2 #if __CUDACC_VER_MAJOR__ >= 11
     5 #include <namd_cub/cub.cuh>
    10 #include <hip/hip_runtime.h>
    11 #include <hipcub/hipcub.hpp>
    15 #include "ComputeGroupRes2GroupCUDAKernel.h"
    16 #include "ComputeCOMCudaKernel.h"
    17 #include "HipDefines.h"
    19 #ifdef NODEGROUP_FORCE_REGISTER
    21 /*! Compute restraint force, virial, and energy for large restraint
    22     groups (more than 1024 restrained atoms in total), due to restraining
    23     the COM of group 2 (group2COM) to the COM of group 1 (group1COM); the
    24     force acts on group 2 and its opposite on group 1. The COMs of groups
    25     1 and 2 must be computed beforehand and passed in as group1COM and
    27     group2COM. This function also calculates the distance from the COM
    28     of group 1 to the COM of group 2. */
    // computeLargeGroupRestraint2GroupsKernel: applies the COM-COM restraint
    // force to every restrained atom and accumulates the restraint virial.
    //
    // Launch layout (see computeGroupRestraint_2Group): blocks of 128 threads,
    // ceil(totalNumRestrained / 128) blocks. BlockReduce<double, 128> below
    // relies on blockDim.x == 128. One thread per restrained atom: threads
    // [0, numRestrainedGroup1) handle group-1 atoms and threads
    // [numRestrainedGroup1, totalNumRestrained) handle group-2 atoms.
    //
    // Parameters:
    //   restraintExp        exponent n of E = k * |r - r_eq|^n. The power loops
    //                       below multiply the squared distance, so n appears
    //                       to be assumed even (NOTE(review): confirm callers
    //                       enforce this).
    //   restraintK          force constant k.
    //   resCenterVec        equilibrium COM-distance vector r_eq.
    //   resDirection        per-component factor multiplied into the COM
    //                       difference (presumably 0/1 to select restrained
    //                       dimensions -- confirm).
    //   inv_group1_mass,
    //   inv_group2_mass     inverse total group masses; each atom receives
    //                       m * inv_mass of its group's force.
    //   groupAtomsSOAIndex  thread index -> SOA index of the restrained atom.
    //   transform, lat      unwrap atom positions / minimum-image the COM
    //                       difference.
    //   f_normal_*          per-atom forces; the restraint force is added.
    //   d_virial            device accumulator: thread 0 of each block
    //                       atomically adds the block-reduced virial, and the
    //                       last block adds the total into h_extVirial.
    //   h_extVirial, h_resEnergy, h_resForce, h_diffCOM
    //                       written by thread 0 of the last block to finish
    //                       (the h_ prefix suggests host-mapped memory --
    //                       confirm with the caller).
    //   group1COM,
    //   group2COM           group COM inputs; device pointers in multi-GPU runs.
    //   d_tbcatomic         block-completion counter used to find the last
    //                       block.
    29 template<int T_DOENERGY, int T_DOVIRIAL, int T_USEMAGNITUDE, int T_MGPUON>
    30 __global__ void computeLargeGroupRestraint2GroupsKernel(
    31     const int         numRestrainedGroup1,
    32     const int         totalNumRestrained,
    33     const int         restraintExp,
    34     const double      restraintK,
    35     const double3     resCenterVec,
    36     const double3     resDirection,
    37     const double      inv_group1_mass,
    38     const double      inv_group2_mass,
    39     const int*      __restrict    groupAtomsSOAIndex,
    41     const char3*    __restrict    transform,
    42     const float*    __restrict    mass,
    43     const double*   __restrict    pos_x,
    44     const double*   __restrict    pos_y,
    45     const double*   __restrict    pos_z,
    46     double*         __restrict    f_normal_x,
    47     double*         __restrict    f_normal_y,
    48     double*         __restrict    f_normal_z,
    49     cudaTensor*     __restrict    d_virial,
    50     cudaTensor*     __restrict    h_extVirial,
    51     double*         __restrict    h_resEnergy,
    52     double3*        __restrict    h_resForce,
    53     double3*  __restrict    group1COM,  // on device if mgpu
    54     double3*  __restrict    group2COM,  // on device if mgpu
    55     double3*        __restrict    h_diffCOM,
    56     unsigned int*   __restrict    d_tbcatomic)
    58     int tIdx = threadIdx.x + blockIdx.x * blockDim.x;
    59     int totaltb = gridDim.x;
    60     bool isLastBlockDone = false;
    65     double3 diffCOM = {0, 0, 0};
    66     double3 group_f = {0, 0, 0};
    67     double3 pos = {0, 0, 0};
    68     double3 f = {0, 0, 0};
    70     r_virial.xx = 0.0; r_virial.xy = 0.0; r_virial.xz = 0.0;
    71     r_virial.yx = 0.0; r_virial.yy = 0.0; r_virial.yz = 0.0;
    72     r_virial.zx = 0.0; r_virial.zy = 0.0; r_virial.zz = 0.0;
    74     if(tIdx < totalNumRestrained) {
    75         SOAindex = groupAtomsSOAIndex[tIdx];
    77         // For consistency with ComputeGroupRes1, we calculate the
    78         // distance from com1 to com2 along the restrained dimensions,
    79         // so the force acts on group 2
               // In this form the COM inputs are scaled by the inverse group
               // masses, i.e. treated as mass-weighted position sums (the branch
               // condition is not visible here -- presumably the multi-GPU path).
    81      diffCOM.x = (group2COM->x*inv_group2_mass - group1COM->x*inv_group1_mass) * resDirection.x;
    82      diffCOM.y = (group2COM->y*inv_group2_mass - group1COM->y*inv_group1_mass) * resDirection.y;
    83      diffCOM.z = (group2COM->z*inv_group2_mass - group1COM->z*inv_group1_mass) * resDirection.z;
               // In this form group1COM/group2COM are used directly as COMs.
    87        diffCOM.x = (group2COM->x - group1COM->x) * resDirection.x;
    88        diffCOM.y = (group2COM->y - group1COM->y) * resDirection.y;
    89        diffCOM.z = (group2COM->z - group1COM->z) * resDirection.z;
    91         // Calculate the minimum image distance
    92         diffCOM = lat.delta_from_diff(diffCOM);
                   // Magnitude restraint: restrain |diffCOM| to |resCenterVec|; the
                   // force acts along the unit vector diffCOM / |diffCOM|.
    95             // Calculate the difference from equilibrium restraint distance
    96             double comVal = sqrt(diffCOM.x*diffCOM.x + diffCOM.y*diffCOM.y + diffCOM.z*diffCOM.z);
    97             double centerVal = sqrt(resCenterVec.x*resCenterVec.x + resCenterVec.y*resCenterVec.y +
    98                 resCenterVec.z*resCenterVec.z);  
   100             double distDiff = (comVal - centerVal);
   101             double distSqDiff = distDiff * distDiff;
                   // NOTE(review): if the two COMs coincide (comVal == 0) while
                   // centerVal > 0, invCOMVal is infinite and group_f becomes NaN
                   // (0 * inf) below; consider guarding this case.
   102             double invCOMVal = 1.0 / comVal; 
   104             // Calculate energy and force on group of atoms
   105             if(distSqDiff > 0.0f) { // skip r == r_eq: the force below divides by distDiff
   106                 // Energy = k * (r - r_eq)^n
   107                 energy = restraintK * distSqDiff;
   108                 for (int n = 2; n < restraintExp; n += 2) {
   109                     energy *= distSqDiff;
   111                 // Force = -k * n * (r - r_eq)^(n-1)
                       // (energy / distDiff == k * distDiff^(n-1) for even n)
   112                 double force = -energy * restraintExp / distDiff;
   113                 // calculate force along COM difference
   114                 group_f.x = force * diffCOM.x * invCOMVal;
   115                 group_f.y = force * diffCOM.y * invCOMVal;
   116                 group_f.z = force * diffCOM.z * invCOMVal;
                   // Vector restraint: restrain the resDirection-scaled components
                   // of diffCOM to resCenterVec.
   119             // Calculate the difference from equilibrium restraint distance vector
   120             // along the restrained dimensions
   122             resDist.x = (diffCOM.x - resCenterVec.x) * resDirection.x;
   123             resDist.y = (diffCOM.y - resCenterVec.y) * resDirection.y;
   124             resDist.z = (diffCOM.z - resCenterVec.z) * resDirection.z;
   125             // Wrap the distance difference (diffCOM - resCenterVec) 
   126             resDist = lat.delta_from_diff(resDist); 
   128             double distSqDiff = resDist.x*resDist.x + resDist.y*resDist.y + resDist.z*resDist.z;
   130             // Calculate energy and force on group of atoms
   131             if(distSqDiff > 0.0f) { // skip resDist == 0: the force below divides by distSqDiff
   132                 // Energy = k * (r - r_eq)^n
   133                 energy = restraintK * distSqDiff;
   134                 for (int n = 2; n < restraintExp; n += 2) {
   135                     energy *= distSqDiff;
   137                 // Force = -k * n * (r - r_eq)^(n-1) x (r - r_eq)/|r - r_eq|
                       // (energy / distSqDiff * resDist == k * |d|^(n-2) * d)
   138                 double force = -energy * restraintExp / distSqDiff;
   139                 group_f.x = force * resDist.x;
   140                 group_f.y = force * resDist.y;
   141                 group_f.z = force * resDist.z;
   145         // calculate the force on each atom of the group
               // Each atom receives m / M of the group force, so the per-atom forces
               // summed over a group equal group_f (group 2) or -group_f (group 1).
   146         if (tIdx < numRestrainedGroup1) {
   147             // threads [0 , numGroup1Atoms) calculate force for group 1
   148             // We use negative because force is calculated for group 2
   149             f.x = -group_f.x * m * inv_group1_mass;
   150             f.y = -group_f.y * m * inv_group1_mass;
   151             f.z = -group_f.z * m * inv_group1_mass;
   153             // threads [numGroup1Atoms , totalNumRestrained) calculate force for group 2
   154             f.x = group_f.x * m * inv_group2_mass;
   155             f.y = group_f.y * m * inv_group2_mass;
   156             f.z = group_f.z * m * inv_group2_mass;
   158         // apply the bias to each atom in group
   159         f_normal_x[SOAindex] += f.x;
   160         f_normal_y[SOAindex] += f.y;
   161         f_normal_z[SOAindex] += f.z;
   162         // Virial is based on applied force on each atom
   164             // positions must be unwrapped (lat.reverse_transform) for the virial
   165             pos.x = pos_x[SOAindex];
   166             pos.y = pos_y[SOAindex];
   167             pos.z = pos_z[SOAindex];
   168             char3 tr = transform[SOAindex];
   169             pos = lat.reverse_transform(pos, tr);
   170             r_virial.xx = f.x * pos.x;
   171             r_virial.xy = f.x * pos.y;
   172             r_virial.xz = f.x * pos.z;
   173             r_virial.yx = f.y * pos.x;
   174             r_virial.yy = f.y * pos.y;
   175             r_virial.yz = f.y * pos.z;
   176             r_virial.zx = f.z * pos.x;
   177             r_virial.zy = f.z * pos.y;
   178             r_virial.zz = f.z * pos.z;
   183     if(T_DOENERGY || T_DOVIRIAL) {
   185             // Reduce virial values in the thread block
                   // NOTE(review): CUB requires a __syncthreads() between successive
                   // uses of the same TempStorage; the lines between these Sum calls
                   // are not visible here -- confirm they contain the barrier.
   186             typedef cub::BlockReduce<double, 128> BlockReduce;
   187             __shared__ typename BlockReduce::TempStorage temp_storage;
   189             r_virial.xx = BlockReduce(temp_storage).Sum(r_virial.xx);
   191             r_virial.xy = BlockReduce(temp_storage).Sum(r_virial.xy);
   193             r_virial.xz = BlockReduce(temp_storage).Sum(r_virial.xz);
   196             r_virial.yx = BlockReduce(temp_storage).Sum(r_virial.yx);
   198             r_virial.yy = BlockReduce(temp_storage).Sum(r_virial.yy);
   200             r_virial.yz = BlockReduce(temp_storage).Sum(r_virial.yz);
   203             r_virial.zx = BlockReduce(temp_storage).Sum(r_virial.zx);
   205             r_virial.zy = BlockReduce(temp_storage).Sum(r_virial.zy);
   207             r_virial.zz = BlockReduce(temp_storage).Sum(r_virial.zz);
   211         if(threadIdx.x == 0) {
   213                 // thread 0 adds the reduced virial values into device memory
   214                 atomicAdd(&(d_virial->xx), r_virial.xx);
   215                 atomicAdd(&(d_virial->xy), r_virial.xy);
   216                 atomicAdd(&(d_virial->xz), r_virial.xz);
   218                 atomicAdd(&(d_virial->yx), r_virial.yx);
   219                 atomicAdd(&(d_virial->yy), r_virial.yy);
   220                 atomicAdd(&(d_virial->yz), r_virial.yz);
   222                 atomicAdd(&(d_virial->zx), r_virial.zx);
   223                 atomicAdd(&(d_virial->zy), r_virial.zy);
   224                 atomicAdd(&(d_virial->zz), r_virial.zz);
                   // atomicInc returns the previous count; the block that observes
                   // totaltb - 1 is the last block to arrive.
                   // NOTE(review): the usual last-block pattern issues a
                   // __threadfence() before this atomicInc so the d_virial updates
                   // are visible to the last block; not visible here -- confirm.
   227             unsigned int value = atomicInc(&d_tbcatomic[0], totaltb);
   228             isLastBlockDone = (value == (totaltb -1));
   233         if(isLastBlockDone) {
   234         // Thread 0 of the last block will set the host values
   235             if(threadIdx.x == 0) {
                       // energy, diffCOM and group_f are derived from the same COMs in
                       // every active thread, so thread 0 of any block may publish them.
   237                     h_resEnergy[0] = energy;     // restraint energy for each group, needed for output
   238                     h_diffCOM->x = diffCOM.x;    // distance between COM of two restrained groups 
   239                     h_diffCOM->y = diffCOM.y;    // distance between COM of two restrained groups 
   240                     h_diffCOM->z = diffCOM.z;    // distance between COM of two restrained groups 
   241                     h_resForce->x = group_f.x;   // restraint force on group 2
   242                     h_resForce->y = group_f.y;   // restraint force on group 2
   243                     h_resForce->z = group_f.z;   // restraint force on group 2
   246                     // Add virial values to host memory.
   247                     // Accumulate (+=) because several restraint groups contribute to the same external virial.
   248                     h_extVirial->xx += d_virial->xx;
   249                     h_extVirial->xy += d_virial->xy;
   250                     h_extVirial->xz += d_virial->xz;
   251                     h_extVirial->yx += d_virial->yx;
   252                     h_extVirial->yy += d_virial->yy;
   253                     h_extVirial->yz += d_virial->yz;
   254                     h_extVirial->zx += d_virial->zx;
   255                     h_extVirial->zy += d_virial->zy;
   256                     h_extVirial->zz += d_virial->zz;
   258                     //reset the device virial value
   271                 //resets atomic counter
   278       {// need lastBlockDone for MGPU
   279    if(threadIdx.x == 0){
   281      unsigned int value = atomicInc(&d_tbcatomic[0], totaltb);
   282      isLastBlockDone = (value == (totaltb -1));
   288    if(threadIdx.x == 0){
   289      // zero out for next iteration
   296      //resets atomic counter
   305 /*! Compute restraint force, virial, and energy for small restraint
   306     groups (at most 1024 restrained atoms in total), due to restraining
   307     the COM of group 2 (group2COM) to the COM of group 1 (group1COM).
   308     This function also calculates the distance from the COM of
   309     group 1 to the COM of group 2. */
   // computeSmallGroupRestraint2GroupsKernel: single-block variant used when
   // totalNumRestrained <= 1024 (computeGroupRestraint_2Group launches it with
   // grid = 1 and 1024 threads). BlockReduce<double, 1024> below relies on
   // blockDim.x == 1024. One thread per restrained atom: threads
   // [0, numRestrainedGroup1) handle group-1 atoms and threads
   // [numRestrainedGroup1, totalNumRestrained) handle group-2 atoms; the
   // remaining threads only take part in the block-wide reductions.
   // Unlike the large-group kernel, this kernel recomputes the group COMs from
   // unwrapped positions; com1/com2 are also seeded from group1COM/group2COM
   // (device pointers in multi-GPU runs; the guarding condition is not visible
   // here). Thread 0 writes the reduced virial, energy, force and COM distance
   // straight to h_extVirial, h_resEnergy, h_resForce and h_diffCOM (the h_
   // prefix suggests host-mapped memory -- confirm). The visible uses of
   // d_tbcatomic are in the T_MGPUON last-block section at the end.
   // Other parameters have the same meaning as in
   // computeLargeGroupRestraint2GroupsKernel (no d_virial accumulator is used).
   310 template<int T_DOENERGY, int T_DOVIRIAL, int T_USEMAGNITUDE, int T_MGPUON>
   311 __global__ void computeSmallGroupRestraint2GroupsKernel(
   312     const int         numRestrainedGroup1,
   313     const int         totalNumRestrained,
   314     const int         restraintExp,
   315     const double      restraintK,
   316     const double3     resCenterVec,
   317     const double3     resDirection,
   318     const double      inv_group1_mass,
   319     const double      inv_group2_mass,
   320     const int*      __restrict    groupAtomsSOAIndex,
   322     const char3*    __restrict    transform,
   323     const float*    __restrict    mass,
   324     const double*   __restrict    pos_x,
   325     const double*   __restrict    pos_y,
   326     const double*   __restrict    pos_z,
   327     double*         __restrict    f_normal_x,
   328     double*         __restrict    f_normal_y,
   329     double*         __restrict    f_normal_z,
   330     cudaTensor*     __restrict    h_extVirial,
   331     double*         __restrict    h_resEnergy,
   332     double3*        __restrict    h_resForce,
   333     double3*        __restrict    h_diffCOM,
   334     double3*        __restrict    group1COM, // on device in Multi GPU
   335     double3*        __restrict    group2COM,
   336     unsigned int*   __restrict    d_tbcatomic)
   338     int tIdx = threadIdx.x + blockIdx.x * blockDim.x;
   339     __shared__ double3 sh_com1;
   340     __shared__ double3 sh_com2;
   341     bool isLastBlockDone = false;
   342     int totaltb = gridDim.x;
   345     double3 com1 = {0, 0, 0};
   346     double3 com2 = {0, 0, 0};
   347     double3 diffCOM = {0, 0, 0};
   348     double3 group_f = {0, 0, 0};
   349     double3 pos = {0, 0, 0};
   350     double3 f = {0, 0, 0};
   352     r_virial.xx = 0.0; r_virial.xy = 0.0; r_virial.xz = 0.0;
   353     r_virial.yx = 0.0; r_virial.yy = 0.0; r_virial.yz = 0.0;
   354     r_virial.zx = 0.0; r_virial.zy = 0.0; r_virial.zz = 0.0;
   358    com1.x = group1COM->x;
   359    com1.y = group1COM->y;
   360    com1.z = group1COM->z;
   361    com2.x = group2COM->x;
   362    com2.y = group2COM->y;
   363    com2.z = group2COM->z;
   365     if(tIdx < totalNumRestrained){
   366         // First -> recalculate center of mass.
   367         SOAindex = groupAtomsSOAIndex[tIdx];
   369         m = mass[SOAindex]; // Cast from float to double here
   370         pos.x = pos_x[SOAindex];
   371         pos.y = pos_y[SOAindex];
   372         pos.z = pos_z[SOAindex];
   374         // unwrap the coordinate to calculate COM (reused later for the virial)
   375         char3 tr = transform[SOAindex];
   376         pos = lat.reverse_transform(pos, tr);
   378      if (tIdx < numRestrainedGroup1) {
   379             // we initialized the com2 to zero
   384             // we initialized the com1 to zero
   392     // reduce the (mass * position) values for group 1 and 2 in the thread block
           // NOTE(review): CUB requires a __syncthreads() between successive uses
           // of the same TempStorage; the lines between these Sum calls are not
           // visible here -- confirm they contain the barrier.
   393     typedef cub::BlockReduce<double, 1024> BlockReduce;
   394     __shared__ typename BlockReduce::TempStorage temp_storage;
   396       com1.x = BlockReduce(temp_storage).Sum(com1.x);
   398       com1.y = BlockReduce(temp_storage).Sum(com1.y);
   400       com1.z = BlockReduce(temp_storage).Sum(com1.z);
   402       com2.x = BlockReduce(temp_storage).Sum(com2.x);
   404       com2.y = BlockReduce(temp_storage).Sum(com2.y);
   406       com2.z = BlockReduce(temp_storage).Sum(com2.z);
   409     // Thread 0 calculates the COM of group 1 and 2
           // (only thread 0 holds the valid BlockReduce result, so it publishes
           // the COMs through shared memory)
   410     if(threadIdx.x == 0){
   411         sh_com1.x = com1.x * inv_group1_mass; // calculates the COM of group 1
   412         sh_com1.y = com1.y * inv_group1_mass; // calculates the COM of group 1
   413         sh_com1.z = com1.z * inv_group1_mass; // calculates the COM of group 1
   414         sh_com2.x = com2.x * inv_group2_mass; // calculates the COM of group 2
   415         sh_com2.y = com2.y * inv_group2_mass; // calculates the COM of group 2
   416         sh_com2.z = com2.z * inv_group2_mass; // calculates the COM of group 2
           // NOTE(review): sh_com1/sh_com2 are written by thread 0 above and read by
           // every thread below; a __syncthreads() is required in between (the
           // intervening lines are not visible here -- confirm).
   420     if(tIdx < totalNumRestrained) {
   421         // For consistency with distanceZ, we calculate the
   422         // distance from com1 to com2 along the restrained dimensions,
   423         // so the force acts on group 2
   424         diffCOM.x = (sh_com2.x - sh_com1.x) * resDirection.x;
   425         diffCOM.y = (sh_com2.y - sh_com1.y) * resDirection.y; 
   426         diffCOM.z = (sh_com2.z - sh_com1.z) * resDirection.z;
   427         // Calculate the minimum image distance 
   428         diffCOM = lat.delta_from_diff(diffCOM);
   430         if (T_USEMAGNITUDE) {
                   // Magnitude restraint: restrain |diffCOM| to |resCenterVec|; the
                   // force acts along the unit vector diffCOM / |diffCOM|.
   431             // Calculate the difference from equilibrium restraint distance
   432             double comVal = sqrt(diffCOM.x*diffCOM.x + diffCOM.y*diffCOM.y + diffCOM.z*diffCOM.z);
   433             double centerVal = sqrt(resCenterVec.x*resCenterVec.x + resCenterVec.y*resCenterVec.y +
   434                 resCenterVec.z*resCenterVec.z);  
   436             double distDiff = (comVal - centerVal);
   437             double distSqDiff = distDiff * distDiff;
                   // NOTE(review): if the two COMs coincide (comVal == 0) while
                   // centerVal > 0, invCOMVal is infinite and group_f becomes NaN
                   // (0 * inf) below; consider guarding this case.
   438             double invCOMVal = 1.0 / comVal; 
   440             // Calculate energy and force on group of atoms
   441             if(distSqDiff > 0.0f) { // skip r == r_eq: the force below divides by distDiff
   442                 // Energy = k * (r - r_eq)^n
   443                 energy = restraintK * distSqDiff;
   444                 for (int n = 2; n < restraintExp; n += 2) {
   445                     energy *= distSqDiff;
   447                 // Force = -k * n * (r - r_eq)^(n-1)
                       // (energy / distDiff == k * distDiff^(n-1) for even n)
   448                 double force = -energy * restraintExp / distDiff;
   449                 // calculate force along COM difference
   450                 group_f.x = force * diffCOM.x * invCOMVal;
   451                 group_f.y = force * diffCOM.y * invCOMVal;
   452                 group_f.z = force * diffCOM.z * invCOMVal;
                   // Vector restraint: restrain the resDirection-scaled components
                   // of diffCOM to resCenterVec.
   455             // Calculate the difference from equilibrium restraint distance vector
   456             // along the restrained dimensions
   458             resDist.x = (diffCOM.x - resCenterVec.x) * resDirection.x;
   459             resDist.y = (diffCOM.y - resCenterVec.y) * resDirection.y;
   460             resDist.z = (diffCOM.z - resCenterVec.z) * resDirection.z; 
   461             // Wrap the distance difference (diffCOM - resCenterVec) 
   462             resDist = lat.delta_from_diff(resDist); 
   464             double distSqDiff = resDist.x*resDist.x + resDist.y*resDist.y + resDist.z*resDist.z;
   466             // Calculate energy and force on group of atoms
   467             if(distSqDiff > 0.0f) { // skip resDist == 0: the force below divides by distSqDiff
   468                 // Energy = k * (r - r_eq)^n
   469                 energy = restraintK * distSqDiff;
   470                 for (int n = 2; n < restraintExp; n += 2) {
   471                     energy *= distSqDiff;
   473                 // Force = -k * n * (r - r_eq)^(n-1) x (r - r_eq)/|r - r_eq|
                       // (energy / distSqDiff * resDist == k * |d|^(n-2) * d)
   474                 double force = -energy * restraintExp / distSqDiff;
   475                 group_f.x = force * resDist.x;
   476                 group_f.y = force * resDist.y;
   477                 group_f.z = force * resDist.z;
   481         // calculate the force on each atom of the group
               // Each atom receives m / M of the group force, so the per-atom forces
               // summed over a group equal group_f (group 2) or -group_f (group 1).
   482         if (tIdx < numRestrainedGroup1) {
   483             // threads [0 , numGroup1Atoms) calculate force for group 1
   484             // We use negative because force is calculated for group 2
   485             f.x = -group_f.x * m * inv_group1_mass;
   486             f.y = -group_f.y * m * inv_group1_mass;
   487             f.z = -group_f.z * m * inv_group1_mass;
   489             // threads [numGroup1Atoms , totalNumRestrained) calculate force for group 2
   490             f.x = group_f.x * m * inv_group2_mass;
   491             f.y = group_f.y * m * inv_group2_mass;
   492             f.z = group_f.z * m * inv_group2_mass;
   495         // apply the bias to each atom in group
   496         f_normal_x[SOAindex] += f.x ;
   497         f_normal_y[SOAindex] += f.y ;
   498         f_normal_z[SOAindex] += f.z ;
   499         // Virial is based on applied force on each atom
   501             // pos was already unwrapped above (lat.reverse_transform), as the virial requires
   502             r_virial.xx = f.x * pos.x;
   503             r_virial.xy = f.x * pos.y;
   504             r_virial.xz = f.x * pos.z;
   505             r_virial.yx = f.y * pos.x;
   506             r_virial.yy = f.y * pos.y;
   507             r_virial.yz = f.y * pos.z;
   508             r_virial.zx = f.z * pos.x;
   509             r_virial.zy = f.z * pos.y;
   510             r_virial.zz = f.z * pos.z;
   515     if(T_DOENERGY || T_DOVIRIAL) {
   517             // Reduce virial values in the thread block (reuses temp_storage)
   518             r_virial.xx = BlockReduce(temp_storage).Sum(r_virial.xx);
   520             r_virial.xy = BlockReduce(temp_storage).Sum(r_virial.xy);
   522             r_virial.xz = BlockReduce(temp_storage).Sum(r_virial.xz);
   525             r_virial.yx = BlockReduce(temp_storage).Sum(r_virial.yx);
   527             r_virial.yy = BlockReduce(temp_storage).Sum(r_virial.yy);
   529             r_virial.yz = BlockReduce(temp_storage).Sum(r_virial.yz);
   532             r_virial.zx = BlockReduce(temp_storage).Sum(r_virial.zx);
   534             r_virial.zy = BlockReduce(temp_storage).Sum(r_virial.zy);
   536             r_virial.zz = BlockReduce(temp_storage).Sum(r_virial.zz);
   540         // thread zero updates the restraint energy, force and virial
   541         if(threadIdx.x == 0){
   543                 // Add virial values to host memory.
   544                 // Accumulate (+=) because several restraint groups contribute to the same external virial.
   545                 h_extVirial->xx += r_virial.xx;
   546                 h_extVirial->xy += r_virial.xy;
   547                 h_extVirial->xz += r_virial.xz;
   548                 h_extVirial->yx += r_virial.yx;
   549                 h_extVirial->yy += r_virial.yy;
   550                 h_extVirial->yz += r_virial.yz;
   551                 h_extVirial->zx += r_virial.zx;
   552                 h_extVirial->zy += r_virial.zy;
   553                 h_extVirial->zz += r_virial.zz;
   556                 h_resEnergy[0] = energy;     // restraint energy for each group, needed for output
   557                 h_diffCOM->x = diffCOM.x;    // distance between two COM of restrained groups 
   558                 h_diffCOM->y = diffCOM.y;    // distance between two COM of restrained groups  
   559                 h_diffCOM->z = diffCOM.z;    // distance between two COM of restrained groups 
   560                 h_resForce->x = group_f.x;   // restraint force on group 
   561                 h_resForce->y = group_f.y;   // restraint force on group 
   562                 h_resForce->z = group_f.z;   // restraint force on group 
   567       {// need lastBlockDone for T_MGPUON
   568    if(threadIdx.x == 0){
   570      unsigned int value = atomicInc(&d_tbcatomic[0], totaltb);
   571      isLastBlockDone = (value == (totaltb -1));
   578        { // zero out for next iteration
   586          //resets atomic counter
   593 /*! Compute restraint force, energy, and virial due to restraining
   594     the COM of group 2 to the COM of group 1. The force acts on group 2
   595     and the opposite force on group 1. */
   // computeGroupRestraint_2Group: host-side launcher for the two-group COM
   // restraint kernels. It folds the run-time flags into a template-dispatch
   // index `options`:
   //   bit 0 = doEnergy, bit 1 = doVirial, bit 2 = useMagnitude. Cases 8-15
   //   select MGPUON = 1, so bit 3 presumably holds mGpuOn (that term is on a
   //   line not visible here -- confirm).
   // Launch configuration:
   //   totalNumRestrained >  1024: blocks of 128 threads, ceil(total / 128)
   //     blocks. The group COMs are first computed by compute2COMKernel<128>
   //     when !mGpuOn, otherwise by computeDistCOMKernelMgpu from d_peer1COM /
   //     d_peer2COM (one block of numDevices threads); then the large-group
   //     kernel runs.
   //   totalNumRestrained <= 1024: a single block of 1024 threads runs the
   //     small-group kernel, which recomputes the COMs itself. It is preceded
   //     by computeDistCOMKernelMgpu calls whose guard is not visible here
   //     (presumably mGpuOn).
   // COM1Ptr/COM2Ptr point to d_group*COM when mGpuOn and to h_group*COM
   // otherwise. All launches are issued on `stream`.
   // NOTE(review): no cudaGetLastError()/launch-error check is visible here.
   596 void computeGroupRestraint_2Group(
   598     const int         useMagnitude,
   601     const int         numRestrainedGroup1,
   602     const int         totalNumRestrained,
   603     const int         restraintExp,
   604     const double      restraintK,
   605     const double3     resCenterVec,
   606     const double3     resDirection,
   607     const double      inv_group1_mass,
   608     const double      inv_group2_mass,
   609     const int*        d_groupAtomsSOAIndex,
   611     const char3*      d_transform,
   613     const double*     d_pos_x,
   614     const double*     d_pos_y,
   615     const double*     d_pos_z,
   616     double*           d_f_normal_x,
   617     double*           d_f_normal_y,
   618     double*           d_f_normal_z,
   619     cudaTensor*       d_virial,
   620     cudaTensor*       h_extVirial,
   623     double3*          h_group1COM,
   624     double3*          h_group2COM,
   626     double3*          d_group1COM,
   627     double3*          d_group2COM,
   628     double3**         d_peer1COM,
   629     double3**         d_peer2COM,
   630     unsigned int*     d_tbcatomic,
   631     const int         numDevices,
   634     int options = doEnergy + (doVirial << 1) + (useMagnitude << 2)
   636     double3* COM1Ptr=(mGpuOn) ? d_group1COM: h_group1COM;
   637     double3* COM2Ptr=(mGpuOn) ? d_group2COM: h_group2COM;
           // 1024 is both the large/small threshold and the small kernel's block
           // size (its BlockReduce<double, 1024> requires exactly 1024 threads).
   638     const int blocks = (totalNumRestrained > 1024) ? 128 : 1024;
   639     const int grid = (totalNumRestrained > 1024) ? (totalNumRestrained + blocks - 1) / blocks : 1;
   640     if (totalNumRestrained > 1024) {
   641         // first calculate the COM for restraint groups and store it in
   642         // h_group1COM and h_group2COM
   643    if(!mGpuOn) // if we don't have distributed COM
   645        compute2COMKernel<128><<<grid, blocks, 0, stream>>>(
   656             d_groupAtomsSOAIndex,
   664      computeDistCOMKernelMgpu<<<1, numDevices, 0, stream>>>(d_peer1COM,
   667      computeDistCOMKernelMgpu<<<1, numDevices, 0, stream>>>(d_peer2COM,
   671 #define CALL_LARGE_GROUP_RES(DOENERGY, DOVIRIAL, USEMAGNITUDE, MGPUON) \
   672         computeLargeGroupRestraint2GroupsKernel<DOENERGY, DOVIRIAL, USEMAGNITUDE, MGPUON> \
   673         <<<grid, blocks, 0, stream>>>( \
   674        numRestrainedGroup1, totalNumRestrained, \
   675             restraintExp, restraintK, resCenterVec, resDirection, \
   676             inv_group1_mass, inv_group2_mass, d_groupAtomsSOAIndex, \
   677             lat, d_transform, d_mass, d_pos_x, d_pos_y, d_pos_z, \
   678             d_f_normal_x, d_f_normal_y, d_f_normal_z,  d_virial, \
   679             h_extVirial, h_resEnergy, h_resForce, COM1Ptr, \
   680             COM2Ptr, h_diffCOM, d_tbcatomic);
           // Dispatch on `options`: case index == DOENERGY + 2*DOVIRIAL
           // + 4*USEMAGNITUDE + 8*MGPUON.
   682    case 0: CALL_LARGE_GROUP_RES(0, 0, 0, 0); break;
   683    case 1: CALL_LARGE_GROUP_RES(1, 0, 0, 0); break;
   684    case 2: CALL_LARGE_GROUP_RES(0, 1, 0, 0); break;
   685    case 3: CALL_LARGE_GROUP_RES(1, 1, 0, 0); break;
   686    case 4: CALL_LARGE_GROUP_RES(0, 0, 1, 0); break;
   687    case 5: CALL_LARGE_GROUP_RES(1, 0, 1, 0); break;
   688    case 6: CALL_LARGE_GROUP_RES(0, 1, 1, 0); break;
   689    case 7: CALL_LARGE_GROUP_RES(1, 1, 1, 0); break;
   690    case 8: CALL_LARGE_GROUP_RES(0, 0, 0, 1); break;
   691    case 9: CALL_LARGE_GROUP_RES(1, 0, 0, 1); break;
   692    case 10: CALL_LARGE_GROUP_RES(0, 1, 0, 1); break;
   693    case 11: CALL_LARGE_GROUP_RES(1, 1, 0, 1); break;
   694    case 12: CALL_LARGE_GROUP_RES(0, 0, 1, 1); break;
   695    case 13: CALL_LARGE_GROUP_RES(1, 0, 1, 1); break;
   696    case 14: CALL_LARGE_GROUP_RES(0, 1, 1, 1); break;
   697    case 15: CALL_LARGE_GROUP_RES(1, 1, 1, 1); break;
   699 #undef CALL_LARGE_GROUP_RES
   701         // For a small group of restrained atoms, we can just launch
   702         // a single thread block
   704    computeDistCOMKernelMgpu<<<1, numDevices, 0, stream>>>(d_peer1COM,
   707    computeDistCOMKernelMgpu<<<1, numDevices, 0, stream>>>(d_peer2COM,
   711 #define CALL_SMALL_GROUP_RES(DOENERGY, DOVIRIAL, USEMAGNITUDE, MGPUON) \
   712       computeSmallGroupRestraint2GroupsKernel<DOENERGY, DOVIRIAL, USEMAGNITUDE, MGPUON> \
   713         <<<grid, blocks, 0, stream>>>( \
   714        numRestrainedGroup1, totalNumRestrained,      \
   715             restraintExp, restraintK, resCenterVec, resDirection, \
   716             inv_group1_mass, inv_group2_mass, d_groupAtomsSOAIndex, \
   717             lat, d_transform, d_mass, d_pos_x, d_pos_y, d_pos_z, \
   718             d_f_normal_x, d_f_normal_y, d_f_normal_z, \
   719             h_extVirial, h_resEnergy, h_resForce, h_diffCOM, \
   720        COM1Ptr, COM2Ptr, d_tbcatomic \
   723    case 0: CALL_SMALL_GROUP_RES(0, 0, 0, 0); break;
   724    case 1: CALL_SMALL_GROUP_RES(1, 0, 0, 0); break;
   725    case 2: CALL_SMALL_GROUP_RES(0, 1, 0, 0); break;
   726    case 3: CALL_SMALL_GROUP_RES(1, 1, 0, 0); break;
   727    case 4: CALL_SMALL_GROUP_RES(0, 0, 1, 0); break;
   728    case 5: CALL_SMALL_GROUP_RES(1, 0, 1, 0); break;
   729    case 6: CALL_SMALL_GROUP_RES(0, 1, 1, 0); break;
   730    case 7: CALL_SMALL_GROUP_RES(1, 1, 1, 0); break;
   731    case 8: CALL_SMALL_GROUP_RES(0, 0, 0, 1); break;
   732    case 9: CALL_SMALL_GROUP_RES(1, 0, 0, 1); break;
   733    case 10: CALL_SMALL_GROUP_RES(0, 1, 0, 1); break;
   734    case 11: CALL_SMALL_GROUP_RES(1, 1, 0, 1); break;
   735    case 12: CALL_SMALL_GROUP_RES(0, 0, 1, 1); break;
   736    case 13: CALL_SMALL_GROUP_RES(1, 0, 1, 1); break;
   737    case 14: CALL_SMALL_GROUP_RES(0, 1, 1, 1); break;
   738    case 15: CALL_SMALL_GROUP_RES(1, 1, 1, 1); break;
   740 #undef CALL_SMALL_GROUP_RES
   745 #endif // NODEGROUP_FORCE_REGISTER