¿Cómo se puede calcular la distancia euclidiana con NumPy?

Resuelto Nathan Fellman asked hace 15 años • 26 respuestas

Tengo dos puntos en el espacio 3D:

a = (ax, ay, az)
b = (bx, by, bz)

Quiero calcular la distancia entre ellos:

dist = sqrt((ax-bx)^2 + (ay-by)^2 + (az-bz)^2)

¿Cómo hago esto con NumPy? Tengo:

import numpy
a = numpy.array((ax, ay, az))
b = numpy.array((bx, by, bz))
Nathan Fellman avatar Sep 10 '09 02:09 Nathan Fellman
Aceptado

Usar numpy.linalg.norm:

dist = numpy.linalg.norm(a-b)

Esto funciona porque la distancia euclidiana es la norma l2 y el valor predeterminado del ordparámetro numpy.linalg.normes 2. Para obtener más teoría, consulte Introducción a la minería de datos :

ingrese la descripción de la imagen aquí

u0b34a0f6ae avatar Sep 09 '2009 20:09 u0b34a0f6ae

Usar scipy.spatial.distance.euclidean:

from scipy.spatial import distance
a = (1, 2, 3)
b = (4, 5, 6)
dst = distance.euclidean(a, b)
Avision avatar Feb 24 '2014 11:02 Avision

Para cualquiera interesado en calcular múltiples distancias a la vez, hice una pequeña comparación usando perfplot (un pequeño proyecto mío).

El primer consejo es organizar sus datos de manera que las matrices tengan dimensión (3, n)(y, obviamente, sean contiguas a C). Si la suma ocurre en la primera dimensión contigua, las cosas son más rápidas y no importa demasiado si usa sqrt-sumwith axis=0, linalg.normwith axis=0o

a_min_b = a - b
numpy.sqrt(numpy.einsum('ij,ij->j', a_min_b, a_min_b))

que es, por un ligero margen, la variante más rápida. (Eso también es válido solo para una fila).

Las variantes en las que se suma el segundo eje, axis=1son todas sustancialmente más lentas.

ingrese la descripción de la imagen aquí


Código para reproducir la trama:

import numpy
import perfplot
from scipy.spatial import distance


def linalg_norm(data):
    a, b = data[0]
    return numpy.linalg.norm(a - b, axis=1)


def linalg_norm_T(data):
    a, b = data[1]
    return numpy.linalg.norm(a - b, axis=0)


def sqrt_sum(data):
    a, b = data[0]
    return numpy.sqrt(numpy.sum((a - b) ** 2, axis=1))


def sqrt_sum_T(data):
    a, b = data[1]
    return numpy.sqrt(numpy.sum((a - b) ** 2, axis=0))


def scipy_distance(data):
    a, b = data[0]
    return list(map(distance.euclidean, a, b))


def sqrt_einsum(data):
    a, b = data[0]
    a_min_b = a - b
    return numpy.sqrt(numpy.einsum("ij,ij->i", a_min_b, a_min_b))


def sqrt_einsum_T(data):
    a, b = data[1]
    a_min_b = a - b
    return numpy.sqrt(numpy.einsum("ij,ij->j", a_min_b, a_min_b))


def setup(n):
    a = numpy.random.rand(n, 3)
    b = numpy.random.rand(n, 3)
    out0 = numpy.array([a, b])
    out1 = numpy.array([a.T, b.T])
    return out0, out1


b = perfplot.bench(
    setup=setup,
    n_range=[2 ** k for k in range(22)],
    kernels=[
        linalg_norm,
        linalg_norm_T,
        scipy_distance,
        sqrt_sum,
        sqrt_sum_T,
        sqrt_einsum,
        sqrt_einsum_T,
    ],
    xlabel="len(x), len(y)",
)
b.save("norm.png")
Nico Schlömer avatar Dec 12 '2017 14:12 Nico Schlömer

Quiero explicar la respuesta simple con varias notas de desempeño. np.linalg.norm hará quizás más de lo que necesita:

dist = numpy.linalg.norm(a-b)

