训练模型
当您准备好训练数据集(.npz
或 .h5
文件)、决策流图描述文件(.yaml
)和奖励函数(reward.py
)后,就可以使用 python train.py
命令开启虚拟环境模型和策略模型训练。该脚本将实例化 revive.server.ReviveServer
并启动训练流程。
训练脚本
基本命令格式
bash
python train.py \
-df <训练数据文件路径> \
-cf <决策流图文件路径> \
-rf <奖励函数文件路径> \
-vm <虚拟环境训练模式> \
-pm <策略模型训练模式> \
--run_id <训练实验名称>
参数详解
参数 | 描述 | 必需性 | 示例 |
---|---|---|---|
-df | 训练数据文件路径(.npz 或 .h5 文件) | 必需 | test_data.npz |
-vf | 验证数据文件路径(可选) | 可选 | val_data.npz |
-cf | 决策流图文件路径(.yaml 文件) | 必需 | test.yaml |
-rf | 奖励函数文件路径(reward.py 文件) | 策略训练时必需 | test_reward.py |
-rcf | 超参数配置文件(.json 文件) | 可选 | config.json |
-tpn | 策略节点名称 | 可选 | action |
-vm | 虚拟环境训练模式 | 必需 | once 、tune 、None |
-pm | 策略模型训练模式 | 必需 | once 、tune 、None |
--run_id | 训练实验名称 | 可选 | my_experiment |
训练模式说明
虚拟环境训练模式(-vm
)
once
模式:使用默认参数训练虚拟环境模型,适合快速验证和基础应用tune
模式:使用超参数搜索训练模型,需要更多算力和时间,但能获得更优的模型性能None
模式:不训练虚拟环境,适用于使用已有虚拟环境进行策略训练的场景
策略模型训练模式(-pm
)
once
模式:使用默认参数训练策略模型,适合快速验证和基础应用tune
模式:使用超参数搜索训练策略模型,需要更多算力和时间,但能获得更优的策略性能None
模式:不训练策略模型,适用于只训练虚拟环境而不进行策略训练的场景
策略节点选择(-tpn
)
- 必须是决策流图中定义的节点名称
- 如果未指定,默认选择拓扑顺序第一位的节点作为策略节点
- 策略节点是训练过程中需要优化的目标节点
训练脚本示例
完整训练模式
同时训练虚拟环境和策略模型
bash
# 使用默认参数训练
python train.py \
-df test_data.npz \
-cf test.yaml \
-rf test_reward.py \
-vm once \
-pm once \
--run_id complete_once
# 使用超参数优化训练
python train.py \
-df test_data.npz \
-cf test.yaml \
-rf test_reward.py \
-vm tune \
-pm tune \
--run_id complete_tune
分阶段训练模式
只训练虚拟环境模型
bash
# 使用默认参数训练虚拟环境
python train.py \
-df test_data.npz \
-cf test.yaml \
-vm once \
-pm None \
--run_id venv_only
# 使用超参数优化训练虚拟环境
python train.py \
-df test_data.npz \
-cf test.yaml \
-vm tune \
-pm None \
--run_id venv_tune
重要提示
奖励函数仅在训练策略模型时需要提供,训练虚拟环境时不需要。
基于已有虚拟环境训练策略模型
bash
# 使用默认参数训练策略
python train.py \
-df test_data.npz \
-cf test.yaml \
-rf test_reward.py \
-vm None \
-pm once \
--run_id venv_only
# 使用超参数优化训练策略
python train.py \
-df test_data.npz \
-cf test.yaml \
-rf test_reward.py \
-vm None \
-pm tune \
--run_id venv_only
注意事项
当单独训练策略时,REVIVE SDK 将根据 run_id
查找已完成训练的虚拟环境模型。请确保虚拟环境模型已经存在。
自定义配置训练
使用自定义超参数配置
bash
python train.py \
-df test_data.npz \
-cf test.yaml \
-rf test_reward.py \
-rcf config.json \
--run_id custom_config
训练监控
实时监控
在训练过程中,您可以使用 TensorBoard 实时监控训练进度:
bash
# 启动 TensorBoard
tensorboard --logdir logs/<run_id>
# 在浏览器中访问 http://localhost:6006
超参数调优
REVIVE 提供了强大的超参数调优工具。使用 tune
模式时,系统将:
- 自动采样:从预设超参数空间中采样多组参数
- 并行训练:同时训练多个模型配置
- 性能评估:自动评估不同配置的性能
- 最优选择:选择性能最佳的模型配置
超参数配置
您可以通过修改 config.json
文件来调整:
- 搜索空间:定义超参数的范围和分布
- 资源限制:设置训练时间和计算资源限制
自定义训练
高级用法
train.py
是调用 REVIVE SDK 进行模型训练的通用脚本。您也可以基于 revive.server.ReviveServer
类开发自定义训练方法。
示例数据
示例数据文件存储在 revive/data
文件夹中,包括:
test_data.npz
:训练数据test.yaml
:决策流图配置test_reward.py
:奖励函数config.json
:超参数配置
您可以直接使用这些示例文件进行训练测试。
最佳实践
1. 训练策略选择
- 快速验证:使用
once
模式快速验证数据和方法 - 生产部署:使用
tune
模式获得最优性能 - 分阶段训练:先训练虚拟环境,再训练策略
2. 资源管理
- 监控资源使用:使用
nvidia-smi
监控 GPU 使用情况 - 合理设置批次大小:根据显存大小调整批次大小
- 使用验证数据:提供验证数据避免过拟合
3. 调试技巧
- 分析损失曲线:使用 TensorBoard 查看训练趋势
- 检查数据质量:确保训练数据格式正确
通过合理配置训练参数和监控训练过程,您可以获得高质量的虚拟环境模型和策略模型。