There are several approaches to saving (serializing) and loading (deserializing) models for inference in PyTorch.
For example, you may need to load a pre-trained model that someone else saved and that you downloaded from the internet. I recently answered this question on a discussion forum: https://discuss.pytorch.org/t/i-want-to-do-machine-learning-with-android/98753. I'm using this article to give some more details.
SAVING AND LOADING A PRE-TRAINED MODEL WITH load_state_dict
In PyTorch you can save a model by storing its state_dict in a file. A state_dict is a Python dictionary that maps each parameter and buffer name to its tensor. State dicts can easily be saved, updated, modified and restored, which adds great modularity to PyTorch models and optimizers.
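A minimal sketch of the state_dict workflow, using a hypothetical `TinyNet` rather than ESPNet: the model and optimizer state_dicts are saved together in one checkpoint, then the same architecture is rebuilt and the weights are copied back in. An in-memory `io.BytesIO` buffer stands in for a `.pth` file so the example runs anywhere.

```python
import io

import torch
import torch.nn as nn


# Hypothetical toy model standing in for a real network such as ESPNet.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        x = x.mean(dim=(2, 3))  # global average pooling -> shape (N, 8)
        return self.fc(x)


model = TinyNet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# The state_dict maps parameter/buffer names to tensors.
print(list(model.state_dict().keys()))  # ['conv.weight', 'conv.bias', 'fc.weight', 'fc.bias']

# Save both state_dicts in one checkpoint (a path such as "checkpoint.pth" works the same way).
buffer = io.BytesIO()
torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()}, buffer)

# To restore: rebuild the same architecture, then load the saved weights into it.
buffer.seek(0)
checkpoint = torch.load(buffer)
restored = TinyNet()
restored.load_state_dict(checkpoint["model"])
restored_opt = torch.optim.SGD(restored.parameters(), lr=0.1, momentum=0.9)
restored_opt.load_state_dict(checkpoint["optimizer"])
restored.eval()  # switch dropout/batch-norm layers to inference behavior
```

Note that only tensors are stored: the Python class (`TinyNet` here, `ESPNet` in the article) must be available to rebuild the model before calling `load_state_dict`.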
In my example, ESPNet was saved with this method.
I managed to load the encoder model using the ESPNet class, which calls load_state_dict:
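```python
import torch
from Model import ESPNet  # assumes Model.py from the ESPNet repo's train/ folder is importable

# Build ESPNet (20 classes); the constructor loads the encoder weights via load_state_dict
model = ESPNet(20, encoderFile="espnet_p_2_q_8.pth", p=2, q=8)
model.eval()  # inference mode before tracing
example = torch.rand(1, 3, 224, 224)  # dummy input: one RGB 224x224 image
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model.pt")
```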
The output is:
Note that ESPNet uses the default GPU, so an adaptation was required in Model.py (https://github.com/sacmehta/ESPNet/tree/master/train), replacing:
If you don't make this adjustment and have no GPU available, you'll get the following error:
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
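As a hedged sketch, assuming Model.py loads the encoder with a plain `torch.load(encoderFile)` call, the error goes away once `map_location` is passed so that tensors stored on a GPU are remapped to the CPU. `TinyEncoder` and `TinySegNet` below are hypothetical stand-ins for ESPNet's encoder and wrapper classes, not code from the repository; the checkpoint is written to a temporary file only for the demo.

```python
import os
import tempfile

import torch
import torch.nn as nn


# Hypothetical stand-in for a pre-trained encoder.
class TinyEncoder(nn.Module):
    def __init__(self, classes=20):
        super().__init__()
        self.conv = nn.Conv2d(3, classes, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)


# Hypothetical wrapper that, like ESPNet, optionally loads encoder weights from a file.
class TinySegNet(nn.Module):
    def __init__(self, classes=20, encoderFile=None):
        super().__init__()
        self.encoder = TinyEncoder(classes)
        if encoderFile is not None:
            # map_location="cpu" remaps tensors saved on a GPU, so this works on CPU-only machines
            state = torch.load(encoderFile, map_location="cpu")
            self.encoder.load_state_dict(state)

    def forward(self, x):
        return self.encoder(x)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "encoder.pth")
    pretrained = TinyEncoder(20)
    torch.save(pretrained.state_dict(), path)  # simulate a downloaded encoder checkpoint
    model = TinySegNet(20, encoderFile=path).eval()
    out = model(torch.rand(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 20, 112, 112])
```

On a machine with a GPU, the model can still be moved there afterwards with `model.to("cuda")`.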
SAVING AND LOADING WITH TORCHSCRIPT
TorchScript is a way to create serializable and optimizable models from PyTorch code.
Any TorchScript program can be saved from a Python process and loaded into a process where there is no Python dependency.
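The code in this section uses `torch.jit.trace`, which records the operations executed for one example input. A hedged illustration of what that implies: for a module with data-dependent control flow (the hypothetical `Gate` below), tracing freezes the branch taken for the example input, whereas `torch.jit.script` compiles the `forward` source and keeps both branches. For a network without such branches, which is assumed here for ESPNet, tracing is usually sufficient.

```python
import warnings

import torch
import torch.nn as nn


# Hypothetical module whose output depends on a data-dependent branch.
class Gate(nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        else:
            return x + 10


gate = Gate()
pos = torch.ones(3)
neg = torch.full((3,), -1.0)

# Tracing runs the module once on `pos` and records only the branch taken.
# PyTorch emits a TracerWarning about the tensor-to-bool conversion; silence it here.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(gate, pos)

# Scripting compiles forward() itself, so both branches survive.
scripted = torch.jit.script(gate)

print(traced(neg))    # tensor([-2., -2., -2.])  (the "x * 2" branch is replayed)
print(scripted(neg))  # tensor([9., 9., 9.])
```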
The code below uses TorchScript to save the pre-trained ESPNet model that was originally saved with the classic torch.save method:
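```python
import torch
from Model import ESPNet  # assumes Model.py from the ESPNet repo's train/ folder is importable

model = ESPNet(20, encoderFile="espnet_p_2_q_8.pth", p=2, q=8)
model.eval()  # inference mode before tracing
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
# Example of saving with TorchScript
torch.jit.save(traced_script_module, "scriptmodel.pth")
```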
An example of loading via TorchScript:
```python
import torch

model = torch.jit.load("scriptmodel.pth")
```
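A hedged end-to-end check of the TorchScript round trip, using a hypothetical `SmallConvNet` instead of ESPNet: the model is traced, saved, reloaded with `torch.jit.load` (whose `map_location` argument plays the same role as in `torch.load`), and the outputs of the original and reloaded models are compared.

```python
import os
import tempfile

import torch
import torch.nn as nn


# Hypothetical small network standing in for ESPNet.
class SmallConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 4, kernel_size=3, padding=1),
            nn.BatchNorm2d(4),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.features(x)


model = SmallConvNet().eval()  # eval() before tracing so BatchNorm uses running statistics
example = torch.rand(1, 3, 224, 224)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "scriptmodel.pth")
    torch.jit.save(torch.jit.trace(model, example), path)

    # Loading needs only the file: the SmallConvNet class is not used here.
    # map_location="cpu" remaps tensors that were saved from a GPU.
    loaded = torch.jit.load(path, map_location="cpu")

    with torch.no_grad():
        same = torch.allclose(model(example), loaded(example))

print(same)  # True
```

Because the loaded module carries its own graph, the same file can be opened from C++ (libtorch) or on mobile without the Python class definition.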