¿Cómo guardo un modelo entrenado en PyTorch?

Resuelto Wasi Ahmad asked hace 7 años • 11 respuestas

¿Cómo guardo un modelo entrenado en PyTorch? He leído que:

  1. torch.save()/ torch.load()es para guardar/cargar un objeto serializable.
  2. model.state_dict()/ model.load_state_dict()es para guardar/cargar el estado del modelo.
Wasi Ahmad avatar Mar 10 '17 02:03 Wasi Ahmad
Aceptado

Encontré esta página en su repositorio de github:

Enfoque recomendado para guardar un modelo.

Hay dos enfoques principales para serializar y restaurar un modelo.

El primero (recomendado) guarda y carga solo los parámetros del modelo:

torch.save(the_model.state_dict(), PATH)

Entonces despúes:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

El segundo guarda y carga todo el modelo:

torch.save(the_model, PATH)

Entonces despúes:

the_model = torch.load(PATH)

Sin embargo, en este caso, los datos serializados están vinculados a las clases específicas y a la estructura de directorios exacta utilizada, por lo que pueden romperse de varias maneras cuando se usan en otros proyectos o después de algunas refactorizaciones serias.


Consulte también: sección Guardar y cargar el modelo de los tutoriales oficiales de PyTorch.

dontloo avatar May 06 '2017 10:05 dontloo

Depende de lo que quieras hacer.

Caso # 1: Guarde el modelo para usarlo usted mismo para realizar inferencias : guarda el modelo, lo restaura y luego cambia el modelo al modo de evaluación. Esto se hace porque normalmente tienes BatchNormcapas Dropoutque por defecto están en modo tren en construcción:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

Caso # 2: Guarde el modelo para reanudar el entrenamiento más tarde : si necesita seguir entrenando el modelo que está a punto de guardar, debe guardar algo más que solo el modelo. También necesitas guardar el estado del optimizador, épocas, puntuación, etc. Lo harías así:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

Para reanudar el entrenamiento, haría cosas como: state = torch.load(filepath)y luego, para restaurar el estado de cada objeto individual, algo como esto:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

Dado que está reanudando el entrenamiento, NO llame model.eval()una vez que restaure los estados al cargar.

Caso # 3: Modelo para ser utilizado por otra persona sin acceso a su código : En Tensorflow puede crear un .pbarchivo que defina tanto la arquitectura como los pesos del modelo. Esto es muy útil, especialmente cuando se usa Tensorflow serve. La forma equivalente de hacer esto en Pytorch sería:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

Esta forma todavía no es a prueba de balas y, dado que pytorch todavía está experimentando muchos cambios, no la recomendaría.

Jadiel de Armas avatar Mar 02 '2018 23:03 Jadiel de Armas

La biblioteca pickle Python implementa protocolos binarios para serializar y deserializar un objeto Python.

Cuando usted import torch(o cuando use PyTorch), lo hará import picklepor usted y no necesita llamar pickle.dump()directamente pickle.load(), cuáles son los métodos para guardar y cargar el objeto.

De hecho, torch.save()lo torch.load()envolveremos pickle.dump()y pickle.load()para ti.

La state_dictotra respuesta mencionada merece solo algunas notas más.

¿ Qué state_dicttenemos dentro de PyTorch? En realidad, hay dos state_dicts.

El modelo PyTorch es torch.nn.Moduleel que tiene model.parameters()una llamada para obtener parámetros que se pueden aprender (w y b). Estos parámetros que se pueden aprender, una vez configurados aleatoriamente, se actualizarán con el tiempo a medida que aprendamos. Los parámetros que se pueden aprender son los primeros state_dict.

El segundo state_dictes el dictado del estado del optimizador. Recuerde que el optimizador se utiliza para mejorar nuestros parámetros de aprendizaje. Pero el optimizador state_dictestá arreglado. No hay nada que aprender allí.

Debido a que state_dictlos objetos son diccionarios de Python, se pueden guardar, actualizar, modificar y restaurar fácilmente, agregando una gran modularidad a los modelos y optimizadores de PyTorch.

Creemos un modelo súper simple para explicar esto:

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")    
print(model.weight)

print("Model bias:")    
print(model.bias)

print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

Este código generará lo siguiente:

Model's state_dict:
weight      torch.Size([2, 5])
bias      torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state      {}
param_groups      [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]

Tenga en cuenta que este es un modelo mínimo. Puede intentar agregar una pila de secuencial

model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.Conv2d(A, B, C)
          torch.nn.Linear(H, D_out),
        )

Tenga en cuenta que solo las capas con parámetros que se pueden aprender (capas convolucionales, capas lineales, etc.) y búferes registrados (capas de norma por lotes) tienen entradas en el archivo state_dict.

Las cosas que no se pueden aprender pertenecen al objeto optimizador state_dict, que contiene información sobre el estado del optimizador, así como los hiperparámetros utilizados.

El resto de la historia es la misma; en la fase de inferencia (esta es una fase en la que usamos el modelo después del entrenamiento) para predecir; Predecimos en función de los parámetros que aprendimos. Entonces, para la inferencia, solo necesitamos guardar los parámetros model.state_dict().

torch.save(model.state_dict(), filepath)

Y para usar más adelante model.load_state_dict(torch.load(filepath)) model.eval()

Nota: No olvide la última línea, model.eval()esto es crucial después de cargar el modelo.

Tampoco intentes guardar torch.save(model.parameters(), filepath). Es model.parameters()solo el objeto generador.

Por otro lado, torch.save(model, filepath)guarda el objeto del modelo en sí, pero tenga en cuenta que el modelo no tiene el optimizador state_dict. Consulte la otra excelente respuesta de @Jadiel de Armas para guardar el dictado de estado del optimizador.

prosti avatar Apr 17 '2019 19:04 prosti

Una convención común de PyTorch es guardar modelos utilizando una extensión de archivo .pt o .pth.

Guardar/cargar todo el modelo

Ahorrar:

path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)

Carga:

(La clase de modelo debe definirse en alguna parte)

model.load_state_dict(torch.load(PATH))
model.eval()
harsh avatar May 13 '2019 20:05 harsh