2 #if __CUDACC_VER_MAJOR__ >= 11
     5 #include <namd_cub/cub.cuh>
     8 #include <hip/hip_runtime.h>
     9 #include <hipcub/hipcub.hpp>
    11 #endif // end NAMD_CUDA vs. NAMD_HIP
    13 #include "HipDefines.h"
    15 #include "ComputeRestraintsCUDAKernel.h"
    17 #ifdef NODEGROUP_FORCE_REGISTER
    20 #define PI 3.141592653589793
    23 // Host function to update the rotation matrix
    24 void vec_rotation_matrix(double angle, double3 v, cudaTensor& m){
    27   double xs, ys, zs, one_c;
    28   s = sin(angle * PI/180.0);
    29   c = cos(angle * PI/180.0);
    35   mag = sqrt(v.x*v.x + v.y*v.y + v.z*v.z);
    38     // Return a 3x3 identity matrix
    50   m.xx = (one_c * (v.x * v.x) ) + c;
    51   m.xy = (one_c * (v.x * v.y) ) - zs;
    52   m.xz = (one_c * (v.z * v.x) ) + ys;
    54   m.yx = (one_c * (v.x * v.y) ) + zs;
    55   m.yy = (one_c * (v.y * v.y) ) + c;
    56   m.yz = (one_c * (v.y * v.z) ) - xs;
    58   m.zx = (one_c * (v.z * v.x) ) - ys;
    59   m.zy = (one_c * (v.y * v.z) ) + xs;
    60   m.zz = (one_c * (v.z * v.z) ) + c;
    64 template<bool T_DOENERGY>
    65 __global__ void computeRestrainingForceKernel(
    66   const int currentTime, 
    67   const int nConstrainedAtoms,
    69   const double consScaling, 
    73   const bool spheConsOn, 
    74   const bool consSelectX,
    75   const bool consSelectY,
    76   const bool consSelectZ, 
    78   const double3  rotAxis,
    79   const double3  rotPivot,
    80   const double3  moveVel,
    81   const double3  spheConsCenter,
    82   const int*    __restrict d_constrainedSOA, 
    83   const int*    __restrict d_constrainedID, 
    84   const double* __restrict d_pos_x,
    85   const double* __restrict d_pos_y, 
    86   const double* __restrict d_pos_z,
    87   const double* __restrict d_k,
    88   const double* __restrict d_cons_x,
    89   const double* __restrict d_cons_y,
    90   const double* __restrict d_cons_z,
    91   double*       __restrict f_normal_x, 
    92   double*       __restrict f_normal_y,
    93   double*       __restrict f_normal_z,
    94   double*       __restrict d_bcEnergy,
    95   double*       __restrict h_bcEnergy,
    96   double3*      __restrict d_netForce, 
    97   double3*      __restrict h_netForce,
    98   cudaTensor*   __restrict d_virial, 
    99   cudaTensor*   __restrict h_virial,
   101   unsigned int* __restrict tbcatomic,
   106   int tid = threadIdx.x + (blockIdx.x * blockDim.x);
   108   int totaltb = gridDim.x;
   109   bool isLastBlockDone;
   111   if(threadIdx.x == 0){
   118   double3 r_netForce = {0, 0, 0};
   120   r_virial.xx = 0.0; r_virial.xy = 0.0; r_virial.xz = 0.0;
   121   r_virial.yx = 0.0; r_virial.yy = 0.0; r_virial.yz = 0.0;
   122   r_virial.zx = 0.0; r_virial.zy = 0.0; r_virial.zz = 0.0;
   124   if(tid < nConstrainedAtoms){
   126     // Index of the constrained atom in the SOA data structure
   127     int indexC = d_constrainedID[tid];
   128     int soaID = d_constrainedSOA[indexC];
   130     // Atomic fixed positions
   131     double ref_x = d_cons_x[indexC];
   132     double ref_y = d_cons_y[indexC];
   133     double ref_z = d_cons_z[indexC];
   135     // JM: BAD BAD BAD -> UNCOALESCED GLOBAL MEMORY ACCESS
   136     double pos_x = d_pos_x[soaID];
   137     double pos_y = d_pos_y[soaID];
   138     double pos_z = d_pos_z[soaID];
   141     double k = d_k[indexC];
   143     // I can just store consScaling * k here instead of doing the math
   146       ref_x += currentTime * moveVel.x;
   147       ref_y += currentTime * moveVel.y;
   148       ref_z += currentTime * moveVel.z;
   153       // do a matrix-vector operation
   155       double rx = ref_x - rotPivot.x;
   156       double ry = ref_y - rotPivot.y;
   157       double rz = ref_z - rotPivot.z;
   159       ref_x = rotMatrix.xx * rx + rotMatrix.xy * ry + rotMatrix.xz * rz;
   160       ref_y = rotMatrix.yx * rx + rotMatrix.yy * ry + rotMatrix.yz * rz;
   161       ref_z = rotMatrix.zx * rx + rotMatrix.zy * ry + rotMatrix.zz * rz;
   164     // END moving and rotating constraints
   167       // JM: This code sucks, but maybe it's not a very common use-case, so let's go with it for now
   169       diff.x = ref_x - spheConsCenter.x;
   170       diff.y = ref_y - spheConsCenter.y;
   171       diff.z = ref_z - spheConsCenter.z;
   173       double refRad = sqrt(diff.x * diff.x + diff.y*diff.y + diff.z * diff.z); // Whoops
   175       // Reusing diff here as relPos: first let's store global position - spherical center
   176       diff = lat.delta(Vector(pos_x, pos_y, pos_z), spheConsCenter);
   178       refRad *= rsqrt(diff.x * diff.x + diff.y * diff.y + diff.z * diff.z); // 2x-whoops
   179       // now we recalculate refPos:
   180       ref_x = spheConsCenter.x + diff.x * refRad;
   181       ref_y = spheConsCenter.y + diff.y * refRad;
   182       ref_z = spheConsCenter.z + diff.z * refRad;
   185     // Calculating the RIJ vector as lattice.delta(ref, pos);
   188     rij = lat.delta(Vector(ref_x, ref_y, ref_z), Vector(pos_x, pos_y, pos_z));
   190     vpos.x = ref_x - rij.x;
   191     vpos.y = ref_y - rij.y;
   192     vpos.z = ref_z - rij.z;
   195       rij.x *= (1.0 * consSelectX);
   196       rij.y *= (1.0 * consSelectY);
   197       rij.z *= (1.0 * consSelectZ);
   200     double r2 = rij.x * rij.x + rij.y*rij.y + rij.z*rij.z;
   201     double r  = sqrt(r2); // 3x-whoops
   205       double value = k * (pow(r, consExp)); // NOTE: this consExp is an int, so it might be better to just do a loop
   206       if (T_DOENERGY) energy = value;
   213       // JM: BAD BAD BAD ->UNCOALESCED GLOBAL MEMORY ACCESS
   214       f_normal_x[soaID] += rij.x;
   215       f_normal_y[soaID] += rij.y;
   216       f_normal_z[soaID] += rij.z;
   217       r_netForce.x = rij.x;
   218       r_netForce.y = rij.y;
   219       r_netForce.z = rij.z;
   221       // Now we calculate the virial contribution
   222       // JM: is this virial symmetrical? 
   223       r_virial.xx = rij.x * vpos.x;
   224       r_virial.xy = rij.x * vpos.y;
   225       r_virial.xz = rij.x * vpos.z;
   226       r_virial.yx = rij.y * vpos.x;
   227       r_virial.yy = rij.y * vpos.y;
   228       r_virial.yz = rij.y * vpos.z;
   229       r_virial.zx = rij.z * vpos.x;
   230       r_virial.zy = rij.z * vpos.y;
   231       r_virial.zz = rij.z * vpos.z;
   237     // Reduce energy and virials
   238     typedef cub::BlockReduce<double, 128> BlockReduce;
   239     __shared__ typename BlockReduce::TempStorage temp_storage;
   240     energy  = BlockReduce(temp_storage).Sum(energy);
   243     r_netForce.x  = BlockReduce(temp_storage).Sum(r_netForce.x);
   245     r_netForce.y  = BlockReduce(temp_storage).Sum(r_netForce.y);
   247     r_netForce.z  = BlockReduce(temp_storage).Sum(r_netForce.z);
   249     r_virial.xx  = BlockReduce(temp_storage).Sum(r_virial.xx);
   251     r_virial.xy  = BlockReduce(temp_storage).Sum(r_virial.xy);
   253     r_virial.xz  = BlockReduce(temp_storage).Sum(r_virial.xz);
   255     r_virial.yx  = BlockReduce(temp_storage).Sum(r_virial.yx);
   257     r_virial.yy  = BlockReduce(temp_storage).Sum(r_virial.yy);
   259     r_virial.yz  = BlockReduce(temp_storage).Sum(r_virial.yz);
   261     r_virial.zx  = BlockReduce(temp_storage).Sum(r_virial.zx);
   263     r_virial.zy  = BlockReduce(temp_storage).Sum(r_virial.zy);
   265     r_virial.zz  = BlockReduce(temp_storage).Sum(r_virial.zz);
   268     if(threadIdx.x == 0){
   269       atomicAdd(d_bcEnergy, energy);
   270       atomicAdd(&(d_netForce->x), r_netForce.x);
   271       atomicAdd(&(d_netForce->y), r_netForce.y);
   272       atomicAdd(&(d_netForce->z), r_netForce.z);
   274       atomicAdd(&(d_virial->xx), r_virial.xx);
   275       atomicAdd(&(d_virial->xy), r_virial.xy);
   276       atomicAdd(&(d_virial->xz), r_virial.xz);
   278       atomicAdd(&(d_virial->yx), r_virial.yx);
   279       atomicAdd(&(d_virial->yy), r_virial.yy);
   280       atomicAdd(&(d_virial->yz), r_virial.yz);
   282       atomicAdd(&(d_virial->zx), r_virial.zx);
   283       atomicAdd(&(d_virial->zy), r_virial.zy);
   284       atomicAdd(&(d_virial->zz), r_virial.zz);
   287       unsigned int value = atomicInc(tbcatomic, totaltb);
   288       isLastBlockDone = (value == (totaltb -1));
   296     if(threadIdx.x == 0){
   297       //updates to host-mapped mem
   298       h_bcEnergy[0]  = d_bcEnergy[0];      
   299       h_netForce->x = d_netForce->x;
   300       h_netForce->y = d_netForce->y;
   301       h_netForce->z = d_netForce->z;
   303       h_virial->xx = d_virial->xx;
   304       h_virial->xy = d_virial->xy;
   305       h_virial->xz = d_virial->xz;
   307       h_virial->yx = d_virial->yx;
   308       h_virial->yy = d_virial->yy;
   309       h_virial->yz = d_virial->yz;
   311       h_virial->zx = d_virial->zx;
   312       h_virial->zy = d_virial->zy;
   313       h_virial->zz = d_virial->zz;
   338 void computeRestrainingForce(
   341   const int currentTime, 
   342   const int nConstrainedAtoms,
   344   const double consScaling, 
   345   const bool movConsOn, 
   346   const bool rotConsOn,
   347   const bool selConsOn,
   348   const bool spheConsOn, 
   349   const bool consSelectX,
   350   const bool consSelectY,
   351   const bool consSelectZ, 
   353   const double3  rotAxis,
   354   const double3  rotPivot,
   355   const double3  moveVel,
   356   const double3  spheConsCenter,
   357   const int*     d_constrainedSOA, 
   358   const int*     d_constrainedID, 
   359   const double*  d_pos_x,
   360   const double*  d_pos_y, 
   361   const double*  d_pos_z,
   363   const double*  d_cons_x,
   364   const double*  d_cons_y,
   365   const double*  d_cons_z,
   366   double*        d_f_normal_x, 
   367   double*        d_f_normal_y,
   368   double*        d_f_normal_z,
   374   cudaTensor* d_virial, 
   375   cudaTensor* h_virial, 
   376   cudaTensor  rotationMatrix, 
   377   unsigned int* d_tbcatomic, 
   381   const int blocks = 128;
   382   const int grid = (nConstrainedAtoms + blocks - 1) / blocks;
   384   // we calculate the rotation matrix for this timestep on the host; hopefully this is fast enough
   385   vec_rotation_matrix(rotVel * currentTime, rotAxis, rotationMatrix);
   387   if(doEnergy || doVirial){
   388     computeRestrainingForceKernel<true> <<<grid, blocks, 0, stream >>>(
   428     computeRestrainingForceKernel <false> <<<grid, blocks, 0, stream>>>( 
   469 #endif // NODEGROUP_FORCE_REGISTER