00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 
00010 
00011 
00012 
00013 
00014 
00015 
00021 
00022 
00023 #if defined(VMDCPUDISPATCH) && defined(VMDUSEAVX512) 
00024 
00025 #include <immintrin.h>
00026 
00027 #include <math.h>
00028 #include <stdio.h>
00029 #include "Orbital.h"
00030 #include "DrawMolecule.h"
00031 #include "utilities.h"
00032 #include "Inform.h"
00033 #include "WKFThreads.h"
00034 #include "WKFUtils.h"
00035 #include "ProfileHooks.h"
00036 
// Scale factor multiplied into every grid-point/atom coordinate difference
// before the basis functions are evaluated. The name and value (~1.8897)
// correspond to the Angstrom -> Bohr radius conversion.
00037 #define ANGS_TO_BOHR 1.88972612478289694072f
00038 
// Compiler-specific alignment qualifier: GCC-style attribute for GNU
// compilers (other than the Intel compiler), __declspec otherwise.
// NOTE(review): the code visible in this file spells out
// __attribute__((aligned(64))) directly and does not use this macro.
00039 #if defined(__GNUC__) && ! defined(__INTEL_COMPILER)
00040 #define __align(X)  __attribute__((aligned(X) ))
00041 #else
00042 #define __align(X) __declspec(align(X) )
00043 #endif
00044 
// -log2(e): lets exp(-a*r^2) be computed as exp2(a * MLOG2EF * r^2) with the
// base-2 exponential intrinsic.
00045 #define MLOG2EF    -1.44269504088896f
00046 
00047 #if 0
/* Debug helper: write the 16 single-precision lanes of v to stdout on one
 * line, prefixed with "mm512: ". */
