跳转到内容

多时间步节点拼接

多时间步节点用于拼接多个历史时间步的数据作为输入,从而提升模型的精度和鲁棒性。通过使用多时间步节点,模型可以获得更丰富的输入信息,进一步增强对数据动态变化的捕捉能力及整体性能。以汽车驾驶场景为例,当前时刻的速度信息往往有限,难以全面描述车辆的状态。通过拼接多个时间步内的速度数据,不仅可以获取到速度本身的信息,还可以推导出加速度、减速度等动态特征,从而更精确地预测汽车未来的速度和行驶状态,并提升车辆控制系统的响应速度和调控精度。

因此在构建模型时,将多个历史时间步的数据拼接为输入是一种非常有效的策略。REVIVE 提供了一种简单的配置方法,方便用户快速实现多时间步节点的拼接功能。

下面是一个示例,展示如何进行多时间步节点的数据拼接:

yaml
metadata:
   graph:
     action:
     - observation
     next_observation:
     - action
     - observation
     columns:
     ...

在下面的 .yaml 文件中,我们通过添加节点信息并将 observationts 属性配置为5, 以获得 observation 节点的历史5步拼帧数据。多时间步拼接后的节点名称为 ts_observationts_observation 节点将自动将历史5时间步的 observation 数据进行拼接。

需要注意的是,REVIVE会自动检测存在 ts_前缀的节点。用户应避免在自定义节点名时使用 ts_ 前缀。此外,在上述例子中拼帧节点 ts_observation 可以类比为一个长度为5的队列, observation 在其中符合先入先出(FIFO)原则,最新的 observation 将会被拼在特征维度的最后。

yaml
metadata:
   graph:
     action:
     - ts_observation
     next_ts_observation:
     - action
     - ts_observation
   columns:
   ...

   nodes:
     observation:
       ts: 5

在一些特殊的任务中, action 节点的输出需要考虑历史 action的影响。因此我们也需要对 action 节点进行拼帧操作:

yaml
metadata:
   graph:
     action:
     - ts_action
     - ts_observation
     next_ts_action:
     - action
     - ts_action
     next_ts_observation:
     - next_ts_action
     - ts_observation
   columns:
   ...

   nodes:
     observation:
       ts: 5
       ts_repeat: false
     action:
       ts: 5
       endpoint: 5
       ts_repeat: false

在上面的示例中,我们新增了一个 ts_action节点。需要注意的是,我们同时设置了 endpoint 属性。 endpoint用于控制 ts节点中数据的结束位置。默认情况下,如果不设置endpoint,系统会自动使用 ts+1来获取历史拼帧时间步的数据以及当前时间步的数据。然而,在本案例中,我们需要预测当前的 action,因此通过设置 endpoint属性,确保 ts_action中仅包含历史的 action数据。

还可以增加专家函数以完成多时间步节点的转移函数。该方法虽然比较复杂,但是可以降低REVIVE学习转移节点的难度,提高虚拟环境模型精度。下面是一个示例:

通过yaml文件指定专家函数:

yaml
metadata:
  graph:
    action:
    - ts_observation
    next_observation:
    - ts_observation
    - action
    next_ts_observation:
    - next_observation
    - ts_observation

  columns:
  ...

  nodes:
    observation:
      ts: 5

  expert_functions:
    next_ts_observation:
      'node_function' : 'expert_function.next_ts_observation_func'

专家函数:

python
import torch
from typing import Dict

def next_ts_observation_func(data: Dict[str, torch.Tensor]) -> torch.Tensor:
   obs_dim = data["next_observation"].shape[-1]
   next_obs = data["next_observation"]
   ts_obs = data["ts_observation"]
   next_ts_obs = torch.cat([ts_obs, next_obs], axis=-1)[..., obs_dim:]

   return next_ts_obs