En primer lugar, esta función está diseñada para trabajar con una lista y devolver todos los valores, por ejemplo, para comparar la distancia desde pAel conjunto de puntos sP:

sP = set(points)
pA = point
distances = np.linalg.norm(sP - pA, ord=2, axis=1.)  # 'distances' is a list

Recuerda varias cosas:

  • Las llamadas a funciones de Python son caras.
  • [Regular] Python no almacena en caché las búsquedas de nombres.

Entonces

def distance(pointA, pointB):
    dist = np.linalg.norm(pointA - pointB)
    return dist

No es tan inocente como parece.

>>> dis.dis(distance)
  2           0 LOAD_GLOBAL              0 (np)
              2 LOAD_ATTR                1 (linalg)
              4 LOAD_ATTR                2 (norm)
              6 LOAD_FAST                0 (pointA)
              8 LOAD_FAST                1 (pointB)
             10 BINARY_SUBTRACT
             12 CALL_FUNCTION            1
             14 STORE_FAST               2 (dist)

  3          16 LOAD_FAST                2 (dist)
             18 RETURN_VALUE

En primer lugar, cada vez que lo llamamos, tenemos que hacer una búsqueda global de "np", una búsqueda de alcance de "linalg" y una búsqueda de alcance de "norm", y la sobrecarga de simplemente llamar a la función puede equivaler a docenas de python. instrucciones.

Por último, desperdiciamos dos operaciones para almacenar el resultado y recargarlo para su devolución...

Primer paso para mejorar: hacer la búsqueda más rápida, omitir la tienda

def distance(pointA, pointB, _norm=np.linalg.norm):
    return _norm(pointA - pointB)

Obtenemos el mucho más simplificado:

>>> dis.dis(distance)
  2           0 LOAD_FAST                2 (_norm)
              2 LOAD_FAST                0 (pointA)
              4 LOAD_FAST                1 (pointB)
              6 BINARY_SUBTRACT
              8 CALL_FUNCTION            1
             10 RETURN_VALUE

Sin embargo, la sobrecarga de la llamada a la función todavía supone algo de trabajo. Y querrás hacer puntos de referencia para determinar si sería mejor que hicieras los cálculos tú mismo:

def distance(pointA, pointB):
    return (
        ((pointA.x - pointB.x) ** 2) +
        ((pointA.y - pointB.y) ** 2) +
        ((pointA.z - pointB.z) ** 2)
    ) ** 0.5  # fast sqrt

En algunas plataformas, **0.5es más rápido que math.sqrt. Su experiencia puede ser diferente.

**** Notas de rendimiento avanzadas.

¿Por qué estás calculando la distancia? Si el único propósito es exhibirlo,

 print("The target is %.2fm away" % (distance(a, b)))

superar. Pero si estás comparando distancias, haciendo comprobaciones de alcance, etc., me gustaría agregar algunas observaciones útiles de rendimiento.

Tomemos dos casos: ordenar por distancia o seleccionar una lista de elementos que cumplan una restricción de rango.

# Ultra naive implementations. Hold onto your hat.

def sort_things_by_distance(origin, things):
    return things.sort(key=lambda thing: distance(origin, thing))

def in_range(origin, range, things):
    things_in_range = []
    for thing in things:
        if distance(origin, thing) <= range:
            things_in_range.append(thing)

Lo primero que debemos recordar es que estamos usando Pitágoras para calcular la distancia ( dist = sqrt(x^2 + y^2 + z^2)), por lo que estamos haciendo muchas sqrtllamadas. Matemáticas 101:

dist = root ( x^2 + y^2 + z^2 )
:.
dist^2 = x^2 + y^2 + z^2
and
sq(N) < sq(M) iff M > N
and
sq(N) > sq(M) iff N > M
and
sq(N) = sq(M) iff N == M

En resumen: hasta que realmente requieramos la distancia en una unidad de X en lugar de X^2, podemos eliminar la parte más difícil de los cálculos.

# Still naive, but much faster.

def distance_sq(left, right):
    """ Returns the square of the distance between left and right. """
    return (
        ((left.x - right.x) ** 2) +
        ((left.y - right.y) ** 2) +
        ((left.z - right.z) ** 2)
    )

