跳转到内容

冻结节点网络参数

在训练过程中,可以冻结某些节点的网络参数不进行梯度更新,这在存在有预训练环境模型时非常有用。

例如在机器人控制任务上,假设我们已经在一个A型号的机器人任务上训练完成了一个较好的虚拟环境模型,现在我们希望快速将A型号的虚拟环境模型迁移到新的B型号机器人任务上,两个型号的机器人系统共用了大部分相同的机械结构,但是也存在一些不同。这时我们可以将A型号机器人的虚拟环境模型作为预训练模型,然后通过参数冻结功能将模型中与A型号机器人相同的部分保护起来,这样训练过程中就不会破坏这些节点的参数设置了。然后,我们可以在B型号机器人上继续训练,只对那些与A型号机器人不同的部分进行微调或者重新学习。这样做的好处是可以节省训练时间和计算资源,同时也可以在一定程度上避免过拟合,因为预训练模型已经学会了一些通用的特征,可以避免从头开始训练过程中的过拟合现象。同时,由于模型已经预训练过,所以迁移学习过程中需要拟合的新数据集的大小不需要太大。

需要注意的是,在进行参数冻结时,我们需要根据具体问题自行决定哪些节点需要被冻结,哪些节点需要重新训练。如果冻结太多的节点,那么新任务的机器人可能无法得到足够的学习;如果冻结太少的节点,那么新任务的机器人可能无法发挥预训练模型的优势,训练时间也会很长。因此,在具体操作时需要合理权衡。

在下面的 .yaml 示例中,我们将 action 节点的 freeze 属性配置为 true 。在训练期间, action 节点的网络参数将不会更新。

yaml
metadata:

   graph:
     action:
     - observation
     next_observation:
     - action
     - observation

   columns:
   ...

   nodes:
     action:
       freeze: true

重要提示

不能同时冻结所有网络节点。至少需要存在一个网络节点可以进行梯度更新,否则REVIVE将报错。