Aritmética modular y optimizaciones NTT (DFT de campo finito)
Quería usar NTT para elevar al cuadrado rápidamente (consulte Cálculo rápido de cuadrados bignum ), pero el resultado es lento incluso para números realmente grandes... más de 12000 bits.
Entonces mi pregunta es:
- ¿Existe alguna forma de optimizar mi transformación NTT? No quise acelerarlo mediante paralelismo (hilos); Esta es solo una capa de bajo nivel.
- ¿Hay alguna manera de acelerar mi aritmética modular?
Este es mi código fuente (ya optimizado) en C++ para NTT (está completo y funciona al 100% en C++ sin necesidad de bibliotecas de terceros y también debe ser seguro para subprocesos. ¡¡¡Cuidado, la matriz fuente se usa como temporal!!! , Además, no puede transformar la matriz en sí misma).
//---------------------------------------------------------------------------
class fourier_NTT // Number theoretic transform
{
public:
DWORD r,L,p,N;
DWORD W,iW,rN;
fourier_NTT(){ r=0; L=0; p=0; W=0; iW=0; rN=0; }
// main interface
void NTT(DWORD *dst,DWORD *src,DWORD n=0); // DWORD dst[n] = fast NTT(DWORD src[n])
void INTT(DWORD *dst,DWORD *src,DWORD n=0); // DWORD dst[n] = fast INTT(DWORD src[n])
// Helper functions
bool init(DWORD n); // init r,L,p,W,iW,rN
void NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD w); // DWORD dst[n] = fast NTT(DWORD src[n])
// Only for testing
void NTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w); // DWORD dst[n] = slow NTT(DWORD src[n])
void INTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w); // DWORD dst[n] = slow INTT(DWORD src[n])
// DWORD arithmetics
DWORD shl(DWORD a);
DWORD shr(DWORD a);
// Modular arithmetics
DWORD mod(DWORD a);
DWORD modadd(DWORD a,DWORD b);
DWORD modsub(DWORD a,DWORD b);
DWORD modmul(DWORD a,DWORD b);
DWORD modpow(DWORD a,DWORD b);
};
//---------------------------------------------------------------------------
void fourier_NTT:: NTT(DWORD *dst,DWORD *src,DWORD n)
{
if (n>0) init(n);
NTT_fast(dst,src,N,W);
// NTT_slow(dst,src,N,W);
}
//---------------------------------------------------------------------------
void fourier_NTT::INTT(DWORD *dst,DWORD *src,DWORD n)
{
if (n>0) init(n);
NTT_fast(dst,src,N,iW);
for (DWORD i=0;i<N;i++) dst[i]=modmul(dst[i],rN);
// INTT_slow(dst,src,N,W);
}
//---------------------------------------------------------------------------
bool fourier_NTT::init(DWORD n)
{
// (max(src[])^2)*n < p else NTT overflow can ocur !!!
r=2; p=0xC0000001; if ((n<2)||(n>0x10000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x30000000/n; // 32:30 bit best for unsigned 32 bit
// r=2; p=0x78000001; if ((n<2)||(n>0x04000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x3c000000/n; // 31:27 bit best for signed 32 bit
// r=2; p=0x00010001; if ((n<2)||(n>0x00000020)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x00000020/n; // 17:16 bit best for 16 bit
// r=2; p=0x0a000001; if ((n<2)||(n>0x01000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x01000000/n; // 28:25 bit
N=n; // size of vectors [DWORDs]
W=modpow(r, L); // Wn for NTT
iW=modpow(r,p-1-L); // Wn for INTT
rN=modpow(n,p-2 ); // scale for INTT
return true;
}
//---------------------------------------------------------------------------
void fourier_NTT:: NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD w)
{
if (n<=1) { if (n==1) dst[0]=src[0]; return; }
DWORD i,j,a0,a1,n2=n>>1,w2=modmul(w,w);
// reorder even,odd
for (i=0,j=0;i<n2;i++,j+=2) dst[i]=src[j];
for ( j=1;i<n ;i++,j+=2) dst[i]=src[j];
// recursion
NTT_fast(src ,dst ,n2,w2); // even
NTT_fast(src+n2,dst+n2,n2,w2); // odd
// restore results
for (w2=1,i=0,j=n2;i<n2;i++,j++,w2=modmul(w2,w))
{
a0=src[i];
a1=modmul(src[j],w2);
dst[i]=modadd(a0,a1);
dst[j]=modsub(a0,a1);
}
}
//---------------------------------------------------------------------------
void fourier_NTT:: NTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w)
{
DWORD i,j,wj,wi,a,n2=n>>1;
for (wj=1,j=0;j<n;j++)
{
a=0;
for (wi=1,i=0;i<n;i++)
{
a=modadd(a,modmul(wi,src[i]));
wi=modmul(wi,wj);
}
dst[j]=a;
wj=modmul(wj,w);
}
}
//---------------------------------------------------------------------------
void fourier_NTT::INTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w)
{
DWORD i,j,wi=1,wj=1,a,n2=n>>1;
for (wj=1,j=0;j<n;j++)
{
a=0;
for (wi=1,i=0;i<n;i++)
{
a=modadd(a,modmul(wi,src[i]));
wi=modmul(wi,wj);
}
dst[j]=modmul(a,rN);
wj=modmul(wj,iW);
}
}
//---------------------------------------------------------------------------
DWORD fourier_NTT::shl(DWORD a) { return (a<<1)&0xFFFFFFFE; }
DWORD fourier_NTT::shr(DWORD a) { return (a>>1)&0x7FFFFFFF; }
//---------------------------------------------------------------------------
DWORD fourier_NTT::mod(DWORD a)
{
DWORD bb;
for (bb=p;(DWORD(a)>DWORD(bb))&&(!DWORD(bb&0x80000000));bb=shl(bb));
for (;;)
{
if (DWORD(a)>=DWORD(bb)) a-=bb;
if (bb==p) break;
bb =shr(bb);
}
return a;
}
//---------------------------------------------------------------------------
DWORD fourier_NTT::modadd(DWORD a,DWORD b)
{
DWORD d,cy;
a=mod(a);
b=mod(b);
d=a+b;
cy=(shr(a)+shr(b)+shr((a&1)+(b&1)))&0x80000000;
if (cy) d-=p;
if (DWORD(d)>=DWORD(p)) d-=p;
return d;
}
//---------------------------------------------------------------------------
DWORD fourier_NTT::modsub(DWORD a,DWORD b)
{
DWORD d;
a=mod(a);
b=mod(b);
d=a-b; if (DWORD(a)<DWORD(b)) d+=p;
if (DWORD(d)>=DWORD(p)) d-=p;
return d;
}
//---------------------------------------------------------------------------
DWORD fourier_NTT::modmul(DWORD a,DWORD b)
{ // b bez orezania !
int i;
DWORD d;
a=mod(a);
for (d=0,i=0;i<32;i++)
{
if (DWORD(a&1)) d=modadd(d,b);
a=shr(a);
b=modadd(b,b);
}
return d;
}
//---------------------------------------------------------------------------
DWORD fourier_NTT::modpow(DWORD a,DWORD b)
{ // a,b bez orezania !
int i;
DWORD d=1;
for (i=0;i<32;i++)
{
d=modmul(d,d);
if (DWORD(b&0x80000000)) d=modmul(d,a);
b=shl(b);
}
return d;
}
//---------------------------------------------------------------------------
Ejemplo de uso de mi clase NTT:
fourier_NTT ntt;
const DWORD n=32
DWORD x[N]={0,1,2,3,....31},y[N]={32,33,34,35,...63},z[N];
ntt.NTT(z,x,N); // z[N]=NTT(x[N]), also init constants for N
ntt.NTT(x,y); // x[N]=NTT(y[N]), no recompute of constants, use last N
// modular convolution y[]=z[].x[]
for (i=0;i<n;i++) y[i]=ntt.modmul(z[i],x[i]);
ntt.INTT(x,y); // x[N]=INTT(y[N]), no recompute of constants, use last N
// x[]=convolution of original x[].y[]
Algunas mediciones antes de las optimizaciones (no Clase NTT):
a = 0.98765588997654321000 | 389*32 bits
looped 1x times
sqr1[ 3.177 ms ] fast sqr
sqr2[ 720.419 ms ] NTT sqr
mul1[ 5.588 ms ] simpe mul
mul2[ 3.172 ms ] karatsuba mul
mul3[ 1053.382 ms ] NTT mul
Algunas medidas después de mis optimizaciones (código actual, menor tamaño/recuento de parámetros de recursividad y mejor aritmética modular):
a = 0.98765588997654321000 | 389*32 bits
looped 1x times
sqr1[ 3.214 ms ] fast sqr
sqr2[ 208.298 ms ] NTT sqr
mul1[ 5.564 ms ] simpe mul
mul2[ 3.113 ms ] karatsuba mul
mul3[ 302.740 ms ] NTT mul
Verifique los tiempos de NTT mul y NTT sqr (mis optimizaciones lo aceleran un poco más de 3 veces). Es solo un bucle de 1x, por lo que no es muy preciso (error ~ 10%), pero la aceleración es notable incluso ahora (normalmente lo hago 1000x y más, pero mi NTT es demasiado lento para eso).
Puedes usar mi código libremente... Simplemente mantén mi nick y/o enlace a esta página en algún lugar (rem in code, readme.txt, about o lo que sea). Espero que ayude... (No vi la fuente C++ para NTT rápidos en ninguna parte, así que tuve que escribirla yo mismo). Se probaron raíces de unidad para todos los N aceptados, consulte la fourier_NTT::init(DWORD n)
función.
PD: Para obtener más información sobre NTT, consulte Traducción de FFT compleja a FFT de campo finito . Este código se basa en mis publicaciones dentro de ese enlace.
[editar1:] Más cambios en el código
Logré optimizar aún más mi aritmética modular, explotando que el módulo principal siempre es 0xC0000001 y eliminando llamadas innecesarias. La aceleración resultante es sorprendente (más de 40 veces) ahora y la multiplicación de NTT es más rápida que karatsuba después del umbral de 1500 * 32 bits. Por cierto, la velocidad de mi NTT ahora es la misma que la de mi DFFT optimizado en dobles de 64 bits.
Algunas medidas:
a = 0.98765588997654321000 | 1553*32bits
looped 10x times
mul2[ 28.585 ms ] karatsuba mul
mul3[ 26.311 ms ] NTT mul
Nuevo código fuente para aritmética modular:
//---------------------------------------------------------------------------
DWORD fourier_NTT::mod(DWORD a)
{
if (a>p) a-=p;
return a;
}
//---------------------------------------------------------------------------
DWORD fourier_NTT::modadd(DWORD a,DWORD b)
{
DWORD d,cy;
if (a>p) a-=p;
if (b>p) b-=p;
d=a+b;
cy=((a>>1)+(b>>1)+(((a&1)+(b&1))>>1))&0x80000000;
if (cy ) d-=p;
if (d>p) d-=p;
return d;
}
//---------------------------------------------------------------------------
DWORD fourier_NTT::modsub(DWORD a,DWORD b)
{
DWORD d;
if (a>p) a-=p;
if (b>p) b-=p;
d=a-b;
if (a<b) d+=p;
if (d>p) d-=p;
return d;
}
//---------------------------------------------------------------------------
DWORD fourier_NTT::modmul(DWORD a,DWORD b)
{
DWORD _a,_b,_p;
_a=a;
_b=b;
_p=p;
asm {
mov eax,_a
mov ebx,_b
mul ebx // H(edx),L(eax) = eax * ebx
mov ebx,_p
div ebx // eax = H(edx),L(eax) / ebx
mov _a,edx // edx = H(edx),L(eax) % ebx
}
return _a;
}
//---------------------------------------------------------------------------
DWORD fourier_NTT::modpow(DWORD a,DWORD b)
{ // b bez orezania!
int i;
DWORD d=1;
if (a>p) a-=p;
for (i=0;i<32;i++)
{
d=modmul(d,d);
if (DWORD(b&0x80000000)) d=modmul(d,a);
b<<=1;
}
return d;
}
//---------------------------------------------------------------------------
Como puede ver, las funciones shl
y shr
ya no se utilizan. Creo que modpow se puede optimizar aún más, pero no es una función crítica porque solo se llama muy pocas veces. La función más crítica es modmul, y parece estar en la mejor forma posible.
Mas preguntas:
- ¿Existe alguna otra opción para acelerar NTT?
- ¿Son seguras mis optimizaciones de aritmética modular? (Los resultados parecen ser los mismos, pero podría perderme algo).
[editar2] Nuevas optimizaciones
a = 0.99991970486 | 2000*32 bits
looped 10x
sqr1[ 13.908 ms ] fast sqr
sqr2[ 13.649 ms ] NTT sqr
mul1[ 19.726 ms ] simpe mul
mul2[ 31.808 ms ] karatsuba mul
mul3[ 19.373 ms ] NTT mul
Implementé todas las cosas utilizables de todos sus comentarios (gracias por la información).
Aceleraciones:
- +2,5 % al eliminar modificaciones de seguridad innecesarias (Mandalf The Beige)
- +34,9% por el uso de poderes W,iW precalculados (Mysticial)
- +35% total
Código fuente completo real:
//---------------------------------------------------------------------------
//--- Number theoretic transforms: 2.03 -------------------------------------
//---------------------------------------------------------------------------
#ifndef _fourier_NTT_h
#define _fourier_NTT_h
//---------------------------------------------------------------------------
//---------------------------------------------------------------------------
class fourier_NTT // Number theoretic transform
{
public:
DWORD r,L,p,N;
DWORD W,iW,rN; // W=(r^L) mod p, iW=inverse W, rN = inverse N
DWORD *WW,*iWW,NN; // Precomputed (W,iW)^(0,..,NN-1) powers
// Internals
fourier_NTT(){ r=0; L=0; p=0; W=0; iW=0; rN=0; WW=NULL; iWW=NULL; NN=0; }
~fourier_NTT(){ _free(); }
void _free(); // Free precomputed W,iW powers tables
void _alloc(DWORD n); // Allocate and precompute W,iW powers tables
// Main interface
void NTT(DWORD *dst,DWORD *src,DWORD n=0); // DWORD dst[n] = fast NTT(DWORD src[n])
void iNTT(DWORD *dst,DWORD *src,DWORD n=0); // DWORD dst[n] = fast INTT(DWORD src[n])
// Helper functions
bool init(DWORD n); // init r,L,p,W,iW,rN
void NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD w); // DWORD dst[n] = fast NTT(DWORD src[n])
void NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD *w2,DWORD i2);
// Only for testing
void NTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w); // DWORD dst[n] = slow NTT(DWORD src[n])
void iNTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w); // DWORD dst[n] = slow INTT(DWORD src[n])
// Modular arithmetics (optimized, but it works only for p >= 0x80000000!!!)
DWORD mod(DWORD a);
DWORD modadd(DWORD a,DWORD b);
DWORD modsub(DWORD a,DWORD b);
DWORD modmul(DWORD a,DWORD b);
DWORD modpow(DWORD a,DWORD b);
};
//---------------------------------------------------------------------------
//---------------------------------------------------------------------------
void fourier_NTT::_free()
{
NN=0;
if ( WW) delete[] WW; WW=NULL;
if (iWW) delete[] iWW; iWW=NULL;
}
//---------------------------------------------------------------------------
void fourier_NTT::_alloc(DWORD n)
{
if (n<=NN) return;
DWORD *tmp,i,w;
tmp=new DWORD[n]; if ((NN)&&( WW)) for (i=0;i<NN;i++) tmp[i]= WW[i]; if ( WW) delete[] WW; WW=tmp; WW[0]=1; for (i=NN?NN:1,w= WW[i-1];i<n;i++){ w=modmul(w, W); WW[i]=w; }
tmp=new DWORD[n]; if ((NN)&&(iWW)) for (i=0;i<NN;i++) tmp[i]=iWW[i]; if (iWW) delete[] iWW; iWW=tmp; iWW[0]=1; for (i=NN?NN:1,w=iWW[i-1];i<n;i++){ w=modmul(w,iW); iWW[i]=w; }
NN=n;
}
//---------------------------------------------------------------------------
void fourier_NTT:: NTT(DWORD *dst,DWORD *src,DWORD n)
{
if (n>0) init(n);
NTT_fast(dst,src,N,WW,1);
// NTT_fast(dst,src,N,W);
// NTT_slow(dst,src,N,W);
}
//---------------------------------------------------------------------------
void fourier_NTT::iNTT(DWORD *dst,DWORD *src,DWORD n)
{
if (n>0) init(n);
NTT_fast(dst,src,N,iWW,1);
// NTT_fast(dst,src,N,iW);
for (DWORD i=0;i<N;i++) dst[i]=modmul(dst[i],rN);
// iNTT_slow(dst,src,N,W);
}
//---------------------------------------------------------------------------
bool fourier_NTT::init(DWORD n)
{
// (max(src[])^2)*n < p else NTT overflow can ocur!!!
r=2; p=0xC0000001; if ((n<2)||(n>0x10000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x30000000/n; // 32:30 bit best for unsigned 32 bit
// r=2; p=0x78000001; if ((n<2)||(n>0x04000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x3c000000/n; // 31:27 bit best for signed 32 bit
// r=2; p=0x00010001; if ((n<2)||(n>0x00000020)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x00000020/n; // 17:16 bit best for 16 bit
// r=2; p=0x0a000001; if ((n<2)||(n>0x01000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x01000000/n; // 28:25 bit
N=n; // Size of vectors [DWORDs]
W=modpow(r, L); // Wn for NTT
iW=modpow(r,p-1-L); // Wn for INTT
rN=modpow(n,p-2 ); // Scale for INTT
_alloc(n>>1); // Precompute W,iW powers
return true;
}
//---------------------------------------------------------------------------
void fourier_NTT:: NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD w)
{
if (n<=1) { if (n==1) dst[0]=src[0]; return; }
DWORD i,j,a0,a1,n2=n>>1,w2=modmul(w,w);
// Reorder even,odd
for (i=0,j=0;i<n2;i++,j+=2) dst[i]=src[j];
for ( j=1;i<n ;i++,j+=2) dst[i]=src[j];
// Recursion
NTT_fast(src ,dst ,n2,w2); // Even
NTT_fast(src+n2,dst+n2,n2,w2); // Odd
// Restore results
for (w2=1,i=0,j=n2;i<n2;i++,j++,w2=modmul(w2,w))
{
a0=src[i];
a1=modmul(src[j],w2);
dst[i]=modadd(a0,a1);
dst[j]=modsub(a0,a1);
}
}
//---------------------------------------------------------------------------
void fourier_NTT:: NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD *w2,DWORD i2)
{
if (n<=1) { if (n==1) dst[0]=src[0]; return; }
DWORD i,j,a0,a1,n2=n>>1;
// Reorder even,odd
for (i=0,j=0;i<n2;i++,j+=2) dst[i]=src[j];
for ( j=1;i<n ;i++,j+=2) dst[i]=src[j];
// Recursion
i=i2<<1;
NTT_fast(src ,dst ,n2,w2,i); // Even
NTT_fast(src+n2,dst+n2,n2,w2,i); // Odd
// Restore results
for (i=0,j=n2;i<n2;i++,j++,w2+=i2)
{
a0=src[i];
a1=modmul(src[j],*w2);
dst[i]=modadd(a0,a1);
dst[j]=modsub(a0,a1);
}
}
//---------------------------------------------------------------------------
void fourier_NTT:: NTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w)
{
DWORD i,j,wj,wi,a;
for (wj=1,j=0;j<n;j++)
{
a=0;
for (wi=1,i=0;i<n;i++)
{
a=modadd(a,modmul(wi,src[i]));
wi=modmul(wi,wj);
}
dst[j]=a;
wj=modmul(wj,w);
}
}
//---------------------------------------------------------------------------
void fourier_NTT::iNTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w)
{
DWORD i,j,wi=1,wj=1,a;
for (wj=1,j=0;j<n;j++)
{
a=0;
for (wi=1,i=0;i<n;i++)
{
a=modadd(a,modmul(wi,src[i]));
wi=modmul(wi,wj);
}
dst[j]=modmul(a,rN);
wj=modmul(wj,iW);
}
}
//---------------------------------------------------------------------------
DWORD fourier_NTT::mod(DWORD a)
{
if (a>p) a-=p;
return a;
}
//---------------------------------------------------------------------------
DWORD fourier_NTT::modadd(DWORD a,DWORD b)
{
DWORD d,cy;
//if (a>p) a-=p;
//if (b>p) b-=p;
d=a+b;
cy=((a>>1)+(b>>1)+(((a&1)+(b&1))>>1))&0x80000000;
if (cy ) d-=p;
if (d>p) d-=p;
return d;
}
//---------------------------------------------------------------------------
DWORD fourier_NTT::modsub(DWORD a,DWORD b)
{
DWORD d;
//if (a>p) a-=p;
//if (b>p) b-=p;
d=a-b;
if (a<b) d+=p;
if (d>p) d-=p;
return d;
}
//---------------------------------------------------------------------------
DWORD fourier_NTT::modmul(DWORD a,DWORD b)
{
DWORD _a,_b,_p;
_a=a;
_b=b;
_p=p;
asm {
mov eax,_a
mov ebx,_b
mul ebx // H(edx),L(eax) = eax * ebx
mov ebx,_p
div ebx // eax = H(edx),L(eax) / ebx
mov _a,edx // edx = H(edx),L(eax) % ebx
}
return _a;
}
//---------------------------------------------------------------------------
DWORD fourier_NTT::modpow(DWORD a,DWORD b)
{ // b is not mod(p)!
int i;
DWORD d=1;
//if (a>p) a-=p;
for (i=0;i<32;i++)
{
d=modmul(d,d);
if (DWORD(b&0x80000000)) d=modmul(d,a);
b<<=1;
}
return d;
}
//---------------------------------------------------------------------------
//---------------------------------------------------------------------------
#endif
//---------------------------------------------------------------------------
//---------------------------------------------------------------------------
Todavía existe la posibilidad de utilizar menos basura del montón separándola NTT_fast
en dos funciones. Uno con WW[]
y el otro con iWW[]
que lleva a un parámetro menos en llamadas recursivas. Pero no espero mucho de él (sólo puntero de 32 bits) y prefiero tener una función para una mejor gestión del código en el futuro. Muchas funciones están inactivas ahora (para pruebas), como variantes lentas mod
y la función rápida más antigua (con w
parámetro en lugar de *w2,i2
).
Para evitar desbordamientos en grandes conjuntos de datos, limite los números de entrada a p/4
bits. ¿Dónde p
está el número de bits por elemento NTT , por lo que para esta versión de 32 bits utilice (32 bit/4 -> 8 bit)
valores de entrada máximos?
[editar3] Multiplicación de cadenas simple bigint
para pruebas
//---------------------------------------------------------------------------
char* mul_NTT(const char *sx,const char *sy)
{
char *s;
int i,j,k,n;
// n = min power of 2 <= 2 max length(x,y)
for (i=0;sx[i];i++); for (n=1;n<i;n<<=1); i--;
for (j=0;sx[j];j++); for (n=1;n<j;n<<=1); n<<=1; j--;
DWORD *x,*y,*xx,*yy,a;
x=new DWORD[n]; xx=new DWORD[n];
y=new DWORD[n]; yy=new DWORD[n];
// Zero padding
for (k=0;i>=0;i--,k++) x[k]=sx[i]-'0'; for (;k<n;k++) x[k]=0;
for (k=0;j>=0;j--,k++) y[k]=sy[j]-'0'; for (;k<n;k++) y[k]=0;
//NTT
fourier_NTT ntt;
ntt.NTT(xx,x,n);
ntt.NTT(yy,y);
// Convolution
for (i=0;i<n;i++) xx[i]=ntt.modmul(xx[i],yy[i]);
//INTT
ntt.iNTT(yy,xx);
//suma
a=0; s=new char[n+1]; for (i=0;i<n;i++) { a+=yy[i]; s[n-i-1]=(a%10)+'0'; a/=10; } s[n]=0;
delete[] x; delete[] xx;
delete[] y; delete[] yy;
return s;
}
//---------------------------------------------------------------------------
Yo uso AnsiString
's, así que lo transfiero a char*
, con suerte, no cometí ningún error. Parece que funciona correctamente (en comparación con la AnsiString
versión).
sx,sy
son números enteros decádicos- Devuelve la cadena asignada
(char*)=sx*sy
Esto es solo ~4 bits por palabra de datos de 32 bits, por lo que no hay riesgo de desbordamiento, pero, por supuesto, es más lento. En mi bignum
biblioteca utilizo una representación binaria y uso 8 bit
fragmentos por WORD de 32 bits para NTT . Más que eso es arriesgado si N
es grande...
Diviértete con esto
En primer lugar, muchas gracias por publicarlo y hacerlo de uso gratuito. Realmente aprecio eso.
Pude usar algunos trucos para eliminar algunas ramificaciones, reorganicé el bucle principal, modifiqué el ensamblaje y pude obtener una aceleración de 1,35x.
Además, agregué una condición de preprocesador para 64 bits, ya que Visual Studio no permite el ensamblaje en línea en modo de 64 bits (gracias Microsoft; siéntete libre de arruinarte).
Algo extraño sucedió cuando estaba optimizando la función modsub(). Lo reescribí usando trucos de bits como lo hice con modadd (que fue más rápido). Pero por alguna razón, la versión bit de modsub era más lenta. No estoy seguro de por qué. Podría ser solo mi computadora.
//
// Mandalf The Beige
// Based on:
// Spektre
// http://stackoverflow.com/questions/18577076/modular-arithmetics-and-ntt-finite-field-dft-optimizations
//
// This code may be freely used however you choose, so long as it is accompanied by this notice.
//
#ifndef H__OPTIMIZED_NUMBER_THEORETIC_TRANSFORM__HDR
#define H__OPTIMIZED_NUMBER_THEORETIC_TRANSFORM__HDR
#include <string.h>
#ifndef uint32
#define uint32 unsigned long int
#endif
#ifndef uint64
#define uint64 unsigned long long int
#endif
class fast_ntt // number theoretic transform
{
public:
fast_ntt()
{
r = 0; L = 0;
W = 0; iW = 0; rN = 0;
}
// main interface
void NTT(uint32 *dst, uint32 *src, uint32 n = 0); // uint32 dst[n] = fast NTT(uint32 src[n])
void INTT(uint32 *dst, uint32 *src, uint32 n = 0); // uint32 dst[n] = fast INTT(uint32 src[n])
// helper functions
private:
bool init(uint32 n); // init r,L,p,W,iW,rN
void NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w); // uint32 dst[n] = fast NTT(uint32 src[n])
void NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w); // uint32 dst[n] = fast NTT(uint32 src[n])
void NTT_fast(uint32 *dst, const uint32 *src, uint32 n, uint32 w);
// only for testing
void NTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w); // uint32 dst[n] = slow NTT(uint32 src[n])
void INTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w); // uint32 dst[n] = slow INTT(uint32 src[n])
// uint32 arithmetics
// modular arithmetics
inline uint32 modadd(uint32 a, uint32 b);
inline uint32 modsub(uint32 a, uint32 b);
inline uint32 modmul(uint32 a, uint32 b);
inline uint32 modpow(uint32 a, uint32 b);
uint32 r, L, N;//, p;
uint32 W, iW, rN;
const uint32 p = 0xC0000001;
};
//---------------------------------------------------------------------------
void fast_ntt::NTT(uint32 *dst, uint32 *src, uint32 n)
{
if (n > 0)
{
init(n);
}
NTT_fast(dst, src, N, W);
// NTT_slow(dst,src,N,W);
}
//---------------------------------------------------------------------------
void fast_ntt::INTT(uint32 *dst, uint32 *src, uint32 n)
{
if (n > 0)
{
init(n);
}
NTT_fast(dst, src, N, iW);
for (uint32 i = 0; i<N; i++)
{
dst[i] = modmul(dst[i], rN);
}
// INTT_slow(dst,src,N,W);
}
//---------------------------------------------------------------------------
bool fast_ntt::init(uint32 n)
{
// (max(src[])^2)*n < p else NTT overflow can ocur !!!
r = 2;
//p = 0xC0000001;
if ((n < 2) || (n > 0x10000000))
{
r = 0; L = 0; W = 0; // p = 0;
iW = 0; rN = 0; N = 0;
return false;
}
L = 0x30000000 / n; // 32:30 bit best for unsigned 32 bit
// r=2; p=0x78000001; if ((n<2)||(n>0x04000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x3c000000/n; // 31:27 bit best for signed 32 bit
// r=2; p=0x00010001; if ((n<2)||(n>0x00000020)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x00000020/n; // 17:16 bit best for 16 bit
// r=2; p=0x0a000001; if ((n<2)||(n>0x01000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x01000000/n; // 28:25 bit
N = n; // size of vectors [uint32s]
W = modpow(r, L); // Wn for NTT
iW = modpow(r, p - 1 - L); // Wn for INTT
rN = modpow(n, p - 2); // scale for INTT
return true;
}
//---------------------------------------------------------------------------
void fast_ntt::NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
if(n > 1)
{
if(dst != src)
{
NTT_calc(dst, src, n, w);
}
else
{
uint32* temp = new uint32[n];
NTT_calc(temp, src, n, w);
memcpy(dst, temp, n * sizeof(uint32));
delete [] temp;
}
}
else if(n == 1)
{
dst[0] = src[0];
}
}
void fast_ntt::NTT_fast(uint32 *dst, const uint32 *src, uint32 n, uint32 w)
{
if (n > 1)
{
uint32* temp = new uint32[n];
memcpy(temp, src, n * sizeof(uint32));
NTT_calc(dst, temp, n, w);
delete[] temp;
}
else if (n == 1)
{
dst[0] = src[0];
}
}
void fast_ntt::NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
if(n > 1)
{
uint32 i, j, a0, a1,
n2 = n >> 1,
w2 = modmul(w, w);
// reorder even,odd
for (i = 0, j = 0; i < n2; i++, j += 2)
{
dst[i] = src[j];
}
for (j = 1; i < n; i++, j += 2)
{
dst[i] = src[j];
}
// recursion
if(n2 > 1)
{
NTT_calc(src, dst, n2, w2); // even
NTT_calc(src + n2, dst + n2, n2, w2); // odd
}
else if(n2 == 1)
{
src[0] = dst[0];
src[1] = dst[1];
}
// restore results
w2 = 1, i = 0, j = n2;
a0 = src[i];
a1 = src[j];
dst[i] = modadd(a0, a1);
dst[j] = modsub(a0, a1);
while (++i < n2)
{
w2 = modmul(w2, w);
j++;
a0 = src[i];
a1 = modmul(src[j], w2);
dst[i] = modadd(a0, a1);
dst[j] = modsub(a0, a1);
}
}
}
//---------------------------------------------------------------------------
void fast_ntt::NTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
uint32 i, j, wj, wi, a,
n2 = n >> 1;
for (wj = 1, j = 0; j < n; j++)
{
a = 0;
for (wi = 1, i = 0; i < n; i++)
{
a = modadd(a, modmul(wi, src[i]));
wi = modmul(wi, wj);
}
dst[j] = a;
wj = modmul(wj, w);
}
}
//---------------------------------------------------------------------------
void fast_ntt::INTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
uint32 i, j, wi = 1, wj = 1, a, n2 = n >> 1;
for (wj = 1, j = 0; j < n; j++)
{
a = 0;
for (wi = 1, i = 0; i < n; i++)
{
a = modadd(a, modmul(wi, src[i]));
wi = modmul(wi, wj);
}
dst[j] = modmul(a, rN);
wj = modmul(wj, iW);
}
}
//---------------------------------------------------------------------------
uint32 fast_ntt::modadd(uint32 a, uint32 b)
{
uint32 d;
d = a + b;
if(d < a)
{
d -= p;
}
if (d >= p)
{
d -= p;
}
return d;
}
//---------------------------------------------------------------------------
uint32 fast_ntt::modsub(uint32 a, uint32 b)
{
uint32 d;
d = a - b;
if (d > a)
{
d += p;
}
return d;
}
//---------------------------------------------------------------------------
uint32 fast_ntt::modmul(uint32 a, uint32 b)
{
uint32 _a = a;
uint32 _b = b;
// Original
uint32 _p = p;
__asm
{
mov eax, _a;
mul _b;
div _p;
mov eax, edx;
};
}
uint32 fast_ntt::modpow(uint32 a, uint32 b)
{
//*
uint64 D, M, A, P;
P = p; A = a;
M = 0llu - (b & 1);
D = (M & A) | ((~M) & 1);
while ((b >>= 1) != 0)
{
A = modmul(A, A);
//A = (A * A) % P;
if ((b & 1) == 1)
{
//D = (D * A) % P;
D = modmul(D, A);
}
}
return (uint32)D;
}
Nuevo módulo
uint32 fast_ntt::modmul(uint32 a, uint32 b)
{
uint32 _a = a;
uint32 _b = b;
__asm
{
mov eax, a;
mul b;
mov ebx, eax;
mov eax, 2863311530;
mov ecx, edx;
mul edx;
shld edx, eax, 1;
mov eax, 3221225473;
mul edx;
sub ebx, eax;
mov eax, 3221225473;
sbb ecx, edx;
jc addback;
neg ecx;
and ecx, eax;
sub ebx, ecx;
sub ebx, eax;
sbb edx, edx;
and eax, edx;
addback:
add eax, ebx;
};
}
[EDITAR] Spektre, según tus comentarios, cambié modadd y modsub a su original. También me di cuenta de que hice algunos cambios en la función NTT recursiva que no debería haber hecho.
[EDITAR2] Se eliminaron declaraciones if innecesarias y funciones bit a bit.
[EDITAR3] Se agregó un nuevo ensamblaje en línea de modmul.