00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 
00010 
00011 
00012 
00013 
00014 
00015 
00022 #include <stdio.h>
00023 #include <stdlib.h>
00024 #include <string.h>
00025 #include <math.h>
00026 #include "WKFThreads.h"
00027 #include "OrbitalJIT.h"
00028 
00029 
00030 
00031 
00032 
/* Multiplicative factor converting a length in Angstroms to Bohr. */
#define ANGS_TO_BOHR 1.8897259877218677f

/*
 * Unroll factors and thread-block (work-group) dimensions.
 * BLOCKSIZEX/BLOCKSIZEY are emitted by orbital_jit_generate() as the
 * required OpenCL work-group size.
 * Derived macros are fully parenthesized so they expand safely inside
 * larger expressions (e.g. "n / BLOCKSIZE").
 */
#define UNROLLX       1
#define UNROLLY       1
#define BLOCKSIZEX    8
#define BLOCKSIZEY    8
#define BLOCKSIZE     (BLOCKSIZEX * BLOCKSIZEY)

/*
 * Tile extent covered by one block in each dimension, and the matching
 * "size minus one" masks.  The masks are only meaningful as bit masks
 * while the tile sizes are powers of two.
 */
#define TILESIZEX (BLOCKSIZEX * UNROLLX)
#define TILESIZEY (BLOCKSIZEY * UNROLLY)
#define GPU_X_ALIGNMASK (TILESIZEX - 1)
#define GPU_Y_ALIGNMASK (TILESIZEY - 1)

/* NOTE(review): not referenced in the visible code; presumably a memory
 * coalescing/alignment granularity -- confirm units against the callers. */
#define MEMCOALESCE  384

/*
 * Numeric shell type codes; orbital_jit_generate() compares the entries of
 * its shell_types[] argument against these values.
 */
#define S_SHELL 0
#define P_SHELL 1
#define D_SHELL 2
#define F_SHELL 3
#define G_SHELL 4
#define H_SHELL 5

/*
 * Size limits.  NOTE(review): none of these are referenced in the visible
 * code; presumably they bound fixed-size (constant memory?) arrays used by
 * the non-JIT kernels -- confirm before changing any value.
 */
#define MAX_ATOM_SZ 256

#define MAX_ATOMPOS_SZ (MAX_ATOM_SZ)

#define MAX_ATOM_BASIS_SZ (MAX_ATOM_SZ)

#define MAX_ATOMSHELL_SZ (MAX_ATOM_SZ)

#define MAX_BASIS_SZ 6144

#define MAX_SHELL_SZ 1024

