多时间步节点拼接
多时间步节点用于拼接多个历史时间步的数据作为输入,从而提升模型的精度和鲁棒性。通过使用多时间步节点,模型可以获得更丰富的输入信息,进一步增强对数据动态变化的捕捉能力及整体性能。以汽车驾驶场景为例,当前时刻的速度信息往往有限,难以全面描述车辆的状态。通过拼接多个时间步内的速度数据,不仅可以获取到速度本身的信息,还可以推导出加速度、减速度等动态特征,从而更精确地预测汽车未来的速度和行驶状态,并提升车辆控制系统的响应速度和调控精度。
因此在构建模型时,将多个历史时间步的数据拼接为输入是一种非常有效的策略。REVIVE 提供了一种简单的配置方法,方便用户快速实现多时间步节点的拼接功能。
下面是一个示例,展示如何进行多时间步节点的数据拼接:
metadata:
graph:
action:
- observation
next_observation:
- action
- observation
columns:
...
在下面的 .yaml
文件中,我们通过添加节点信息并将 observation
的 ts
属性配置为5, 以获得 observation
节点的历史5步拼帧数据。多时间步拼接后的节点名称为 ts_observation
。ts_observation
节点将自动将历史5时间步的 observation
数据进行拼接。
需要注意的是,REVIVE会自动检测存在 ts_
前缀的节点。用户应避免在自定义节点名时使用 ts_
前缀。此外,在上述例子中拼帧节点 ts_observation
可以类比为一个长度为5的队列, observation
在其中符合先入先出(FIFO)原则,最新的 observation
将会被拼在特征维度的最后。
metadata:
graph:
action:
- ts_observation
next_ts_observation:
- action
- ts_observation
columns:
...
nodes:
observation:
ts: 5
在一些特殊的任务中, action
节点的输出需要考虑历史 action
的影响。因此我们也需要对 action
节点进行拼帧操作:
metadata:
graph:
action:
- ts_action
- ts_observation
next_ts_action:
- action
- ts_action
next_ts_observation:
- next_ts_action
- ts_observation
columns:
...
nodes:
observation:
ts: 5
ts_repeat: false
action:
ts: 5
endpoint: 5
ts_repeat: false
在上面的示例中,我们新增了一个 ts_action
节点。需要注意的是,我们同时设置了 endpoint
属性。 endpoint
用于控制 ts
节点中数据的结束位置。默认情况下,如果不设置endpoint
,系统会自动使用 ts+1
来获取历史拼帧时间步的数据以及当前时间步的数据。然而,在本案例中,我们需要预测当前的 action
,因此通过设置 endpoint
属性,确保 ts_action
中仅包含历史的 action
数据。
还可以增加专家函数以完成多时间步节点的转移函数。该方法虽然比较复杂,但是可以降低REVIVE学习转移节点的难度,提高虚拟环境模型精度。下面是一个示例:
通过yaml文件指定专家函数:
metadata:
graph:
action:
- ts_observation
next_observation:
- ts_observation
- action
next_ts_observation:
- next_observation
- ts_observation
columns:
...
nodes:
observation:
ts: 5
expert_functions:
next_ts_observation:
'node_function' : 'expert_function.next_ts_observation_func'
专家函数:
import torch
from typing import Dict
def next_ts_observation_func(data: Dict[str, torch.Tensor]) -> torch.Tensor:
obs_dim = data["next_observation"].shape[-1]
next_obs = data["next_observation"]
ts_obs = data["ts_observation"]
next_ts_obs = torch.cat([ts_obs, next_obs], axis=-1)[..., obs_dim:]
return next_ts_obs