PyTorch comment charger un modèle pré-entraîné ?

Il existe plusieurs approches pour enregistrer (sérialiser) et charger (désérialiser) des modèles pour l’inférence dans PyTorch.

Par exemple vous pouvez avoir besoin de charger un modèle qui est déjà entrainé et sauvegarder qui provient d’internet. Plus récemment j’ai répondu à cette question sur un forum de discussion https://discuss.pytorch.org/t/i-want-to-do-machine-learning-with-android/98753. Je profite de cette article pour donner quelques détails.

SAVE ET LOAD D’UN MODE PRE-TRAINED avec load_state_dict

Dans PyTorch on peut sauvegarder un modèle en stockant dans son fichier son state_dict ce sont des dictionnaires Python, ils peuvent être facilement enregistrés, mis à jour, modifiés et restaurés, ajoutant une grande modularité aux modèles et optimiseurs PyTorch.

Dans mon exemple :

ESPNet est sauvegardé avec cette méthode

https://github.com/sacmehta/ESPNet/tree/master/pretrained/encoder

J’ai réussi à charger le modèle de l’encodeur en utilisant la classe ESPNet qui fait un load_state_dict

import torch
model = ESPNet(20,encoderFile="espnet_p_2_q_8.pth", p=2, q=8)
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model.pt")

La sortie est

Encoder loaded!

A noter que ESPNet utilise le GPU par défaut et qu’une adaptation dans Model.py dans https://github.com/sacmehta/ESPNet/tree/master/train a été nécessaire en remplaçant :

self.encoder.load_state_dict(torch.load(encoderFile)

Par :

self.encoder.load_state_dict(torch.load(encoderFile,map_location='cpu'))

Sinon si vous ne faites pas cette adaptation n’avez pas la possibilité de lancer sur un GPU vous allez obtenir l’erreur suivante :

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
PYTORCH TORCHSCRIPT

SAUVEGARDE ET LOAD PAR PYTORCH SCRIPT

https://pytorch.org/docs/stable/jit.html

TorchScript est un moyen de créer des modèles sérialisables et optimisables à partir du code PyTorch.

Tout programme TorchScript peut être enregistré à partir d’un processus Python et chargé dans un processus où il n’y a pas de dépendance Python.

Le code ci-dessous permet de sauvegarder via l’utilisation de TorchScript le modèle pré-entraîné de ESPNet qui avait été sauvegardé par la méthode classique torch.save.

import torch


model = ESPNet(20,encoderFile="espnet_p_2_q_8.pth", p=2, q=8)
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
#Exemple de save avec TorchScript
torch.jit.save(traced_script_module, "scriptmodel.pth")

Un exemple pour loader via TorchScript ci-dessous :

import torch
model = torch.jit.load("scriptmodel.pth")

RESSOURCES SUR LE SUJET

https://pytorch.org/docs/stable/jit.html

https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html

https://stackoverflow.com/questions/53900396/what-are-torch-scripts-in-pytorch

Laisser un commentaire

Votre adresse e-mail ne sera pas publiée. Les champs obligatoires sont indiqués avec *