多时间步节点拼接
多时间步节点用于拼接多个历史时间步的数据作为输入,从而提升模型的精度和鲁棒性。通过使用多时间步节点,模型可以获得更丰富的输入信息,进一步增强对数据动态变化的捕捉能力及整体性能。以汽车驾驶场景为例,当前时刻的速度信息往往有限,难以全面描述车辆的状态。通过拼接多个时间步内的速度数据,不仅可以获取到速度本身的信息,还可以推导出加速度、减速度等动态特征,从而更精确地预测汽车未来的速度和行驶状态,并提升车辆控制系统的响应速度和调控精度。
因此在构建模型时,将多个历史时间步的数据拼接为输入是一种非常有效的策略。REVIVE 提供了一种简单的配置方法,方便用户快速实现多时间步节点的拼接功能。
下面是一个示例,展示如何进行多时间步节点的数据拼接:
metadata:
graph:
action:
- observation
next_observation:
- action
- observation
columns:
...在下面的 .yaml 文件中,我们通过添加节点信息并将 observation 的 ts 属性配置为 5,以获得 observation 节点的历史 5 步拼帧数据。多时间步拼接后的节点名称为 ts_observation。ts_observation 节点将自动将历史 5 时间步的 observation 数据进行拼接。
需要注意的是,REVIVE 会自动检测存在 ts_ 前缀的节点。用户应避免在自定义节点名时使用 ts_ 前缀。此外,在上述例子中拼帧节点 ts_observation 可以类比为一个长度为 5 的队列,observation 在其中符合先入先出(FIFO)原则,最新的 observation 将会被拼在特征维度的最后。
metadata:
graph:
action:
- ts_observation
next_ts_observation:
- action
- ts_observation
columns:
...
nodes:
observation:
ts: 5在一些特殊的任务中,action 节点的输出需要考虑历史 action 的影响。因此我们也需要对 action 节点进行拼帧操作:
metadata:
graph:
action:
- ts_action
- ts_observation
next_ts_action:
- action
- ts_action
next_ts_observation:
- next_ts_action
- ts_observation
columns:
...
nodes:
observation:
ts: 5
ts_repeat: false
action:
ts: 5
ts_endpoint: 5
ts_repeat: false在上面的示例中,我们新增了一个 ts_action 节点。需要注意的是,我们同时设置了 ts_endpoint 属性。ts_endpoint 用于控制 ts 节点中数据的结束位置。默认情况下,如果不设置 ts_endpoint,系统会自动使用 ts+1 来获取历史拼帧时间步的数据以及当前时间步的数据。然而,在本案例中,我们需要预测当前的 action,因此通过设置 ts_endpoint 属性,确保 ts_action 中仅包含历史的 action 数据。
还可以增加专家函数以完成多时间步节点的转移函数。该方法虽然比较复杂,但是可以降低 REVIVE 学习转移节点的难度,提高虚拟环境模型精度。下面是一个示例:
通过 yaml 文件指定专家函数:
metadata:
graph:
action:
- ts_observation
next_observation:
- ts_observation
- action
next_ts_observation:
- next_observation
- ts_observation
columns:
...
nodes:
observation:
ts: 5
expert_functions:
next_ts_observation:
'node_function' : 'expert_function.next_ts_observation_func'专家函数:
import torch
from typing import Dict
def next_ts_observation_func(data: Dict[str, torch.Tensor]) -> torch.Tensor:
obs_dim = data["next_observation"].shape[-1]
next_obs = data["next_observation"]
ts_obs = data["ts_observation"]
next_ts_obs = torch.cat([ts_obs, next_obs], axis=-1)[..., obs_dim:]
return next_ts_obs