跳到主要内容

PyTorch 模型更新

pytorch_model_update.py 示例演示了如何使用 OutputModel 类训练模型并进行日志记录。

该示例执行以下操作

  • examples 项目中创建名为 Model update pytorch 的任务。
  • 在 CIFAR10 数据集上训练用于图像分类的神经网络。
  • 使用 OutputModel 对象记录模型、其标签枚举和配置字典。
禁用自动框架日志记录

该示例禁用默认的 PyTorch 输出自动捕获,以演示如何手动控制从 PyTorch 记录的内容。有关更多信息,请参阅此常见问题解答

初始化

为任务实例化一个 OutputModel 对象。

from clearml import Task, OutputModel

task = Task.init(
project_name="examples",
task_name="Model update pytorch",
auto_connect_frameworks={"pytorch": False}
)

output_model = OutputModel(task=task)

标签枚举

使用 Task.connect_label_enumeration 方法记录标签枚举字典,这将更新任务的结果模型信息。当前运行的任务通过 Task.current_task 类方法访问。

# store the label enumeration of the training model
classes = ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck",)
enumeration = {k: v for v, k in enumerate(classes, 1)}
Task.current_task().connect_label_enumeration(enumeration)
直接设置模型枚举

您可以使用 OutputModel.update_labels 方法直接设置模型的标签枚举。

模型配置

使用 OutputModel.update_design 方法向模型添加配置字典。

model_config_dict = {
"list_of_ints": [1, 2, 3, 4],
"dict": {
"sub_value": "string",
"sub_integer": 11
},
"value": 13.37
}

model.update_design(config_dict=model_config_dict)

更新模型

要更新模型,请使用 OutputModel.update_weights()。这将模型上传到设置的存储目的地(请参阅设置上传目的地),并将该位置注册为任务的输出模型。

# CONDITION depicts a custom condition for when to save the model. The model is saved and then updated in ClearML
CONDITION = True

if CONDITION:
torch.save(net.state_dict(), PATH)
model.update_weights(weights_filename=PATH)

Web UI

模型显示在任务的ARTIFACTS选项卡中。

Task artifacts Task artifacts

点击模型名称会进入模型页面,您可以在其中查看模型的详细信息并访问模型。

Model page Model page

模型的NETWORK选项卡显示其配置。

Model network tab Model network tab

模型的LABELS选项卡显示其标签枚举。

Model labels Model labels