From e391cbd6c7803fd096cb3850d8376de1252bfd12 Mon Sep 17 00:00:00 2001 From: hehesheng Date: Tue, 21 Jan 2025 19:15:03 +0800 Subject: [PATCH] feat: chore --- .gitignore | 4 +- A1_motor.py | 160 ++++++++++++++++++++++++++++++++++++++++ example.py | 35 ++++++++- motor_instance.py | 40 ++-------- motor_manager.py | 58 +++++++++++++-- unitree_actuator_sdk.py | 10 +-- 6 files changed, 259 insertions(+), 48 deletions(-) create mode 100644 A1_motor.py diff --git a/.gitignore b/.gitignore index 8899f26..08ad934 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +__pycache__ +.vscode .venv 3rdparty - +logs diff --git a/A1_motor.py b/A1_motor.py new file mode 100644 index 0000000..00f31c3 --- /dev/null +++ b/A1_motor.py @@ -0,0 +1,160 @@ +import time +import logging +from typing import override + +from unitree_actuator_sdk import * + +from utils import timeit + +from motor_instance import MotorInstance + +logger = logging.getLogger(__file__.split("/")[-1]) + + +class A1Motor(MotorInstance): + motor_name: str + serial_path: str + id: int + motor_mode: MotorMode + motor_type: MotorType + serial: SerialPort + + motor_cmd: MotorCmd = MotorCmd() + motor_data: MotorData = MotorData() + + # motor init value + _tau: float = 0 + _dq: float = 0 + _q: float = 0 + _kp: float = 0 + _kd: float = 0 + + def __init__( + self, + serial_path: str, + id: int, + motor_name: str = None, + mode: MotorMode = MotorMode.CALIBRATE, + motor_type: MotorType = MotorType.A1, + tau: float = 0, + dq: float = 0, + q: float = 0, + kp: float = 0, + kd: float = 0, + ): + self.serial_path = serial_path + self.motor_id = id + self.motor_mode = mode + self.motor_type = motor_type + self.motor_name = "-".join([self.serial_path, str(id), str(int(time.time()))]) if motor_name is None else motor_name + self.serial = SerialPort(self.serial_path) + self.reduction_ratio = queryGearRatio(self.motor_type) + self._tau = 0 + self._dq = 0 + self._q = 0 + self._kp = 0 + self._kd = 0 + self.init_motor_cmd() + self.init_motor_data() + + def init_motor_cmd(self): + motor_cmd = MotorCmd() + motor_cmd.motorType = self.motor_type + motor_cmd.mode = queryMotorMode(self.motor_type, self.motor_mode) + motor_cmd.id = self.motor_id + motor_cmd.tau = self._tau + motor_cmd.dq = self._dq + motor_cmd.q = self._q + motor_cmd.kp = self._kp + motor_cmd.kd = self._kd + self.motor_cmd = motor_cmd + + def init_motor_data(self): + motor_data = MotorData() + motor_data.motorType = self.motor_type + motor_data.mode = queryMotorMode(self.motor_type, self.motor_mode) + motor_data.motor_id = self.motor_id + motor_data.tau = 0 + motor_data.dq = 0 + motor_data.q = 0 + motor_data.temp = 0 + self.motor_data = motor_data + + @property + def tau(self): + return self.motor_data.tau + + @tau.setter + def tau(self, value): + if abs(value) < 128: + self.motor_cmd.tau = value + else: + raise Exception(f"tau value invalid: {value}") + + @property + def dq(self): + return self.motor_data.dq + + @dq.setter + def dq(self, value): + if abs(value) < 256: + self.motor_cmd.dq = value * self.reduction_ratio + else: + raise Exception(f"dq value invalid: {value}") + + @property + def q(self): + return self.motor_data.q + + @q.setter + def q(self, value): + if abs(value) < 823549: + self.motor_cmd.q = value * self.reduction_ratio + else: + raise Exception(f"q value invalid: {value}") + + @property + def kp(self): + return self.motor_cmd.kp + + @kp.setter + def kp(self, value): + if 0 <= value and value < 16: + self.motor_cmd.kp = value + else: + raise Exception(f"kp value invalid: {value}") + + @property + def kd(self): + return self.motor_cmd.kd + + @kd.setter + def kd(self, value): + if 0 <= value and value < 32: + self.motor_cmd.kd = value + else: + raise Exception(f"kd value invalid: {value}") + + @property + def temp(self): + return self.motor_data.temp + + @timeit + def reset(self): + self.init_motor_cmd() + self.sendrecv(self.motor_cmd) + + @timeit + def send_pingpong(self): + pass + + @override + def get_motor_name(self): + return self.motor_name + + @override + def sendrecv(self, cmd: MotorCmd) -> MotorData: + data: MotorData = MotorData() + self.serial.sendRecv(cmd, data) + self.motor_data = data + return data diff --git a/example.py b/example.py index 584bd0b..5d4bcf6 100644 --- a/example.py +++ b/example.py @@ -1,3 +1,36 @@ import os +import time -print("Hello world") +from unitree_actuator_sdk import * + +from A1_motor import A1Motor +from motor_manager import MotorManager + +# init manager with 20ms(50hz) +manager = MotorManager(20) + +motor00_name = manager.register_motor(A1Motor("/dev/my485serial0", 0)) +motor01_name = manager.register_motor(A1Motor("/dev/my485serial0", 1)) + +motor10_name = manager.register_motor(A1Motor("/dev/my485serial1", 0)) + + +def process_motor_data(name_to_motor_data: dict[str, MotorData]): + # do something you want, eg: + obs_tensor = [data for _, data in name_to_motor_data.items()] + new_action = policy(obs_tensor).detach().numpy().squeeze() + new_cmds_dict = {} + i = 0 + for name, motor in name_to_motor_data.items(): + new_cmds_dict[name] = new_action[i] + i = i + 1 + manager.update_motor_cmds(new_cmds_dict) + + +manager.add_motor_data_callback(process_motor_data) + +manager.run() + +time.sleep(1000) + +manager.stop() diff --git a/motor_instance.py b/motor_instance.py index a08a5d0..a821afe 100644 --- a/motor_instance.py +++ b/motor_instance.py @@ -1,44 +1,14 @@ import time +import logging from unitree_actuator_sdk import * +logger = logging.getLogger(__file__.split("/")[-1]) + class MotorInstance(object): - motor_name: str - serial_path: str - id: int - motor_mode: MotorMode - motor_type: MotorType - serial: SerialPort - - motor_cmd: MotorCmd - motor_data: MotorData - - def __init__( - self, - serial_path: str, - id: int, - motor_name: str = None, - mode: MotorMode = MotorMode.CALIBRATE, - motor_type: MotorType = MotorType.A1, - ): - self.serial_path = serial_path - self.id = id - self.motor_mode = mode - self.motor_type = motor_type - self.motor_name = "-".join([self.serial_path, str(id), str(int(time.time()))]) if motor_name is None else motor_name - self.serial = SerialPort(self.serial_path) - - def reset(self): - pass - - def send_pingpong(self): - pass - def get_motor_name(self): - return self.motor_name + pass def sendrecv(self, cmd: MotorCmd) -> MotorData: - data: MotorData = MotorData() - self.serial.sendRecv(cmd, data) - return data + pass diff --git a/motor_manager.py b/motor_manager.py index 6fc49a5..8eec72a 100644 --- a/motor_manager.py +++ b/motor_manager.py @@ -1,3 +1,4 @@ +import logging.config import time import threading import traceback @@ -6,10 +7,12 @@ import os import logging import yaml import typing +import copy +import enum from unitree_actuator_sdk import * -from motor_instance import MotorInstance +from motor_instance import * base_dir = os.path.dirname(__file__) log_config = None @@ -32,15 +35,25 @@ class MotorManager(object): motor_dict: dict[MotorInstance] = dict() loop_flag: bool = False - motor_cmds: dict[str, int] + motor_cmds_and_data_sem: threading.Semaphore = threading.Semaphore() + motor_cmds: dict[str, MotorCmd] = dict() + motor_data: dict[str, MotorData] = dict() - task_list: dict[str, typing.Callable] = {} + task_list: dict[str, typing.Callable] = dict() + + motor_data_callback_list: list[typing.Callable] = list() def __init__(self, cmd_interval_ms: int): self.cmd_interval_ms = cmd_interval_ms + self.register_task(MotorManager.transfer_motor_cmds_task) + self.register_task(MotorManager.notify_motor_data_task) def register_motor(self, motor: MotorInstance): self.motor_dict[motor.get_motor_name()] = motor + return motor.get_motor_name() + + def get_motor(self, motor_name: str) -> MotorInstance: + return self.motor_dict.get(motor_name) def register_task(self, task: typing.Callable, task_name: str = None) -> str: if task_name is None or task_name == "": @@ -48,6 +61,38 @@ class MotorManager(object): self.task_list[task_name] = task return task_name + def update_motor_cmds(self, motor_cmds: dict[str, MotorCmd]): + with self.motor_cmds_and_data_sem as sem: + self.motor_cmds = motor_cmds + + @staticmethod + def transfer_motor_cmds_task(self: "MotorManager"): + motor_cmds = {} + with self.motor_cmds_and_data_sem as sem: + motor_cmds = copy.deepcopy(self.motor_cmds) + motor_data = {} + for name, motor_cmd in motor_cmds.items(): + motor_instance = self.motor_dict.get(name) + if motor_instance is None: + continue + motor_data[name] = motor_instance.sendrecv(motor_cmd) + with self.motor_cmds_and_data_sem as sem: + self.motor_data = motor_data + + @staticmethod + def notify_motor_data_task(self: "MotorManager"): + notify_data = {} + with self.motor_cmds_and_data_sem as sem: + notify_data = copy.deepcopy(self.motor_data) + for callback in self.motor_data_callback_list: + try: + callback(notify_data) + except Exception as e: + logger.warning(f"notify data callback error: {traceback.format_exc()}") + + def add_motor_data_callback(self, callback: typing.Callable): + self.motor_data_callback_list.append(callback) + def run(self): self.loop_flag = True self.transfer_thread = threading.Thread(target=self.loop) @@ -66,15 +111,16 @@ class MotorManager(object): while self.loop_flag: for task_name, task in self.task_list.items(): try: - task(self, task_name) + task(self) except Exception as e: logger.warning(f"run task: {task_name} has trouble: {traceback.format_exc()}") cur_time = time.time() * 1000 time_delta = cur_time - next_run_time + sleep_seconds = self.cmd_interval_ms / 1000 if time_delta < 0: logger.warning(f"loop run too slow, took {self.cmd_interval_ms - time_delta} ms") - next_run_time = cur_time + self.cmd_interval_ms - # timeout and no sleep + next_run_time = (((time_delta // self.cmd_interval_ms) * -1) + 1) * self.cmd_interval_ms + next_run_time + # no sleep and run next loop immediately else: sleep_seconds = (next_run_time - cur_time) / 1000 next_run_time = next_run_time + self.cmd_interval_ms diff --git a/unitree_actuator_sdk.py b/unitree_actuator_sdk.py index ff01487..89cfb21 100644 --- a/unitree_actuator_sdk.py +++ b/unitree_actuator_sdk.py @@ -20,11 +20,11 @@ class MotorCmd(object): hex_len: int id: int mode: int - tau: float - dq: float - q: float - kp: float - kd: float + tau: float # T + dq: float # W + q: float # pos + kp: float # K_P + kd: float # K_W class MotorData(object):