def sort_things_by_distance(origin, things):
    return things.sort(key=lambda thing: distance_sq(origin, thing))

def in_range(origin, range, things):
    things_in_range = []

    # Remember that sqrt(N)**2 == N, so if we square
    # range, we don't need to root the distances.
    range_sq = range**2

    for thing in things:
        if distance_sq(origin, thing) <= range_sq:
            things_in_range.append(thing)

Genial, ambas funciones ya no generan costosas raíces cuadradas. Será mucho más rápido, pero antes de continuar, compruébalo tú mismo: ¿por qué sort_things_by_distance necesitaba un descargo de responsabilidad "ingenuo" en las dos ocasiones anteriores? Responda al final (*a1).

Podemos mejorar in_range convirtiéndolo en un generador:

def in_range(origin, range, things):
    range_sq = range**2
    yield from (thing for thing in things
                if distance_sq(origin, thing) <= range_sq)

Esto tiene beneficios especialmente si estás haciendo algo como:

if any(in_range(origin, max_dist, things)):
    ...

Pero si lo siguiente que vas a hacer requiere distancia,

for nearby in in_range(origin, walking_distance, hotdog_stands):
    print("%s %.2fm" % (nearby.name, distance(origin, nearby)))

considere producir tuplas:

def in_range_with_dist_sq(origin, range, things):
    range_sq = range**2
    for thing in things:
        dist_sq = distance_sq(origin, thing)
        if dist_sq <= range_sq: yield (thing, dist_sq)

Esto puede resultar especialmente útil si puede encadenar comprobaciones de alcance ("encontrar cosas que estén cerca de X y dentro de Nm de Y", ya que no es necesario volver a calcular la distancia).

Pero ¿qué pasa si estamos buscando en una lista muy grande thingsy anticipamos que muchos de ellos no merecen ser considerados?

En realidad, existe una optimización muy simple:

def in_range_all_the_things(origin, range, things):
    range_sq = range**2
    for thing in things:
        dist_sq = (origin.x - thing.x) ** 2
        if dist_sq <= range_sq:
            dist_sq += (origin.y - thing.y) ** 2
            if dist_sq <= range_sq:
                dist_sq += (origin.z - thing.z) ** 2
                if dist_sq <= range_sq:
                    yield thing

Que esto sea útil dependerá del tamaño de las "cosas".

def in_range_all_the_things(origin, range, things):
    range_sq = range**2
    if len(things) >= 4096:
        for thing in things:
            dist_sq = (origin.x - thing.x) ** 2
            if dist_sq <= range_sq:
                dist_sq += (origin.y - thing.y) ** 2
                if dist_sq <= range_sq:
                    dist_sq += (origin.z - thing.z) ** 2
                    if dist_sq <= range_sq:
                        yield thing
    elif len(things) > 32:
        for things in things:
            dist_sq = (origin.x - thing.x) ** 2
            if dist_sq <= range_sq:
                dist_sq += (origin.y - thing.y) ** 2 + (origin.z - thing.z) ** 2
                if dist_sq <= range_sq:
                    yield thing
    else:
        ... just calculate distance and range-check it ...

Y nuevamente, considere generar dist_sq. Nuestro ejemplo de hot dog entonces se convierte en:

# Chaining generators
info = in_range_with_dist_sq(origin, walking_distance, hotdog_stands)
info = (stand, dist_sq**0.5 for stand, dist_sq in info)
for stand, dist in info:
    print("%s %.2fm" % (stand, dist))

(*a1: la clave de clasificación de sort_things_by_distance llama a Distance_sq para cada elemento, y esa clave de apariencia inocente es una lambda, que es una segunda función que debe invocarse...)

kfsone avatar Nov 28 '2017 22:11 kfsone

Otro ejemplo de este método de resolución de problemas :

def dist(x,y):   
    return numpy.sqrt(numpy.sum((x-y)**2))

a = numpy.array((xa,ya,za))
b = numpy.array((xb,yb,zb))
dist_a_b = dist(a,b)
Nathan Fellman avatar Sep 09 '2009 19:09 Nathan Fellman