!///////////////////////////////////////////////////////////////////////
!
!      Author:          M. Shiga, H. Kimizuka
!      Last updated:    Jan 23, 2025 by M. Shiga
!      Description:     extensive MPI parallelization
!
!///////////////////////////////////////////////////////////////////////
!***********************************************************************
      subroutine my_mpi_bcast_xyz_XMPI  ( ux, uy, uz, ioption )
!***********************************************************************
!
!     In-place all-gather of the arrays ux, uy, uz (natom x nbead).
!
!     ioption = 1 : gather atom blocks over mpi_comm_sub, only for the
!                   beads jstart_bead..jend_bead
!     ioption = 2 : gather bead blocks over mpi_comm_main, only for the
!                   atoms jstart_atom..jend_atom
!     ioption = 3 : gather atom blocks over mpi_comm_sub, all beads
!     ioption = 4 : gather bead blocks over mpi_comm_main, all atoms
!
!     Any other value of ioption leaves the arrays untouched.
!
!     Every gather uses mpi_in_place, so each rank must already hold
!     its own block at the final position inside ux, uy, uz.
!
!-----------------------------------------------------------------------
!     //   shared variables
!-----------------------------------------------------------------------

      use common_variables, only : &
     &   nprocs_sub, nprocs_main, &
     &   mpi_comm_main, mpi_comm_sub, natom, nbead

      use XMPI_variables, only : &
     &   istart_bead, iend_bead, istart_atom, iend_atom, &
     &   jstart_bead, jend_bead, jstart_atom, jend_atom

!-----------------------------------------------------------------------
!     //   local variables
!-----------------------------------------------------------------------

      implicit none

      include 'mpif.h'

      real(8), intent(inout) :: ux(natom,nbead)
      real(8), intent(inout) :: uy(natom,nbead)
      real(8), intent(inout) :: uz(natom,nbead)

      integer, intent(in) :: ioption

      integer :: i, l, ibead_first, ibead_last
      integer :: vector_type, trans_type
      integer :: counts(nprocs_sub), disps(nprocs_sub), &
     &           counts_m(nprocs_main), disps_m(nprocs_main)
      integer(kind=mpi_address_kind) :: lb, extent

!     //   deliberately not initialized in the declaration: that would
!     //   give ierr the save attribute
      integer :: ierr

!-----------------------------------------------------------------------
!     //   allgatherv of atoms
!     //      ioption = 1:  beads jstart_bead..jend_bead
!     //      ioption = 3:  all beads
!-----------------------------------------------------------------------

      if ( ( ioption .eq. 1 ) .or. ( ioption .eq. 3 ) ) then

          if ( ioption .eq. 1 ) then
             ibead_first = jstart_bead
             ibead_last  = jend_bead
          else
             ibead_first = 1
             ibead_last  = nbead
          end if

!         //   block l of a bead column starts at atom istart_atom(l)

          do l = 1, nprocs_sub
             counts(l) = iend_atom(l) - istart_atom(l) + 1
             disps(l) = istart_atom(l) - 1
          end do

          do i = ibead_first, ibead_last

             call mpi_allgatherv( &
     &          mpi_in_place, 0, mpi_datatype_null, ux(1,i), &
     &          counts, disps, &
     &          mpi_double_precision, mpi_comm_sub, ierr )

             call mpi_allgatherv( &
     &          mpi_in_place, 0, mpi_datatype_null, uy(1,i), &
     &          counts, disps, &
     &          mpi_double_precision, mpi_comm_sub, ierr )

             call mpi_allgatherv( &
     &          mpi_in_place, 0, mpi_datatype_null, uz(1,i), &
     &          counts, disps, &
     &          mpi_double_precision, mpi_comm_sub, ierr )

          end do

!-----------------------------------------------------------------------
!     //   allgatherv of beads (atoms jstart_atom..jend_atom only)
!-----------------------------------------------------------------------

      else if ( ioption .eq. 2 ) then

