使用模型
当 REVIVE SDK 完成虚拟环境模型和策略模型训练后,您可以在日志文件夹(logs/<run_id>
)下找到保存的模型文件(.pkl
或 .onnx
格式)。这些模型可以用于推理和部署。
模型文件说明
训练完成后,您将获得以下模型文件:
env.pkl
/env.onnx
:虚拟环境模型,用于模拟环境状态转移policy.pkl
/policy.onnx
:策略模型,用于生成最优决策动作
使用 PKL 模型
虚拟环境模型推理
虚拟环境模型被序列化为 env.pkl
文件。使用 pickle
加载模型后,可以通过 venv.infer_one_step()
或 venv.infer_k_steps()
函数进行推理。
python
import os
import pickle
import numpy as np
# 获取虚拟环境模型文件路径
venv_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
"logs/run_id", "env.pkl")
# 加载虚拟环境模型
with open(venv_path, 'rb') as f:
venv = pickle.load(f, encoding='utf-8')
# 准备输入状态数据
state = {
"states": np.array([-0.5, 0.5, 0.2]),
"actions": np.array([1.0])
}
# 单时间步推理
output = venv.infer_one_step(state)
print("虚拟环境模型单步输出:", output)
# K个时间步推理(返回长度为k的列表)
output = venv.infer_k_steps(state, k=3)
print("虚拟环境模型多步输出:", output)
推理参数说明
虚拟环境模型推理支持以下可选参数:
deterministic
:控制输出确定性True
(默认):返回最可能的输出False
:根据模型概率分布进行采样
clip
:控制输出裁剪True
(默认):将输出裁剪到有效范围内False
:不进行裁剪- 裁剪范围基于 YAML 配置文件,未配置时自动从数据中计算
策略模型推理
策略模型被序列化为 policy.pkl
文件。使用 pickle
加载模型后,可以通过 policy.infer()
函数进行推理。
python
import os
import pickle
import numpy as np
# 获取策略模型文件路径
policy_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
"logs/run_id", "policy.pkl")
# 加载策略模型
with open(policy_path, 'rb') as f:
policy = pickle.load(f, encoding='utf-8')
# 准备输入状态数据
state = {"states": np.array([-0.5, 0.5, 0.2])}
print("策略模型输入状态:", state)
# 策略推理,输出动作
action = policy.infer(state)
print("策略模型输出动作:", action)
推理参数说明
策略模型推理同样支持以下可选参数:
deterministic
:控制输出确定性True
(默认):返回最可能的策略动作False
:根据策略模型概率分布进行采样
clip
:控制动作裁剪True
(默认):将动作裁剪到动作空间有效范围内False
:不进行裁剪- 裁剪范围基于 YAML 配置文件,未配置时自动从数据中计算
使用 ONNX 模型
虚拟环境模型和策略模型也会被保存为 .onnx
格式,便于跨平台部署和集成。
虚拟环境模型(ONNX)
python
import os
import onnxruntime
import numpy as np
# 获取ONNX模型路径
venv_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
"logs/run_id", "env.onnx")
# 创建推理会话
venv = onnxruntime.InferenceSession(venv_path)
# 准备输入数据(支持灵活的batch_size)
# 确保数据类型为float32
venv_input = {
'temperature': np.array([0.5, 0.4, 0.3], dtype=np.float32).reshape(3, -1),
'door_open': np.array([1.0, 0.0, 1.0], dtype=np.float32).reshape(3, -1)
}
# 指定输出节点名称
venv_output_names = ["action", "next_temperature"]
# 执行推理(类似于pkl模型的venv.infer_one_step())
# 输出按venv_output_names顺序存储在数组中
output = venv.run(input_feed=venv_input, output_names=venv_output_names)
print("ONNX虚拟环境模型输出:", output)
策略模型(ONNX)
python
import os
import onnxruntime
import numpy as np
# 获取ONNX模型路径
policy_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
"logs/run_id", "policy.onnx")
# 创建推理会话
policy = onnxruntime.InferenceSession(policy_path)
# 准备输入数据
policy_input = {
'temperature': np.array([0.5, 0.4, 0.3], dtype=np.float32).reshape(3, -1)
}
# 指定输出节点名称
policy_output_names = ["action"]
# 执行推理(类似于pkl模型的policy.infer())
output = policy.run(input_feed=policy_input, output_names=policy_output_names)
print("ONNX策略模型输出:", output)
模型格式对比
特性 | PKL 格式 | ONNX 格式 |
---|---|---|
加载速度 | 较快 | 较慢 |
文件大小 | 较大 | 较小 |
跨平台性 | 有限 | 优秀 |
部署便利性 | 一般 | 优秀 |
推理性能 | 优秀 | 良好 |
功能完整性 | 完整 | 基础 |
最佳实践
1. 模型选择建议
- 开发调试:使用 PKL 格式,功能完整且易于调试
- 生产部署:使用 ONNX 格式,便于跨平台部署
- 性能要求高:优先考虑 PKL 格式
2. 推理优化
- 批量处理:尽可能使用批量推理提高效率
- 数据类型:确保输入数据为正确的数据类型(float32)
- 内存管理:及时释放不需要的模型对象
3. 错误处理
python
try:
# 加载模型
with open(model_path, 'rb') as f:
model = pickle.load(f, encoding='utf-8')
# 执行推理
result = model.infer(input_data)
except FileNotFoundError:
print("模型文件不存在,请检查路径")
except Exception as e:
print(f"推理过程中发生错误: {e}")
4. 性能监控
python
import time
# 记录推理时间
start_time = time.time()
result = model.infer(input_data)
end_time = time.time()
print(f"推理耗时: {end_time - start_time:.4f} 秒")
部署建议
本地部署
- 使用 PKL 格式进行快速原型验证
- 确保环境依赖完整
云端部署
- 使用 ONNX 格式便于容器化部署
- 考虑使用 TensorRT 等推理引擎优化性能
边缘部署
- ONNX 格式更适合资源受限的环境
- 考虑模型量化和压缩
通过合理选择模型格式和优化推理流程,您可以高效地将训练好的模型应用到实际业务场景中。