!///////////////////////////////////////////////////////////////////////
!
!      Author:          M. Shiga
!      Last updated:    Jan 23, 2025 by M. Shiga
!      Description:     extensive MPI parallelization
!
!///////////////////////////////////////////////////////////////////////
!***********************************************************************
      subroutine correct_force_XMPI
!***********************************************************************
!
!     Remove the net force and/or net torque from the centroid forces
!     of a path-integral simulation.
!
!     Flow:  Cartesian force -> normal-mode force (nm_trans_force_XMPI(1)),
!            correct the normal-mode forces, then normal-mode force ->
!            Cartesian force (nm_trans_force_XMPI(0)).
!
!     Controlled by common_variables:
!       method        only 'PIMD ', 'PIHMC', 'CMD  ', 'RPMD ', 'TRPMD',
!                     'BCMD ' (first five characters) are handled;
!                     every other method returns immediately.
!       itrans_start  = 2 : remove the net force
!                           (subtract_force_cent_XMPI)
!       irot_start    = 2 : remove the net torque
!                           (subtract_torque_cent_XMPI),
!                           only if iboundary = 0 (free boundary)
!
!     The early exit happens only when NEITHER correction applies.
!     Previously the test used .or. and skipped everything unless both
!     flags were 2, which silently disabled either correction on its
!     own.
!
!     NOTE: the called routines use my_mpi_allreduce_* and are
!     presumably collective, so every rank must reach the same
!     branches here.  All inputs are shared variables, which keeps
!     the branches rank-independent.
!***********************************************************************

!-----------------------------------------------------------------------
!     /*   shared variables                                           */
!-----------------------------------------------------------------------

      use common_variables, only : &
     &   itrans_start, irot_start, iboundary, method

      implicit none

!-----------------------------------------------------------------------
!     /*   local variables                                            */
!-----------------------------------------------------------------------

!     /*   .true. if the net force / net torque is to be removed       */
      logical :: do_trans, do_rot

!-----------------------------------------------------------------------
!     /*   return by option                                           */
!-----------------------------------------------------------------------

      select case ( method(1:5) )
      case ( 'PIMD ', 'PIHMC', 'CMD  ', 'RPMD ', 'TRPMD', 'BCMD ' )
         continue
      case default
         return
      end select

      do_trans = ( itrans_start .eq. 2 )

!     /*   rotation is removed only for free boundary conditions      */
      do_rot   = ( irot_start .eq. 2 ) .and. ( iboundary .eq. 0 )

!     /*   nothing to correct:  skip the normal mode transforms       */
      if ( .not. ( do_trans .or. do_rot ) ) return

!-----------------------------------------------------------------------
!     /*   Cartesian force  ->  normal mode force                     */
!-----------------------------------------------------------------------

      call nm_trans_force_XMPI ( 1 )

!-----------------------------------------------------------------------
!     /*   remove net force                                           */
!-----------------------------------------------------------------------

      if ( do_trans ) call subtract_force_cent_XMPI

!-----------------------------------------------------------------------
!     /*   subtract rotation:  only free boundary condition           */
!-----------------------------------------------------------------------

      if ( do_rot ) call subtract_torque_cent_XMPI

!-----------------------------------------------------------------------
!     /*   normal mode force  ->  Cartesian force                     */
!-----------------------------------------------------------------------

      call nm_trans_force_XMPI ( 0 )

      return
      end





!***********************************************************************
      subroutine subtract_force_cent_XMPI
!***********************************************************************
!
!     Make the net centroid force vanish by subtracting the mean force
!     from every atom's force at bead index 1.
!
!     Acts on the normal-mode forces fux, fuy, fuz (bead index 1 only).
!     A rank contributes its atoms jstart_atom..jend_atom only if its
!     bead range jstart_bead..jend_bead contains bead 1.  The partial
!     force sums and atom counts are combined with
!     my_mpi_allreduce_real_0, which every rank calls regardless of
!     bead ownership.  Finally nm_trans_force_XMPI(0) is called.
!***********************************************************************

!-----------------------------------------------------------------------
!     /*   shared variables                                           */
!-----------------------------------------------------------------------

      use common_variables, only : &
     &   fux, fuy, fuz

      use XMPI_variables, only : &
     &   jstart_atom, jend_atom, jstart_bead, jend_bead

!-----------------------------------------------------------------------
!     /*   local variables                                            */
!-----------------------------------------------------------------------

      implicit none

      integer :: k

      logical :: has_bead1

!     /*   net force components (later: mean force), atom count       */
      real(8) :: netfx, netfy, netfz, ncount