!         //   Strided type describing the band to communicate: one
!         //   block of (jend_atom-jstart_atom+1) atoms per bead, for
!         //   (jend_bead-jstart_bead+1) beads, consecutive blocks
!         //   being natom elements apart.

          call mpi_type_vector( jend_bead - jstart_bead + 1, &
     &                          jend_atom - jstart_atom + 1, natom, &
     &                          mpi_double_precision, vector_type, ierr)

          call mpi_type_commit( vector_type, ierr )

!         //   Resize the type so that its extent is a whole band of
!         //   (jend_bead-jstart_bead+1) full columns of natom values
!         //   (8 bytes each).  Displacement i-1 then places the band
!         //   of the i-th process directly after the band of the
!         //   (i-1)-th process.  lb and extent must be of kind
!         //   mpi_address_kind; the product is evaluated in that kind
!         //   so that it cannot overflow a default integer.
!         //   NOTE(review): the extent is built from this rank's own
!         //   jstart_bead/jend_bead, so the layout presumably relies
!         //   on every rank of mpi_comm_main owning the same number
!         //   of beads - confirm against the bead distribution.

          lb = 0

          extent = int( natom, kind=mpi_address_kind ) &
     &           * int( jend_bead - jstart_bead + 1, &
     &                  kind=mpi_address_kind ) * 8

          call mpi_type_create_resized( vector_type, lb, extent, &
     &                                  trans_type, ierr )

          call mpi_type_commit( trans_type, ierr )

          counts_m(1:nprocs_main) = 1

          do i = 1, nprocs_main
             disps_m(i) = i - 1
          end do

!         //   Since the gather is done in place, the receive buffer
!         //   has to start at (jstart_atom,1) to ensure the correct
!         //   placement of all elements.

          call mpi_allgatherv( &
     &       mpi_in_place, 0, mpi_datatype_null, ux(jstart_atom,1), &
     &       counts_m, disps_m, &
     &       trans_type, mpi_comm_main, ierr )

          call mpi_allgatherv( &
     &       mpi_in_place, 0, mpi_datatype_null, uy(jstart_atom,1), &
     &       counts_m, disps_m, &
     &       trans_type, mpi_comm_main, ierr )

          call mpi_allgatherv( &
     &       mpi_in_place, 0, mpi_datatype_null, uz(jstart_atom,1), &
     &       counts_m, disps_m, &
     &       trans_type, mpi_comm_main, ierr )

          call mpi_type_free(trans_type, ierr)
          call mpi_type_free(vector_type, ierr)

!-----------------------------------------------------------------------
!     //   allgatherv of beads (for all atoms)
!-----------------------------------------------------------------------

      else if ( ioption .eq. 4 ) then

!         //   counts are in elements (beads of the process x natom);
!         //   the displacement of process i is the running sum of the
!         //   counts of the processes before it

          counts_m(1) = ( iend_bead(1) - istart_bead(1) + 1 ) * natom
          disps_m(1) = 0

          do i = 2, nprocs_main
             counts_m(i) = ( iend_bead(i) - istart_bead(i) + 1 ) * natom
             disps_m(i) = disps_m(i-1) + counts_m(i-1)
          end do

          call mpi_allgatherv &
     &       ( mpi_in_place, 0, mpi_datatype_null, ux, &
     &         counts_m, disps_m, &
     &         mpi_double_precision, mpi_comm_main, ierr )

          call mpi_allgatherv &
     &       ( mpi_in_place, 0, mpi_datatype_null, uy, &
     &         counts_m, disps_m, &
     &         mpi_double_precision, mpi_comm_main, ierr )

          call mpi_allgatherv &
     &       ( mpi_in_place, 0, mpi_datatype_null, uz, &
     &         counts_m, disps_m, &
     &         mpi_double_precision, mpi_comm_main, ierr )

      end if

!-----------------------------------------------------------------------
!     //   end of subroutine
!-----------------------------------------------------------------------

      return
      end subroutine my_mpi_bcast_xyz_XMPI





!***********************************************************************
      subroutine my_mpi_reduce_xyz_XMPI( ux, uy, uz, ioption )
