¿Cómo guardo un modelo entrenado en PyTorch?
¿Cómo guardo un modelo entrenado en PyTorch? He leído que:
torch.save()
/torch.load()
es para guardar/cargar un objeto serializable.model.state_dict()
/model.load_state_dict()
es para guardar/cargar el estado del modelo.
Encontré esta página en su repositorio de github:
Enfoque recomendado para guardar un modelo.
Hay dos enfoques principales para serializar y restaurar un modelo.
El primero (recomendado) guarda y carga solo los parámetros del modelo:
torch.save(the_model.state_dict(), PATH)
Entonces despúes:
the_model = TheModelClass(*args, **kwargs) the_model.load_state_dict(torch.load(PATH))
El segundo guarda y carga todo el modelo:
torch.save(the_model, PATH)
Entonces despúes:
the_model = torch.load(PATH)
Sin embargo, en este caso, los datos serializados están vinculados a las clases específicas y a la estructura de directorios exacta utilizada, por lo que pueden romperse de varias maneras cuando se usan en otros proyectos o después de algunas refactorizaciones serias.
Consulte también: sección Guardar y cargar el modelo de los tutoriales oficiales de PyTorch.
Depende de lo que quieras hacer.
Caso # 1: Guarde el modelo para usarlo usted mismo para realizar inferencias : guarda el modelo, lo restaura y luego cambia el modelo al modo de evaluación. Esto se hace porque normalmente tienes BatchNorm
capas Dropout
que por defecto están en modo tren en construcción:
torch.save(model.state_dict(), filepath)
#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()
Caso # 2: Guarde el modelo para reanudar el entrenamiento más tarde : si necesita seguir entrenando el modelo que está a punto de guardar, debe guardar algo más que solo el modelo. También necesitas guardar el estado del optimizador, épocas, puntuación, etc. Lo harías así:
state = {
'epoch': epoch,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
...
}
torch.save(state, filepath)
Para reanudar el entrenamiento, haría cosas como: state = torch.load(filepath)
y luego, para restaurar el estado de cada objeto individual, algo como esto:
model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])
Dado que está reanudando el entrenamiento, NO llame model.eval()
una vez que restaure los estados al cargar.
Caso # 3: Modelo para ser utilizado por otra persona sin acceso a su código : En Tensorflow puede crear un .pb
archivo que defina tanto la arquitectura como los pesos del modelo. Esto es muy útil, especialmente cuando se usa Tensorflow serve
. La forma equivalente de hacer esto en Pytorch sería:
torch.save(model, filepath)
# Then later:
model = torch.load(filepath)
Esta forma todavía no es a prueba de balas y, dado que pytorch todavía está experimentando muchos cambios, no la recomendaría.
La biblioteca pickle Python implementa protocolos binarios para serializar y deserializar un objeto Python.
Cuando usted import torch
(o cuando use PyTorch), lo hará import pickle
por usted y no necesita llamar pickle.dump()
directamente pickle.load()
, cuáles son los métodos para guardar y cargar el objeto.
De hecho, torch.save()
lo torch.load()
envolveremos pickle.dump()
y pickle.load()
para ti.
La state_dict
otra respuesta mencionada merece solo algunas notas más.
¿ Qué state_dict
tenemos dentro de PyTorch? En realidad, hay dos state_dict
s.
El modelo PyTorch es torch.nn.Module
el que tiene model.parameters()
una llamada para obtener parámetros que se pueden aprender (w y b). Estos parámetros que se pueden aprender, una vez configurados aleatoriamente, se actualizarán con el tiempo a medida que aprendamos. Los parámetros que se pueden aprender son los primeros state_dict
.
El segundo state_dict
es el dictado del estado del optimizador. Recuerde que el optimizador se utiliza para mejorar nuestros parámetros de aprendizaje. Pero el optimizador state_dict
está arreglado. No hay nada que aprender allí.
Debido a que state_dict
los objetos son diccionarios de Python, se pueden guardar, actualizar, modificar y restaurar fácilmente, agregando una gran modularidad a los modelos y optimizadores de PyTorch.
Creemos un modelo súper simple para explicar esto:
import torch
import torch.optim as optim
model = torch.nn.Linear(5, 2)
# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
print("Model weight:")
print(model.weight)
print("Model bias:")
print(model.bias)
print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
Este código generará lo siguiente:
Model's state_dict:
weight torch.Size([2, 5])
bias torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328, 0.1360, 0.1553, -0.1838, -0.0316],
[ 0.0479, 0.1760, 0.1712, 0.2244, 0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]
Tenga en cuenta que este es un modelo mínimo. Puede intentar agregar una pila de secuencial
model = torch.nn.Sequential(
torch.nn.Linear(D_in, H),
torch.nn.Conv2d(A, B, C)
torch.nn.Linear(H, D_out),
)
Tenga en cuenta que solo las capas con parámetros que se pueden aprender (capas convolucionales, capas lineales, etc.) y búferes registrados (capas de norma por lotes) tienen entradas en el archivo state_dict
.
Las cosas que no se pueden aprender pertenecen al objeto optimizador state_dict
, que contiene información sobre el estado del optimizador, así como los hiperparámetros utilizados.
El resto de la historia es la misma; en la fase de inferencia (esta es una fase en la que usamos el modelo después del entrenamiento) para predecir; Predecimos en función de los parámetros que aprendimos. Entonces, para la inferencia, solo necesitamos guardar los parámetros model.state_dict()
.
torch.save(model.state_dict(), filepath)
Y para usar más adelante model.load_state_dict(torch.load(filepath)) model.eval()
Nota: No olvide la última línea, model.eval()
esto es crucial después de cargar el modelo.
Tampoco intentes guardar torch.save(model.parameters(), filepath)
. Es model.parameters()
solo el objeto generador.
Por otro lado, torch.save(model, filepath)
guarda el objeto del modelo en sí, pero tenga en cuenta que el modelo no tiene el optimizador state_dict
. Consulte la otra excelente respuesta de @Jadiel de Armas para guardar el dictado de estado del optimizador.
Una convención común de PyTorch es guardar modelos utilizando una extensión de archivo .pt o .pth.
Guardar/cargar todo el modelo
Ahorrar:
path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)
Carga:
(La clase de modelo debe definirse en alguna parte)
model.load_state_dict(torch.load(PATH))
model.eval()