!-----------------------------------------------------------------------
!     /*   local partial sums over bead 1                             */
!-----------------------------------------------------------------------

      has_bead1 = ( jstart_bead .le. 1 ) .and. ( jend_bead .ge. 1 )

      netfx  = 0.d0
      netfy  = 0.d0
      netfz  = 0.d0
      ncount = 0.d0

      if ( has_bead1 ) then
         do k = jstart_atom, jend_atom
            netfx  = netfx  + fux(k,1)
            netfy  = netfy  + fuy(k,1)
            netfz  = netfz  + fuz(k,1)
            ncount = ncount + 1.d0
         end do
      end if

!-----------------------------------------------------------------------
!     /*   global sums:  executed on every rank                       */
!-----------------------------------------------------------------------

      call my_mpi_allreduce_real_0( netfx  )
      call my_mpi_allreduce_real_0( netfy  )
      call my_mpi_allreduce_real_0( netfz  )
      call my_mpi_allreduce_real_0( ncount )

!     /*   net force  ->  mean force per atom                         */

      netfx = netfx/ncount
      netfy = netfy/ncount
      netfz = netfz/ncount

!-----------------------------------------------------------------------
!     /*   shift the local bead-1 forces by the mean                  */
!-----------------------------------------------------------------------

      if ( has_bead1 ) then
         do k = jstart_atom, jend_atom
            fux(k,1) = fux(k,1) - netfx
            fuy(k,1) = fuy(k,1) - netfy
            fuz(k,1) = fuz(k,1) - netfz
         end do
      end if

!-----------------------------------------------------------------------
!     /*   normal modes to cartesian                                  */
!-----------------------------------------------------------------------

      call nm_trans_force_XMPI ( 0 )

      return
      end





!***********************************************************************
      subroutine subtract_torque_cent_XMPI
!***********************************************************************
!
!     Remove the net torque from the forces at bead index 1.
!
!     Uses the normal-mode positions ux, uy, uz, the forces fux, fuy,
!     fuz and the masses fictmass, all at bead index 1:
!
!       1. center of mass  (xg(1), yg(1), zg(1))
!       2. net torque      b = sum_i  r_i x F_i        (r_i from COM)
!       3. inertia tensor  a = sum_i  m_i ( |r_i|^2 1 - r_i r_i^T )
!       4. ddiag_MPI(a, e, c, 3): eigenvalues e, eigenvectors c
!       5. f = a^-1 b, evaluated in the principal frame
!          (d = c^T b, d_k = d_k / e_k, f = c d), assuming the columns
!          of c are orthonormal eigenvectors.  Principal moments
!          <= eig_tol are not divided by.
!       6. F_i <- F_i - m_i ( f x r_i )
!
!     Special cases on natom: 1 atom -> no correction; 2 atoms ->
!     the first principal component is zeroed.  The natom = 2 case
!     presumably relies on ddiag_MPI returning eigenvalues in
!     ascending order, so that e(1) is the molecular axis (confirm).
!
!     Local atom range jstart_atom..jend_atom is used only by ranks
!     whose bead range contains bead 1.  The my_mpi_allreduce_real_*
!     calls are executed by every rank.  On exit
!     nm_trans_force_XMPI(0) is called.
!
!     Side effect: xg(1), yg(1), zg(1) are overwritten.
!***********************************************************************

!-----------------------------------------------------------------------
!     /*   shared variables                                           */
!-----------------------------------------------------------------------

      use common_variables, only : &
     &   ux, uy, uz, fux, fuy, fuz, xg, yg, zg, fictmass, natom

      use XMPI_variables, only : &
     &   jstart_atom, jend_atom, jstart_bead, jend_bead

!-----------------------------------------------------------------------
!     /*   local variables                                            */
!-----------------------------------------------------------------------

      implicit none

!     /*   principal moments at or below this are treated as zero.   */
!     /*   a parameter:  no implicit save, no shadowing of TINY()     */
!     NOTE(review): absolute threshold; a nearly linear molecule with
!     natom > 2 may have e(1) slightly above it from round-off.
      real(8), parameter :: eig_tol = 1.d-14

      integer :: i, j

      real(8) :: fm, sumx, sumy, sumz, sump

!     /*   position of atom i relative to the center of mass          */
      real(8) :: rx, ry, rz

      real(8), dimension(3,3) :: a, c

      real(8), dimension(3)   :: b, d, e, f

!-----------------------------------------------------------------------
!     /*   center of mass                                             */
!-----------------------------------------------------------------------

      sumx = 0.d0
      sumy = 0.d0
      sumz = 0.d0
      sump = 0.d0

      do j = jstart_bead, jend_bead

         if ( j .ne. 1 ) cycle

         do i = jstart_atom, jend_atom

            sumx = sumx + fictmass(i,1)*ux(i,1)
            sumy = sumy + fictmass(i,1)*uy(i,1)
            sumz = sumz + fictmass(i,1)*uz(i,1)
            sump = sump + fictmass(i,1)

         end do

      end do

      call my_mpi_allreduce_real_0( sumx )
      call my_mpi_allreduce_real_0( sumy )
      call my_mpi_allreduce_real_0( sumz )
      call my_mpi_allreduce_real_0( sump )

      xg(1) = sumx/sump
      yg(1) = sumy/sump
      zg(1) = sumz/sump