!***********************************************************************
!
!     Block-wise sum of ux, uy, uz over mpi_comm_sub.
!
!     For each process l of mpi_comm_sub, the atoms
!     istart_atom(l)..iend_atom(l) are summed over all processes and
!     the result is stored on process l-1 (the root of that reduction)
!     only.  On the other processes that block is left unchanged.
!
!     ioption = 1 : perform the reduction
!     otherwise   : do nothing
!
!-----------------------------------------------------------------------
!     //   shared variables
!-----------------------------------------------------------------------

      use common_variables, only : &
     &   myrank_sub, nprocs_sub, mpi_comm_sub, natom

      use XMPI_variables, only : &
     &   istart_atom, iend_atom, natom_paral

!-----------------------------------------------------------------------
!     //   local variables
!-----------------------------------------------------------------------

      implicit none

      include 'mpif.h'

      real(8), intent(inout) :: ux(natom), uy(natom), uz(natom)
      integer, intent(in) :: ioption

!     //   p: packed send buffer,  q: reduced result (root only)
      real(8), dimension(:), allocatable :: p, q

      integer :: j, k, l, n, ierr

!-----------------------------------------------------------------------
!     //   reduce the atoms of each process onto that process
!-----------------------------------------------------------------------

      if ( ioption .ne. 1 ) return

      do l = 1, nprocs_sub

!        //   three components per atom of block l

         n = 3 * natom_paral(l)

         allocate( p(n) )
         allocate( q(n) )

!        //   pack (x,y,z) of block l, interleaved per atom

         k = 0
         do j = istart_atom(l), iend_atom(l)
            p(k+1) = ux(j)
            p(k+2) = uy(j)
            p(k+3) = uz(j)
            k = k + 3
         end do

         call mpi_reduce( p, q, n, mpi_double_precision, mpi_sum, &
     &                    l-1, mpi_comm_sub, ierr )

!        //   the receive buffer of mpi_reduce is defined on the root
!        //   only, so only the root may copy it back; elsewhere q
!        //   holds uninitialized memory

         if ( myrank_sub .eq. (l-1) ) then

            k = 0
            do j = istart_atom(l), iend_atom(l)
               ux(j) = q(k+1)
               uy(j) = q(k+2)
               uz(j) = q(k+3)
               k = k + 3
            end do

         end if

         deallocate( p )
         deallocate( q )

      end do

      return
      end subroutine my_mpi_reduce_xyz_XMPI





!***********************************************************************
      subroutine my_mpi_bcast_pot_XMPI  ( pot, vir_bead )
!***********************************************************************
!
!     In-place all-gather over mpi_comm_main of the per-bead
!     quantities pot(nbead) and vir_bead(3,3,nbead).  Process i
!     contributes the beads istart_bead(i)..iend_bead(i).
!
!-----------------------------------------------------------------------
!     //   shared variables
!-----------------------------------------------------------------------

      use common_variables, only : &
     &   nprocs_main, mpi_comm_main, nbead

      use XMPI_variables, only : &
     &   istart_bead, iend_bead

!-----------------------------------------------------------------------
!     //   local variables
!-----------------------------------------------------------------------

      implicit none

      include 'mpif.h'

      real(8), intent(inout) :: pot(nbead), vir_bead(3,3,nbead)

      integer :: i, ierr

!     //   bead counts and bead offsets of each process
      integer :: nbead_m(nprocs_main), ibead_m(nprocs_main)

!     //   the same, in units of array elements
      integer :: counts_m(nprocs_main), disps_m(nprocs_main)

!-----------------------------------------------------------------------
!     //   number of beads per process and their running offsets:
!     //   the offset of process i is the sum of the counts of the
!     //   processes before it
!-----------------------------------------------------------------------

      nbead_m(1) = iend_bead(1) - istart_bead(1) + 1
      ibead_m(1) = 0

      do i = 2, nprocs_main
         nbead_m(i) = iend_bead(i) - istart_bead(i) + 1
         ibead_m(i) = ibead_m(i-1) + nbead_m(i-1)
      end do

