Loading serialized PyTorch models with torch.load without restricting deserialization can lead to arbitrary code execution.

Why is this an issue?

In PyTorch, it is common to load serialized models using the torch.load function. Under the hood, torch.load uses Python's pickle module to load the model and its weights. Pickle data can instruct the loader to call arbitrary functions, so if the model file comes from an untrusted source, an attacker can embed a malicious payload that is executed during deserialization.
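The mechanism can be illustrated with the standard pickle module alone, without PyTorch. The following is a minimal sketch with a harmless payload: the Payload class and the PICKLE_DEMO environment variable are hypothetical names chosen for this illustration, and a real attack would call something far more damaging than setting a variable.

```python
import os
import pickle


class Payload:
    """Hypothetical attacker-controlled object hidden in a serialized file."""

    def __reduce__(self):
        # Tells pickle to rebuild this object by calling exec(<code>, {}).
        # The code here is harmless: it only sets an environment variable.
        code = "import os; os.environ['PICKLE_DEMO'] = 'payload ran'"
        return (exec, (code, {}))


blob = pickle.dumps(Payload())  # what an attacker would ship as the "model"
result = pickle.loads(blob)     # deserializing runs the embedded code

print(os.environ.get("PICKLE_DEMO"))
```

No method of the loaded object is ever called explicitly; the call to pickle.loads is enough to run the payload.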

What is the potential impact?

Remote code execution

An attacker who controls a model file loaded with torch.load can execute arbitrary code on the target machine. This can lead to full system compromise, data exfiltration, or further lateral movement within a network.

How to fix it

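Use a safer alternative to load the model, such as safetensors.torch.load_model, which reads tensors from a safetensors file and does not rely on pickle. Alternatively, PyTorch can be instructed to only load the weights by setting the parameter weights_only=True. This does not remove pickle entirely: it restricts the unpickler to tensors, primitive types, dictionaries, and explicitly allowlisted types, so arbitrary callables embedded in the file are rejected instead of executed. Since PyTorch 2.6, weights_only defaults to True, but setting it explicitly keeps the call safe on older versions. Note that the use of weights_only generally requires saving only the state_dict of a model instead of the whole model.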
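To give an intuition for how a weights-only mode can defend against malicious pickles, here is a conceptual sketch of an allowlist-based unpickler built on the standard library. It is not PyTorch's implementation: RestrictedUnpickler, restricted_loads, ALLOWED, and Payload are hypothetical names for this illustration, and the allowlist is deliberately tiny.

```python
import collections
import io
import pickle

# Hypothetical allowlist for this sketch; a real loader maintains its own.
ALLOWED = {("collections", "OrderedDict")}


class RestrictedUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        # Called whenever the pickle stream references a class or function.
        if (module, name) in ALLOWED:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"global '{module}.{name}' is not allowed")


def restricted_loads(data):
    """Deserialize bytes, rejecting any global outside the allowlist."""
    return RestrictedUnpickler(io.BytesIO(data)).load()


class Payload:
    def __reduce__(self):
        # Asks the unpickler to call print(...); stands in for a malicious call.
        return (print, ("payload ran",))


safe = pickle.dumps(collections.OrderedDict(weight=[1.0, 2.0]))
malicious = pickle.dumps(Payload())

print(restricted_loads(safe))  # plain containers and numbers load normally

try:
    restricted_loads(malicious)
except pickle.UnpicklingError as err:
    print("rejected:", err)  # the callable is refused before it can run
```

The design choice is to deny by default: anything not explicitly allowed fails to load, which is why whole-model pickles that reference custom classes typically need to be re-saved as a state_dict.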

Code examples

Noncompliant code example

import torch

model = torch.load('model.pth') # Noncompliant

Compliant solution

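import torch
import safetensors.torch  # the torch submodule must be imported explicitly

model = MyModel()  # MyModel: your torch.nn.Module subclass, defined elsewhere
# The file must be in safetensors format, e.g. written by safetensors.torch.save_model
safetensors.torch.load_model(model, 'model.safetensors')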

Resources

Documentation

Standards