跳转到内容

训练模型

当我们准备好训练数据集( .npz.h5 文件)、决策流图描述文件( .yaml) 和奖励函数( reward.py )后。 我们可以使用 python train.py命令开启虚拟环境模型和策略模型训练。该脚本将实例化revive.server.ReviveServer 并开启训练。

训练脚本

bash
python train.py \
    -df <训练数据文件路> \
    -cf <决策流图文件路> \
    -rf <奖励函数文件路> \
    -vm <训练虚拟环境模> \
    -pm <训练策略模型模> \
    --run_id <训练实验名>

运行 train.py 脚本可定义的命令行参数如下:

  • -df: 训练数据的文件路径( .npz.h5 文件)。
  • -vf: 验证数据的文件路径(可选)。
  • -cf: 决策流图的文件路径( .yaml )。
  • -rf: 定义的奖励函数的文件路径( reward.py )(仅在训练策略时需要)。
  • -rcf: 支持进行超参配置的 .json 文件(可选)。
  • -tpn: 策略节点的名称。必须是决策流图中定义的节点;如果未指定,在默认情况下,排在拓扑顺序第一位的节点将作为策略节点。
  • -vm: 训练虚拟环境的不同模式, 包括: once,tune,None
    • once 模式: REVIVE SDK将使用默认参数训练模型。
    • tune 模式: REVIVE SDK将使用超参数搜索来训练模型,需要消耗大量的算力和时间,以搜寻超参数来获得更优的模型结果。
    • None 模式: REVIVE SDK不会训练虚拟环境,它适用于调用已有虚拟环境进行策略训练。
  • -pm: 策略模型的训练模式, 包括: once,tune,None
    • once 模式: REVIVE SDK将使用默认参数训练模型。
    • tune 模式: REVIVE SDK将使用超参数搜索来训练模型,需要消耗大量的算力和时间,以搜寻超参数来获得更优的模型结果。
    • None 模式: REVIVE SDK不会训练策略,它适用于只训练虚拟环境而不进行策略训练的情况。
  • --run_id: 用户为训练实验提供的名称。REVIVE将创建 logs/<run_id> 作为日志目录。如果未提供,REVIVE将随机生成名称。

训练脚本示例

训练虚拟环境和策略模型

python
# 使用once模式训练环境模型和策略模型
python train.py \
    -df test_data.npz \
    -cf test.yaml \
    -rf test_reward.py \
    -vm once \
    -pm once \
    --run_id once

# 使用tune模式训练环境模型和策略模型
python train.py \
    -df test_data.npz \
    -cf test.yaml \
    -rf test_reward.py \
    -vm tune \
    -pm tune \
    --run_id tune

只训练虚拟环境

python
# 使用once模式只训练环境模型
python train.py \
    -df test_data.npz \
    -cf test.yaml \
    -vm once \
    -pm None \
    --run_id venv_once

# 使用tune模式只训练环境模型
python train.py \
    -df test_data.npz \
    -cf test.yaml \
    -vm tune \
    -pm None \
    --run_id venv_tune

重要提示

需要定义的奖励函数仅在策略训练的时候需要被提供。

在已有的环境模型基础上训练策略模型

重要提示

当单独训练策略时,REVIVE SDK将根据索引 run_id查找完成训练的虚拟环境模型。

bash
# 使用once模式训练环境模型和策略模型
python train.py \
    -df test_data.npz \
    -cf test.yaml \
    -rf test_reward.py \
    -vm None \
    -pm once \
    --run_id venv_once

# 使用once模式训练环境模型和策略模型
python train.py \
    -df test_data.npz \
    -cf test.yaml \
    -rf test_reward.py \
    -vm None \
    -pm tune \
    --run_id venv_tune

重要提示

train.py 是一个调用REVIVE SDK进行模型训练的通用脚本。 我们也可以基于recovery.server.ReviveServer 类进行自定义训练方法的开发。

请参阅API中的 revive.server.ReviveServerREVIVE API <../revive_server_cn> )了解更多信息。

在训练过程中,我们可以随时使用tensorboard打开日志目录以监控训练过程。当REVIVESDK完成虚拟环境模型训练和策略模型训练后。 我们可以在日志文件夹(logs/<run_id>)下找到保存的模型( .pkl.onnx)。

REVIVE提供了一个超参数调优工具,使用 tune模式开始训练将切换到超参调优模式。在该模式下,REVIVE将从预设超参数空间中采样多组超参数用于模型训练,通常使用超参调优模式可以获得更好的模型。有关超参调优模式的详细说明,请参阅文档中的REVIVEAPI(revive.conf <../revive_conf_cn>)部分。 我们也可以通过修改 config.json中的相关配置来调整超参搜索空间和搜索方法。

bash
python train.py \
    -df test_data.npz \
    -cf test.yaml \
    -rf test_reward.py \
    -rcf config.json \
    --run_id test

重要提示

示例数据 test_data.npz, test.yaml, test_reward.pyconfig.json存储在 revive/data 文件夹中。