Loading serialized PyTorch models with torch.load without restricting deserialization can lead to arbitrary code execution.
In PyTorch, it is common to load serialized models using the torch.load function. Under the hood, torch.load uses Python's
pickle module to deserialize the model and its weights. If the file comes from an untrusted source, an attacker can embed a malicious
payload in it, and that payload is executed during deserialization.
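The mechanism can be illustrated with the standard pickle module alone. This is a minimal, deliberately harmless sketch: the Payload class is hypothetical, and the expression it evaluates stands in for whatever an attacker would actually run.

```python
import pickle


class Payload:
    """Hypothetical object standing in for attacker-controlled file content."""

    def __reduce__(self):
        # pickle stores this (callable, args) pair in the byte stream and
        # invokes the callable when the stream is loaded. A harmless
        # arithmetic expression is used here; an attacker could name any
        # importable callable instead.
        return (eval, ("6 * 7",))


# The "attacker" serializes the object ...
malicious_bytes = pickle.dumps(Payload())

# ... and the "victim" merely loads it. No Payload instance comes back:
# the result is whatever the embedded call returned.
result = pickle.loads(malicious_bytes)
print(result)
```

The loading side never calls eval itself. The call happens inside pickle.loads, which is why loading an untrusted file is enough to run code.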
An attacker who controls a model file loaded with torch.load can execute arbitrary code on the target machine. This can lead to full
system compromise, data exfiltration, or further lateral movement within a network.
Use a safer alternative to load the model, such as safetensors.torch.load_model. Alternatively, PyTorch can be instructed to only load
the weights by setting the parameter weights_only=True. This replaces the default unpickler with a restricted one that only reconstructs
tensors, primitive types, and an allowlist of known types, so a payload embedded in the file is rejected instead of executed.
Note that weights_only=True is meant for files that contain only a model's state_dict, so save the state_dict instead of the whole model.
import torch
model = torch.load('model.pth') # Noncompliant
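import torch
import safetensors.torch  # the torch submodule must be imported explicitly

model = MyModel()  # MyModel is the application's own torch.nn.Module subclass
# The file must have been written in the safetensors format, for example with safetensors.torch.save_model
safetensors.torch.load_model(model, 'model.pth')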
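The weights_only alternative described above can be sketched as a save/load round trip. This is an illustrative example under stated assumptions, not a prescribed workflow: TinyModel and the temporary file path are hypothetical placeholders for the real model class and checkpoint location.

```python
import os
import tempfile

import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for the application's model class."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)


# Hypothetical checkpoint location in a fresh temporary directory.
path = os.path.join(tempfile.mkdtemp(), "weights.pth")

# Save only the state_dict (parameter tensors), not the whole model object.
source = TinyModel()
torch.save(source.state_dict(), path)

# Load with the restricted unpickler, then copy the tensors into a model
# instance that was constructed from trusted code.
restored = TinyModel()
state_dict = torch.load(path, weights_only=True)
restored.load_state_dict(state_dict)
```

Because the model architecture comes from the application's own code and only tensors are read from the file, the checkpoint no longer decides which classes or functions get instantiated.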