!-----------------------------------------------------------------------
!     //   allgatherv of virial:  nine elements per bead
!-----------------------------------------------------------------------

      counts_m(:) = 9 * nbead_m(:)
      disps_m(:)  = 9 * ibead_m(:)

      call mpi_allgatherv( &
     &   mpi_in_place, 0, mpi_datatype_null, &
     &   vir_bead, counts_m, disps_m, &
     &   mpi_double_precision, mpi_comm_main, ierr )

!-----------------------------------------------------------------------
!     //   allgatherv of potential:  one element per bead
!-----------------------------------------------------------------------

      counts_m(:) = nbead_m(:)
      disps_m(:)  = ibead_m(:)

      call mpi_allgatherv( &
     &   mpi_in_place, 0, mpi_datatype_null, &
     &   pot, counts_m, disps_m, &
     &   mpi_double_precision, mpi_comm_main, ierr)

!-----------------------------------------------------------------------
!     //   end of subroutine
!-----------------------------------------------------------------------

      return
      end subroutine my_mpi_bcast_pot_XMPI





!***********************************************************************
      subroutine my_mpi_bcast_mnhc_cent_XMPI  ( ux, uy, uz, ioption )
!***********************************************************************
!
!     Exchange of the arrays ux, uy, uz (natom x nnhc x ncolor) among
!     the processes of mpi_comm_sub:  process l-1 broadcasts its atom
!     block istart_atom(l)..iend_atom(l), for l = 1..nprocs_sub, so
!     that afterwards every participating process holds all atoms.
!
!     Only processes with myrank_main = 0 take part; all others
!     return immediately.
!
!     ioption = 1 or 3 : perform the exchange (both values act alike)
!     otherwise        : do nothing
!
!-----------------------------------------------------------------------
!     //   shared variables
!-----------------------------------------------------------------------

      use common_variables, only : &
     &   myrank_main, myrank_sub, nprocs_sub, &
     &   mpi_comm_sub, natom, ncolor, nnhc

      use XMPI_variables, only : &
     &   istart_atom, iend_atom, natom_paral

!-----------------------------------------------------------------------
!     //   local variables
!-----------------------------------------------------------------------

      implicit none

      include 'mpif.h'

      real(8), intent(inout) :: ux(natom,nnhc,ncolor)
      real(8), intent(inout) :: uy(natom,nnhc,ncolor)
      real(8), intent(inout) :: uz(natom,nnhc,ncolor)

      integer, intent(in) :: ioption

!     //   packed message buffer
      real(8), dimension(:), allocatable :: p

      integer :: i, j, k, l, m, n, ierr

!-----------------------------------------------------------------------
!     //   nothing to do for the other bead groups or other options
!-----------------------------------------------------------------------

      if ( myrank_main .ne. 0 ) return

      if ( ( ioption .ne. 1 ) .and. ( ioption .ne. 3 ) ) return

!-----------------------------------------------------------------------
!     //   broadcast the atom block of each process in turn
!-----------------------------------------------------------------------

      do l = 1, nprocs_sub

!        //   message length: exactly the number of values packed
!        //   below (x, y, z for every atom of block l, every chain
!        //   element and every color); the arrays carry no bead index

         n = 3 * natom_paral(l) * nnhc * ncolor

         allocate( p(n) )

!        //   the owner of block l packs its data

         if ( myrank_sub .eq. (l-1) ) then

            k = 0
            do i = 1, ncolor
            do m = 1, nnhc
            do j = istart_atom(l), iend_atom(l)
               p(k+1) = ux(j,m,i)
               p(k+2) = uy(j,m,i)
               p(k+3) = uz(j,m,i)
               k = k + 3
            end do
            end do
            end do

         end if

         call mpi_bcast( p, n, mpi_double_precision, l-1, &
     &                   mpi_comm_sub, ierr )

!        //   everybody else unpacks in the same order

         if ( myrank_sub .ne. (l-1) ) then

            k = 0
            do i = 1, ncolor
            do m = 1, nnhc
            do j = istart_atom(l), iend_atom(l)
               ux(j,m,i) = p(k+1)
               uy(j,m,i) = p(k+2)
               uz(j,m,i) = p(k+3)
               k = k + 3
            end do
            end do
            end do

         end if

         deallocate( p )

      end do

