
Learn It Before It's Too Late! A Complete Guide to RDT × LeRobot and RDK S100 Deployment

Author: SkyXZ
CSDN: SkyXZ~ - CSDN blog
Cnblogs: SkyXZ - 博客园
Robot arm: LeRobot SO-101           Data-collection machine: MacBook Pro, Python 3.10
Training server: Ubuntu 22.04, CUDA 12.4, 8 × NVIDIA A100-SXM4-40GB
Dev board: RDK OS 4.0.2 (based on Ubuntu 22.04), Python 3.10.12, OpenExplore 3.2.0
Related resources:

  • LeRobot Doc:https://huggingface.co/docs/lerobot/main/en/index
  • RDT 170M&1B:https://github.com/thu-ml/RoboticsDiffusionTransformer
  • RDT on RDKS100: RDT on Double RDK S100P full-pipeline documentation
1. Environment Setup & Robot Arm Configuration


  • All code has been uploaded to GitHub: GitHub - xiongqi123123/LeRobot-VLA: Classic VLA for LeRobot
Environment Setup

We start by installing the LeRobot environment. We use conda for environment management by default, so first run the following command to create a Python 3.10 virtual environment:
    conda create -y -n lerobot python=3.10
Then, inside the environment, run the following commands to install the dependencies LeRobot needs (the lerobot source used here is my modified fork, which only adds a local server/client; everything else is identical to the official source):
    # step 0: install build dependencies
    sudo apt-get install cmake build-essential python-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev
    # step 1: activate the environment
    conda activate lerobot
    # step 2: install ffmpeg
    conda install ffmpeg -c conda-forge
    # choose ONE of the two options below:
    # step 3 (option A): install lerobot from source
    git clone https://github.com/xiongqi123123/LeRobot-VLA.git
    cd LeRobot-VLA/lerobot
    pip install -e .
    # step 3 (option B): install from PyPI
    pip install lerobot
    # to install extras, use one of the following
    pip install 'lerobot[all]'          # All available features
    pip install 'lerobot[aloha,pusht]'  # Specific features (Aloha & Pusht)
    pip install 'lerobot[feetech]'      # Feetech motor support
Robot Arm Assembly


  • Official assembly tutorial: SO-101 - Hugging Face
Since the SO-ARM101 arm does not come with a mounting point for an on-arm camera, we designed and added a camera mount on the original gripper (its mounting holes line up with the gripper's). The camera we use is Yahboom's 1080P driver-free USB camera; the 3D-print files are saved in the repository.
1.png

With the arm kit freshly unboxed, first separate the parts for the Follower and Leader arms. Note that the Follower arm uses 12 ST-3215-C001 (7.4 V, 1:345 gear ratio) motors, while the Leader arm uses different motor models for different joints; the motor model for each joint is shown in the figure below:
2.png

Next, assemble the Follower and Leader arms following the official 3D demo animations. Below are assembly examples for joints 1 through 5 and the gripper; apart from the gripper, the first five joints are assembled in exactly the same way, you only need to watch the motor models on the Leader arm:
Motor Configuration

With both arms assembled, we can configure their motors and assign IDs. The new version of LeRobot provides CLI commands for this. First plug both arms' serial cables into the computer and run the command below, then follow the prompt: unplug one of the cables and press Enter, and it tells you which port that arm was on (it simply records all ports that are present, then diffs against the list after unplugging to see which one disappeared; a minimal sketch of this trick follows the figure below). Example output is shown in the figure:
    lerobot-find-port
3.png
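The port-finding trick above boils down to a set difference over serial ports. A minimal sketch using pyserial (an assumption: pyserial is installed; lerobot-find-port does this for you interactively):

    # Compare the serial ports present before and after unplugging one arm
    from serial.tools import list_ports  # pip install pyserial

    input("Plug in both arms, then press Enter...")
    before = {p.device for p in list_ports.comports()}
    input("Unplug ONE arm's USB cable, then press Enter...")
    after = {p.device for p in list_ports.comports()}
    print("The unplugged arm was on:", before - after)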

LeRobot uses bus servos, and every motor on the bus is identified by a unique ID. Brand-new motors usually ship with a default ID of 1, so for the motors and the controller to communicate properly we need to give each motor its own unique ID. Once you know which serial port belongs to which arm, you can configure them arm by arm, motor by motor. With the new LeRobot library there is no need to rerun the configuration command repeatedly: just run the command below, plug each servo cable into the control board in turn, and press Enter each time. See the following video for details:
    lerobot-setup-motors \
        --robot.type=so101_follower \
        --robot.port=/dev/tty.usbmodem585A0076841  # <- paste here the port found at previous step
Next, simply run the following command to calibrate the arm:
    lerobot-calibrate \
        --robot.type=so101_follower \
        --robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
        --robot.id=my_awesome_follower_arm # <- Give the robot a unique name
We use the RoboTwin-modified version of RDT, which is fairly simple and quick to work with (for more on RDT see: RDT on Double RDK S100P full-pipeline documentation). RDT is a dual-arm policy by default, whereas our LeRobot data is single-arm, and we only recorded two camera views instead of RDT's default three, so training directly would fail with index-mismatch errors. We therefore have to modify the dataset-loading code: first change the action normalization so the data is divided by [[180, 180, 180, 180, 180, 180]] at load time; then map LeRobot's single arm onto the action dimensions of RDT's right arm, drop the left arm entirely, and substitute the Ground image for the right-arm image RDT loads by default (a minimal sketch of this mapping follows the block below). Replace the original RDT/data/hdf5_vla_dataset.py with the modified code.
    lerobot-find-cameras opencv
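As described above, the single-arm data has to be normalized and scattered into RDT's unified state vector on the right-arm joint indices. A minimal sketch of that mapping (assuming RDT's STATE_VEC_IDX_MAPPING and a unified state_dim of 128 from configs/base.yaml; the full version lives in the modified RDT/data/hdf5_vla_dataset.py shown later in this post):

    # Map LeRobot's 6-DoF single arm onto RDT's right-arm slots of the unified state vector
    import numpy as np
    from configs.state_vec import STATE_VEC_IDX_MAPPING  # part of the RDT repo

    STATE_DIM = 128  # RDT's unified state vector size (configs/base.yaml: common.state_dim)
    RIGHT_ARM_INDICES = [STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6)]

    def lerobot_to_rdt_state(qpos_deg: np.ndarray) -> np.ndarray:
        """qpos_deg: (T, 6) LeRobot joint values; returns (T, 128) unified vectors."""
        qpos = qpos_deg / np.array([[180, 180, 180, 180, 180, 180]])  # the normalization used above
        uni = np.zeros(qpos.shape[:-1] + (STATE_DIM,))
        uni[..., RIGHT_ARM_INDICES] = qpos  # left-arm dims stay zero for the single-arm robot
        return uni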
Training

Training Environment Setup

First we install the environment RDT needs for training: go into the RDT directory and install the following packages in order:
    lerobot-teleoperate \
        --robot.type=so101_follower \
        --robot.port=/dev/tty.usbmodem5AB90671801 \
        --robot.id=my_awesome_follower_arm \
        --teleop.type=so101_leader \
        --teleop.port=/dev/tty.usbmodem5AB90671501 \
        --teleop.id=my_awesome_leader_arm
Besides the dependencies above we also need flash_attn for acceleration. To avoid network problems we download a prebuilt wheel manually from https://github.com/Dao-AILab/flash-attention/releases. Pick the wheel matching your installed torch and CUDA versions, and also choose cxx11abi TRUE or FALSE depending on how your PyTorch build was compiled (a short programmatic check follows the block below).
    lerobot-teleoperate \
        --robot.type=so101_follower \
        --robot.port=/dev/tty.usbmodem5AB90671801 \
        --robot.id=my_awesome_follower_arm \
        --robot.cameras="{ arm: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}, front: {type: opencv, index_or_path: 2, width: 1920, height: 1080, fps: 30}}" \
        --teleop.type=so101_leader \
        --teleop.port=/dev/tty.usbmodem5AB90671501 \
        --teleop.id=my_awesome_leader_arm \
        --display_data=true
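A quick way to check the relevant versions programmatically (a sketch; torch.compiled_with_cxx11_abi() is available in recent PyTorch releases):

    # Print the torch/CUDA versions and the C++ ABI flag used to pick the flash_attn wheel
    import torch
    print("torch:", torch.__version__)
    print("cuda:", torch.version.cuda)
    print("cxx11 abi:", torch.compiled_with_cxx11_abi())  # False -> use the cxx11abiFALSE wheel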
We can then tell from the output which build to download. As shown below, our PyTorch was built with CXX11_ABI = 0, so we need the cxx11abiFALSE .whl file:
    lerobot-record \
        --robot.type=so101_follower \
        --robot.port=/dev/tty.usbmodem5AB90671801 \
        --robot.id=my_awesome_follower_arm \
        --robot.cameras="{ arm: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}, front: {type: opencv, index_or_path: 2, width: 1920, height: 1080, fps: 30}}" \
        --teleop.type=so101_leader \
        --teleop.port=/dev/tty.usbmodem5AB90671501 \
        --teleop.id=my_awesome_leader_arm \
        --display_data=true \
        --dataset.repo_id=skyxz/blackmarker_scence1 \
        --dataset.num_episodes=5 \
        --dataset.single_task="Grab the black marker and put it in the bin" \
        --dataset.push_to_hub=False \
        --dataset.episode_time_s=15 \
        --dataset.reset_time_s=5
If your setup matches mine, the version you need to download and install is the following; just install it once downloaded:
    (RoboTwin) qi.xiong@A100-Test:~/Data_Qi/LeRobot/skyxz/blackmarker_scence1$ python3 -c "import h5py; f=h5py.File('/home/qi.xiong/Data_Qi/RDT/processed_data/place_dual_shoes-demo_clean-300/episode_3/episode_3.hdf5','r'); f.visit(print); f.close()"
    action
    observations
    observations/images
    observations/images/cam_high
    observations/images/cam_left_wrist
    observations/images/cam_right_wrist
    observations/left_arm_dim
    observations/qpos
    observations/right_arm_dim
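The listing above shows the layout of a converted episode. A quick way to sanity-check the shapes of one of our converted LeRobot episodes (a sketch; the path is an example, and our conversion only writes cam_high and cam_right_wrist):

    # Inspect one converted episode's datasets and shapes
    import h5py

    with h5py.File("processed_data/blackmarker_scence1/episode_0/episode_0.hdf5", "r") as f:
        print("action:", f["action"].shape)                          # (T, 6)
        print("qpos:", f["observations/qpos"].shape)                 # (T, 6)
        print("right_arm_dim:", f["observations/right_arm_dim"][0])  # 6 for our single arm
        for cam in ("cam_high", "cam_right_wrist"):
            if cam in f["observations/images"]:
                print(cam, "frames:", f["observations/images"][cam].shape[0])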
Downloading the Pretrained Models

RDT comes in a 1B version (SigLIP + DiT + adaptors) and a standalone 170M (DiT-only) version; the only difference is the hidden_size and depth of the final DiT, which the 170M version simply halves relative to 1B. If you plan to deploy on the RDK S100, use the RDT-170M version covered below to get workable performance. Before training, download the pretrained models and finish setting up the training environment as follows:
  1. #!/usr/bin/env python3
  2. """
  3. LeRobot到RDT数据转换脚本
  4. LeRobot机器人结构:
  5. - 5个关节 (shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll)
  6. - 1个夹爪 (gripper)
  7. - 总计:6个自由度 (6DOF)
  8. 维度映射(匹配RDT训练代码):
  9. - left_arm_dim = 0 (单臂机器人,左臂不存在)
  10. - right_arm_dim = 6 (5关节 + 1夹爪,映射到RDT的right_arm部分)
  11. - 状态向量:6维 [joint1, joint2, joint3, joint4, joint5, gripper]
  12. - RDT索引映射:right_arm_joint_0_pos到right_arm_joint_5_pos (索引0-5)
  13. """
  14. import sys
  15. import os
  16. import h5py
  17. import numpy as np
  18. import cv2
  19. import argparse
  20. import yaml
  21. import json
  22. import subprocess
  23. from pathlib import Path
  24. import pandas as pd
  25. import torch
  26. current_dir = os.path.dirname(__file__)
  27. sys.path.append(os.path.join(current_dir, ".."))
  28. from models.multimodal_encoder.t5_encoder import T5Embedder
  29. def extract_frames_from_video(video_path, output_dir, episode_idx):
  30.     if not os.path.exists(video_path):
  31.         print(f"  No video file: {video_path}")
  32.         return []
  33.    
  34.     temp_dir = os.path.join(output_dir, f"temp_frames_{episode_idx}")
  35.     if not os.path.exists(temp_dir):
  36.         os.makedirs(temp_dir)
  37.    
  38.     output_pattern = os.path.join(temp_dir, "frame_%04d.jpg")
  39.    
  40.     try:
  41.         cmd = [
  42.             'ffmpeg', '-i', video_path,
  43.             '-vf', 'fps=30',
  44.             '-q:v', '2',
  45.             output_pattern,
  46.             '-y'
  47.         ]
  48.         
  49.         result = subprocess.run(cmd, capture_output=True, text=True)
  50.         
  51.         if result.returncode != 0:
  52.             print(f"  Failed to extract frames with ffmpeg: {result.stderr}")
  53.             return []
  54.         
  55.         frames = []
  56.         frame_files = sorted([f for f in os.listdir(temp_dir) if f.endswith('.jpg')])
  57.         
  58.         for frame_file in frame_files:
  59.             frame_path = os.path.join(temp_dir, frame_file)
  60.             frame = cv2.imread(frame_path)
  61.             if frame is not None:
  62.                 frame_resized = cv2.resize(frame, (640, 480))
  63.                 frames.append(frame_resized)
  64.         
  65.         print(f"  Successfully extracted {len(frames)} frames")
  66.         
  67.         for frame_file in frame_files:
  68.             os.remove(os.path.join(temp_dir, frame_file))
  69.         os.rmdir(temp_dir)
  70.         
  71.         return frames
  72.         
  73.     except Exception as e:
  74.         print(f"  Error extracting frames: {e}")
  75.         return []
  76. def load_lerobot_episode(data_dir, episode_idx, output_dir):
  77.     """加载LeRobot的单个episode数据
  78.    
  79.     LeRobot数据结构:
  80.     - action: 6维 [shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper]
  81.     - observation.state: 6维 [shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper]
  82.     - 图像: 高位相机 + 手臂相机
  83.     """
  84.     parquet_path = os.path.join(data_dir, "data/chunk-000", f"episode_{episode_idx:06d}.parquet")
  85.     if not os.path.exists(parquet_path):
  86.         print(f"Episode {episode_idx} parquet file does not exist: {parquet_path}")
  87.         return None
  88.    
  89.     df = pd.read_parquet(parquet_path)
  90.    
  91.     actions = []
  92.     qpos = []
  93.    
  94.     for i in range(len(df)):
  95.         action = df['action'].iloc[i]
  96.         state = df['observation.state'].iloc[i]
  97.         
  98.         if isinstance(action, np.ndarray):
  99.             actions.append(action.astype(np.float32))
  100.         else:
  101.             actions.append(np.array(action, dtype=np.float32))
  102.             
  103.         if isinstance(state, np.ndarray):
  104.             qpos.append(state.astype(np.float32))
  105.         else:
  106.             qpos.append(np.array(state, dtype=np.float32))
  107.    
  108.     high_cam_path = os.path.join(data_dir, "videos/chunk-000/observation.images.high", f"episode_{episode_idx:06d}.mp4")
  109.     arm_cam_path = os.path.join(data_dir, "videos/chunk-000/observation.images.arm", f"episode_{episode_idx:06d}.mp4")
  110.    
  111.     print(f"  Extracting high camera frames...")
  112.     high_images = extract_frames_from_video(high_cam_path, output_dir, episode_idx)
  113.    
  114.     print(f"  Extracting arm camera frames...")
  115.     arm_images = extract_frames_from_video(arm_cam_path, output_dir, episode_idx)
  116.    
  117.     target_frames = len(df)
  118.     if len(high_images) > target_frames:
  119.         high_images = high_images[:target_frames]
  120.     if len(arm_images) > target_frames:
  121.         arm_images = arm_images[:target_frames]
  122.    
  123.     while len(high_images) < target_frames and high_images:
  124.         high_images.append(high_images[-1])
  125.     while len(arm_images) < target_frames and arm_images:
  126.         arm_images.append(arm_images[-1])
  127.    
  128.     return {
  129.         'actions': np.array(actions),
  130.         'qpos': np.array(qpos),
  131.         'high_images': high_images,
  132.         'arm_images': arm_images,
  133.         'episode_length': len(df)
  134.     }
  135. def images_encoding(imgs):
  136.     if not imgs:
  137.         return [], 0
  138.         
  139.     encode_data = []
  140.     padded_data = []
  141.     max_len = 0
  142.    
  143.     for i in range(len(imgs)):
  144.         success, encoded_image = cv2.imencode(".jpg", imgs[i])
  145.         if success:
  146.             jpeg_data = encoded_image.tobytes()
  147.             encode_data.append(jpeg_data)
  148.             max_len = max(max_len, len(jpeg_data))
  149.         else:
  150.             print(f"  Image encoding failed: {i}")
  151.             empty_data = b""
  152.             encode_data.append(empty_data)
  153.    
  154.     for i in range(len(imgs)):
  155.         padded_data.append(encode_data[i].ljust(max_len, b"\0"))
  156.    
  157.     return encode_data, max_len
  158. def load_task_instructions(data_dir):
  159.     tasks_file = os.path.join(data_dir, "meta/tasks.jsonl")
  160.     if not os.path.exists(tasks_file):
  161.         print(f"Warning: tasks file not found: {tasks_file}")
  162.         return None
  163.    
  164.     instructions = []
  165.     with open(tasks_file, 'r') as f:
  166.         for line in f:
  167.             if line.strip():
  168.                 task_data = json.loads(line.strip())
  169.                 instructions.append(task_data["task"])
  170.    
  171.     print(f"  加载了 {len(instructions)} 个任务指令")
  172.     return instructions
  173. def encode_language_instruction(instruction_text, t5_embedder, device):
  174.     try:
  175.         text_embeds, attn_mask = t5_embedder.get_text_embeddings([instruction_text])
  176.         
  177.         valid_embeds = text_embeds[0][attn_mask[0]].float()
  178.         return valid_embeds.cpu().numpy()
  179.         
  180.     except Exception as e:
  181.         print(f"  Language encoding failed: {e}")
  182.         return np.zeros((1, 4096))
  183. def convert_lerobot_to_rdt(data_dir, output_dir, episode_num, gpu=0, no_language=False):
  184.     if not os.path.exists(output_dir):
  185.         os.makedirs(output_dir)
  186.    
  187.     print(f"Start converting LeRobot data to RDT format...")
  188.     print(f"Data source: {data_dir}")
  189.     print(f"Output directory: {output_dir}")
  190.     print(f"Processing episode number: {episode_num}")
  191.     print(f"GPU device: {gpu}")
  192.    
  193.     scene_name = os.path.basename(data_dir)
  194.    
  195.     instructions = None
  196.     if not no_language:
  197.         instructions = load_task_instructions(data_dir)
  198.    
  199.     t5_embedder = None
  200.     if not no_language and instructions:
  201.         try:
  202.             print(f"  Initializing T5 encoder...")
  203.             t5_model_path = "/home/qi.xiong/Data_Qi/t5-v1_1-xxl"
  204.             if not os.path.exists(t5_model_path):
  205.                 print(f"  Warning: T5 model path does not exist: {t5_model_path}")
  206.                 print(f"  Will skip language processing")
  207.                 no_language = True
  208.             else:
  209.                 t5_embedder = T5Embedder(
  210.                     from_pretrained=t5_model_path,
  211.                     device=f"cuda:{gpu}" if torch.cuda.is_available() else "cpu",
  212.                     model_max_length=120,
  213.                     use_offload_folder=None,
  214.                 )
  215.                 print(f"  T5 encoder initialized successfully")
  216.         except Exception as e:
  217.             print(f"  T5 encoder initialization failed: {e}")
  218.             print(f"  Will skip language processing")
  219.             no_language = True
  220.    
  221.     for i in range(episode_num):
  222.         print(f"Processing episode {i}...")
  223.         
  224.         episode_data = load_lerobot_episode(data_dir, i, output_dir)
  225.         if episode_data is None:
  226.             print(f"Skipping episode {i}")
  227.             continue
  228.         
  229.         episode_output_dir = os.path.join(output_dir, f"episode_{i}")
  230.         if not os.path.exists(episode_output_dir):
  231.             os.makedirs(episode_output_dir)
  232.         
  233.         hdf5_path = os.path.join(episode_output_dir, f"episode_{i}.hdf5")
  234.         
  235.         with h5py.File(hdf5_path, "w") as f:
  236.             f.create_dataset("action", data=episode_data['actions'])
  237.             
  238.             obs = f.create_group("observations")
  239.             obs.create_dataset("qpos", data=episode_data['qpos'])
  240.             
  241.             image = obs.create_group("images")
  242.             
  243.             if episode_data['high_images']:
  244.                 print(f"  Encoding high camera images...")
  245.                 high_enc, len_high = images_encoding(episode_data['high_images'])
  246.                 if high_enc and len_high > 0:
  247.                     image.create_dataset("cam_high", data=high_enc, dtype=f"S{len_high}")
  248.                     print(f"  Saved high camera images: {len(episode_data['high_images'])} frames")
  249.                 else:
  250.                     print(f"  Warning: High camera images encoding failed")
  251.             
  252.             if episode_data['arm_images']:
  253.                 print(f"  Encoding arm camera images...")
  254.                 arm_enc, len_arm = images_encoding(episode_data['arm_images'])
  255.                 if arm_enc and len_arm > 0:
  256.                     image.create_dataset("cam_right_wrist", data=arm_enc, dtype=f"S{len_arm}")
  257.                     print(f"  Saved arm camera images: {len(episode_data['arm_images'])} frames")
  258.                 else:
  259.                     print(f"  Warning: Arm camera images encoding failed")
  260.             
  261.             # 添加机器人维度信息(LeRobot: 5个关节 + 1个夹爪)
  262.             # 根据process_data.py的逻辑,每个时间步都需要记录维度信息
  263.             # LeRobot是单臂机器人,只有右臂:5个关节 + 1个夹爪 = 6维
  264.             # 左臂:0维(单臂机器人)
  265.             
  266.             # 为每个时间步记录维度信息
  267.             left_arm_dim = [0] * len(episode_data['actions'])  # 左臂0维(单臂机器人)
  268.             right_arm_dim = [6] * len(episode_data['actions'])  # 右臂6维(5关节+1夹爪)
  269.             
  270.             obs.create_dataset("left_arm_dim", data=np.array(left_arm_dim))
  271.             obs.create_dataset("right_arm_dim", data=np.array(right_arm_dim))
  272.         
  273.         print(f"  Episode {i} converted successfully: {hdf5_path}")
  274.         print(f"  Data length: {episode_data['episode_length']}")
  275.         print(f"  Action shape: {episode_data['actions'].shape}")
  276.         print(f"  Qpos shape: {episode_data['qpos'].shape}")
  277.         print(f"  High camera frames: {len(episode_data['high_images'])}")
  278.         print(f"  Arm camera frames: {len(episode_data['arm_images'])}")
  279.         
  280.         if not no_language and t5_embedder and instructions:
  281.             print(f"  Processing language instructions...")
  282.             try:
  283.                 instruction = instructions[0]
  284.                
  285.                 language_features = encode_language_instruction(instruction, t5_embedder, f"cuda:{gpu}")
  286.                
  287.                 instructions_dir = os.path.join(episode_output_dir, "instructions")
  288.                 if not os.path.exists(instructions_dir):
  289.                     os.makedirs(instructions_dir)
  290.                
  291.                 lang_embed_path = os.path.join(instructions_dir, "lang_embed_0.pt")
  292.                 torch.save(torch.from_numpy(language_features), lang_embed_path)
  293.                
  294.                 print(f"  Language instruction encoded successfully: {instruction}")
  295.                 print(f"  Language features saved to: {lang_embed_path}")
  296.                 print(f"  Language features shape: {language_features.shape}, data type: {language_features.dtype}")
  297.                
  298.             except Exception as e:
  299.                 print(f"  Language instruction processing failed: {e}")
  300.    
  301.     print(f"\nConversion completed! Processed {episode_num} episodes")
  302.     print(f"Output directory: {output_dir}")
  303. def main():
  304.     parser = argparse.ArgumentParser(description="Convert LeRobot data to RDT format")
  305.     parser.add_argument("--data_dir", type=str, required=True,
  306.                        help="LeRobot data directory path")
  307.     parser.add_argument("--output_dir", type=str, required=True,
  308.                        help="Output directory path")
  309.     parser.add_argument("--episode_num", type=int, default=10,
  310.                        help="Number of episodes to process")
  311.     parser.add_argument("--gpu", type=int, default=0,
  312.                        help="GPU device ID")
  313.     parser.add_argument("--no_language", action="store_true",
  314.                        help="Skip language processing")
  315.    
  316.     args = parser.parse_args()
  317.    
  318.     if not os.path.exists(args.data_dir):
  319.         print(f"Error: Data directory does not exist: {args.data_dir}")
  320.         return
  321.    
  322.     meta_file = os.path.join(args.data_dir, "meta/info.json")
  323.     if not os.path.exists(meta_file):
  324.         print(f"Error: Meta information file not found: {meta_file}")
  325.         return
  326.    
  327.     try:
  328.         subprocess.run(['ffmpeg', '-version'], capture_output=True, check=True)
  329.         print("ffmpeg is available, will use ffmpeg to extract video frames")
  330.     except (subprocess.CalledProcessError, FileNotFoundError):
  331.         print("Warning: ffmpeg is not available, image data may not be extracted correctly")
  332.         print("Please install ffmpeg: conda install -c conda-forge ffmpeg=6.1")
  333.         return
  334.    
  335.     with open(meta_file, 'r') as f:
  336.         meta_info = yaml.safe_load(f)
  337.    
  338.     total_episodes = meta_info.get('total_episodes', 10)
  339.     if args.episode_num > total_episodes:
  340.         print(f"Warning: Requested episode number ({args.episode_num}) exceeds available number ({total_episodes})")
  341.         args.episode_num = total_episodes
  342.    
  343.     convert_lerobot_to_rdt(
  344.         args.data_dir,
  345.         args.output_dir,
  346.         args.episode_num,
  347.         args.gpu,
  348.         args.no_language
  349.     )
  350. if __name__ == "__main__":
  351.     main()
Fine-tuning RDT-1B

To train the 1B version, modify the rdt parameters under the model section of RDT/configs/base.yaml:
    # Option 1: bash process_data_rdt.sh data_dir=${1} output_dir=${2} episode_num=${3} gpu_id=${4}
    bash process_data_rdt.sh /home/qi.xiong/Data_Qi/LeRobot/skyxz/redmarker_scence4 /home/qi.xiong/DualArm/RoboTwin/policy/RDT-LeRobot/processed_data/redmarker_scence4 5 0
    # Option 2: python scripts/process_data_lerobot.py --data_dir --output_dir --episode_num --gpu
    python3 scripts/process_data_lerobot.py --data_dir /home/qi.xiong/Data_Qi/LeRobot/skyxz/redmarker_scence4 --output_dir /home/qi.xiong/DualArm/RoboTwin/policy/RDT-LeRobot/processed_data/redmarker_scence4
Then generate the training yml config and point pretrained_model_name_or_path at the 1B model downloaded earlier:
  1. import os
  2. import fnmatch
  3. import json
  4. import h5py
  5. import yaml
  6. import cv2
  7. import numpy as np
  8. from configs.state_vec import STATE_VEC_IDX_MAPPING
  9. class HDF5VLADataset:
  10.     """
  11.     This class is used to sample episodes from the embododiment dataset
  12.     stored in HDF5.
  13.     """
  14.     def __init__(self, model_config_path) -> None:
  15.         # [Modify] The path to the HDF5 dataset directory
  16.         # Each HDF5 file contains one episode
  17.         with open(model_config_path, "r") as f:
  18.             model_config = yaml.safe_load(f)
  19.         HDF5_DIR = model_config["data_path"]
  20.         self.DATASET_NAME = "agilex"
  21.         self.file_paths = []
  22.         for root, _, files in os.walk(HDF5_DIR):
  23.             for filename in fnmatch.filter(files, "*.hdf5"):
  24.                 file_path = os.path.join(root, filename)
  25.                 self.file_paths.append(file_path)
  26.         # Load the config
  27.         with open("configs/base.yaml", "r") as file:
  28.             config = yaml.safe_load(file)
  29.         self.CHUNK_SIZE = config["common"]["action_chunk_size"]
  30.         self.IMG_HISORY_SIZE = config["common"]["img_history_size"]
  31.         self.STATE_DIM = config["common"]["state_dim"]
  32.         # Get each episode's len (use original length, not standardized length)
  33.         episode_lens = []
  34.         for file_path in self.file_paths:
  35.             try:
  36.                 with h5py.File(file_path, "r") as f:
  37.                     qpos = f["observations"]["qpos"][:]
  38.                     num_steps = qpos.shape[0]
  39.                     episode_lens.append(num_steps)
  40.             except Exception as e:
  41.                 print(f"Warning: Could not read {file_path}: {e}")
  42.                 episode_lens.append(0)
  43.         self.episode_sample_weights = np.array(episode_lens) / np.sum(episode_lens)
  44.     def __len__(self):
  45.         return len(self.file_paths)
  46.     def get_dataset_name(self):
  47.         return self.DATASET_NAME
  48.     def get_item(self, index: int = None, state_only=False):
  49.         """Get a training sample at a random timestep.
  50.         Args:
  51.             index (int, optional): the index of the episode.
  52.                 If not provided, a random episode will be selected.
  53.             state_only (bool, optional): Whether to return only the state.
  54.                 In this way, the sample will contain a complete trajectory rather
  55.                 than a single timestep. Defaults to False.
  56.         Returns:
  57.            sample (dict): a dictionary containing the training sample.
  58.         """
  59.         while True:
  60.             if index is None:
  61.                 file_path = np.random.choice(self.file_paths, p=self.episode_sample_weights)
  62.             else:
  63.                 file_path = self.file_paths[index]
  64.             valid, sample = (self.parse_hdf5_file(file_path)
  65.                              if not state_only else self.parse_hdf5_file_state_only(file_path))
  66.             if valid:
  67.                 return sample
  68.             else:
  69.                 index = np.random.randint(0, len(self.file_paths))
  70.     def parse_hdf5_file(self, file_path):
  71.         """[Modify] Parse a hdf5 file to generate a training sample at
  72.             a random timestep.
  73.         Args:
  74.             file_path (str): the path to the hdf5 file
  75.         Returns:
  76.             valid (bool): whether the episode is valid, which is useful for filtering.
  77.                 If False, this episode will be dropped.
  78.             dict: a dictionary containing the training sample,
  79.                 {
  80.                     "meta": {
  81.                         "dataset_name": str,    # the name of your dataset.
  82.                         "#steps": int,          # the number of steps in the episode,
  83.                                                 # also the total timesteps.
  84.                         "instruction": str      # the language instruction for this episode.
  85.                     },
  86.                     "step_id": int,             # the index of the sampled step,
  87.                                                 # also the timestep t.
  88.                     "state": ndarray,           # state[t], (1, STATE_DIM).
  89.                     "state_std": ndarray,       # std(state[:]), (STATE_DIM,).
  90.                     "state_mean": ndarray,      # mean(state[:]), (STATE_DIM,).
  91.                     "state_norm": ndarray,      # norm(state[:]), (STATE_DIM,).
  92.                     "actions": ndarray,         # action[t:t+CHUNK_SIZE], (CHUNK_SIZE, STATE_DIM).
  93.                     "state_indicator", ndarray, # indicates the validness of each dim, (STATE_DIM,).
  94.                     "cam_high": ndarray,        # external camera image, (IMG_HISORY_SIZE, H, W, 3)
  95.                                                 # or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable.
  96.                     "cam_high_mask": ndarray,   # indicates the validness of each timestep, (IMG_HISORY_SIZE,) boolean array.
  97.                                                 # For the first IMAGE_HISTORY_SIZE-1 timesteps, the mask should be False.
  98.                     "cam_left_wrist": ndarray,  # left wrist camera image, (IMG_HISORY_SIZE, H, W, 3).
  99.                                                 # or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable.
  100.                     "cam_left_wrist_mask": ndarray,
  101.                     "cam_right_wrist": ndarray, # right wrist camera image, (IMG_HISORY_SIZE, H, W, 3).
  102.                                                 # or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable.
  103.                                                 # If only one wrist, make it right wrist, plz.
  104.                     "cam_right_wrist_mask": ndarray
  105.                 } or None if the episode is invalid.
  106.         """
  107.         with h5py.File(file_path, "r") as f:
  108.             qpos = f["observations"]["qpos"][:]
  109.             left_arm_dim = f["observations"]["left_arm_dim"][:]
  110.             right_arm_dim = f["observations"]["right_arm_dim"][:]
  111.             num_steps = qpos.shape[0]
  112.             action_dim = qpos
  113.             # [Optional] We drop too-short episode
  114.             # if num_steps < 128:
  115.             #     return False, None
  116.             # [Optional] We skip the first few still steps
  117.             EPS = 1e-2
  118.             # Get the idx of the first qpos whose delta exceeds the threshold
  119.             qpos_delta = np.abs(qpos - qpos[0:1])
  120.             indices = np.where(np.any(qpos_delta > EPS, axis=1))[0]
  121.             if len(indices) > 0:
  122.                 first_idx = indices[0]
  123.             else:
  124.                 raise ValueError("Found no qpos that exceeds the threshold.")
  125.             # We randomly sample a timestep
  126.             step_id = np.random.randint(first_idx - 1, num_steps)
  127.             # Load the instruction
  128.             dir_path = os.path.dirname(file_path)
  129.             # with open(os.path.join(dir_path, 'instruction.json'), 'r') as f_instr:
  130.             #     instruction_dict = json.load(f_instr)
  131.             # # We have 1/3 prob to use original instruction,
  132.             # # 1/3 to use simplified instruction,
  133.             # # and 1/3 to use expanded instruction.
  134.             # instruction_type = np.random.choice([
  135.             #     'instruction', 'expanded_instruction'])
  136.             # instruction = instruction_dict[instruction_type]
  137.             # if isinstance(instruction, list):
  138.             #    instruction = np.random.choice(instruction)
  139.             # You can also use precomputed language embeddings (recommended)
  140.             # instruction = "path/to/lang_embed.pt"
  141.             instructions_path = os.path.join(dir_path, "instructions")
  142.             instructions_names = []
  143.             for filename in os.listdir(instructions_path):
  144.                 # 检查文件名是否以.pt结尾
  145.                 if filename.endswith(".pt"):
  146.                     instructions_names.append(os.path.join(instructions_path, filename))
  147.             instruction = np.random.choice(instructions_names)
  148.             # print(f"choose {instruction} file as instruction.")
  149.             # Assemble the meta
  150.             meta = {
  151.                 "dataset_name": self.DATASET_NAME,
  152.                 "#steps": num_steps,
  153.                 "step_id": step_id,
  154.                 "instruction": instruction,
  155.             }
  156.             # Rescale gripper to [0, 1]
  157.             # qpos = qpos / np.array([[1 for i in range(left_arm_dim[0] + 1 + right_arm_dim[0] + 1)]])
  158.             # target_qpos = f["action"][step_id:step_id + self.CHUNK_SIZE] / np.array(
  159.             #     [[1 for i in range(left_arm_dim[0] + 1 + right_arm_dim[0] + 1)]])
  160.             qpos = qpos / np.array(
  161.             #    [[1, 1, 1, 1, 1, 1, 4.7908, 1, 1, 1, 1, 1, 1, 4.7888]]
  162.             [[180, 180, 180, 180, 180, 180]]
  163.             )
  164.             target_qpos = f['action'][step_id:step_id + self.CHUNK_SIZE] / np.array(
  165.             #    [[1, 1, 1, 1, 1, 1, 11.8997, 1, 1, 1, 1, 1, 1, 13.9231]]
  166.             [[180, 180, 180, 180, 180, 180]]
  167.             )
  168.             # Parse the state and action
  169.             state = qpos[step_id:step_id + 1]
  170.             state_std = np.std(qpos, axis=0)
  171.             state_mean = np.mean(qpos, axis=0)
  172.             state_norm = np.sqrt(np.mean(qpos**2, axis=0))
  173.             actions = target_qpos
  174.             if actions.shape[0] < self.CHUNK_SIZE:
  175.                 # Pad the actions using the last action
  176.                 actions = np.concatenate(
  177.                     [
  178.                         actions,
  179.                         np.tile(actions[-1:], (self.CHUNK_SIZE - actions.shape[0], 1)),
  180.                     ],
  181.                     axis=0,
  182.                 )
  183.             # Fill the state/action into the unified vector
  184.             def fill_in_state(values):
  185.                 # Target indices corresponding to your state space
  186.                 # In this example: 6 joints + 1 gripper for each arm
  187.                 UNI_STATE_INDICES =  [
  188.                     STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6)
  189.                 # ] + [
  190.                     # STATE_VEC_IDX_MAPPING["right_gripper_open"]
  191.                 ]
  192.                 uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM, ))
  193.                 uni_vec[..., UNI_STATE_INDICES] = values
  194.                 return uni_vec
  195.             state = fill_in_state(state)
  196.             state_indicator = fill_in_state(np.ones_like(state_std))
  197.             state_std = fill_in_state(state_std)
  198.             state_mean = fill_in_state(state_mean)
  199.             state_norm = fill_in_state(state_norm)
  200.             # If action's format is different from state's,
  201.             # you may implement fill_in_action()
  202.             actions = fill_in_state(actions)
  203.             # Parse the images
  204.             def parse_img(key):
  205.                 imgs = []
  206.                 for i in range(max(step_id - self.IMG_HISORY_SIZE + 1, 0), step_id + 1):
  207.                     img_bits = f["observations"]["images"][key][i]
  208.                     img = cv2.imdecode(np.frombuffer(img_bits, np.uint8), cv2.IMREAD_COLOR)
  209.                     imgs.append(img)
  210.                 imgs = np.stack(imgs)
  211.                 if imgs.shape[0] < self.IMG_HISORY_SIZE:
  212.                     # Pad the images using the first image
  213.                     imgs = np.concatenate(
  214.                         [
  215.                             np.tile(
  216.                                 imgs[:1],
  217.                                 (self.IMG_HISORY_SIZE - imgs.shape[0], 1, 1, 1),
  218.                             ),
  219.                             imgs,
  220.                         ],
  221.                         axis=0,
  222.                     )
  223.                 return imgs
  224.             # `cam_high` is the external camera image
  225.             cam_high = parse_img("cam_high")
  226.             # For step_id = first_idx - 1, the valid_len should be one
  227.             valid_len = min(step_id - (first_idx - 1) + 1, self.IMG_HISORY_SIZE)
  228.             cam_high_mask = np.array([False] * (self.IMG_HISORY_SIZE - valid_len) + [True] * valid_len)
  229.             # cam_left_wrist = parse_img("cam_left_wrist")
  230.             # cam_left_wrist_mask = cam_high_mask.copy()
  231.             cam_left_wrist = np.zeros((self.IMG_HISORY_SIZE, 0, 0, 0))#parse_img('cam_right_wrist')
  232.             cam_left_wrist_mask = np.array([False] * self.IMG_HISORY_SIZE)#cam_high_mask.copy()
  233.             cam_right_wrist = parse_img("cam_right_wrist")
  234.             cam_right_wrist_mask = cam_high_mask.copy()  # 使用相同的掩码逻辑
  235.             # Return the resulting sample
  236.             # For unavailable images, return zero-shape arrays, i.e., (IMG_HISORY_SIZE, 0, 0, 0)
  237.             # E.g., return np.zeros((self.IMG_HISORY_SIZE, 0, 0, 0)) for the key "cam_left_wrist",
  238.             # if the left-wrist camera is unavailable on your robot
  239.             return True, {
  240.                 "meta": meta,
  241.                 "state": state,
  242.                 "state_std": state_std,
  243.                 "state_mean": state_mean,
  244.                 "state_norm": state_norm,
  245.                 "actions": actions,
  246.                 "state_indicator": state_indicator,
  247.                 "cam_high": cam_high,
  248.                 "cam_high_mask": cam_high_mask,
  249.                 "cam_left_wrist": cam_left_wrist,
  250.                 "cam_left_wrist_mask": cam_left_wrist_mask,
  251.                 "cam_right_wrist": cam_right_wrist,
  252.                 "cam_right_wrist_mask": cam_right_wrist_mask,
  253.             }
  254.     def parse_hdf5_file_state_only(self, file_path):
  255.         """[Modify] Parse a hdf5 file to generate a state trajectory.
  256.         Args:
  257.             file_path (str): the path to the hdf5 file
  258.         Returns:
  259.             valid (bool): whether the episode is valid, which is useful for filtering.
  260.                 If False, this episode will be dropped.
  261.             dict: a dictionary containing the training sample,
  262.                 {
  263.                     "state": ndarray,           # state[:], (T, STATE_DIM).
  264.                     "action": ndarray,          # action[:], (T, STATE_DIM).
  265.                 } or None if the episode is invalid.
  266.         """
  267.         with h5py.File(file_path, "r") as f:
  268.             qpos = f["observations"]["qpos"][:]
  269.             left_arm_dim = f["observations"]["left_arm_dim"][:]
  270.             right_arm_dim = f["observations"]["right_arm_dim"][:]
  271.             num_steps = qpos.shape[0]
  272.             # [Optional] We drop too-short episode
  273.             # if num_steps < 128:
  274.             # return False, None
  275.             # [Optional] We skip the first few still steps
  276.             EPS = 1e-2
  277.             # Get the idx of the first qpos whose delta exceeds the threshold
  278.             qpos_delta = np.abs(qpos - qpos[0:1])
  279.             indices = np.where(np.any(qpos_delta > EPS, axis=1))[0]
  280.             if len(indices) > 0:
  281.                 first_idx = indices[0]
  282.             else:
  283.                 raise ValueError("Found no qpos that exceeds the threshold.")
  284.             # Rescale gripper to [0, 1]
  285.             # qpos = qpos / np.array([[1 for i in range(left_arm_dim[0] + right_arm_dim[0] + 2)]])
  286.             # target_qpos = f["action"][:] / np.array([[1 for i in range(left_arm_dim[0] + right_arm_dim[0] + 2)]])
  287.             
  288.             qpos = qpos / np.array(
  289.             #    [[1, 1, 1, 1, 1, 1, 4.7908, 1, 1, 1, 1, 1, 1, 4.7888]]
  290.             [[180, 180, 180, 180, 180, 180]]
  291.             )
  292.             target_qpos = f['action'][first_idx - 1:] / np.array(
  293.             #    [[1, 1, 1, 1, 1, 1, 11.8997, 1, 1, 1, 1, 1, 1, 13.9231]]
  294.             [[180, 180, 180, 180, 180, 180]]
  295.             )
  296.             # Parse the state and action
  297.             state = qpos[first_idx - 1:]
  298.             action = target_qpos[first_idx - 1:]
  299.             
  300.             # Standardize trajectory length to avoid batch size mismatch
  301.             # Use a fixed length (e.g., 128) or pad/truncate to match
  302.             target_length = 128  # You can adjust this value
  303.             if state.shape[0] > target_length:
  304.                 # Truncate to target length
  305.                 state = state[:target_length]
  306.                 action = action[:target_length]
  307.             elif state.shape[0] < target_length:
  308.                 # Pad with the last state/action
  309.                 pad_length = target_length - state.shape[0]
  310.                 state = np.concatenate([state, np.tile(state[-1:], (pad_length, 1))], axis=0)
  311.                 action = np.concatenate([action, np.tile(action[-1:], (pad_length, 1))], axis=0)
  312.             # Fill the state/action into the unified vector
  313.             def fill_in_state(values):
  314.                 # Target indices corresponding to your state space
  315.                 # In this example: 6 joints + 1 gripper for each arm
  316.                 UNI_STATE_INDICES =  [
  317.                     STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6)
  318.                 # ] + [
  319.                     # STATE_VEC_IDX_MAPPING["right_gripper_open"]
  320.                 ]
  321.                 uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM, ))
  322.                 uni_vec[..., UNI_STATE_INDICES] = values
  323.                 return uni_vec
  324.             state = fill_in_state(state)
  325.             action = fill_in_state(action)
  326.             # Return the resulting sample
  327.             return True, {"state": state, "action": action}
  328. if __name__ == "__main__":
  329.     ds = HDF5VLADataset()
  330.     for i in range(len(ds)):
  331.         print(f"Processing episode {i}/{len(ds)}...")
  332.         ds.get_item(i)
Then start training directly:
    # step 1: install torch and torchvision
    pip install torch==2.1.0 torchvision==0.16.0  --index-url https://download.pytorch.org/whl/cu121
    # step 2: install packaging
    pip install packaging==24.0
    # step 3: install the remaining dependencies
    pip install -r requirements.txt
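After installing the training dependencies, a quick sanity check along these lines (a sketch) confirms that torch sees the GPUs and that flash_attn imports once its wheel is installed:

    # Verify the training environment
    import torch
    print(torch.__version__, torch.version.cuda, torch.cuda.device_count())
    try:
        import flash_attn
        print("flash_attn", flash_attn.__version__)
    except ImportError:
        print("flash_attn not installed yet - see the wheel notes above")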
Fine-tuning RDT-170M

To train the 170M version, modify the rdt parameters under the model section of RDT/configs/base.yaml:
    $ python3 -c "import torch ; print(torch.__config__.show())"
Then generate the training yml config and point pretrained_model_name_or_path at the 170M model downloaded earlier:
  1. PyTorch built with:
  2.   - GCC 9.3
  3.   - C++ Version: 201703
  4.   - Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications
  5.   - Intel(R) MKL-DNN v3.4.2 (Git Hash 1137e04ec0b5251ca2b4400a4fd3c667ce843d67)
  6.   - OpenMP 201511 (a.k.a. OpenMP 4.5)
  7.   - LAPACK is enabled (usually provided by MKL)
  8.   - NNPACK is enabled
  9.   - CPU capability usage: AVX2
  10.   - CUDA Runtime 12.1
  11.   - NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90
  12.   - CuDNN 90.1  (built against CUDA 12.4)
  13.   - Magma 2.6.1
  14.   - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=12.1, CUDNN_VERSION=9.1.0, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=2.4.1, USE_CUDA=ON, USE_CUDNN=ON, USE_CUSPARSELT=1, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_GLOO=ON, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF,
Then start training directly:
    pip3 install flash_attn-2.7.2.post1+cu12torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
For reference, with batch size 16 across two A100-40GB cards (8 per card), VRAM usage and training speed are shown below. As the RDT paper suggests, we track the overall_avg_sample_mse metric; with only 100 episodes of data, both RDT-170M and the 1B version fit the data at around 7,000 steps, with the metric dropping to the 1e-4 level.
4.png

5.png

4. Real-World Evaluation

Once training is done we can evaluate on the real robot. Since we currently control the LeRobot arm from a Mac while inference runs on the A100 server or the RDK S100, we also need communication code between the two ends. Here we use plain sockets to build a ServerClient; better alternatives such as ZMQ are not covered. The implementation is below, and this file must be placed on both the local machine and the inference server:
    # Option 1: run the ready-made script in my repository
    cd weights/RDT
    bash _download.sh
    # Option 2: download manually
    export HF_ENDPOINT=https://hf-mirror.com # mirror to speed up downloads from within China
    huggingface-cli download google/t5-v1_1-xxl --local-dir t5-v1_1-xxl
    huggingface-cli download google/siglip-so400m-patch14-384 --local-dir siglip-so400m-patch14-384
    huggingface-cli download robotics-diffusion-transformer/rdt-1b --local-dir rdt-1b
    huggingface-cli download robotics-diffusion-transformer/rdt-170m --local-dir rdt-170m
Single-GPU Server Deployment

Next we can write the inference code on the server. Based on RDT/scripts/agilex_model.py, we implement lerobot_rdt_server: we modify the RoboticDiffusionTransformerModel class, run inference through its step method, and integrate our ServerClient to receive the arm observations sent from the local machine (a minimal serving-loop sketch follows the config block below). The implementation is as follows:
    rdt:
      # 1B: num_head 32 hidden_size 2048 depth: 28
      # 170M: num_head 32 hidden_size 1024 depth: 14
      hidden_size: 2048
      depth: 28
      num_heads: 32
      cond_pos_embed_type: multimodal
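A minimal sketch of the serving loop, assuming the step(proprio, images, text_embeds) interface from agilex_model.py and the ServerClient class shown in this post (the paths, image-list layout, and policy construction are assumptions to adapt to your setup):

    # Skeleton of a serving loop: receive an observation, run RDT, send back the action chunk
    import torch
    from server_client import ServerClient  # the socket helper from this post

    def serve(policy, lang_embed_path="lang_embed_0.pt", host="0.0.0.0", port=5000):
        server = ServerClient(host=host, port=port, is_server=True)
        text_embeds = torch.load(lang_embed_path).unsqueeze(0)  # precomputed T5 embedding
        while True:
            obs = server.receive()  # dict sent by the Mac client
            if obs is None:
                continue
            proprio = torch.from_numpy(obs["state"]).float()  # (1, 6) joint positions
            # image list layout (history x cameras) is assumed; match your model's expectations
            images = [obs["cam_high"], None, obs["cam_right_wrist"]]
            with torch.no_grad():
                actions = policy.step(proprio, images, text_embeds)  # (chunk_size, 6)
            server.send(actions)  # tensors/arrays are handled by the serializer above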
The new version of LeRobot replaces the old observation-fetching function and also changes the data structures involved, so we copy record and strip it down to a minimal observation-fetching and arm-control example that handles local data transmission and control (a client-side loop sketch follows the block below). The code is below; the file lives at lerobot/src/lerobot/record_rdt.py:
    bash generate.sh RDT1B_LeRobot
    # Generated on 2025-08-28 17:14:20
    model: RDT1B_LeRobot
    data_path: training_data/RDT1B_LeRobot
    checkpoint_path: checkpoints/RDT1B_LeRobot
    pretrained_model_name_or_path: ../weights/RDT/rdt-1b
    cuda_visible_device: '0,1,2,3'
    train_batch_size: 16
    sample_batch_size: 32
    max_train_steps: 10000
    checkpointing_period: 2500
    sample_period: 100
    checkpoints_total_limit: 40
    learning_rate: 0.0001
    dataloader_num_workers: 8
    state_noise_snr: 40
    gradient_accumulation_steps: 1
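The client side mirrors the server: grab an observation, ship it over the socket, and play back the returned action chunk. A rough sketch (the robot API calls and dict keys are assumptions based on the new LeRobot Robot interface and the camera names used in the record command above; obs_to_state/action_to_dict are hypothetical helpers):

    # Client-side control loop in the spirit of record_rdt.py
    import time
    from server_client import ServerClient

    def run_client(robot, host="192.168.1.10", port=5000, fps=30, steps=1000):
        client = ServerClient(host=host, port=port, is_server=False)
        for _ in range(steps):
            obs = robot.get_observation()  # assumption: new-API Robot.get_observation()
            client.send({
                "state": obs_to_state(obs),       # 6 joint positions (hypothetical helper)
                "cam_high": obs["front"],         # camera keys follow --robot.cameras above
                "cam_right_wrist": obs["arm"],
            })
            action_chunk = client.receive()       # (chunk_size, 6) predicted by the server
            for action in action_chunk:
                robot.send_action(action_to_dict(action))  # hypothetical helper -> motor keys
                time.sleep(1.0 / fps)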
Then adjust the corresponding settings in main() for your actual cameras and serial ports; start the server and then the local client, and you should see the arm start moving and complete the task:
    bash finetune.sh RDT1B_LeRobot
RDK S100 Deployment

Currently only RDT-170M can be deployed on the RDK S100. Following the documentation "RDT on Double RDK S100P full-pipeline documentation", we walk through putting LeRobot's RDT on the board step by step. First use the script RDT/export_all.py in the repository's RDT directory to export all of RDT's ONNX models; the script exposes the following configurable parameters, which you only need to set for your own setup (a minimal export sketch follows the block below):
    rdt:
      # 1B: num_head 32 hidden_size 2048 depth: 28
      # 170M: num_head 32 hidden_size 1024 depth: 14
      hidden_size: 1024
      depth: 14
      num_heads: 32
      cond_pos_embed_type: multimodal
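For reference, exporting each of the small adaptor modules boils down to a call like the following (a sketch only; export_all.py handles the real model loading, all sub-models, and the calibration-data dump, and the module/file names here are illustrative):

    # Export a single adaptor module to ONNX with a dynamic token axis
    import torch

    def export_adaptor(adaptor: torch.nn.Module, in_dim: int, onnx_path: str = "img_adaptor.onnx"):
        adaptor.eval()
        dummy = torch.randn(1, 1, in_dim)  # (batch, tokens, feature_dim)
        torch.onnx.export(
            adaptor, dummy, onnx_path,
            input_names=["input"], output_names=["output"],
            opset_version=17,
            dynamic_axes={"input": {1: "tokens"}, "output": {1: "tokens"}},
        )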
After running the script, the exported ONNX models plus the dumped calibration data and files will be generated in the RDT directory:
    bash generate.sh RDT170M_LeRobot
    # Generated on 2025-08-28 17:14:20
    model: RDT170M_LeRobot
    data_path: training_data/RDT170M_LeRobot
    checkpoint_path: checkpoints/RDT170M_LeRobot
    pretrained_model_name_or_path: ../weights/RDT/rdt-170m
    cuda_visible_device: '0,1'
    train_batch_size: 16
    sample_batch_size: 32
    max_train_steps: 10000
    checkpointing_period: 2500
    sample_period: 100
    checkpoints_total_limit: 40
    learning_rate: 0.0001
    dataloader_num_workers: 8
    state_noise_snr: 40
    gradient_accumulation_steps: 1
Next we use the standard Docker environment delivered with the RDK S100 toolchain to quantize and compile the BPU models; see the following for the commands to install, mount, and use the Docker image:
    bash finetune.sh RDT170M_LeRobot
Inside the Docker container, run bash build_all.sh and the build starts automatically, exporting HBM models that can run on the RDK S100 board.
6.png

Below are some reference metrics from quantizing the DiT, img_adaptor, lang_adaptor, and state_adaptor models:
  1. import socket
  2. import numpy as np
  3. import zlib
  4. import json
  5. import base64
  6. import time
  7. from typing import Any
  8. import torch
  9. class NumpyEncoder(json.JSONEncoder):
  10.     """Enhanced json encoder for numpy types and PyTorch tensors with array reconstruction info"""
  11.     def default(self, obj):
  12.         if isinstance(obj, np.ndarray):
  13.             return {
  14.                 '__numpy_array__': True,
  15.                 'data': base64.b64encode(obj.tobytes()).decode('ascii'),
  16.                 'dtype': str(obj.dtype),
  17.                 'shape': obj.shape
  18.             }
  19.         elif torch is not None and isinstance(obj, torch.Tensor):
  20.             # 将 PyTorch Tensor 转换为 numpy 数组
  21.             numpy_array = obj.cpu().detach().numpy()
  22.             return {
  23.                 '__numpy_array__': True,
  24.                 'data': base64.b64encode(numpy_array.tobytes()).decode('ascii'),
  25.                 'dtype': str(numpy_array.dtype),
  26.                 'shape': numpy_array.shape
  27.             }
  28.         elif isinstance(obj, (np.integer, np.floating, np.bool_)):
  29.             return obj.item()
  30.         return super().default(obj)
  31. def numpy_to_json(data: Any) -> str:
  32.     return json.dumps(data, cls=NumpyEncoder)
  33. def json_to_numpy(json_str: str) -> Any:
  34.     def hook(dct):
  35.         if '__numpy_array__' in dct:
  36.             data = base64.b64decode(dct['data'])
  37.             return np.frombuffer(data, dtype=dct['dtype']).reshape(dct['shape'])
  38.         return dct
  39.     return json.loads(json_str, object_hook=hook)
  40. class CommonUtils:
  41.     @staticmethod
  42.     def serialize(data: Any) -> bytes:
  43.         return zlib.compress(numpy_to_json(data).encode('utf-8'))
  44.     @staticmethod
  45.     def deserialize(data: bytes) -> Any:
  46.         return json_to_numpy(zlib.decompress(data).decode('utf-8'))
  47. def send_all(sock, payload):
  48.     sock.sendall(len(payload).to_bytes(8, 'big') + payload)
  49. def recv_all(sock) :
  50.     length_bytes = sock.recv(8)
  51.     if not length_bytes:
  52.         return None
  53.     length = int.from_bytes(length_bytes, 'big')
  54.     buf = b''
  55.     while len(buf) < length:
  56.         chunk = sock.recv(length - len(buf))
  57.         if not chunk:
  58.             return None
  59.         buf += chunk
  60.     return buf
  61. class ServerClient:
  62.     def __init__(self, host='localhost', port=5000, is_server=True):
  63.         self.host, self.port, self.is_server = host, port, is_server
  64.         self.utils = CommonUtils()
  65.         self._connect()
  66.     def _connect(self):
  67.         if self.is_server:
  68.             self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  69.             self.sock.bind((self.host, self.port))
  70.             self.sock.listen(1)
  71.             print(f"[ServerClient] Listening on {self.host}:{self.port}")
  72.             self.conn, addr = self.sock.accept()
  73.             print(f"[ServerClient] Connected by {addr}")
  74.         else:
  75.             while True:
  76.                 try:
  77.                     self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  78.                     self.sock.connect((self.host, self.port))
  79.                     self.conn = self.sock
  80.                     print(f"[ServerClient] Connected to {self.host}:{self.port}")
  81.                     break
  82.                 except (ConnectionRefusedError, OSError):
  83.                     print("[ServerClient] Waiting for server...")
  84.                     time.sleep(2)
  85.     def send(self, data):
  86.         payload = self.utils.serialize(data)
  87.         try:
  88.             send_all(self.conn, payload)
  89.         except (BrokenPipeError, ConnectionResetError, OSError):
  90.             print("[ServerClient] Connection lost. Reconnecting...")
  91.             self._connect()
  92.             send_all(self.conn, payload)
  93.     def receive(self):
  94.         try:
  95.             buf = recv_all(self.conn)
  96.             return self.utils.deserialize(buf) if buf else None
  97.         except (BrokenPipeError, ConnectionResetError, OSError):
  98.             print("[ServerClient] Connection lost. Reconnecting...")
  99.             self._connect()
  100.             return None
  101.     def close(self):
  102.         self.conn.close()
  103.         self.sock.close()
  104. class Client:
  105.     def __init__(self, host='127.0.0.1', port=5000):
  106.         self.host, self.port = host, port
  107.         self.utils = CommonUtils()
  108.         self.connect()
  109.     def connect(self):
  110.         while True:
  111.             try:
  112.                 self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  113.                 self.sock.connect((self.host, self.port))
  114.                 print(f"[Client] Connected to {self.host}:{self.port}")
  115.                 break
  116.             except (ConnectionRefusedError, OSError):
  117.                 print("[Client] Waiting for server...")
  118.                 time.sleep(2)
  119.     def send(self, data):
  120.         payload = self.utils.serialize(data)
  121.         try:
  122.             send_all(self.sock, payload)
  123.         except (BrokenPipeError, ConnectionResetError, OSError):
  124.             print("[Client] Connection lost. Reconnecting...")
  125.             self.connect()
  126.             send_all(self.sock, payload)
  127.     def receive(self):
  128.         try:
  129.             buf = recv_all(self.sock)
  130.             return self.utils.deserialize(buf) if buf else None
  131.         except (BrokenPipeError, ConnectionResetError, OSError):
  132.             print("[Client] Connection lost. Reconnecting...")
  133.             self.connect()
  134.             return None
  135.     def close(self):
  136.         self.sock.close()
  137.         print("[Client] Closed.")
  138. class Server:
  139.     def __init__(self, host='0.0.0.0', port=5000):
  140.         self.host, self.port = host, port
  141.         self.utils = CommonUtils()
  142.         self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  143.         self.sock.bind((self.host, self.port))
  144.         self.sock.listen(1)
  145.         print(f"[Server] Listening on {self.host}:{self.port}")
  146.         self._wait_client()
  147.     def _wait_client(self):
  148.         print("[Server] Waiting for client...")
  149.         self.conn, addr = self.sock.accept()
  150.         print(f"[Server] Connected by {addr}")
  151.     def send(self, data: Any):
  152.         payload = self.utils.serialize(data)
  153.         try:
  154.             send_all(self.conn, payload)
  155.         except (BrokenPipeError, ConnectionResetError, OSError):
  156.             print("[Server] Client disconnected. Waiting new client...")
  157.             self._wait_client()
  158.             send_all(self.conn, payload)
  159.     def receive(self):
  160.         try:
  161.             buf = recv_all(self.conn)
  162.             return self.utils.deserialize(buf) if buf else None
  163.         except (BrokenPipeError, ConnectionResetError, OSError):
  164.             print("[Server] Client disconnected. Waiting new client...")
  165.             self._wait_client()
  166.             return None
  167.     def close(self):
  168.         self.conn.close()
  169.         self.sock.close()
  170.         print("[Server] Closed.")
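A minimal usage sketch of the Server/Client classes above (host/port values are examples; each side runs on its own machine):

    # On the inference machine (A100 server or RDK S100)
    import numpy as np
    from server_client import Server
    server = Server(host="0.0.0.0", port=5000)
    obs = server.receive()               # blocks until the Mac sends an observation dict
    server.send(np.zeros((64, 6)))       # e.g. an action chunk; arrays are serialized transparently

    # On the Mac that drives the arm (run separately)
    import numpy as np
    from server_client import Client
    client = Client(host="192.168.1.10", port=5000)
    client.send({"state": np.zeros(6)})  # observation dict
    actions = client.receive()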
Next, run the following command to download the pre-quantized and compiled SigLIP into the build-output folder BPU_RDT_Policy, then copy this folder, the test_data folder, and all files under RDKS_ModelRun/RDT in the LeRobot-VLA repository to the RDK S100 board, as shown in the figure:
  1. import os, sys
  2. import numpy as np
  3. import torch
  4. from PIL import Image
  5. from torchvision import transforms
  6. import yaml
  7. from pathlib import Path
  8. # get current workspace
  9. current_file = Path(__file__)
  10. sys.path.append(os.path.join(current_file.parent.parent, "models"))
  11. sys.path.append(os.path.join(current_file.parent.parent, "models"))
  12. sys.path.append(os.path.join(current_file.parent.parent))  
  13. from configs.state_vec import STATE_VEC_IDX_MAPPING
  14. from multimodal_encoder.siglip_encoder import SiglipVisionTower
  15. from multimodal_encoder.t5_encoder import T5Embedder
  16. from rdt_runner import RDTRunner
  17. from server_client import ServerClient
  18. # The indices that the raw vector should be mapped to in the unified action vector
  19. AGILEX_STATE_INDICES = [
  20.     STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6)
  21. ]
  22. # Create the RDT model
  23. def create_model(args, **kwargs):
  24.     model = RoboticDiffusionTransformerModel(args, **kwargs)
  25.     pretrained = kwargs.get("pretrained", None)
  26.     if pretrained is not None and os.path.isfile(pretrained):
  27.         model.load_pretrained_weights(pretrained)
  28.     return model
  29. class RoboticDiffusionTransformerModel(object):
  30.     """A wrapper for the RDT model, which handles
  31.     1. Model initialization
  32.     2. Encodings of instructions
  33.     3. Model inference
  34.     """
  35.     def __init__(
  36.         self,
  37.         args,
  38.         device="cuda",
  39.         dtype=torch.bfloat16,
  40.         image_size=None,
  41.         control_frequency=25,
  42.         pretrained=None,
  43.         pretrained_vision_encoder_name_or_path=None,
  44.     ):
  45.         self.args = args
  46.         self.dtype = dtype
  47.         self.image_size = image_size
  48.         self.device = device
  49.         self.control_frequency = control_frequency
  50.         # We do not use the text encoder due to limited GPU memory
  51.         # self.text_tokenizer, self.text_model = self.get_text_encoder(pretrained_text_encoder_name_or_path)
  52.         self.image_processor, self.vision_model = self.get_vision_encoder(pretrained_vision_encoder_name_or_path)
  53.         self.policy = self.get_policy(pretrained)
  54.         self.reset()
  55.     def get_policy(self, pretrained):
  56.         """Initialize the model."""
  57.         # Initialize model with arguments
  58.         if pretrained is None or os.path.isfile(pretrained):
  59.             img_cond_len = (self.args["common"]["img_history_size"] * self.args["common"]["num_cameras"] *
  60.                             self.vision_model.num_patches)
  61.             _model = RDTRunner(
  62.                 action_dim=self.args["common"]["state_dim"],
  63.                 pred_horizon=self.args["common"]["action_chunk_size"],
  64.                 config=self.args["model"],
  65.                 lang_token_dim=self.args["model"]["lang_token_dim"],
  66.                 img_token_dim=self.args["model"]["img_token_dim"],
  67.                 state_token_dim=self.args["model"]["state_token_dim"],
  68.                 max_lang_cond_len=self.args["dataset"]["tokenizer_max_length"],
  69.                 img_cond_len=img_cond_len,
  70.                 img_pos_embed_config=[
  71.                     # No initial pos embed in the last grid size
  72.                     # since we've already done in ViT
  73.                     (
  74.                         "image",
  75.                         (
  76.                             self.args["common"]["img_history_size"],
  77.                             self.args["common"]["num_cameras"],
  78.                             -self.vision_model.num_patches,
  79.                         ),
  80.                     ),
  81.                 ],
  82.                 lang_pos_embed_config=[
  83.                     # Similarly, no initial pos embed for language
  84.                     ("lang", -self.args["dataset"]["tokenizer_max_length"]),
  85.                 ],
  86.                 dtype=self.dtype,
  87.             )
  88.         else:
  89.             _model = RDTRunner.from_pretrained(pretrained)
  90.         return _model
  91.     def get_text_encoder(self, pretrained_text_encoder_name_or_path):
  92.         text_embedder = T5Embedder(
  93.             from_pretrained=pretrained_text_encoder_name_or_path,
  94.             model_max_length=self.args["dataset"]["tokenizer_max_length"],
  95.             device=self.device,
  96.         )
  97.         tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model
  98.         return tokenizer, text_encoder
  99.     def get_vision_encoder(self, pretrained_vision_encoder_name_or_path):
  100.         vision_encoder = SiglipVisionTower(vision_tower=pretrained_vision_encoder_name_or_path, args=None)
  101.         image_processor = vision_encoder.image_processor
  102.         return image_processor, vision_encoder
  103.     def reset(self):
  104.         """Set model to evaluation mode."""
  105.         device = self.device
  106.         weight_dtype = self.dtype
  107.         self.policy.eval()
  108.         # self.text_model.eval()
  109.         self.vision_model.eval()
  110.         self.policy = self.policy.to(device, dtype=weight_dtype)
  111.         # self.text_model = self.text_model.to(device, dtype=weight_dtype)
  112.         self.vision_model = self.vision_model.to(device, dtype=weight_dtype)
  113.     def load_pretrained_weights(self, pretrained=None):
  114.         if pretrained is None:
  115.             return
  116.         print(f"Loading weights from {pretrained}")
  117.         filename = os.path.basename(pretrained)
  118.         if filename.endswith(".pt"):
  119.             checkpoint = torch.load(pretrained)
  120.             self.policy.load_state_dict(checkpoint["module"])
  121.         elif filename.endswith(".safetensors"):
  122.             from safetensors.torch import load_model
  123.             load_model(self.policy, pretrained)
  124.         else:
  125.             raise NotImplementedError(f"Unknown checkpoint format: {pretrained}")
  126.     def encode_instruction(self, instruction, device="cuda"):
  127.         """Encode string instruction to latent embeddings.
  128.         Args:
  129.             instruction: a string of instruction
  130.             device: a string of device
  131.         Returns:
  132.             pred: a tensor of latent embeddings of shape (text_max_length, 512)
  133.         """
  134.         tokens = self.text_tokenizer(instruction, return_tensors="pt", padding="longest",
  135.                                      truncation=True)["input_ids"].to(device)
  136.         tokens = tokens.view(1, -1)
  137.         with torch.no_grad():
  138.             pred = self.text_model(tokens).last_hidden_state.detach()
  139.         return pred
  140.     def _format_joint_to_state(self, joints):
  141.         """
  142.         Format the joint proprioception into the unified action vector.
  143.         Args:
  144.             joints (torch.Tensor): The joint proprioception to be formatted.
  145.                 qpos ([B, N, 6]).
  146.         Returns:
  147.             state (torch.Tensor): The formatted vector for RDT ([B, N, 128]).
  148.         """
  149.         # Normalize the joint angles (degrees) to roughly [-1, 1] by dividing by 180
  150.         joints = joints / torch.tensor(
  151.             [[[180, 180, 180, 180, 180, 180]]],
  152.             device=joints.device,
  153.             dtype=joints.dtype,
  154.         )
  155.         B, N, _ = joints.shape
  156.         state = torch.zeros(
  157.             (B, N, self.args["model"]["state_token_dim"]),
  158.             device=joints.device,
  159.             dtype=joints.dtype,
  160.         )
  161.         # Fill into the unified state vector
  162.         state[:, :, AGILEX_STATE_INDICES] = joints
  163.         # Assemble the mask indicating each dimension's availability
  164.         state_elem_mask = torch.zeros(
  165.             (B, self.args["model"]["state_token_dim"]),
  166.             device=joints.device,
  167.             dtype=joints.dtype,
  168.         )
  169.         state_elem_mask[:, AGILEX_STATE_INDICES] = 1
  170.         return state, state_elem_mask
  171.     def _unformat_action_to_joint(self, action):
  172.         """
  173.         Unformat the unified action vector into the joint action to be executed.
  174.         Args:
  175.             action (torch.Tensor): The unified action vector to be unformatted.
  176.                 ([B, N, 128])
  177.         Returns:
  178.             joints (torch.Tensor): The unformatted robot joint action.
  179.                 qpos ([B, N, 6]).
  180.         """
  181.         action_indices = AGILEX_STATE_INDICES
  182.         joints = action[:, :, action_indices]
  183.         # Rescale the gripper back to the action range
  184.         # Note that the action range and proprioception range are different
  185.         # for Mobile ALOHA robot
  186.         joints = joints * torch.tensor(
  187.             [[[180, 180, 180, 180, 180, 180]]],
  188.             device=joints.device,
  189.             dtype=joints.dtype,
  190.         )
  191.         return joints
  192.     @torch.no_grad()
  193.     def step(self, proprio, images, text_embeds):
  194.         """
  195.         Predict the next action chunk given the
  196.         proprioceptive states, images, and instruction embeddings.
  197.         Args:
  198.             proprio: proprioceptive states
  199.             images: RGB images, the order should be
  200.                 [ext_{t-1}, right_wrist_{t-1}, left_wrist_{t-1},
  201.                 ext_{t}, right_wrist_{t}, left_wrist_{t}]
  202.             text_embeds: instruction embeddings
  203.         Returns:
  204.             action: predicted action
  205.         """
  206.         device = self.device
  207.         dtype = self.dtype
  208.         # The background image used for padding
  209.         background_color = np.array([int(x * 255) for x in self.image_processor.image_mean],
  210.                                     dtype=np.uint8).reshape(1, 1, 3)
  211.         background_image = (np.ones(
  212.             (
  213.                 self.image_processor.size["height"],
  214.                 self.image_processor.size["width"],
  215.                 3,
  216.             ),
  217.             dtype=np.uint8,
  218.         ) * background_color)
  219.         # Preprocess the images by order and encode them
  220.         image_tensor_list = []
  221.         for image in images:
  222.             if image is None:
  223.                 # Replace it with the background image
  224.                 image = Image.fromarray(background_image)
  225.             else:
  226.                 # Convert numpy array to PIL Image if needed
  227.                 if isinstance(image, np.ndarray):
  228.                     image = Image.fromarray(image)
  229.             if self.image_size is not None:
  230.                 image = transforms.Resize(self.image_size)(image)
  231.             if self.args["dataset"].get("auto_adjust_image_brightness", False):
  232.                 pixel_values = list(image.getdata())
  233.                 average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
  234.                 if average_brightness <= 0.15:
  235.                     image = transforms.ColorJitter(brightness=(1.75, 1.75))(image)
  236.             if self.args["dataset"].get("image_aspect_ratio", "pad") == "pad":
  237.                 def expand2square(pil_img, background_color):
  238.                     width, height = pil_img.size
  239.                     if width == height:
  240.                         return pil_img
  241.                     elif width > height:
  242.                         result = Image.new(pil_img.mode, (width, width), background_color)
  243.                         result.paste(pil_img, (0, (width - height) // 2))
  244.                         return result
  245.                     else:
  246.                         result = Image.new(pil_img.mode, (height, height), background_color)
  247.                         result.paste(pil_img, ((height - width) // 2, 0))
  248.                         return result
  249.                 image = expand2square(image, tuple(int(x * 255) for x in self.image_processor.image_mean))
  250.             image = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
  251.             image_tensor_list.append(image)
  252.         image_tensor = torch.stack(image_tensor_list, dim=0).to(device, dtype=dtype)
  253.         image_embeds = self.vision_model(image_tensor).detach()
  254.         image_embeds = image_embeds.reshape(-1, self.vision_model.hidden_size).unsqueeze(0)
  255.         # Prepare the proprioception states and the control frequency
  256.         # Convert numpy array to tensor if needed
  257.         if isinstance(proprio, np.ndarray):
  258.             # Copy the array to make it writable
  259.             proprio = torch.from_numpy(proprio.copy())
  260.         
  261.         joints = proprio.to(device).unsqueeze(0)  # (1, 1, 6)
  262.         states, state_elem_mask = self._format_joint_to_state(joints)  # (1, 1, 128), (1, 128)
  263.         states, state_elem_mask = states.to(device, dtype=dtype), state_elem_mask.to(device, dtype=dtype)
  264.         states = states[:, -1:, :]  # (1, 1, 128)
  265.         ctrl_freqs = torch.tensor([self.control_frequency]).to(device)
  266.         text_embeds = text_embeds.to(device, dtype=dtype)
  267.         # Predict the next action chunk given the inputs
  268.         trajectory = self.policy.predict_action(
  269.             lang_tokens=text_embeds,
  270.             lang_attn_mask=torch.ones(text_embeds.shape[:2], dtype=torch.bool, device=text_embeds.device),
  271.             img_tokens=image_embeds,
  272.             state_tokens=states,
  273.             action_mask=state_elem_mask.unsqueeze(1),
  274.             ctrl_freqs=ctrl_freqs,
  275.         )
  276.         trajectory = self._unformat_action_to_joint(trajectory).to(torch.float32)
  277.         return trajectory
  278. class LERobotRDTServer:
  279.     def __init__(self, pretrained_vision_encoder_name_or_path, pretrained, args, lang_model):
  280.         self.policy = create_model(
  281.             args=args,
  282.             dtype=torch.bfloat16,
  283.             pretrained=pretrained,
  284.             pretrained_vision_encoder_name_or_path=pretrained_vision_encoder_name_or_path,
  285.             control_frequency=30,
  286.         )
  287.         self.server = ServerClient(host="0.0.0.0", port=5002, is_server=True)
  288.         
  289.         # Load and debug language embeddings
  290.         self.lang_embeddings = torch.load(lang_model)
  291.         print(f"Loaded language embeddings shape: {self.lang_embeddings.shape}")
  292.         print(f"Model expects tokenizer_max_length: {self.policy.args['dataset']['tokenizer_max_length']}")
  293.         print(f"Model lang_token_dim: {self.policy.args['model']['lang_token_dim']}")
  294.         
  295.         # Check if dimensions match
  296.         expected_seq_len = self.policy.args["dataset"]["tokenizer_max_length"]
  297.         expected_hidden_dim = self.policy.args["model"]["lang_token_dim"]
  298.         
  299.         # Handle different embedding formats
  300.         if len(self.lang_embeddings.shape) == 2:
  301.             # Format: [seq_len, hidden_dim]
  302.             actual_seq_len, actual_hidden_dim = self.lang_embeddings.shape
  303.             if actual_seq_len != expected_seq_len:
  304.                 print(f"WARNING: Sequence length mismatch! Expected {expected_seq_len}, got {actual_seq_len}")
  305.             if actual_hidden_dim != expected_hidden_dim:
  306.                 print(f"WARNING: Hidden dimension mismatch! Expected {expected_hidden_dim}, got {actual_hidden_dim}")
  307.         elif len(self.lang_embeddings.shape) == 3:
  308.             # Format: [batch_size, seq_len, hidden_dim]
  309.             actual_batch, actual_seq_len, actual_hidden_dim = self.lang_embeddings.shape
  310.             if actual_seq_len != expected_seq_len:
  311.                 print(f"WARNING: Sequence length mismatch! Expected {expected_seq_len}, got {actual_seq_len}")
  312.             if actual_hidden_dim != expected_hidden_dim:
  313.                 print(f"WARNING: Hidden dimension mismatch! Expected {expected_hidden_dim}, got {actual_hidden_dim}")
  314.         else:
  315.             print(f"WARNING: Unexpected embedding shape: {self.lang_embeddings.shape}")
  316.         
  317.     def run(self):
  318.         print("LERobot RDT Server started, waiting for messages...")
  319.         try:
  320.             while True:
  321.                 print("Waiting for RDT data...")
  322.                 rdt_data = self.server.receive()
  323.                 print(f"Received RDT data, message_id: {rdt_data['message_id']}")
  324.                
  325.                 # Perform inference
  326.                 # Ensure language embeddings have correct shape
  327.                 if len(self.lang_embeddings.shape) == 2:
  328.                     # [seq_len, hidden_dim] -> [1, seq_len, hidden_dim]
  329.                     text_embeds = self.lang_embeddings.unsqueeze(0)
  330.                 else:
  331.                     # Already [batch_size, seq_len, hidden_dim]
  332.                     text_embeds = self.lang_embeddings
  333.                
  334.                 action = self.policy.step(
  335.                     proprio=rdt_data["proprio"],
  336.                     images=rdt_data["images"],
  337.                     text_embeds=text_embeds,
  338.                 )
  339.                
  340.                 # Prepare response - use 'actions' key to match client expectation
  341.                 message_id = rdt_data["message_id"]
  342.                 action_data = {
  343.                     "message_id": message_id,
  344.                     "actions": action,  # Changed from 'action' to 'actions'
  345.                 }
  346.                
  347.                 # Send response
  348.                 print(f"send action data, action_data: {action_data}")
  349.                 self.server.send(action_data)
  350.                 print(f"Sent action data for message_id: {message_id}")
  351.                
  352.         except KeyboardInterrupt:
  353.             print("\nServer stopped by user")
  354.             self.server.close()
  355.         except Exception as e:
  356.             print(f"Error in server loop: {e}")
  357.             self.server.close()
  358.             raise
  359. if __name__ == "__main__":
  360.     path_to_rdt_model_wights = "/home/qi.xiong/DualArm/RoboTwin/policy/RDT/checkpoints/RDT_LeRobot/checkpoint-7500/pytorch_model/mp_rank_00_model_states.pt"
  361.     path_to_vision_encoder_model = "/home/qi.xiong/DualArm/RoboTwin/policy/weights/RDT/siglip-so400m-patch14-384"
  362.     lang_model = "/home/qi.xiong/DualArm/RoboTwin/policy/RDT/scripts/lerobot_rdt_data/greenmarker_scene1/episode_4/instructions/lang_embed_0.pt"
  363.     with open("/home/qi.xiong/DualArm/RoboTwin/policy/RDT/configs/base.yaml", "r") as fp:
  364.         config = yaml.safe_load(fp)
  365.     rdt_server = LERobotRDTServer(path_to_vision_encoder_model, path_to_rdt_model_wights, config, lang_model)
  366.     rdt_server.run()
复制代码
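        To make the single-arm adaptation above more concrete, here is a standalone sketch of what _format_joint_to_state and _unformat_action_to_joint effectively do: the six joint angles (in degrees) are divided by 180 and scattered into the 128-dimensional unified state vector at the right-arm joint positions, together with a mask marking which dimensions are valid, and predictions are mapped back the same way. The indices below are placeholders; the real ones come from STATE_VEC_IDX_MAPPING in configs/state_vec.py.
```python
# Standalone sketch of the single-arm state packing used above.
# Assumption: the six right_arm_joint_{i}_pos indices are placeholders here,
# not the real values from STATE_VEC_IDX_MAPPING.
import torch

STATE_DIM = 128                                  # unified state size (state_token_dim)
RIGHT_ARM_JOINT_INDICES = list(range(50, 56))    # placeholder for AGILEX_STATE_INDICES

def format_joint_to_state(joints_deg: torch.Tensor):
    """joints_deg: (B, N, 6) angles in degrees -> (B, N, 128) state and (B, 128) mask."""
    joints = joints_deg / 180.0                            # normalize to roughly [-1, 1]
    B, N, _ = joints.shape
    state = torch.zeros(B, N, STATE_DIM)
    state[:, :, RIGHT_ARM_JOINT_INDICES] = joints          # scatter into the unified vector
    mask = torch.zeros(B, STATE_DIM)
    mask[:, RIGHT_ARM_JOINT_INDICES] = 1                   # mark which dims are actually used
    return state, mask

def unformat_action_to_joint(action_128: torch.Tensor):
    """Inverse: pick the six right-arm dims out of (B, N, 128) and rescale to degrees."""
    return action_128[:, :, RIGHT_ARM_JOINT_INDICES] * 180.0

state, mask = format_joint_to_state(torch.zeros(1, 1, 6))
print(state.shape, mask.shape)                   # torch.Size([1, 1, 128]) torch.Size([1, 128])
```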
7.png

        Next, we can simply run the following command to first verify model inference with the calibration test data. If it finishes without errors, the model runs inference correctly; if a dependency turns out to be missing during the run, just install it:
  1. import logging
  2. import time
  3. from dataclasses import asdict, dataclass
  4. from pathlib import Path
  5. from pprint import pformat
  6. from collections import deque
  7. import torch
  8. from PIL import Image
  9. import numpy as np
  10. from lerobot.cameras import (  # noqa: F401
  11.     CameraConfig,  # noqa: F401
  12. )
  13. from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig  # noqa: F401
  14. from lerobot.configs import parser
  15. from lerobot.configs.policies import PreTrainedConfig
  16. from lerobot.datasets.image_writer import safe_stop_image_writer
  17. from lerobot.datasets.lerobot_dataset import LeRobotDataset
  18. from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
  19. from lerobot.datasets.video_utils import VideoEncodingManager
  20. from lerobot.policies.factory import make_policy
  21. from lerobot.policies.pretrained import PreTrainedPolicy
  22. from lerobot.robots import (  # noqa: F401
  23.     Robot,
  24.     RobotConfig,
  25.     bi_so100_follower,
  26.     hope_jr,
  27.     koch_follower,
  28.     make_robot_from_config,
  29.     so100_follower,
  30.     so101_follower,
  31. )
  32. from lerobot.robots.so101_follower.so101_follower import SO101Follower
  33. from lerobot.robots.so101_follower.config_so101_follower import SO101FollowerConfig
  34. from lerobot.teleoperators import (  # noqa: F401
  35.     Teleoperator,
  36.     TeleoperatorConfig,
  37.     bi_so100_leader,
  38.     homunculus,
  39.     koch_leader,
  40.     make_teleoperator_from_config,
  41.     so100_leader,
  42.     so101_leader,
  43. )
  44. from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop
  45. from lerobot.utils.control_utils import (
  46.     init_keyboard_listener,
  47.     is_headless,
  48.     predict_action,
  49.     sanity_check_dataset_name,
  50.     sanity_check_dataset_robot_compatibility,
  51. )
  52. from lerobot.utils.robot_utils import busy_wait
  53. from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data
  54. from server_client import *
  55. debug_save_img = True
  56. _action_queue = deque([], maxlen=64)
  57. _message_id = 0
  58. _last_cam_high = None  
  59. _last_cam_right_wrist = None  
  60. @safe_stop_image_writer
  61. def record_loop(
  62.     robot: Robot,
  63.     client: Client,
  64. ):
  65.     global _action_queue, _message_id, _last_cam_high, _last_cam_right_wrist
  66.     observation = robot.get_observation()
  67.     cam_high = observation['high']
  68.     cam_right_wrist = observation['arm']
  69.     image_arrs = [
  70.         _last_cam_high,
  71.         _last_cam_right_wrist,
  72.         None,
  73.         cam_high,
  74.         cam_right_wrist,
  75.         None
  76.     ]
  77.     images = [arr if arr is not None else None
  78.             for arr in image_arrs]
  79.     joint_positions = [observation[key] for key in observation.keys() if key.endswith('.pos')]
  80.     proprio = torch.tensor(joint_positions, dtype=torch.float32).unsqueeze(0)
  81. ################### Debug images ########################
  82.     if debug_save_img:
  83.         imgs_to_show = [cam_high, cam_right_wrist, _last_cam_high, _last_cam_right_wrist]
  84.         if all(img is not None for img in imgs_to_show):
  85.             pil_imgs = []
  86.             for img in imgs_to_show:
  87.                 if img.dtype != np.uint8:
  88.                     img = np.clip(img, 0, 1)
  89.                     img = (img * 255).astype(np.uint8)
  90.                 if img.ndim == 2:
  91.                     img = np.stack([img]*3, axis=-1)  
  92.                 elif img.shape[-1] == 1:
  93.                     img = np.repeat(img, 3, axis=-1)
  94.                 pil_imgs.append(Image.fromarray(img))
  95.             w, h = pil_imgs[0].size
  96.             for i in range(4):
  97.                 if pil_imgs[i].size != (w, h):
  98.                     pil_imgs[i] = pil_imgs[i].resize((w, h))
  99.             new_img = Image.new('RGB', (w*2, h*2))
  100.             new_img.paste(pil_imgs[0], (0, 0))       # top-left: current high
  101.             new_img.paste(pil_imgs[1], (w, 0))       # top-right: current wrist
  102.             new_img.paste(pil_imgs[2], (0, h))       # bottom-left: previous high
  103.             new_img.paste(pil_imgs[3], (w, h))       # bottom-right: previous wrist
  104.             debug_save_path = "debug_2x2.png"
  105.             new_img.save(debug_save_path)
  106.             print(f"Debug image saved at: {debug_save_path}")
  107.             # new_img.show()
  108. ################### Debug images ########################
  109.     rdt_data = {
  110.         'message_id': _message_id,
  111.         'proprio': proprio,
  112.         'images': images,
  113.         'text_embeds': ""
  114.     }
  115.     client.send(rdt_data)
  116.     _message_id += 1
  117.     print(f"send new rdt data done, message_id: {_message_id-1}")
  118.     action_data = client.receive()
  119.     if action_data is None:
  120.         print("ERROR: Server returned None. Is the RDT server running?")
  121.         print("Please start the RDT server first!")
  122.         raise ConnectionError("Failed to receive response from RDT server")
  123.     actions = action_data['actions']
  124.     action_message_id = action_data["message_id"]
  125.     print(f"receive actions done, message_id: {action_message_id}")
  126.     # print(f"receive actions contents: {actions}")
  127.     actions_array = np.array(actions)
  128.     if len(actions_array.shape) == 3:
  129.         action_sequence = actions_array[0, :, :]  # take all timesteps of the first batch
  130.     else:
  131.         raise ValueError(f"action chunk should be 3-dim, but got shape {actions_array.shape}")
  132.    
  133.     joint_names = ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"]
  134.     for step_idx in range(0, len(action_sequence), 4):  # execute every 4th of the 64 predicted steps
  135.         action_values = action_sequence[step_idx]
  136.         action_dict = {f"{joint}.pos": float(action_values[i]) for i, joint in enumerate(joint_names)}
  137.         sent_action = robot.send_action(action_dict)
  138.         time.sleep(0.1)  
  139.     _last_cam_high = cam_high
  140.     _last_cam_right_wrist = cam_right_wrist
  141. def main():
  142.     robot = SO101Follower(SO101FollowerConfig(
  143.         port="/dev/tty.usbmodem5AB90671801",
  144.         id="my_awesome_follower_arm",
  145.         cameras={
  146.             "arm": OpenCVCameraConfig(index_or_path=0, width=1920, height=1080, fps=30),
  147.             "high": OpenCVCameraConfig(index_or_path=2, width=1920, height=1080, fps=30)
  148.         }
  149.     ))
  150.    
  151.     robot.connect()
  152.     client = Client(host="localhost", port=5002)
  153.     try:
  154.         while True:
  155.             record_loop(
  156.                 robot,
  157.                 client
  158.             )
  159.             time.sleep(0.1)
  160.     except KeyboardInterrupt:
  161.         pass
  162.     robot.disconnect()
  163. if __name__ == "__main__":
  164.     main()
复制代码
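        Note that the control loop above does not execute all 64 predicted steps: it sends every fourth action with a 0.1 s pause, i.e. roughly a 10 Hz command rate and 16 commands per chunk. The sketch below isolates that subsampling logic from the robot and the socket; send_action here is only a stand-in for robot.send_action, and the stride, horizon, and joint names follow the script above.
```python
# Minimal sketch of the chunk-subsampling execution used in record_loop above.
# Assumption: `send_action` is a stand-in for robot.send_action().
import time
import numpy as np

JOINT_NAMES = ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"]

def execute_chunk(actions: np.ndarray, send_action, stride: int = 4, dt: float = 0.1):
    """actions: (1, horizon, 6) chunk from the server; execute every `stride`-th step."""
    if actions.ndim != 3:
        raise ValueError(f"expected a 3-dim action chunk, got shape {actions.shape}")
    chunk = actions[0]                                   # (horizon, 6)
    for step in chunk[::stride]:
        action_dict = {f"{j}.pos": float(v) for j, v in zip(JOINT_NAMES, step)}
        send_action(action_dict)
        time.sleep(dt)                                   # ~10 Hz command rate with dt = 0.1

# Dry run with a dummy chunk and a print-only "robot":
execute_chunk(np.zeros((1, 64, 6), dtype=np.float32), send_action=print, dt=0.0)
```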
        Once the test passes, we proceed just as with the server-side inference above: first start the on-board inference code so that it acts as an inference server on the board, then run our local arm-control code:
  1. # 服务端
  2. python3 RDT/scripts/lerobot_rdt_server.py
  3. # 客户端
  4. python3 lerobot/src/lerobot/record_rdt.py
复制代码
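        One detail worth noting: record_rdt.py as listed above connects to host="localhost", which only works when server and client run on the same machine. When the inference server runs on the RDKS100 board, point the client at the board's IP address instead; the address below is only a placeholder.
```python
# Example only: connect the local control script to the RDT server on the board.
# 192.168.1.100 is a placeholder; use the RDKS100's actual IP, and keep port 5002
# consistent with lerobot_rdt_server.py above.
from server_client import Client

client = Client(host="192.168.1.100", port=5002)   # was: Client(host="localhost", port=5002)
```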
        Reference resource usage on the RDKS100 board is shown below (and do check out my Dtop: Jetson有Jtop,Linux有Htop,RDK也有Dtop! - SkyXZ - 博客园):
8.png

        The RDKS100 can also drive the LeRobot arm directly. However, no spare camera was on hand when this document was written, so the head camera could only be provided through Apple's Continuity Camera; with the RDK connected to the LeRobot directly, that head camera would be inaccessible, which is why we take the detour of controlling the arm from the local Mac instead...