static void print_mm512_ps(__m512 v) {
  __attribute__((aligned(64))) float lanes[16];
  _mm512_storeu_ps(lanes, v);

  printf("mm512: ");
  for (int lane = 0; lane < 16; ++lane) {
    printf("%g ", lanes[lane]);
  }
  printf("\n");
}
00058 #endif
00059 
00060 
00061 
00062 
00063 
00064 
00065 int evaluate_grid_avx512er(int numatoms,
00066                            const float *wave_f, const float *basis_array,
00067                            const float *atompos,
00068                            const int *atom_basis,
00069                            const int *num_shells_per_atom,
00070                            const int *num_prim_per_shell,
00071                            const int *shell_types,
00072                            const int *numvoxels,
00073                            float voxelsize,
00074                            const float *origin,
00075                            int density,
00076                            float * orbitalgrid) {
00077   if (!orbitalgrid)
00078     return -1;
00079 
00080   int nx, ny, nz;
00081   __attribute__((aligned(64))) float sxdelta[16]; 
00082   for (nx=0; nx<16; nx++) 
00083     sxdelta[nx] = ((float) nx) * voxelsize * ANGS_TO_BOHR;
00084 
00085   
00086   
00087   int numgridxy = numvoxels[0]*numvoxels[1];
00088   for (nz=0; nz<numvoxels[2]; nz++) {
00089     float grid_x, grid_y, grid_z;
00090     grid_z = origin[2] + nz * voxelsize;
00091     for (ny=0; ny<numvoxels[1]; ny++) {
00092       grid_y = origin[1] + ny * voxelsize;
00093       int gaddrzy = ny*numvoxels[0] + nz*numgridxy;
00094       for (nx=0; nx<numvoxels[0]; nx+=16) {
00095         grid_x = origin[0] + nx * voxelsize;
00096 
00097         
00098         
00099         int at;
00100         int prim, shell;
00101 
00102         
00103         __m512 value = _mm512_set1_ps(0.0f);
00104 
00105         
00106         int ifunc = 0; 
00107         int shell_counter = 0;
00108 
00109         
00110         for (at=0; at<numatoms; at++) {
00111           int maxshell = num_shells_per_atom[at];
00112           int prim_counter = atom_basis[at];
00113 
00114           
00115           float sxdist = (grid_x - atompos[3*at  ])*ANGS_TO_BOHR;
00116           float sydist = (grid_y - atompos[3*at+1])*ANGS_TO_BOHR;
00117           float szdist = (grid_z - atompos[3*at+2])*ANGS_TO_BOHR;
00118 
00119           float sydist2 = sydist*sydist;
00120           float szdist2 = szdist*szdist;
00121           float yzdist2 = sydist2 + szdist2;
00122 
00123           __m512 xdelta = _mm512_load_ps(&sxdelta[0]); 
00124           __m512 xdist  = _mm512_set1_ps(sxdist);
00125           xdist = _mm512_add_ps(xdist, xdelta);
00126           __m512 ydist  = _mm512_set1_ps(sydist);
00127           __m512 zdist  = _mm512_set1_ps(szdist);
00128           __m512 xdist2 = _mm512_mul_ps(xdist, xdist);
00129           __m512 ydist2 = _mm512_mul_ps(ydist, ydist);
00130           __m512 zdist2 = _mm512_mul_ps(zdist, zdist);
00131           __m512 dist2  = _mm512_set1_ps(yzdist2); 
00132           dist2 = _mm512_add_ps(dist2, xdist2);
00133  
00134           
00135           
00136           
00137           
00138           
00139           for (shell=0; shell < maxshell; shell++) {
00140             __m512 contracted_gto = _mm512_set1_ps(0.0f);
00141 
00142             
00143             
00144             
00145             
00146             
00147             
00148             
00149             
00150             int maxprim = num_prim_per_shell[shell_counter];
00151             int shelltype = shell_types[shell_counter];
00152             for (prim=0; prim<maxprim; prim++) {
00153               
00154               float exponent       = -basis_array[prim_counter    ];
00155               float contract_coeff =  basis_array[prim_counter + 1];
00156 
00157               
00158 #if 1
00159               __m512 expval = _mm512_mul_ps(_mm512_set1_ps(-exponent * MLOG2EF), dist2);
00160               
00161               __m512 retval = _mm512_exp2a23_ps(expval);
00162               contracted_gto = _mm512_fmadd_ps(_mm512_set1_ps(contract_coeff), retval, contracted_gto);
00163 #else
00164               __m512 expval = _mm512_mul_ps(_mm512_set1_ps(-exponent), dist2);
00165               
00166               expval = _mm512_mul_ps(expval, _mm512_set1_ps(MLOG2EF));
00167               __m512 retval = _mm512_exp2a23_ps(expval);
00168               __m512 ctmp = _mm512_mul_ps(_mm512_set1_ps(contract_coeff), retval);
00169               contracted_gto = _mm512_add_ps(contracted_gto, ctmp);
00170 #endif
00171 
00172               prim_counter += 2;
00173             }
00174 
00175             
00176             __m512 tmpshell = _mm512_set1_ps(0.0f);
00177             switch (shelltype) {
00178               
00179               case S_SHELL:
00180                 value = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), contracted_gto, value);
00181                 break;
00182 
00183               case P_SHELL:
00184                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), xdist, tmpshell);
00185                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), ydist, tmpshell);
00186                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), zdist, tmpshell);
00187                 value = _mm512_fmadd_ps(tmpshell, contracted_gto, value);
00188                 break;
00189 
00190               case D_SHELL:
00191                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), xdist2, tmpshell);
00192                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(xdist, ydist), tmpshell);
00193                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), ydist2, tmpshell);
00194                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(xdist, zdist), tmpshell);
00195                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(ydist, zdist), tmpshell);
00196                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), zdist2, tmpshell);
00197                 value = _mm512_fmadd_ps(tmpshell, contracted_gto, value);
00198                 break;
00199 
00200               case F_SHELL:
00201                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(xdist2, xdist), tmpshell);
00202                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(xdist2, ydist), tmpshell);
00203                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(ydist2, xdist), tmpshell);
00204                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(ydist2, ydist), tmpshell);
00205                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(xdist2, zdist), tmpshell);
00206                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(_mm512_mul_ps(xdist, ydist), zdist), tmpshell);
00207                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(ydist2, zdist), tmpshell);
00208                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(zdist2, xdist), tmpshell);
00209                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(zdist2, ydist), tmpshell);
00210                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(zdist2, zdist), tmpshell);
00211                 value = _mm512_fmadd_ps(tmpshell, contracted_gto, value);
00212                 break;
00213 
00214  
00215 #if 0
00216               default:
00217                 
00218                 int i, j; 
00219                 float xdp, ydp, zdp;
00220                 float xdiv = 1.0f / xdist;
00221                 for (j=0, zdp=1.0f; j<=shelltype; j++, zdp*=zdist) {
00222                   int imax = shelltype - j; 
00223                   for (i=0, ydp=1.0f, xdp=pow(xdist, imax); i<=imax; i++, ydp*=ydist, xdp*=xdiv) {
00224                     tmpshell += wave_f[ifunc++] * xdp * ydp * zdp;
00225                   }
00226                 }
00227                 value += tmpshell * contracted_gto;
00228 #endif
00229             } 
00230 
00231             shell_counter++;
00232           } 
00233         } 
00234 
00235         
00236         if (density) {
00237           __mmask16 mask = _mm512_cmplt_ps_mask(value, _mm512_set1_ps(0.0f));
00238           __m512 sqdensity = _mm512_mul_ps(value, value);
00239           __m512 orbdensity = _mm512_mask_mul_ps(sqdensity, mask, sqdensity,
00240                                                  _mm512_set1_ps(-1.0f));
00241           _mm512_storeu_ps(&orbitalgrid[gaddrzy + nx], orbdensity);
00242         } else {
00243           _mm512_storeu_ps(&orbitalgrid[gaddrzy + nx], value);
00244         }
00245       }
00246     }
00247   }
00248 
00249   
00250   
00251   
00252   
00253   
00254   
00255   
00256   
00257   _mm256_zeroupper();
00258 
00259   return 0;
00260 }
00261 
00262 #endif
00263 
00264