训练模型
当我们准备好训练数据集( .npz
或 .h5
文件)、决策流图描述文件( .yaml
) 和奖励函数( reward.py
)后。 我们可以使用 python train.py
命令开启虚拟环境模型和策略模型训练。该脚本将实例化revive.server.ReviveServer
并开启训练。
训练脚本
python train.py \
-df <训练数据文件路径> \
-cf <决策流图文件路径> \
-rf <奖励函数文件路径> \
-vm <训练虚拟环境模式> \
-pm <训练策略模型模式> \
--run_id <训练实验名称>
运行 train.py
脚本可定义的命令行参数如下:
- -df: 训练数据的文件路径(
.npz
或.h5
文件)。 - -vf: 验证数据的文件路径(可选)。
- -cf: 决策流图的文件路径(
.yaml
)。 - -rf: 定义的奖励函数的文件路径(
reward.py
)(仅在训练策略时需要)。 - -rcf: 支持进行超参配置的
.json
文件(可选)。 - -tpn: 策略节点的名称。必须是决策流图中定义的节点;如果未指定,在默认情况下,排在拓扑顺序第一位的节点将作为策略节点。
- -vm: 训练虚拟环境的不同模式, 包括:
once
,tune
,None
。once
模式: REVIVE SDK将使用默认参数训练模型。tune
模式: REVIVE SDK将使用超参数搜索来训练模型,需要消耗大量的算力和时间,以搜寻超参数来获得更优的模型结果。None
模式: REVIVE SDK不会训练虚拟环境,它适用于调用已有虚拟环境进行策略训练。
- -pm: 策略模型的训练模式, 包括:
once
,tune
,None
。once
模式: REVIVE SDK将使用默认参数训练模型。tune
模式: REVIVE SDK将使用超参数搜索来训练模型,需要消耗大量的算力和时间,以搜寻超参数来获得更优的模型结果。None
模式: REVIVE SDK不会训练策略,它适用于只训练虚拟环境而不进行策略训练的情况。
- --run_id: 用户为训练实验提供的名称。REVIVE将创建
logs/<run_id>
作为日志目录。如果未提供,REVIVE将随机生成名称。
训练脚本示例
训练虚拟环境和策略模型
# 使用once模式训练环境模型和策略模型
python train.py \
-df test_data.npz \
-cf test.yaml \
-rf test_reward.py \
-vm once \
-pm once \
--run_id once
# 使用tune模式训练环境模型和策略模型
python train.py \
-df test_data.npz \
-cf test.yaml \
-rf test_reward.py \
-vm tune \
-pm tune \
--run_id tune
只训练虚拟环境
# 使用once模式只训练环境模型
python train.py \
-df test_data.npz \
-cf test.yaml \
-vm once \
-pm None \
--run_id venv_once
# 使用tune模式只训练环境模型
python train.py \
-df test_data.npz \
-cf test.yaml \
-vm tune \
-pm None \
--run_id venv_tune
重要提示
需要定义的奖励函数仅在策略训练的时候需要被提供。
在已有的环境模型基础上训练策略模型
重要提示
当单独训练策略时,REVIVE SDK将根据索引 run_id
查找完成训练的虚拟环境模型。
# 使用once模式训练环境模型和策略模型
python train.py \
-df test_data.npz \
-cf test.yaml \
-rf test_reward.py \
-vm None \
-pm once \
--run_id venv_once
# 使用once模式训练环境模型和策略模型
python train.py \
-df test_data.npz \
-cf test.yaml \
-rf test_reward.py \
-vm None \
-pm tune \
--run_id venv_tune
重要提示
train.py
是一个调用REVIVE SDK进行模型训练的通用脚本。 我们也可以基于recovery.server.ReviveServer
类进行自定义训练方法的开发。
请参阅API中的 revive.server.ReviveServer
(REVIVE API <../revive_server_cn>
)了解更多信息。
在训练过程中,我们可以随时使用tensorboard打开日志目录以监控训练过程。当REVIVESDK完成虚拟环境模型训练和策略模型训练后。 我们可以在日志文件夹(logs/<run_id>
)下找到保存的模型( .pkl
或 .onnx
)。
REVIVE提供了一个超参数调优工具,使用 tune
模式开始训练将切换到超参调优模式。在该模式下,REVIVE将从预设超参数空间中采样多组超参数用于模型训练,通常使用超参调优模式可以获得更好的模型。有关超参调优模式的详细说明,请参阅文档中的REVIVEAPI(revive.conf <../revive_conf_cn>
)部分。 我们也可以通过修改 config.json
中的相关配置来调整超参搜索空间和搜索方法。
python train.py \
-df test_data.npz \
-cf test.yaml \
-rf test_reward.py \
-rcf config.json \
--run_id test
重要提示
示例数据 test_data.npz
, test.yaml
, test_reward.py
和 config.json
存储在 revive/data
文件夹中。