!-----------------------------------------------------------------------
!     //   end of subroutine
!-----------------------------------------------------------------------

      return
      end subroutine my_mpi_bcast_mnhc_cent_XMPI





!***********************************************************************
      subroutine my_mpi_bcast_mnhc_mode_XMPI  ( ux, uy, uz, ioption )
!***********************************************************************
!
!     Exchange of the arrays ux, uy, uz (natom x nnhc x nbead) by a
!     sequence of broadcasts: each process of the communicator in turn
!     packs its own block, broadcasts it, and the others unpack it.
!
!     ioption = 1 : over mpi_comm_sub;  atom block of each process,
!                   beads jstart_bead..jend_bead
!     ioption = 2 : over mpi_comm_main; bead block of each process,
!                   atoms jstart_atom..jend_atom
!     ioption = 3 : over mpi_comm_sub;  atom block of each process,
!                   all beads
!     ioption = 4 : over mpi_comm_main; bead block of each process,
!                   all atoms
!     otherwise   : do nothing
!
!-----------------------------------------------------------------------
!     //   shared variables
!-----------------------------------------------------------------------

      use common_variables, only : &
     &   myrank_main, myrank_sub, nprocs_sub, nprocs_main, &
     &   mpi_comm_main, mpi_comm_sub, natom, nbead, nnhc

      use XMPI_variables, only : &
     &   istart_bead, iend_bead, istart_atom, iend_atom, nbead_paral, &
     &   natom_paral, jstart_bead, jend_bead, jstart_atom, jend_atom

!-----------------------------------------------------------------------
!     //   local variables
!-----------------------------------------------------------------------

      implicit none

      include 'mpif.h'

      real(8) :: ux(natom,nnhc,nbead)
      real(8) :: uy(natom,nnhc,nbead)
      real(8) :: uz(natom,nnhc,nbead)

!     //   packed message buffer
      real(8), dimension(:), allocatable :: buf

      integer :: ioption, ierr

!     //   communicator in use, its size and my rank in it
      integer :: comm, nproc, myrank

!     //   current sender (1-based), its rank, and message length
      integer :: iproc, root, nbuf

!     //   bead and atom ranges of the block being sent
      integer :: ibead_lo, ibead_hi, iatom_lo, iatom_hi

      integer :: ibead, inhc, iatom, k

!-----------------------------------------------------------------------
!     //   choose the communicator:  atoms are exchanged over
!     //   mpi_comm_sub, beads over mpi_comm_main
!-----------------------------------------------------------------------

      select case ( ioption )

      case ( 1, 3 )

         comm   = mpi_comm_sub
         nproc  = nprocs_sub
         myrank = myrank_sub

      case ( 2, 4 )

         comm   = mpi_comm_main
         nproc  = nprocs_main
         myrank = myrank_main

      case default

         return

      end select

!-----------------------------------------------------------------------
!     //   one broadcast per process of the communicator
!-----------------------------------------------------------------------

      do iproc = 1, nproc

         root = iproc - 1

!        //   block limits and message length for this sender

         select case ( ioption )

         case ( 1 )

            ibead_lo = jstart_bead
            ibead_hi = jend_bead
            iatom_lo = istart_atom(iproc)
            iatom_hi = iend_atom(iproc)

            nbuf = 3 * natom_paral(iproc) &
     &               * nbead_paral(myrank_main+1) * nnhc

         case ( 2 )

            ibead_lo = istart_bead(iproc)
            ibead_hi = iend_bead(iproc)
            iatom_lo = jstart_atom
            iatom_hi = jend_atom

            nbuf = 3 * nbead_paral(iproc) &
     &               * natom_paral(myrank_sub+1) * nnhc

         case ( 3 )

            ibead_lo = 1
            ibead_hi = nbead
            iatom_lo = istart_atom(iproc)
            iatom_hi = iend_atom(iproc)

            nbuf = 3 * natom_paral(iproc) * nbead * nnhc

         case ( 4 )

            ibead_lo = istart_bead(iproc)
            ibead_hi = iend_bead(iproc)
            iatom_lo = 1
            iatom_hi = natom

            nbuf = 3 * nbead_paral(iproc) * natom * nnhc

         end select

         allocate( buf(nbuf) )

