跳转到内容

自定义节点的损失函数

REVIVE在训练虚拟环境时,支持为每个节点配置自定义的损失函数以约束学习到的节点模型。下面展示一个示例通过自定义模块为节点配置自定义损失函数。

首先需要通过自定义模块文件( user_module.py )定义损失函数。

python
import torch
from typing import Dict

def mae_loss(kwargs) -> torch.Tensor:
    # 获得当前节点名
    node_name = kwargs["node_name"]
    # 节点网络输出的分布
    node_dist = kwargs["node_dist"]
    # 当前节点对应的专家数据
    expert_data = kwargs["expert_data"]
    # 获得决策流图
    graph = kwargs["graph"]

    # get network output data -> node_dist.mode
    # get node expert data -> expert_data[node_name]
    # reverse normalization data -> graph.nodes[node_name].processor.deprocess_torch({node_name:expert_data[node_name]})
    policy_loss = (node_dist.mode - expert_data[node_name]).abs().sum(dim=-1).mean()

    return policy_loss

然后在 .yaml 文件中配置节点要使用的损失函数,配置的函数名 = user_module. + 自定义模块中定义的函数名称

yaml
metadata:
  graph:
  act:
    - obs
  next_obs:
    - obs
    - act
  rew:
    - obs
    - act
    - next_obs
  expert_functions:
    rew:
      'node_function': 'expert_function.reward_node_function'
  nodes:
    act:
      loss_type: 'user_module.mae_loss'

最后在训练时需要使用 -umf 参数指定自定义模块文件,示例如下:

bash
python train.py \
    -df test_data.npz \
    -cf test.yaml \
    -umf user_module.py \
    -rf test_reward.py \
    -vm once \
    -pm once \
    --run_id once