跳转到内容

使用模型

当 REVIVE SDK 完成虚拟环境模型和策略模型训练后,您可以在日志文件夹(logs/<run_id>)下找到保存的模型文件(.pkl.onnx 格式)。这些模型可以用于推理和部署。

模型文件说明

训练完成后,您将获得以下模型文件:

  • env.pkl / env.onnx:虚拟环境模型,用于模拟环境状态转移
  • policy.pkl / policy.onnx:策略模型,用于生成最优决策动作

使用 PKL 模型

虚拟环境模型推理

虚拟环境模型被序列化为 env.pkl 文件。使用 pickle 加载模型后,可以通过 venv.infer_one_step()venv.infer_k_steps() 函数进行推理。

python
import os
import pickle
import numpy as np

# 获取虚拟环境模型文件路径
venv_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                        "logs/run_id", "env.pkl")

# 加载虚拟环境模型
with open(venv_path, 'rb') as f:
    venv = pickle.load(f, encoding='utf-8')

# 准备输入状态数据
state = {
    "states": np.array([-0.5, 0.5, 0.2]),
    "actions": np.array([1.0])
}

# 单时间步推理
output = venv.infer_one_step(state)
print("虚拟环境模型单步输出:", output)

# K个时间步推理(返回长度为k的列表)
output = venv.infer_k_steps(state, k=3)
print("虚拟环境模型多步输出:", output)

推理参数说明

虚拟环境模型推理支持以下可选参数:

  • deterministic:控制输出确定性
    • True(默认):返回最可能的输出
    • False:根据模型概率分布进行采样
  • clip:控制输出裁剪
    • True(默认):将输出裁剪到有效范围内
    • False:不进行裁剪
    • 裁剪范围基于 YAML 配置文件,未配置时自动从数据中计算

策略模型推理

策略模型被序列化为 policy.pkl 文件。使用 pickle 加载模型后,可以通过 policy.infer() 函数进行推理。

python
import os
import pickle
import numpy as np

# 获取策略模型文件路径
policy_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                          "logs/run_id", "policy.pkl")

# 加载策略模型
with open(policy_path, 'rb') as f:
    policy = pickle.load(f, encoding='utf-8')

# 准备输入状态数据
state = {"states": np.array([-0.5, 0.5, 0.2])}
print("策略模型输入状态:", state)

# 策略推理,输出动作
action = policy.infer(state)
print("策略模型输出动作:", action)

推理参数说明

策略模型推理同样支持以下可选参数:

  • deterministic:控制输出确定性
    • True(默认):返回最可能的策略动作
    • False:根据策略模型概率分布进行采样
  • clip:控制动作裁剪
    • True(默认):将动作裁剪到动作空间有效范围内
    • False:不进行裁剪
    • 裁剪范围基于 YAML 配置文件,未配置时自动从数据中计算

使用 ONNX 模型

虚拟环境模型和策略模型也会被保存为 .onnx 格式,便于跨平台部署和集成。

虚拟环境模型(ONNX)

python
import os
import onnxruntime
import numpy as np

# 获取ONNX模型路径
venv_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                        "logs/run_id", "env.onnx")

# 创建推理会话
venv = onnxruntime.InferenceSession(venv_path)

# 准备输入数据(支持灵活的batch_size)
# 确保数据类型为float32
venv_input = {
    'temperature': np.array([0.5, 0.4, 0.3], dtype=np.float32).reshape(3, -1),
    'door_open': np.array([1.0, 0.0, 1.0], dtype=np.float32).reshape(3, -1)
}

# 指定输出节点名称
venv_output_names = ["action", "next_temperature"]

# 执行推理(类似于pkl模型的venv.infer_one_step())
# 输出按venv_output_names顺序存储在数组中
output = venv.run(input_feed=venv_input, output_names=venv_output_names)
print("ONNX虚拟环境模型输出:", output)

策略模型(ONNX)

python
import os
import onnxruntime
import numpy as np

# 获取ONNX模型路径
policy_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                          "logs/run_id", "policy.onnx")

# 创建推理会话
policy = onnxruntime.InferenceSession(policy_path)

# 准备输入数据
policy_input = {
    'temperature': np.array([0.5, 0.4, 0.3], dtype=np.float32).reshape(3, -1)
}

# 指定输出节点名称
policy_output_names = ["action"]

# 执行推理(类似于pkl模型的policy.infer())
output = policy.run(input_feed=policy_input, output_names=policy_output_names)
print("ONNX策略模型输出:", output)

模型格式对比

特性PKL 格式ONNX 格式
加载速度较快较慢
文件大小较大较小
跨平台性有限优秀
部署便利性一般优秀
推理性能优秀良好
功能完整性完整基础

最佳实践

1. 模型选择建议

  • 开发调试:使用 PKL 格式,功能完整且易于调试
  • 生产部署:使用 ONNX 格式,便于跨平台部署
  • 性能要求高:优先考虑 PKL 格式

2. 推理优化

  • 批量处理:尽可能使用批量推理提高效率
  • 数据类型:确保输入数据为正确的数据类型(float32)
  • 内存管理:及时释放不需要的模型对象

3. 错误处理

python
try:
    # 加载模型
    with open(model_path, 'rb') as f:
        model = pickle.load(f, encoding='utf-8')

    # 执行推理
    result = model.infer(input_data)

except FileNotFoundError:
    print("模型文件不存在,请检查路径")
except Exception as e:
    print(f"推理过程中发生错误: {e}")

4. 性能监控

python
import time

# 记录推理时间
start_time = time.time()
result = model.infer(input_data)
end_time = time.time()

print(f"推理耗时: {end_time - start_time:.4f} 秒")

部署建议

本地部署

  • 使用 PKL 格式进行快速原型验证
  • 确保环境依赖完整

云端部署

  • 使用 ONNX 格式便于容器化部署
  • 考虑使用 TensorRT 等推理引擎优化性能

边缘部署

  • ONNX 格式更适合资源受限的环境
  • 考虑模型量化和压缩

通过合理选择模型格式和优化推理流程,您可以高效地将训练好的模型应用到实际业务场景中。