¿Qué hace .view() en PyTorch?
¿ Qué le hace .view()
a un tensor x
? ¿Qué significan los valores negativos?
x = x.view(-1, 16 * 5 * 5)
view()
reforma el tensor sin copiar la memoria, similar a numpy reshape()
.
Dado un tensor a
con 16 elementos:
import torch
a = torch.range(1, 16)
Para remodelar este tensor y convertirlo en 4 x 4
tensor, use:
a = a.view(4, 4)
Ahora a
será un 4 x 4
tensor. Tenga en cuenta que después de la remodelación, el número total de elementos debe seguir siendo el mismo. Cambiar la forma del tensor a
a 3 x 5
tensor no sería apropiado.
¿Cuál es el significado del parámetro -1?
Si hay alguna situación en la que no sabe cuántas filas desea pero está seguro del número de columnas, puede especificar esto con un -1. ( Tenga en cuenta que puede extender esto a tensores con más dimensiones. Solo uno de los valores del eje puede ser -1 ). Esta es una forma de decirle a la biblioteca: "dame un tensor que tenga tantas columnas y calcularás el número apropiado de filas necesarias para que esto suceda".
Esto se puede ver en este código de definición de modelo . Después de la línea x = self.pool(F.relu(self.conv2(x)))
en la función de avance, tendrá un mapa de características de 16 profundidades. Tienes que aplanarlo para darle la capa completamente conectada. Entonces le dice a PyTorch que remodele el tensor que obtuvo para que tenga un número específico de columnas y le dice que decida el número de filas por sí mismo.
view()
reforma un tensor 'estirando' o 'comprimiendo' sus elementos en la forma que especifiques:
¿Cómo view()
funciona?
Primero, veamos qué es un tensor debajo del capó:
Tensor y su subyacentestorage |
por ejemplo, el tensor de la derecha (forma (3,2)) se puede calcular a partir del de la izquierda cont2 = t1.view(3,2) |
Aquí puede ver que PyTorch crea un tensor al convertir un bloque subyacente de memoria contigua en un objeto similar a una matriz agregando un shape
atributo and stride
:
shape
indica cuánto mide cada dimensiónstride
indica cuántos pasos debes dar en la memoria hasta llegar al siguiente elemento en cada dimensión
view(dim1,dim2,...)
devuelve una vista de la misma información subyacente, pero reformada a un tensor de formadim1 x dim2 x ...
(modificando los atributosshape
ystride
).
Tenga en cuenta que esto supone implícitamente que las dimensiones nuevas y antiguas tienen el mismo producto (es decir, el tensor antiguo y el nuevo tienen el mismo volumen).
PyTorch -1
-1
es un alias de PyTorch para "inferir esta dimensión dado que todas las demás han sido especificadas" (es decir, el cociente del producto original por el nuevo producto). Es una convención tomada de numpy.reshape()
.
Por lo tanto t1.view(3,2)
en nuestro ejemplo sería equivalente a t1.view(3,-1)
o t1.view(-1,2)
.
Hagamos algunos ejemplos, de más simples a más difíciles.
El
view
método devuelve un tensor con los mismos datos que elself
tensor (lo que significa que el tensor devuelto tiene la misma cantidad de elementos), pero con una forma diferente. Por ejemplo:a = torch.arange(1, 17) # a's shape is (16,) a.view(4, 4) # output below 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 [torch.FloatTensor of size 4x4] a.view(2, 2, 4) # output below (0 ,.,.) = 1 2 3 4 5 6 7 8 (1 ,.,.) = 9 10 11 12 13 14 15 16 [torch.FloatTensor of size 2x2x4]
Suponiendo que
-1
no es uno de los parámetros, cuando los multiplicas, el resultado debe ser igual al número de elementos del tensor. Si lo hace:a.view(3, 3)
, generará un mensajeRuntimeError
porque la forma (3 x 3) no es válida para la entrada con 16 elementos. En otras palabras: 3 x 3 no es igual a 16 sino a 9.Puedes usarlo
-1
como uno de los parámetros que pasas a la función, pero solo una vez. Todo lo que sucede es que el método hará los cálculos por usted sobre cómo llenar esa dimensión. Por ejemploa.view(2, -1, 4)
es equivalente aa.view(2, 2, 4)
. [16/(2 x 4) = 2]Observe que el tensor devuelto comparte los mismos datos . Si realiza un cambio en la "vista", está cambiando los datos del tensor original:
b = a.view(4, 4) b[0, 2] = 2 a[2] == 3.0 False
Ahora, veamos un caso de uso más complejo. La documentación dice que cada nueva dimensión de vista debe ser un subespacio de una dimensión original, o solo abarcar d, d + 1, ..., d + k que satisfaga la siguiente condición de contigüidad que para todo i = 0,. .., k - 1, zancada[i] = zancada[i + 1] x tamaño[i + 1] . De lo contrario,
contiguous()
es necesario llamarlo antes de poder ver el tensor. Por ejemplo:a = torch.rand(5, 4, 3, 2) # size (5, 4, 3, 2) a_t = a.permute(0, 2, 3, 1) # size (5, 3, 2, 4) # The commented line below will raise a RuntimeError, because one dimension # spans across two contiguous subspaces # a_t.view(-1, 4) # instead do: a_t.contiguous().view(-1, 4) # To see why the first one does not work and the second does, # compare a.stride() and a_t.stride() a.stride() # (24, 6, 2, 1) a_t.stride() # (24, 2, 1, 6)
Observe que para
a_t
, zancada[0] != zancada[1] x tamaño[1] desde 24 != 2 x 3
torch.Tensor.view()
En pocas palabras, torch.Tensor.view()
que está inspirado en numpy.ndarray.reshape()
o numpy.reshape()
, crea una nueva vista del tensor, siempre que la nueva forma sea compatible con la forma del tensor original.
Entendamos esto en detalle usando un ejemplo concreto.
In [43]: t = torch.arange(18)
In [44]: t
Out[44]:
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17])
Con este tensor t
de forma (18,)
, solo se pueden crear nuevas vistas para las siguientes formas:
(1, 18)
o de manera equivalente (1, -1)
o o equivalente o o equivalente o o equivalente o o equivalente o o equivalente o(-1, 18)
(2, 9)
(2, -1)
(-1, 9)
(3, 6)
(3, -1)
(-1, 6)
(6, 3)
(6, -1)
(-1, 3)
(9, 2)
(9, -1)
(-1, 2)
(18, 1)
(18, -1)
(-1, 1)
Como ya podemos observar en las tuplas de forma anteriores, la multiplicación de los elementos de la tupla de forma (por ejemplo 2*9
, 3*6
etc.) siempre debe ser igual al número total de elementos en el tensor original ( 18
en nuestro ejemplo).
Otra cosa a observar es que usamos a -1
en uno de los lugares de cada una de las tuplas de formas. Al usar a -1
, somos perezosos al hacer el cálculo nosotros mismos y preferimos delegar la tarea a PyTorch para que haga el cálculo de ese valor para la forma cuando crea la nueva vista . Una cosa importante a tener en cuenta es que solo podemos usar una tupla única -1
en la forma. Los valores restantes deben ser proporcionados explícitamente por nosotros. De lo contrario, PyTorch se quejará lanzando un RuntimeError
:
RuntimeError: solo se puede inferir una dimensión
Entonces, con todas las formas mencionadas anteriormente, PyTorch siempre devolverá una nueva vista del tensor original t
. Básicamente, esto significa que simplemente cambia la información de zancada del tensor para cada una de las nuevas vistas que se solicitan.
A continuación se muestran algunos ejemplos que ilustran cómo cambian los pasos de los tensores con cada nueva vista .
# stride of our original tensor `t`
In [53]: t.stride()
Out[53]: (1,)
Ahora veremos los avances de las nuevas vistas :
# shape (1, 18)
In [54]: t1 = t.view(1, -1)
# stride tensor `t1` with shape (1, 18)
In [55]: t1.stride()
Out[55]: (18, 1)
# shape (2, 9)
In [56]: t2 = t.view(2, -1)
# stride of tensor `t2` with shape (2, 9)
In [57]: t2.stride()
Out[57]: (9, 1)
# shape (3, 6)
In [59]: t3 = t.view(3, -1)
# stride of tensor `t3` with shape (3, 6)
In [60]: t3.stride()
Out[60]: (6, 1)
# shape (6, 3)
In [62]: t4 = t.view(6,-1)
# stride of tensor `t4` with shape (6, 3)
In [63]: t4.stride()
Out[63]: (3, 1)
# shape (9, 2)
In [65]: t5 = t.view(9, -1)
# stride of tensor `t5` with shape (9, 2)
In [66]: t5.stride()
Out[66]: (2, 1)
# shape (18, 1)
In [68]: t6 = t.view(18, -1)
# stride of tensor `t6` with shape (18, 1)
In [69]: t6.stride()
Out[69]: (1, 1)
Esa es la magia de la view()
función. Simplemente cambia los pasos del tensor (original) para cada una de las nuevas vistas , siempre que la forma de la nueva vista sea compatible con la forma original.
Otra cosa interesante que se puede observar en las tuplas de pasos es que el valor del elemento en la posición 0 es igual al valor del elemento en la posición 1 de la tupla de forma.
In [74]: t3.shape
Out[74]: torch.Size([3, 6])
|
In [75]: t3.stride() |
Out[75]: (6, 1) |
|_____________|
Esto es porque:
In [76]: t3
Out[76]:
tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17]])
la zancada (6, 1)
dice que para pasar de un elemento al siguiente a lo largo de la dimensión 0 , tenemos que saltar o dar 6 pasos. (es decir, para ir de 0
a 6
, uno tiene que tomar 6 pasos). Pero para pasar de un elemento al siguiente elemento en la primera dimensión , solo necesitamos un paso (por ejemplo, ir de 2
a 3
).
Por lo tanto, la información de los avances es el núcleo de cómo se accede a los elementos desde la memoria para realizar el cálculo.
antorcha.reformar()
Esta función devolvería una vista y es exactamente igual que usar torch.Tensor.view()
siempre que la nueva forma sea compatible con la forma del tensor original. En caso contrario, devolverá una copia.
Sin embargo, las notas de torch.reshape()
advierte que:
Las entradas contiguas y las entradas con pasos compatibles se pueden remodelar sin copiar, pero no se debe depender del comportamiento de copia versus visualización.
Intentemos entender la vista con los siguientes ejemplos:
a=torch.range(1,16)
print(a)
tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
15., 16.])
print(a.view(-1,2))
tensor([[ 1., 2.],
[ 3., 4.],
[ 5., 6.],
[ 7., 8.],
[ 9., 10.],
[11., 12.],
[13., 14.],
[15., 16.]])
print(a.view(2,-1,4)) #3d tensor
tensor([[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.]],
[[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]])
print(a.view(2,-1,2))
tensor([[[ 1., 2.],
[ 3., 4.],
[ 5., 6.],
[ 7., 8.]],
[[ 9., 10.],
[11., 12.],
[13., 14.],
[15., 16.]]])
print(a.view(4,-1,2))
tensor([[[ 1., 2.],
[ 3., 4.]],
[[ 5., 6.],
[ 7., 8.]],
[[ 9., 10.],
[11., 12.]],
[[13., 14.],
[15., 16.]]])
-1 como valor de argumento es una manera fácil de calcular el valor de, digamos, x, siempre que conozcamos los valores de y, z o al revés en el caso de 3d y para 2d nuevamente, una manera fácil de calcular el valor de, por ejemplo, x, siempre que sepamos conocer los valores de y o viceversa.