倒立摆控制任务
任务概述

Gym-Pendulum 是传统强化学习(RL)领域中的经典控制问题之一。如以上动画所示,摆杆的一端连接到固定点,另一端可以自由摆动。该控制问题的目标是在摆杆的自由端施加力矩,使摆杆最终稳定地倒立于固定点之上。在该位置,摆杆可以"站"在固定点上并保持平衡。该问题的详细说明请参考 Pendulum。在本示例中,我们将说明如何使用 REVIVE SDK 构建 Gym-Pendulum 的虚拟环境,并基于虚拟环境学习最优的控制策略。我们还将对比 REVIVE SDK 输出策略和历史数据策略的表现,直观感受并理解 REVIVE 的运行机制和训练效果。
| 项目 | 描述 |
|---|---|
| Action Space | Continuous(1) |
| Observation | Shape (3,) |
| Observation | High [1. 1. 8.] |
| Observation | Low [-1. -1. -8.] |
动作空间
对摆杆的自由端施加力矩,力矩的大小是连续分布于 [-2, 2] 空间中。
观察空间
观察空间为三维,分别代表摆杆与重力方向夹角的正弦值、余弦值和该夹角的角速度值。
任务目标
在 Gym-Pendulum 任务中,我们的目标是在摆杆的一端施加扭矩,使其倒立于固定点上。奖励函数由以下等式确定:
其中,
import torch
import math
def get_reward(data : Dict[str, torch.Tensor]) -> torch.Tensor:
action = data['actions'][...,0:1]
u = torch.clamp(action, -2, 2)
state = data['states'][...,0:3]
costheta = state[:,0].view(-1,1)
sintheta = state[:, 1].view(-1,1)
thdot = state[:, 2].view(-1,1)
x = torch.acos(costheta)
theta = ((x + math.pi) % (2 * math.pi)) - math.pi
costs = theta ** 2 + 0.1 * thdot**2 + 0.001 * (u**2)
return -costs初始状态
摆杆可以以任意夹角以及该夹角的任意角速度作为初始状态。
训练流程
REVIVE 是一个基于历史数据的离线强化学习工具。在摆杆控制任务上使用 REVIVE 的完整流程如下:
- 收集历史决策数据:收集摆杆控制任务的历史运行数据
- 构建决策流图和训练数据:
- 结合业务场景和历史数据构建 决策流图
- 决策流图使用
.yaml文件描述业务数据的交互逻辑 - 训练数据使用
.npz或.h5文件存储决策流图中定义的节点数据
- 定义奖励函数:
- 根据任务目标设计 奖励函数
- 奖励函数指导控制策略优化,确保摆杆能够稳定倒立在固定点上
- 开始模型训练:
- 完成决策流图、训练数据和奖励函数定义后
- 使用 REVIVE 进行虚拟环境模型训练和策略模型训练
- 上线测试:
- 将训练好的策略模型部署到实际环境进行测试
收集历史数据
在本示例中,我们假设已经有一个可用的摆杆控制策略(以下简称:原始策略),我们的目标是通过 REVIVE 训练一个比此策略更优的新策略。我们首先使用这一原始策略来收集历史数据。
定义决策流图和准备数据
一旦有了历史决策数据,就需要根据业务场景构建决策流图。决策流图准确定义了数据之间的决策因果关系。在摆杆控制任务中,我们可以观察到摆杆的状态信息(states)。状态是一个三维量,分别代表摆杆与重力方向夹角的正弦值、余弦值和该夹角的角速度。控制策略 actions 根据 states 的信息对摆杆的自由端施加力矩。
下面的示例展示了 .yaml 文件的详细配置。.yaml 文件通常包含两个主要部分:graph 和 columns。其中 graph 部分定义决策流图,columns 部分定义数据组成。具体配置方法请参考 准备数据 文档。
由于 states 包含三个维度,states 的列需要按顺序定义在 columns 部分。如 gym-pendulum 所示,状态和动作中的变量都是连续分布的,因此使用 continuous 类型描述每一列数据。
metadata:
graph: <- 'graph'部分
actions: <- 对应于 '.npz' 的 `actions`.
- states <- 对应于 '.npz' 的 `states`.
next_states:
- states <- 对应于 '.npz' 的 `states`.
- actions <- 对应于 '.npz' 的 `actions`.
columns: <- 'columns'部分
- obs_0: ---+
dim: states |
type: continuous |
- obs_1: | 这里, 'dim:states' 对应 '.npz' 的 'states'
dim: states | <- 'obs_*' 表示第*维的 'states'。
type: continuous |
- obs_2 | 因为'states'有三个维度,我们按照维度的顺序在
dim: states | 'columns'中进行了定义
type: continuous ---+
- action:
dim: actions
type: continuous根据 准备数据 将数据转换为 .npz 文件进行存储。
定义奖励函数
奖励函数的设计对于学习策略至关重要。一个好的奖励函数应该能够指导策略向着预期的方向进行学习。REVIVE 支持以 Python 源文件的方式定义奖励函数。
倒立摆的目标在于将摆杆倒立在固定点上,此时与重力方向的夹角为 0 度,并获得最高奖励值 0。当摆杆垂直悬挂在固定点上时,此时夹角为最大值 180 度,获得最小奖励值 -16。
其中方程的最大值和最小值分别为 0 和 -16,分别对应于摆杆倒立在固定点上和垂直悬挂的状态。
import torch
import math
def get_reward(data : Dict[str, torch.Tensor]) -> torch.Tensor:
action = data['actions'][...,0:1]
u = torch.clamp(action, -2, 2)
state = data['states'][...,0:3]
costheta = state[:,0].view(-1,1)
sintheta = state[:, 1].view(-1,1)
thdot = state[:, 2].view(-1,1)
x = torch.acos(costheta)
theta = ((x + math.pi) % (2 * math.pi)) - math.pi
costs = theta ** 2 + 0.1 * thdot**2 + 0.001 * (u**2)
return -costs定义奖励函数的更多细节请参考 准备数据 章节的文档介绍。
训练控制策略
我们已经构建完成运行 REVIVE SDK 所需的所有文件,包括 .npz 数据文件、.yaml 文件和 reward.py 奖励函数。这三个文件位于 data 文件夹中。另外还有 config.json 文件,保存了训练所需的超参数。
我们可以使用以下命令开始模型训练:
python train.py \
-df data/expert_data.npz \
-cf data/Env-GAIL-pendulum.yaml \
-rf data/pendulum-reward.py \
-rcf data/config.json \
-vm once \
-pm once \
--run_id pendulum-v1 \
--revive_epoch 1000 \
--ppo_epoch 5000 \
--venv_rollout_horizon 50 \
--ppo_rollout_horizon 50训练模型的更多细节请参考 训练模型 章节的文档介绍。
重要提示
REVIVE 已提供运行示例所需的数据和代码,支持一键运行。数据和代码存储在 SDK 源码库。
测试模型
训练完成后,我们从日志文件中获得 REVIVE 训练后的控制策略,该策略保存路径为 logs/pendulum-v1/policy.pkl。我们尝试在 Gym-Pendulum 环境上测试策略的效果,并与历史数据中的控制策略(原始策略)进行对比。在下面的测试代码中,我们将策略在 Gym-Pendulum 环境中随机测试 50 次,每次执行 300 个时间步长,最后输出这 50 次的平均回报(累计奖励)。REVIVE 训练获得的策略获得了 -137.66 平均奖励,远高于数据中原始策略的 -861.74 奖励值,控制效果提高了约 84%。
import warnings
warnings.filterwarnings('ignore')
from Results import get_results
import pickle
result = get_results('logs/pendulum-v1/policy.pkl', 'url/Old_policy.pkl')
r_revive, r_old, vedio_revive, vedio_old = result.roll_out(50, step=300)
with open('url/results.pkl', 'wb') as f:
pickle.dump([vedio_revive, vedio_old], f)
# 输出:
# REVIVE 平均回报: -137.66
# 原始平均回报: -861.74为了更直观地比较策略效果,我们通过下面的代码生成策略的控制动画。我们在动画中展示钟摆运动的每一步,从比较来看,左侧由 REVIVE 输出的策略可以在 3 秒内将摆杆稳定地倒立在平衡点上,而右侧数据中的原始策略始终不能将摆杆控制到目标位置。
from Video import get_video
from IPython import display
%matplotlib notebook
vedio_revive, vedio_old = pickle.load(open('url/results.pkl', 'rb'))
html = get_video(vedio_revive,vedio_old)
display.display(html)