跳转到内容

训练模型

当您准备好训练数据集(.npz.h5 文件)、决策流图描述文件(.yaml)和奖励函数(reward.py)后,就可以使用 python train.py 命令开启虚拟环境模型和策略模型训练。该脚本将实例化 revive.server.ReviveServer 并启动训练流程。

训练脚本

基本命令格式

bash
python train.py \
    -df <训练数据文件路> \
    -cf <决策流图文件路> \
    -rf <奖励函数文件路> \
    -vm <虚拟环境训练模> \
    -pm <策略模型训练模> \
    --run_id <训练实验名>

参数详解

参数描述必需性示例
-df训练数据文件路径(.npz.h5 文件)必需test_data.npz
-vf验证数据文件路径(可选)可选val_data.npz
-cf决策流图文件路径(.yaml 文件)必需test.yaml
-rf奖励函数文件路径(reward.py 文件)策略训练时必需test_reward.py
-rcf超参数配置文件(.json 文件)可选config.json
-tpn策略节点名称可选action
-vm虚拟环境训练模式必需oncetuneNone
-pm策略模型训练模式必需oncetuneNone
--run_id训练实验名称可选my_experiment

训练模式说明

虚拟环境训练模式(-vm

  • once 模式:使用默认参数训练虚拟环境模型,适合快速验证和基础应用
  • tune 模式:使用超参数搜索训练模型,需要更多算力和时间,但能获得更优的模型性能
  • None 模式:不训练虚拟环境,适用于使用已有虚拟环境进行策略训练的场景

策略模型训练模式(-pm

  • once 模式:使用默认参数训练策略模型,适合快速验证和基础应用
  • tune 模式:使用超参数搜索训练策略模型,需要更多算力和时间,但能获得更优的策略性能
  • None 模式:不训练策略模型,适用于只训练虚拟环境而不进行策略训练的场景

策略节点选择(-tpn

  • 必须是决策流图中定义的节点名称
  • 如果未指定,默认选择拓扑顺序第一位的节点作为策略节点
  • 策略节点是训练过程中需要优化的目标节点

训练脚本示例

完整训练模式

同时训练虚拟环境和策略模型

bash
# 使用默认参数训练
python train.py \
    -df test_data.npz \
    -cf test.yaml \
    -rf test_reward.py \
    -vm once \
    -pm once \
    --run_id complete_once

# 使用超参数优化训练
python train.py \
    -df test_data.npz \
    -cf test.yaml \
    -rf test_reward.py \
    -vm tune \
    -pm tune \
    --run_id complete_tune

分阶段训练模式

只训练虚拟环境模型

bash
# 使用默认参数训练虚拟环境
python train.py \
    -df test_data.npz \
    -cf test.yaml \
    -vm once \
    -pm None \
    --run_id venv_only

# 使用超参数优化训练虚拟环境
python train.py \
    -df test_data.npz \
    -cf test.yaml \
    -vm tune \
    -pm None \
    --run_id venv_tune

重要提示

奖励函数仅在训练策略模型时需要提供,训练虚拟环境时不需要。

基于已有虚拟环境训练策略模型

bash
# 使用默认参数训练策略
python train.py \
    -df test_data.npz \
    -cf test.yaml \
    -rf test_reward.py \
    -vm None \
    -pm once \
    --run_id venv_only

# 使用超参数优化训练策略
python train.py \
    -df test_data.npz \
    -cf test.yaml \
    -rf test_reward.py \
    -vm None \
    -pm tune \
    --run_id venv_only

注意事项

当单独训练策略时,REVIVE SDK 将根据 run_id 查找已完成训练的虚拟环境模型。请确保虚拟环境模型已经存在。

自定义配置训练

使用自定义超参数配置

bash
python train.py \
    -df test_data.npz \
    -cf test.yaml \
    -rf test_reward.py \
    -rcf config.json \
    --run_id custom_config

训练监控

实时监控

在训练过程中,您可以使用 TensorBoard 实时监控训练进度:

bash
# 启动 TensorBoard
tensorboard --logdir logs/<run_id>

# 在浏览器中访问 http://localhost:6006

超参数调优

REVIVE 提供了强大的超参数调优工具。使用 tune 模式时,系统将:

  1. 自动采样:从预设超参数空间中采样多组参数
  2. 并行训练:同时训练多个模型配置
  3. 性能评估:自动评估不同配置的性能
  4. 最优选择:选择性能最佳的模型配置

超参数配置

您可以通过修改 config.json 文件来调整:

  • 搜索空间:定义超参数的范围和分布
  • 资源限制:设置训练时间和计算资源限制

自定义训练

高级用法

train.py 是调用 REVIVE SDK 进行模型训练的通用脚本。您也可以基于 revive.server.ReviveServer 类开发自定义训练方法。

示例数据

示例数据文件存储在 revive/data 文件夹中,包括:

  • test_data.npz:训练数据
  • test.yaml:决策流图配置
  • test_reward.py:奖励函数
  • config.json:超参数配置

您可以直接使用这些示例文件进行训练测试。

最佳实践

1. 训练策略选择

  • 快速验证:使用 once 模式快速验证数据和方法
  • 生产部署:使用 tune 模式获得最优性能
  • 分阶段训练:先训练虚拟环境,再训练策略

2. 资源管理

  • 监控资源使用:使用 nvidia-smi 监控 GPU 使用情况
  • 合理设置批次大小:根据显存大小调整批次大小
  • 使用验证数据:提供验证数据避免过拟合

3. 调试技巧

  • 分析损失曲线:使用 TensorBoard 查看训练趋势
  • 检查数据质量:确保训练数据格式正确

通过合理配置训练参数和监控训练过程,您可以获得高质量的虚拟环境模型和策略模型。