!-----------------------------------------------------------------------
!     /*   net torque about the center of mass:  b = sum r x F        */
!-----------------------------------------------------------------------

      b(:) = 0.d0

      do j = jstart_bead, jend_bead

         if ( j .ne. 1 ) cycle

         do i = jstart_atom, jend_atom

            rx = ux(i,1) - xg(1)
            ry = uy(i,1) - yg(1)
            rz = uz(i,1) - zg(1)

            b(1) = b(1) + ry*fuz(i,1) - rz*fuy(i,1)
            b(2) = b(2) + rz*fux(i,1) - rx*fuz(i,1)
            b(3) = b(3) + rx*fuy(i,1) - ry*fux(i,1)

         end do

      end do

      call my_mpi_allreduce_real_1( b, 3 )

!-----------------------------------------------------------------------
!     /*   moment of inertia tensor                                   */
!-----------------------------------------------------------------------

      a(:,:) = 0.d0

      do j = jstart_bead, jend_bead

         if ( j .ne. 1 ) cycle

         do i = jstart_atom, jend_atom

            fm = fictmass(i,1)

            rx = ux(i,1) - xg(1)
            ry = uy(i,1) - yg(1)
            rz = uz(i,1) - zg(1)

            a(1,1) = a(1,1) + fm*ry*ry + fm*rz*rz
            a(1,2) = a(1,2) - fm*rx*ry
            a(1,3) = a(1,3) - fm*rx*rz
            a(2,1) = a(2,1) - fm*ry*rx
            a(2,2) = a(2,2) + fm*rz*rz + fm*rx*rx
            a(2,3) = a(2,3) - fm*ry*rz
            a(3,1) = a(3,1) - fm*rz*rx
            a(3,2) = a(3,2) - fm*rz*ry
            a(3,3) = a(3,3) + fm*rx*rx + fm*ry*ry

         end do

      end do

      call my_mpi_allreduce_real_2( a, 3, 3 )

!     /*   principal axis:  diagonalize moment of inertia          */

      call ddiag_MPI ( a, e, c, 3 )

!     /*   in principal axis:  torque  (d = c^T b)                 */

      d(1) = c(1,1)*b(1) + c(2,1)*b(2) + c(3,1)*b(3)
      d(2) = c(1,2)*b(1) + c(2,2)*b(2) + c(3,2)*b(3)
      d(3) = c(1,3)*b(1) + c(2,3)*b(2) + c(3,3)*b(3)

!     /*   d = torque divided by moment of inertia                 */

      if ( natom .eq. 1 ) then
         d(1) = 0.d0
         d(2) = 0.d0
         d(3) = 0.d0
      else if ( natom .eq. 2 ) then
         d(1) = 0.d0
         if ( e(2) .gt. eig_tol ) d(2) = d(2)/e(2)
         if ( e(3) .gt. eig_tol ) d(3) = d(3)/e(3)
      else
         if ( e(1) .gt. eig_tol ) d(1) = d(1)/e(1)
         if ( e(2) .gt. eig_tol ) d(2) = d(2)/e(2)
         if ( e(3) .gt. eig_tol ) d(3) = d(3)/e(3)
      end if

!     /*   d in laboratory frame  (f = c d)                        */

      f(1) = c(1,1)*d(1) + c(1,2)*d(2) + c(1,3)*d(3)
      f(2) = c(2,1)*d(1) + c(2,2)*d(2) + c(2,3)*d(3)
      f(3) = c(3,1)*d(1) + c(3,2)*d(2) + c(3,3)*d(3)

!-----------------------------------------------------------------------
!     /*   subtract m ( f x r ) from every local bead-1 force         */
!-----------------------------------------------------------------------

      do j = jstart_bead, jend_bead

         if ( j .ne. 1 ) cycle

         do i = jstart_atom, jend_atom

            fm = fictmass(i,1)

            rx = ux(i,1) - xg(1)
            ry = uy(i,1) - yg(1)
            rz = uz(i,1) - zg(1)

            fux(i,1) = fux(i,1) - f(2)*rz*fm + f(3)*ry*fm
            fuy(i,1) = fuy(i,1) - f(3)*rx*fm + f(1)*rz*fm
            fuz(i,1) = fuz(i,1) - f(1)*ry*fm + f(2)*rx*fm

         end do

      end do

!-----------------------------------------------------------------------
!     /*   normal modes to cartesian                                  */
!-----------------------------------------------------------------------

      call nm_trans_force_XMPI ( 0 )

      return
      end
