2 #if __CUDACC_VER_MAJOR__ >= 11
     5 #include <namd_cub/cub.cuh>
    10 #include <hip/hip_runtime.h>
    11 #include <hipcub/hipcub.hpp>
    15 #include "ComputeSMDCUDAKernel.h"
    16 #include "ComputeCOMCudaKernel.h"
    17 #include "HipDefines.h"
    19 #ifdef NODEGROUP_FORCE_REGISTER
    22 /*! Calculate SMD force and virial for large atom group (numSMDAtoms > 1024)
    23   Multiple thread blocks are launched to do this operation.
    24   The current COM (curCOM) must be calculated and passed to this function. */
       // Launch layout (see computeSMDForce): one thread per SMD atom, 128 threads
       // per block (must match cub::BlockReduce<double, 128> below), and
       // gridDim.x = ceil(numSMDAtoms / 128).
       // On energy steps each block reduces its atoms' virial and atomically adds it
       // into d_virial; the atomic counter tbcatomic[0] identifies the last block to
       // finish, which copies the energy, group force and accumulated d_virial into
       // the h_ext* outputs.
    25 template<bool T_DOENERGY, bool T_MGPUON>
    26 __global__ void computeSMDForceWithCOMKernel(
    27   const int                numSMDAtoms,
    29   const double             inv_group_mass,
    32   const double             velocity,
    33   const double3            direction,
    34   const int                currentTime,
    36   const float *  __restrict mass,
    37   const double*  __restrict pos_x,
    38   const double*  __restrict pos_y,
    39   const double*  __restrict pos_z,
    40   const char3*   __restrict transform,
    41   double*        __restrict f_normal_x,
    42   double*        __restrict f_normal_y,
    43   double*        __restrict f_normal_z,
    44   const int*     __restrict smdAtomsSOAIndex,
    45   cudaTensor*    __restrict d_virial,
    46   double3*       __restrict h_curCOM,
    47   double3*       __restrict d_curCOM,
    48   double3**      __restrict d_peerCOM,
    49   double*        __restrict h_extEnergy,
    50   double3*       __restrict h_extForce,
    51   cudaTensor*    __restrict h_extVirial,
    52   unsigned int*  __restrict tbcatomic)
    54   int tid = threadIdx.x + blockIdx.x * blockDim.x;  
    55   int totaltb = gridDim.x;
    56   bool isLastBlockDone = 0;
    57   double3 group_f = {0, 0, 0};
    59   double3 pos = {0, 0, 0};
    60   double3 f = {0, 0, 0};
         // COM of the SMD group: read from h_curCOM by default; in the multi-GPU
         // path below (guard on a line not shown, presumably T_MGPUON) it is rebuilt
         // by scaling the summed value in d_curCOM by inv_group_mass.
    62   double3 cm={h_curCOM->x, h_curCOM->y, h_curCOM->z};
    63   r_virial.xx = 0.0; r_virial.xy = 0.0; r_virial.xz = 0.0;
    64   r_virial.yx = 0.0; r_virial.yy = 0.0; r_virial.yz = 0.0;
    65   r_virial.zx = 0.0; r_virial.zy = 0.0; r_virial.zz = 0.0;
    70       cm.x = d_curCOM->x * inv_group_mass;
    71       cm.y = d_curCOM->y * inv_group_mass;
    72       cm.z = d_curCOM->z * inv_group_mass;
    75   if(tid < numSMDAtoms){
           // One thread per SMD atom; smdAtomsSOAIndex maps the thread to the atom's
           // index in the mass / pos_* / transform / f_normal_* arrays.
    76     SOAindex = smdAtomsSOAIndex[tid];
    78     // uncoalesced memory access: too bad!
    79     double m = mass[SOAindex]; // Cast from float to double here
    80     pos.x = pos_x[SOAindex];
    81     pos.y = pos_y[SOAindex];
    82     pos.z = pos_z[SOAindex];
    84     // calculate the distance difference along direction
    86     diffCOM.x = cm.x - origCM.x;
    87     diffCOM.y = cm.y - origCM.y;
    88     diffCOM.z = cm.z - origCM.z;
           // diff = projection of the COM displacement (cm - origCM) onto direction
    89     double diff = diffCOM.x*direction.x + diffCOM.y*direction.y + 
    90       diffCOM.z*direction.z;
    92     // Ok so we've calculated the new center of mass, now we can calculate the bias
           // preFactor = target displacement (velocity*currentTime) minus diff.
           // group_f = k*preFactor*direction + k2*(diff*direction - diffCOM); the
           // second term opposes the part of diffCOM perpendicular to direction
           // (assuming direction is a unit vector). k and k2 are declared on lines
           // not shown -- presumably spring_constant / transverse_spring_constant,
           // matching the launch argument order in computeSMDForce.
    93     double preFactor = (velocity*currentTime - diff);
    94     group_f.x = k*preFactor*direction.x + k2*(diff*direction.x - diffCOM.x);
    95     group_f.y = k*preFactor*direction.y + k2*(diff*direction.y - diffCOM.y);
    96     group_f.z = k*preFactor*direction.z + k2*(diff*direction.z - diffCOM.z);
    98     // calculate the force on each atom
           // (the group force is split by mass fraction m * inv_group_mass)
    99     f.x = group_f.x * m * inv_group_mass;
   100     f.y = group_f.y * m * inv_group_mass;
   101     f.z = group_f.z * m * inv_group_mass;
           // Non-atomic +=: NOTE(review) assumes each SOA index appears at most once
           // in smdAtomsSOAIndex -- confirm.
   104     f_normal_x[SOAindex] += f.x ;
   105     f_normal_y[SOAindex] += f.y ;
   106     f_normal_z[SOAindex] += f.z ;
   108       // energy for restraint along the direction
   109       energy = 0.5*k*preFactor*preFactor; 
   110       // energy for transverse restraint
   111       energy += 0.5*k2*(diffCOM.x*diffCOM.x + diffCOM.y*diffCOM.y +
   112         diffCOM.z*diffCOM.z - diff*diff);
   113       // unwrap coordinates before calculating the virial
   114       char3 t = transform[SOAindex];
   115       pos = lat.reverse_transform(pos, t);
             // per-atom virial: outer product of f with the unwrapped position
   116       r_virial.xx = f.x * pos.x;
   117       r_virial.xy = f.x * pos.y;
   118       r_virial.xz = f.x * pos.z;
   119       r_virial.yx = f.y * pos.x;
   120       r_virial.yy = f.y * pos.y;
   121       r_virial.yz = f.y * pos.z;
   122       r_virial.zx = f.z * pos.x;
   123       r_virial.zy = f.z * pos.y;
   124       r_virial.zz = f.z * pos.z;
           // Block-wide virial sums. The aggregate returned by BlockReduce::Sum is
           // only valid in thread 0. NOTE(review): CUB requires a __syncthreads()
           // before temp_storage is reused; the lines between these calls are not
           // visible here -- confirm each reuse is preceded by a barrier.
   130     typedef cub::BlockReduce<double, 128> BlockReduce;
   131     __shared__ typename BlockReduce::TempStorage temp_storage;
   133     r_virial.xx = BlockReduce(temp_storage).Sum(r_virial.xx);
   135     r_virial.xy = BlockReduce(temp_storage).Sum(r_virial.xy);
   137     r_virial.xz = BlockReduce(temp_storage).Sum(r_virial.xz);
   140     r_virial.yx = BlockReduce(temp_storage).Sum(r_virial.yx);
   142     r_virial.yy = BlockReduce(temp_storage).Sum(r_virial.yy);
   144     r_virial.yz = BlockReduce(temp_storage).Sum(r_virial.yz);
   147     r_virial.zx = BlockReduce(temp_storage).Sum(r_virial.zx);
   149     r_virial.zy = BlockReduce(temp_storage).Sum(r_virial.zy);
   151     r_virial.zz = BlockReduce(temp_storage).Sum(r_virial.zz);
           // Thread 0 publishes the block's partial virial, then bumps the block
           // counter; the block that reads totaltb - 1 is the last one to finish.
           // NOTE(review): a __threadfence() between the atomicAdds and the atomicInc
           // is needed for the last block to be guaranteed to see every partial sum;
           // the line between them is not visible here -- confirm it is present.
   154     if(threadIdx.x == 0){
   155       atomicAdd(&(d_virial->xx), r_virial.xx);
   156       atomicAdd(&(d_virial->xy), r_virial.xy);
   157       atomicAdd(&(d_virial->xz), r_virial.xz);
   159       atomicAdd(&(d_virial->yx), r_virial.yx);
   160       atomicAdd(&(d_virial->yy), r_virial.yy);
   161       atomicAdd(&(d_virial->yz), r_virial.yz);
   163       atomicAdd(&(d_virial->zx), r_virial.zx);
   164       atomicAdd(&(d_virial->zy), r_virial.zy);
   165       atomicAdd(&(d_virial->zz), r_virial.zz);
   168       unsigned int value = atomicInc(&tbcatomic[0], totaltb);
   169       isLastBlockDone = (value == (totaltb -1));
           // Thread 0 of every block maps to a valid atom (gridDim.x =
           // ceil(numSMDAtoms / 128)), so its energy and group_f are always set.
   173     // Last block will set the host values
   175       if(threadIdx.x == 0){
   176         h_extEnergy[0] = energy;
   177         h_extForce->x  = group_f.x;
   178         h_extForce->y  = group_f.y;
   179         h_extForce->z  = group_f.z;
   181         h_extVirial->xx = d_virial->xx;
   182         h_extVirial->xy = d_virial->xy;
   183         h_extVirial->xz = d_virial->xz;
   184         h_extVirial->yx = d_virial->yx;
   185         h_extVirial->yy = d_virial->yy;
   186         h_extVirial->yz = d_virial->yz;
   187         h_extVirial->zx = d_virial->zx;
   188         h_extVirial->zy = d_virial->zy;
   189         h_extVirial->zz = d_virial->zz;
   190         //reset the device virial value
   206     { // compute isLastBlockDone in the non energy steps
             // the counter still advances, presumably so the last block can reset
             // it in the cleanup below
   207       if(threadIdx.x == 0){
   209    unsigned int value = atomicInc(&tbcatomic[0], totaltb);
   210    isLastBlockDone = (value == (totaltb -1));
   215     if(threadIdx.x == 0){
   224       //resets atomic counter
   232 /*! Calculate SMD force, virial, and COM for small atom group (numSMDAtoms <= 1024)
   233   A single thread block is launched to do this operation.
   234   The current COM is calculated and stored in h_curCM. */
       // Launch layout (see computeSMDForce): gridDim.x = 1 and blockDim.x = 1024,
       // matching cub::BlockReduce<double, 1024> below; one thread per SMD atom.
       // group_f and energy are __shared__: thread 0 computes them from the reduced
       // COM (BlockReduce::Sum returns the aggregate to thread 0 only), and every
       // thread then reads group_f for its per-atom force.
   235 template<bool T_DOENERGY, bool T_MGPUON>
   236 __global__ void computeSMDForceKernel(
   237   const int                numSMDAtoms,
   239   const double             inv_group_mass,
   242   const double             velocity,
   243   const double3            direction,
   244   const int                currentTime,
   245   const double3            origCM,
   246   const float * __restrict mass,
   247   const double* __restrict pos_x,
   248   const double* __restrict pos_y,
   249   const double* __restrict pos_z,
   250   const char3*  __restrict transform,
   251   double*       __restrict f_normal_x,
   252   double*       __restrict f_normal_y,
   253   double*       __restrict f_normal_z,
   254   const int*    __restrict smdAtomsSOAIndex,
   255   double3*      __restrict h_curCM,
   256   double3*      __restrict d_curCM,
   257   double3**     __restrict d_peerCOM,
   258   double*       __restrict h_extEnergy,
   259   double3*      __restrict h_extForce,
   260   cudaTensor*   __restrict h_extVirial,
   261   unsigned int*  __restrict tbcatomic)
   263   __shared__ double3 group_f;
   264   __shared__ double energy;
   265   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   266   int totaltb = gridDim.x;
   267   bool isLastBlockDone = 0;
   269   double3 cm = {0, 0, 0};
   270   double3 pos = {0, 0, 0};
   271   double3 f = {0, 0, 0};
   273   r_virial.xx = 0.0; r_virial.xy = 0.0; r_virial.xz = 0.0;
   274   r_virial.yx = 0.0; r_virial.yy = 0.0; r_virial.yz = 0.0;
   275   r_virial.zx = 0.0; r_virial.zy = 0.0; r_virial.zz = 0.0;
   277   // in the mGpuOn case the COM must be calculated across devices and passed in
   278   if(tid < numSMDAtoms){
   279     // First -> recalculate center of mass.
   280     // Each thread loads its own atom; the block-wide sum ends up in thread 0 only
   281     SOAindex = smdAtomsSOAIndex[tid];
   282     m = mass[SOAindex]; // Cast from float to double here
   283     pos.x = pos_x[SOAindex];
   284     pos.y = pos_y[SOAindex];
   285     pos.z = pos_z[SOAindex];
   287     // unwrap the coordinate to calculate COM
   288     char3 t = transform[SOAindex];
   289     pos = lat.reverse_transform(pos, t);
           // NOTE(review): the lines not shown here presumably accumulate m * pos
           // into cm for the reduction below -- confirm.
   297   // now reduce the values and add it to thread zero
   298   typedef cub::BlockReduce<double, 1024> BlockReduce;
   299   __shared__ typename BlockReduce::TempStorage temp_storage;
         // BlockReduce::Sum is only valid in thread 0. NOTE(review): reusing
         // temp_storage needs a __syncthreads() between calls; those lines are not
         // visible here -- confirm.
   301     cm.x = BlockReduce(temp_storage).Sum(cm.x);
   303     cm.y = BlockReduce(temp_storage).Sum(cm.y);
   305     cm.z = BlockReduce(temp_storage).Sum(cm.z);
   308   // Thread 0: finish the COM, then compute the group bias force (and energy)
   309   if(threadIdx.x == 0){
          // multi-GPU path (guard not shown, presumably T_MGPUON): use the
          // cross-device sum in d_curCM; otherwise normalize the local sum.
   312    cm.x = d_curCM->x * inv_group_mass;
   313    cm.y = d_curCM->y * inv_group_mass;
   314    cm.z = d_curCM->z * inv_group_mass;
   318    cm.x *= inv_group_mass; // calculates the current center of mass
   319    cm.y *= inv_group_mass; // calculates the current center of mass
   320    cm.z *= inv_group_mass; // calculates the current center of mass
   322     // calculate the distance difference along direction
   324     diffCOM.x = cm.x - origCM.x;
   325     diffCOM.y = cm.y - origCM.y;
   326     diffCOM.z = cm.z - origCM.z; 
           // diff = projection of the COM displacement (cm - origCM) onto direction
   327     double diff = diffCOM.x*direction.x + diffCOM.y*direction.y + 
   328       diffCOM.z*direction.z;
   330     // Ok so we've calculated the new center of mass, now we can calculate the bias
           // preFactor = target displacement (velocity*currentTime) minus diff;
           // k / k2 are declared on lines not shown (presumably the parallel and
           // transverse spring constants, per the launch argument order).
   331     double preFactor = (velocity*currentTime - diff);
   332     group_f.x = k*preFactor*direction.x + k2*(diff*direction.x - diffCOM.x);
   333     group_f.y = k*preFactor*direction.y + k2*(diff*direction.y - diffCOM.y);
   334     group_f.z = k*preFactor*direction.z + k2*(diff*direction.z - diffCOM.z);
   336       // energy for restraint along the direction
   337       energy = 0.5*k*preFactor*preFactor; 
   338       // energy for transverse restraint
   339       energy += 0.5*k2*(diffCOM.x*diffCOM.x + diffCOM.y*diffCOM.y +
   340         diffCOM.z*diffCOM.z - diff*diff);
         // NOTE(review): all threads read the __shared__ group_f below, so a
         // __syncthreads() must separate thread 0's write above from these reads;
         // the lines in between are not visible here -- confirm.
   345   if(tid < numSMDAtoms){
   346     // calculate the force on each atom
   347     f.x = group_f.x * m * inv_group_mass;
   348     f.y = group_f.y * m * inv_group_mass;
   349     f.z = group_f.z * m * inv_group_mass;
   352     f_normal_x[SOAindex] += f.x ;
   353     f_normal_y[SOAindex] += f.y ;
   354     f_normal_z[SOAindex] += f.z ;
             // per-atom virial uses the position unwrapped during the COM pass
   356       r_virial.xx = f.x * pos.x;
   357       r_virial.xy = f.x * pos.y;
   358       r_virial.xz = f.x * pos.z;
   359       r_virial.yx = f.y * pos.x;
   360       r_virial.yy = f.y * pos.y;
   361       r_virial.yz = f.y * pos.z;
   362       r_virial.zx = f.z * pos.x;
   363       r_virial.zy = f.z * pos.y;
   364       r_virial.zz = f.z * pos.z;
           // With the single-block launch totaltb == 1, so thread 0 reads value == 0
           // (assuming the counter was reset after the previous call) and
           // isLastBlockDone becomes true.
   368     if(threadIdx.x == 0){
   370       unsigned int value = atomicInc(&tbcatomic[0], totaltb);
   371       isLastBlockDone = (value == (totaltb -1));
   375     r_virial.xx = BlockReduce(temp_storage).Sum(r_virial.xx);
   377     r_virial.xy = BlockReduce(temp_storage).Sum(r_virial.xy);
   379     r_virial.xz = BlockReduce(temp_storage).Sum(r_virial.xz);
   382     r_virial.yx = BlockReduce(temp_storage).Sum(r_virial.yx);
   384     r_virial.yy = BlockReduce(temp_storage).Sum(r_virial.yy);
   386     r_virial.yz = BlockReduce(temp_storage).Sum(r_virial.yz);
   389     r_virial.zx = BlockReduce(temp_storage).Sum(r_virial.zx);
   391     r_virial.zy = BlockReduce(temp_storage).Sum(r_virial.zy);
   393     r_virial.zz = BlockReduce(temp_storage).Sum(r_virial.zz);
           // Only one block, so thread 0's reduced virial is already the total and
           // goes straight to h_extVirial (no d_virial accumulation is needed).
   396     if(threadIdx.x == 0){
   397       // thread zero updates the value
   398       h_curCM->x = cm.x; // update current center of mass
   399       h_curCM->y = cm.y; // update current center of mass
   400       h_curCM->z = cm.z; // update current center of mass
   401       h_extEnergy[0] = energy;    // bias energy
   402       h_extForce->x  = group_f.x; // bias force
   403       h_extForce->y  = group_f.y;
   404       h_extForce->z  = group_f.z;
   406       h_extVirial->xx = r_virial.xx;
   407       h_extVirial->xy = r_virial.xy;
   408       h_extVirial->xz = r_virial.xz;
   409       h_extVirial->yx = r_virial.yx;
   410       h_extVirial->yy = r_virial.yy;
   411       h_extVirial->yz = r_virial.yz;
   412       h_extVirial->zx = r_virial.zx;
   413       h_extVirial->zy = r_virial.zy;
   414       h_extVirial->zz = r_virial.zz;
   417   // last block cleans up
   420       if(threadIdx.x == 0){
   421    // zero out for next iteration
   425    //resets atomic counter
   433 /*! Compute SMD force and virial on group of atoms */
       // Host-side dispatcher. numSMDAtoms > 1024: multi-block path using
       // computeSMDForceWithCOMKernel, with the COM computed first by a separate
       // kernel. Otherwise: single-block computeSMDForceKernel, which computes the
       // COM itself. The runtime flags doEnergy / mGpuOn select the template
       // instantiation. All launches are issued on `stream`.
   434 void computeSMDForce(
   436   const double      inv_group_mass,
   437   const double      spring_constant,
   438   const double      transverse_spring_constant, 
   439   const double      velocity, 
   440   const double3     direction,
   442   const int         currentTime,
   444   const double3     origCM,  
   446   const double*     d_pos_x, 
   447   const double*     d_pos_y,
   448   const double*     d_pos_z, 
   449   const char3*      d_transform,
   450   double *          d_f_normal_x, 
   451   double *          d_f_normal_y, 
   452   double *          d_f_normal_z, 
   453   const int         numSMDAtoms, 
   454   const int*        d_smdAtomsSOAIndex,
   458   cudaTensor*       d_extVirial,
   461   cudaTensor*       h_extVirial, 
   462   unsigned int*     d_tbcatomic,
   463   const int         numDevices,
   464   const int         deviceIndex,
         // NOTE: `blocks` is the number of threads per block and must equal the
         // BlockReduce size of the kernel that is launched (128 or 1024); `grid` is
         // the number of blocks: one per 128 atoms, or a single block.
   468   const int blocks = (numSMDAtoms > 1024) ? 128 : 1024;
   469   const int grid = (numSMDAtoms > 1024) ? (numSMDAtoms + blocks - 1) / blocks : 1;
         // Launch helpers; argument order must match the kernel parameter lists.
   471 #define CALL_WITH_COM(DOENERGY, MGPUON) \
   472  computeSMDForceWithCOMKernel<DOENERGY, MGPUON> \
   473    <<< grid, blocks, 0 , stream >>> \
   474    (numSMDAtoms, lat, inv_group_mass, spring_constant, \
   475     transverse_spring_constant, velocity,  direction, currentTime, \
   476     origCM, d_mass, d_pos_x, d_pos_y, d_pos_z,  d_transform, \
   477     d_f_normal_x, d_f_normal_y, d_f_normal_z,  d_smdAtomsSOAIndex, \
   478     d_extVirial, h_curCM, d_curCM, d_peerCOM, h_extEnergy, h_extForce,     \
   479     h_extVirial, d_tbcatomic);
   481 #define CALL(DOENERGY, MGPUON) \
   482   computeSMDForceKernel<DOENERGY, MGPUON> \
   483   <<<grid, blocks, 0, stream>>> \
   484     (numSMDAtoms, lat, inv_group_mass, spring_constant, \
   485      transverse_spring_constant, velocity, direction,  currentTime, \
   486      origCM, d_mass, d_pos_x, d_pos_y, d_pos_z, d_transform, \
   487      d_f_normal_x, d_f_normal_y, d_f_normal_z, d_smdAtomsSOAIndex, \
   488      h_curCM, d_curCM, d_peerCOM, h_extEnergy, h_extForce, h_extVirial, \
   491   if (numSMDAtoms > 1024) {
           // Large group: COM pass first (presumably summed across devices when
           // mGpuOn), then exactly one of the four CALL_WITH_COM lines below runs.
   493       { //first calculate the COM for SMD group and store it in h_curCM
   494    computeCOMKernel<128><<<grid, blocks, 0, stream>>>(
   509       {// sum up the COMs across devices to this device
   510    computeDistCOMKernelMgpu<<<grid, blocks, 0, stream>>>(d_peerCOM,
   514     if(doEnergy && mGpuOn) CALL_WITH_COM(true, true);
   515     if(doEnergy && !mGpuOn) CALL_WITH_COM(true, false);
   516     if(!doEnergy && mGpuOn) CALL_WITH_COM(false, true);
   517     if(!doEnergy && !mGpuOn) CALL_WITH_COM(false, false);
   522    {// sum up the COMs across devices to this device
   523      computeDistCOMKernelMgpu<<<grid, blocks, 0, stream>>>(d_peerCOM,
   528       if(doEnergy && mGpuOn) CALL(true, true);
   529       if(doEnergy && !mGpuOn) CALL(true, false);
   530       if(!doEnergy && mGpuOn) CALL(false, true);
   531       if(!doEnergy && !mGpuOn) CALL(false, false);
         // NOTE(review): no cudaGetLastError() check is visible after these launches
         // in this view; if the hidden lines lack one, consider adding it.
   538 void initPeerCOMmgpu(
   539             const int numDevices,
   540             const int deviceIndex,
   541             double3** d_peerPool,
   545   const int blocks = numDevices;
         // One thread per device (blockDim.x = numDevices); `grid` and the rest of
         // the argument list are on lines not shown here. NOTE(review): presumably
         // initPeerCOMKernel sets up d_peerPool for this deviceIndex -- confirm.
   547   initPeerCOMKernel<<<grid, blocks, 0, stream>>>( numDevices,
   554 /* called in earlier phase to handle multi device COM */
       // Runs computeCOMKernelMgpu over this device's SMD atoms (presumably this
       // device's contribution to the group COM -- confirm), using the same launch
       // shape as computeSMDForce: 128 threads x ceil(numSMDAtoms / 128) blocks, or
       // a single 1024-thread block.
   555 void computeCOMSMDMgpu(
   556   const int         numSMDAtoms,
   559   const double*     d_pos_x,
   560   const double*     d_pos_y,
   561   const double*     d_pos_z,
   562   const char3*      d_transform,
   563   const int*        d_smdAtomsSOAIndex,
   565   double3**         d_peer_curCM,
   566   unsigned int*     d_tbcatomic,
   567   const int         numDevices,
   568   const int         deviceIndex,
   571   // block it up if large, otherwise all in one go
   572   const int blocks = (numSMDAtoms > 1024) ? 128 : 1024;
   573   const int grid = (numSMDAtoms > 1024) ? (numSMDAtoms + blocks - 1) / blocks : 1;
   574   //initialize the device memory to zero here
         // NOTE(review): cudaMemset is issued on the default stream while the kernels
         // below run on `stream`; if `stream` is non-blocking the two are not ordered.
         // Consider cudaMemsetAsync(d_peerCOM, 0, sizeof(double3), stream).
   575   cudaCheck(cudaMemset(d_peerCOM, 0, sizeof(double3)));
   576   if(numSMDAtoms >1024)
   577     computeCOMKernelMgpu<128><<<grid, blocks, 0, stream>>>(numSMDAtoms,
   579                                                       d_pos_x, d_pos_y, d_pos_z,
   587     computeCOMKernelMgpu<1024><<<grid, blocks, 0, stream>>>(numSMDAtoms,
   589                                                  d_pos_x, d_pos_y, d_pos_z,
   598 #endif // NODEGROUP_FORCE_REGISTER