跳转到内容

自定义网络节点

REVIVE 采用神经网络进行建模,默认集成了 MLP、ResNet、GRU、LSTM 和 Transformer 模型。然而,在实际应用中,用户可能希望针对任务需求,对决策流图中的某些网络节点进行定制。为满足这一需求,REVIVE 特别支持用户自定义网络节点设置。

下面以倒立摆控制任务为例,展示如何为摆杆任务的 动作节点 来增加自定义网络节点:

1. 首先需要在 .yaml 文件中配置 custom_nodes节点

yaml
metadata:
  graph:
    actions:
      - states
    next_states:
      - states
      - actions
  columns:
    - obs_states_0:
        dim: states
        type: continuous
    - obs_states_1:
        dim: states
        type: continuous
    - obs_states_2:
        dim: states
        type: continuous
    - action:
        dim: actions
        type: continuous

  custom_nodes:
    actions: 'custom_node_file.ActionNode'

自定义网络节点的配置通常是 <node_name>: <filename>.<node_class_name> 。 上述的 .yaml 文件为 actions 节点使用自定义网络节点 ActionNode,自定义网络节点应该定义在同级目录的 custom_node_file.py 文件中。 此时 data 目录结构如下:

    data/
    |-- Env-GAIL-pendulum_custom_node.yaml
    |-- config.json
    |-- custom_node_file.py
    |-- expert_data.npz
    `-- pendulum-reward.py

2. 编辑自定义网络节点文件

我们还需要编辑 custom_node_file.py 文件,如下所示:

python
import torch
from torch import nn
from revive.computation.modules import DistributionWrapper, ReviveDistribution
from revive.computation.graph import NetworkDecisionNode


class Net(torch.nn.Module):
    def __init__(self,
                in_features : dict,
                out_features : int,
                hidden_features : int,
                hidden_layers : int,
                dist_config : list):
        super().__init__()
        """
        in_features: Dict.
            - key for input_names, value for dimension of corresponding input.
            - In general: in_features = {'obs_node_1': obs_node_1_dim,
                                         'obs_node_2': obs_node_2_dim, ...}
            - e.g. In Pendulum example: in_features = {'states': 3}
            - It enables you to access to dimensions of each input.
        """

        # ====================== Edit node network in here ======================
        in_features_dims = sum(in_features.values())
        net = []
        for i in range(hidden_layers):
            net.append(nn.Linear(in_features_dims if i == 0 \
                                 else hidden_features, hidden_features))
            net.append(nn.LeakyReLU(negative_slope=0.1, inplace=True))
        net.append(nn.Linear(hidden_features, out_features))
        net.append(nn.Identity())
        self.net = nn.Sequential(*net)
        # =================================================================

        # Do not modify the code below!
        self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)

    def forward(self,
                state : dict,
                adapt_std : torch.Tensor = None,
                **kwargs) -> ReviveDistribution:
        #========== Edit forward method in here =========
        x = state['states']
        output = self.net(x)
        # ===============================================

        # Do not modify the code below!
        dist = self.dist_wrapper(output, adapt_std)

        return dist

    # used to reset necessary network variables: e.g. RNN hidden states
    def reset(self):
        pass


class ActionNode(NetworkDecisionNode):
    custom_node = True  # Do NOT modify it !

    def initialize_network(self,
                          inputs: dict,
                          output_dim: int,
                          dist_config: list,
                          *args, **kwargs):

        #========== Edit initialize_network method in here =========
        self.network = Net(inputs, output_dim,
                            hidden_features=256,
                            hidden_layers=2,
                            dist_config=dist_config)
        #===========================================================

在编辑自定义网络节点文件时,需要尤为注意。通常我们会定义 2 个类: 网络类节点类。 如上所示, Net 类主要实现 网络结构forward, reset 方法。 ActionNode 类主要实现 网络定义 以及其他维护节点的方法。

需要注意的是:

  1. 网络结构中 in_featurepython 字典类型key 代表输入节点的名字, value 代表输入节点的维度。比如倒立摆控制任务 例子中:in_features = {'states': 3}。这样我们可以针对不同输入节点,做不同处理。
  2. 网络结构中 dist_wrapper 是为了把输出包装成 Distribution,这是 REVIVE 的内部结构, 用户不要更改!同样,在 forward 方法中, 用户也不要更改 dist 相关部分。
  3. reset 方法是为了必要时候重置网络中的某些变量,比如 RNN 的隐变量。
  4. 节点类 中用户需要增加额外属性 custom_node=True。只有这样,inputs 才是 python 字典类型

重要提示

更多信息用户可以参考源码中的 revive/revive/computation/modules.py 文件。