!        //   sender:  pack x, y, z interleaved, atoms running fastest

         if ( myrank .eq. root ) then

            k = 0
            do ibead = ibead_lo, ibead_hi
            do inhc = 1, nnhc
            do iatom = iatom_lo, iatom_hi
               buf(k+1) = ux(iatom,inhc,ibead)
               buf(k+2) = uy(iatom,inhc,ibead)
               buf(k+3) = uz(iatom,inhc,ibead)
               k = k + 3
            end do
            end do
            end do

         end if

         call mpi_bcast( buf, nbuf, mpi_double_precision, root, &
     &                   comm, ierr )

!        //   receivers:  unpack in the same order

         if ( myrank .ne. root ) then

            k = 0
            do ibead = ibead_lo, ibead_hi
            do inhc = 1, nnhc
            do iatom = iatom_lo, iatom_hi
               ux(iatom,inhc,ibead) = buf(k+1)
               uy(iatom,inhc,ibead) = buf(k+2)
               uz(iatom,inhc,ibead) = buf(k+3)
               k = k + 3
            end do
            end do
            end do

         end if

         deallocate( buf )

      end do

!-----------------------------------------------------------------------
!     //   end of subroutine
!-----------------------------------------------------------------------

      return
      end subroutine my_mpi_bcast_mnhc_mode_XMPI





!***********************************************************************
      subroutine my_mpi_allreduce_pot_sub( pot, fx, fy, fz, vir, natom )
!***********************************************************************
!
!     Sum the potential pot, the forces fx, fy, fz (natom each) and
!     the virial vir(3,3) over all processes of mpi_comm_sub.  On
!     return every process holds the summed values in all of them.
!
!     All quantities are packed into one buffer so that a single
!     mpi_allreduce is sufficient.  Buffer layout:
!
!        1                       pot
!        2         .. natom+1    fx
!        natom+2   .. 2*natom+1  fy
!        2*natom+2 .. 3*natom+1  fz
!        3*natom+2 .. 3*natom+10 vir, in the order
!                                (1,1),(1,2),(1,3),(2,1),...,(3,3)
!
!-----------------------------------------------------------------------

      use common_variables, only : mpi_comm_sub

      implicit none

      include 'mpif.h'

      integer, intent(in) :: natom
      real(8), intent(inout) :: pot, fx(natom), fy(natom), fz(natom)
      real(8), intent(inout) :: vir(3,3)

!     //   b1: local contribution,  b2: sum over mpi_comm_sub
      real(8) :: b1(3*natom+10), b2(3*natom+10)

      integer :: ierr

!-----------------------------------------------------------------------
!     //   pack
!-----------------------------------------------------------------------

      b1(1) = pot

      b1(        2:  natom+1) = fx(1:natom)
      b1(  natom+2:2*natom+1) = fy(1:natom)
      b1(2*natom+2:3*natom+1) = fz(1:natom)

!     //   transpose gives the row-by-row order listed above

      b1(3*natom+2:3*natom+10) = reshape( transpose( vir ), (/ 9 /) )

!-----------------------------------------------------------------------
!     //   sum over processes
!-----------------------------------------------------------------------

      call MPI_ALLREDUCE ( b1, b2, 3*natom+10, MPI_DOUBLE_PRECISION, &
     &                     MPI_SUM, mpi_comm_sub, ierr )

!-----------------------------------------------------------------------
!     //   unpack:  everything is taken from the reduced buffer b2
!-----------------------------------------------------------------------

      pot = b2(1)

      fx(1:natom) = b2(        2:  natom+1)
      fy(1:natom) = b2(  natom+2:2*natom+1)
      fz(1:natom) = b2(2*natom+2:3*natom+1)

      vir = transpose( reshape( b2(3*natom+2:3*natom+10), (/ 3, 3 /) ) )

      return
      end subroutine my_mpi_allreduce_pot_sub