#define MAX_WAVEF_SZ 6144
00080 
00081 
00082 
00083 
00084 
00085 
00086 int orbital_jit_generate(int jitlanguage,
00087                          const char * srcfilename, int numatoms,
00088                          const float *wave_f, const float *basis_array,
00089                          const int *atom_basis,
00090                          const int *num_shells_per_atom,
00091                          const int *num_prim_per_shell,
00092                          const int *shell_types) {
00093   FILE *ofp=NULL;
00094   if (srcfilename) 
00095     ofp=fopen(srcfilename, "w");
00096 
00097   if (ofp == NULL)
00098     ofp=stdout; 
00099 
00100   
00101   
00102   int at;
00103   int prim, shell;
00104 
00105   
00106   int shell_counter = 0;
00107 
00108   if (jitlanguage == ORBITAL_JIT_CUDA) {
00109     fprintf(ofp, 
00110       "__global__ static void cuorbitalconstmem_jit(int numatoms,\n"
00111       "                          float voxelsize,\n"
00112       "                          float originx,\n"
00113       "                          float originy,\n"
00114       "                          float grid_z, \n"
00115       "                          int density, \n"
00116       "                          float * orbitalgrid) {\n"
00117       "  unsigned int xindex  = blockIdx.x * blockDim.x + threadIdx.x;\n"
00118       "  unsigned int yindex  = blockIdx.y * blockDim.y + threadIdx.y;\n"
00119       "  unsigned int outaddr = gridDim.x * blockDim.x * yindex + xindex;\n"
00120     );
00121   } else if (jitlanguage == ORBITAL_JIT_OPENCL) {
00122     fprintf(ofp, 
00123       "// unit conversion                                                  \n"
00124       "#define ANGS_TO_BOHR 1.8897259877218677f                            \n"
00125     );
00126 
00127     fprintf(ofp, "__kernel __attribute__((reqd_work_group_size(%d, %d, 1)))\n",
00128             BLOCKSIZEX, BLOCKSIZEY);
00129 
00130     fprintf(ofp, 
00131       "void clorbitalconstmem_jit(int numatoms,                            \n"
00132       "                       __constant float *const_atompos,             \n"
00133       "                       __constant float *const_wave_f,              \n"
00134       "                       float voxelsize,                             \n"
00135       "                       float originx,                               \n"
00136       "                       float originy,                               \n"
00137       "                       float grid_z,                                \n"
00138       "                       int density,                                 \n"
00139       "                       __global float * orbitalgrid) {              \n"
00140       "  unsigned int xindex  = get_global_id(0);                          \n"
00141       "  unsigned int yindex  = get_global_id(1);                          \n"
00142       "  unsigned int outaddr = get_global_size(0) * yindex + xindex;      \n"
00143     );
00144   }
00145 
00146   fprintf(ofp, 
00147     "  float grid_x = originx + voxelsize * xindex;\n"
00148     "  float grid_y = originy + voxelsize * yindex;\n"
00149  
00150     "  // similar to C version\n"
00151     "  int at;\n"
00152     "  // initialize value of orbital at gridpoint\n"
00153     "  float value = 0.0f;\n"
00154     "  // initialize the wavefunction and shell counters\n"
00155     "  int ifunc = 0;\n"
00156     "  // loop over all the QM atoms\n"
00157     "  for (at = 0; at < numatoms; at++) {\n"
00158     "    // calculate distance between grid point and center of atom\n"
00159 
00160 
00161     "    float xdist = (grid_x - const_atompos[3*at  ])*ANGS_TO_BOHR;\n"
00162     "    float ydist = (grid_y - const_atompos[3*at+1])*ANGS_TO_BOHR;\n"
00163     "    float zdist = (grid_z - const_atompos[3*at+2])*ANGS_TO_BOHR;\n"
00164     "    float xdist2 = xdist*xdist;\n"
00165     "    float ydist2 = ydist*ydist;\n"
00166     "    float zdist2 = zdist*zdist;\n"
00167     "    float dist2 = xdist2 + ydist2 + zdist2;\n"
00168     "    float contracted_gto=0.0f;\n"
00169     "    float tmpshell=0.0f;\n"
00170     "\n"
00171   );
00172 
00173 #if 0
00174   
00175   for (at=0; at<numatoms; at++) {
00176 #else
00177   
00178   for (at=0; at<1; at++) {
00179 #endif
00180     int maxshell = num_shells_per_atom[at];
00181     int prim_counter = atom_basis[at];
00182 
00183     
00184     for (shell=0; shell < maxshell; shell++) {
00185       
00186       
00187       int maxprim = num_prim_per_shell[shell_counter];
00188       int shelltype = shell_types[shell_counter];
00189       for (prim=0; prim<maxprim; prim++) {
00190         float exponent       = basis_array[prim_counter    ];
00191         float contract_coeff = basis_array[prim_counter + 1];
00192 #if 1
00193         if (jitlanguage == ORBITAL_JIT_CUDA) {
00194           if (prim == 0) {
00195             fprintf(ofp,"    contracted_gto = %ff * exp2f(-%ff*dist2);\n",
00196                     contract_coeff, exponent);
00197           } else {
00198             fprintf(ofp,"    contracted_gto += %ff * exp2f(-%ff*dist2);\n",
00199                     contract_coeff, exponent);
00200           }
00201         } else if (jitlanguage == ORBITAL_JIT_OPENCL) {
00202           if (prim == 0) {
00203             fprintf(ofp,"    contracted_gto = %ff * native_exp2(-%ff*dist2);\n",
00204                     contract_coeff, exponent);
00205           } else {
00206             fprintf(ofp,"    contracted_gto += %ff * native_exp2(-%ff*dist2);\n",
00207                     contract_coeff, exponent);
00208           }
00209         }
00210 #else
00211         if (jitlanguage == ORBITAL_JIT_CUDA) {
00212           if (prim == 0) {
00213             fprintf(ofp,"    contracted_gto = %ff * expf(-%ff*dist2);\n",
00214                     contract_coeff, exponent);
00215           } else {
00216             fprintf(ofp,"    contracted_gto += %ff * expf(-%ff*dist2);\n",
00217                     contract_coeff, exponent);
00218           }
00219         } else if (jitlanguage == ORBITAL_JIT_OPENCL) {
00220           if (prim == 0) {
00221             fprintf(ofp,"    contracted_gto = %ff * native_exp(-%ff*dist2);\n",
00222                     contract_coeff, exponent);
00223           } else {
00224             fprintf(ofp,"    contracted_gto += %ff * native_exp(-%ff*dist2);\n",
00225                     contract_coeff, exponent);
00226           }
00227         }
00228 #endif
00229         prim_counter += 2;
00230       }
00231 
00232       
00233       switch (shelltype) {
00234         case S_SHELL:
00235           fprintf(ofp, 
00236             "    // S_SHELL\n"
00237             "    value += const_wave_f[ifunc++] * contracted_gto;\n");
00238           break;
00239 
00240         case P_SHELL:
00241           fprintf(ofp,
00242             "    // P_SHELL\n"
00243             "    tmpshell = const_wave_f[ifunc++] * xdist;\n"
00244             "    tmpshell += const_wave_f[ifunc++] * ydist;\n"
00245             "    tmpshell += const_wave_f[ifunc++] * zdist;\n"
00246             "    value += tmpshell * contracted_gto;\n"
00247           );
00248           break;
00249 
00250         case D_SHELL:
00251           fprintf(ofp,
00252             "    // D_SHELL\n"
00253             "    tmpshell = const_wave_f[ifunc++] * xdist2;\n"
00254             "    tmpshell += const_wave_f[ifunc++] * xdist * ydist;\n"
00255             "    tmpshell += const_wave_f[ifunc++] * ydist2;\n"
00256             "    tmpshell += const_wave_f[ifunc++] * xdist * zdist;\n"
00257             "    tmpshell += const_wave_f[ifunc++] * ydist * zdist;\n"
00258             "    tmpshell += const_wave_f[ifunc++] * zdist2;\n"
00259             "    value += tmpshell * contracted_gto;\n"
00260           );
00261           break;
00262 
00263         case F_SHELL:
00264           fprintf(ofp,
00265             "    // F_SHELL\n"
00266             "    tmpshell = const_wave_f[ifunc++] * xdist2 * xdist;\n"
00267             "    tmpshell += const_wave_f[ifunc++] * xdist2 * ydist;\n"
00268             "    tmpshell += const_wave_f[ifunc++] * ydist2 * xdist;\n"
00269             "    tmpshell += const_wave_f[ifunc++] * ydist2 * ydist;\n"
00270             "    tmpshell += const_wave_f[ifunc++] * xdist2 * zdist;\n"
00271             "    tmpshell += const_wave_f[ifunc++] * xdist * ydist * zdist;\n"
00272             "    tmpshell += const_wave_f[ifunc++] * ydist2 * zdist;\n"
00273             "    tmpshell += const_wave_f[ifunc++] * zdist2 * xdist;\n"
00274             "    tmpshell += const_wave_f[ifunc++] * zdist2 * ydist;\n"
00275             "    tmpshell += const_wave_f[ifunc++] * zdist2 * zdist;\n"
00276             "    value += tmpshell * contracted_gto;\n"
00277           );
00278           break;
00279 
00280         case G_SHELL:
00281           fprintf(ofp,
00282             "    // G_SHELL\n"
00283             "    tmpshell = const_wave_f[ifunc++] * xdist2 * xdist2;\n"
00284             "    tmpshell += const_wave_f[ifunc++] * xdist2 * xdist * ydist;\n"
00285             "    tmpshell += const_wave_f[ifunc++] * xdist2 * ydist2;\n"
00286             "    tmpshell += const_wave_f[ifunc++] * ydist2 * ydist * xdist;\n"
00287             "    tmpshell += const_wave_f[ifunc++] * ydist2 * ydist2;\n"
00288             "    tmpshell += const_wave_f[ifunc++] * xdist2 * xdist * zdist;\n"
00289             "    tmpshell += const_wave_f[ifunc++] * xdist2 * ydist * zdist;\n"
00290             "    tmpshell += const_wave_f[ifunc++] * ydist2 * xdist * zdist;\n"
00291             "    tmpshell += const_wave_f[ifunc++] * ydist2 * ydist * zdist;\n"
00292             "    tmpshell += const_wave_f[ifunc++] * xdist2 * zdist2;\n"
00293             "    tmpshell += const_wave_f[ifunc++] * zdist2 * xdist * ydist;\n"
00294             "    tmpshell += const_wave_f[ifunc++] * ydist2 * zdist2;\n"
00295             "    tmpshell += const_wave_f[ifunc++] * zdist2 * zdist * xdist;\n"
00296             "    tmpshell += const_wave_f[ifunc++] * zdist2 * zdist * ydist;\n"
00297             "    tmpshell += const_wave_f[ifunc++] * zdist2 * zdist2;\n"
00298             "    value += tmpshell * contracted_gto;\n"
00299           );
00300           break;
00301 
00302       } 
00303       fprintf(ofp, "\n");
00304 
00305       shell_counter++;
00306     } 
00307   } 
00308 
00309   fprintf(ofp, 
00310     "  }\n"
00311     "\n"
00312     "  // return either orbital density or orbital wavefunction amplitude \n"
00313     "  if (density) { \n"
00314   );
00315 
00316   if (jitlanguage == ORBITAL_JIT_CUDA) {
00317     fprintf(ofp, "    orbitalgrid[outaddr] = copysignf(value*value, value);\n");
00318   } else if (jitlanguage == ORBITAL_JIT_OPENCL) {
00319     fprintf(ofp, "    orbitalgrid[outaddr] = copysign(value*value, value);\n");
00320   }
00321 
00322   fprintf(ofp, 
00323     "  } else { \n"
00324     "    orbitalgrid[outaddr] = value; \n"
00325     "  }\n"
00326     "}\n"
00327   );
00328 
00329   if (ofp != stdout)
00330     fclose(ofp);
00331 
00332   return 0;
00333 }
00334 
00335 
00336