!///////////////////////////////////////////////////////////////////////
!
!      Author:          M. Shiga
!      Last updated:    Nov 10, 2018 by M. Shiga
!      Description:     two-stage parallelization
!
!///////////////////////////////////////////////////////////////////////
!=======================================================================
!
!     The following subroutines were added to support two-stage
!     MPI parallelization
!
!=======================================================================
!***********************************************************************
      subroutine my_mpi_init_2
!***********************************************************************
!=======================================================================
!
!     Start MPI and set up the process groups of the two-stage
!     parallelization.  The steps are, in order: MPI_INIT, query of
!     the world size and rank, get_nprocs_2, query of the world
!     group, creation of the pimd, sub and main groups, a barrier,
!     and finally the copy of the pimd size and rank into nprocs
!     and myrank.
!
!     The status returned by the MPI calls is not inspected here.
!
!=======================================================================

      use common_variables, only : &
     &   nprocs_world, myrank_world, mpi_group_world, &
     &   nprocs_pimd, myrank_pimd, nprocs, myrank

      implicit none

      integer :: mpi_stat

      include 'mpif.h'

!-----------------------------------------------------------------------
!     /*    start MPI; size and rank within MPI_COMM_WORLD            */
!-----------------------------------------------------------------------

      call MPI_INIT ( mpi_stat )

      call MPI_COMM_SIZE ( MPI_COMM_WORLD, nprocs_world, mpi_stat )
      call MPI_COMM_RANK ( MPI_COMM_WORLD, myrank_world, mpi_stat )

!-----------------------------------------------------------------------
!     /*    group sizes, then the handle of the world group           */
!-----------------------------------------------------------------------

      call get_nprocs_2

      call MPI_COMM_GROUP ( MPI_COMM_WORLD, mpi_group_world, mpi_stat )

!-----------------------------------------------------------------------
!     /*    build the groups in the order: pimd, sub, main            */
!-----------------------------------------------------------------------

      call setup_mpi_pimd_group_2
      call setup_mpi_sub_group_2
      call setup_mpi_main_group_2

!-----------------------------------------------------------------------
!     /*    synchronize before the ranks are published                */
!-----------------------------------------------------------------------

      call my_mpi_barrier

!-----------------------------------------------------------------------
!     /*    nprocs and myrank take the values of the pimd group       */
!-----------------------------------------------------------------------

      myrank = myrank_pimd
      nprocs = nprocs_pimd

#ifdef debug
      call my_mpi_debug_2
#endif

      end subroutine my_mpi_init_2





!***********************************************************************
      subroutine get_nprocs_2
!***********************************************************************
!=======================================================================
!
!     Read nbead, np_beads and np_force, adjust the two
!     parallelization parameters so that they are consistent with
!     the number of processes in MPI_COMM_WORLD, derive np_cycle,
!     print a summary on world rank 0, and store the resulting
!     group sizes in nprocs_main, nprocs_sub and nprocs_pimd.
!
!     Note that the final np_beads may be recomputed from np_force
!     (see the force parallelization section below).
!
!=======================================================================

      use common_variables, only : &
     &   myrank_world, nbead, iounit, np_beads, np_force, np_cycle, &
     &   nprocs_main, nprocs_sub, nprocs_pimd, nprocs_world

      implicit none

!     /*   set by assignment below: an initializer in the         */
!     /*   declaration would imply the save attribute             */
      integer :: ierr

      ierr = 0

!-----------------------------------------------------------------------
!     /*    read number of beads                                      */
!-----------------------------------------------------------------------

      call read_int1_MPI( nbead, '<nbead>', 7, iounit )

!-----------------------------------------------------------------------
!     /*    bead parallelization parameter                            */
!-----------------------------------------------------------------------

      call read_int1_MPI( np_beads, '<np_beads>', 10, iounit )

      if ( np_beads .le. 0 ) then

!        /*   default value; kept at 1 or more because np_beads   */
!        /*   is used as a divisor below                          */
         np_beads = max( 1, min(nbead,nprocs_world) )

      else if ( np_beads .gt. nbead ) then

         if ( myrank_world .eq. 0 ) then
            write ( 6, '(a)' )
            write ( 6, '(a,a)' ) &
     &         'Warning - np_beads is larger than nbead. ', &
     &         'np_beads reset to 1.'
         end if

         np_beads = 1

      else if ( mod(nprocs_world,np_beads) .ne. 0 ) then

         if ( myrank_world .eq. 0 ) then
            write ( 6, '(a)' )
            write ( 6, '(a,a)' ) &
     &         'Warning - nprocs is not a multiple of np_beads. ', &
     &         'np_beads reset to 1.'
         end if

         np_beads = 1

      end if

!-----------------------------------------------------------------------
!     /*    force parallelization parameter                           */
!-----------------------------------------------------------------------

      call read_int1_MPI( np_force, '<np_force>', 10, iounit )

      if ( np_force .le. 0 ) then

!        /*   default value: processes left per bead group   */
         np_force = nprocs_world / np_beads

      else if ( np_force .gt. nprocs_world ) then

!        /*   report the override like the other resets do   */
         if ( myrank_world .eq. 0 ) then
            write ( 6, '(a)' )
            write ( 6, '(a,a)' ) &
     &         'Warning - np_force is larger than nprocs. ', &
     &         'np_force reset to nprocs/np_beads.'
         end if

         np_force = nprocs_world / np_beads

      else if ( mod(nprocs_world,np_force) .ne. 0 ) then

         if ( myrank_world .eq. 0 ) then
            write ( 6, '(a)' )
            write ( 6, '(a,a)' ) &
     &         'Warning - nprocs is not a multiple of np_force. ', &
     &         'np_beads reset to 1.'
         end if

         np_beads = 1
         np_force = nprocs_world / np_beads

      else

!        /*   a valid np_force takes precedence: np_beads is   */
!        /*   recomputed so that np_beads*np_force = nprocs    */
         np_beads = nprocs_world / np_force

      end if

!-----------------------------------------------------------------------
!     /*    force cycles per step:  nbead / (np_beads*np_force),      */
!     /*    rounded up when the division leaves a remainder           */
!-----------------------------------------------------------------------

      if ( mod(nbead,(np_beads*np_force)) .eq. 0 ) then
         np_cycle = nbead / (np_beads*np_force)
      else
         np_cycle = nbead / (np_beads*np_force) + 1
      end if

!-----------------------------------------------------------------------
!     /*    print                                                     */
!-----------------------------------------------------------------------

      if( myrank_world .eq. 0 ) then

         write ( 6, '(a)' ) 

         write ( 6, '(a)' )  &
     &      '_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/' // &
     &      '_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/'

         write ( 6, '(a)' )

         write( 6, '(a)' )    'Information of parallel computation:'
         write( 6, '(a)' )
         write( 6, '(a,i6)' ) '  Number of processors  = ', nprocs_world
         write( 6, '(a,i6)' ) '  Number of beads       = ', nbead
         write( 6, '(a,i6)' ) '  Bead parallelization  = ', np_beads
         write( 6, '(a,i6)' ) '  Force parallelization = ', np_force
         write( 6, '(a,i6)' ) '  Force cycles per step = ', np_cycle
         write( 6, '(a)' )

      end if

      call my_mpi_barrier_world

!-----------------------------------------------------------------------
!     /*    error termination                                         */
!     /*    (ierr is never changed above, so it is passed as zero;    */
!     /*    the label names this routine, 23 is its length)           */
!-----------------------------------------------------------------------

      call error_handling_MPI &
     &    ( ierr, 'subroutine get_nprocs_2', 23 )

!-----------------------------------------------------------------------
!     /*    print                                                     */
!-----------------------------------------------------------------------

      if( myrank_world .eq. 0 ) then

          write ( 6, '(a)')  &
     &       '_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/' // &
     &       '_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/'

      end if

!-----------------------------------------------------------------------
!     /*    number of processors in main, sub and pimd groups         */
!-----------------------------------------------------------------------

      nprocs_main = np_beads
      nprocs_sub  = np_force
      nprocs_pimd = np_beads * np_force

      return
      end





!***********************************************************************
      subroutine setup_mpi_pimd_group_2
!***********************************************************************
!=======================================================================
!
!     Build the pimd group from world ranks 0 ... nprocs_pimd-1,
!     create the communicator mpi_comm_pimd for it and store the
!     rank of the calling process in myrank_pimd.
!
!     NOTE(review): MPI_COMM_CREATE returns MPI_COMM_NULL on a
!     process that is not a member of the group, and the rank query
!     below is made unconditionally - confirm how world ranks at or
!     above nprocs_pimd are meant to be handled.
!
!=======================================================================

      use common_variables, only : &
     &   mpi_group_world, mpi_group_pimd, mpi_comm_pimd, &
     &   nprocs_pimd, myrank_pimd

      implicit none

      integer :: mpi_stat, k

      integer, allocatable :: members(:)

      include 'mpif.h'

!     /*   world ranks that belong to the pimd group   */
      allocate( members(nprocs_pimd) )

      members(:) = (/ ( k - 1, k = 1, nprocs_pimd ) /)

!     /*   group  ->  communicator  ->  rank of this process   */
      call MPI_GROUP_INCL &
     &   ( mpi_group_world, nprocs_pimd, members, &
     &     mpi_group_pimd, mpi_stat )

      call MPI_COMM_CREATE &
     &   ( MPI_COMM_WORLD, mpi_group_pimd, mpi_comm_pimd, mpi_stat )

      call MPI_COMM_RANK ( mpi_comm_pimd, myrank_pimd, mpi_stat )

      deallocate( members )

      end subroutine setup_mpi_pimd_group_2





!***********************************************************************
      subroutine setup_mpi_main_group_2
!***********************************************************************
!=======================================================================
!
!     Build the main group of the calling process: nprocs_main world
!     ranks spaced nprocs_sub apart, all sharing the remainder
!     mod(myrank_world,nprocs_sub).  The communicator mpi_comm_main
!     is created for that group and the rank of the calling process
!     within it is stored in myrank_main.
!
!     NOTE(review): MPI_COMM_CREATE returns MPI_COMM_NULL on a
!     process whose own world rank is not in the list, and the rank
!     query below is made unconditionally - confirm that every
!     calling rank is always a member of its own list.
!
!=======================================================================

      use common_variables, only : &
     &   mpi_group_world, mpi_group_main, mpi_comm_main, &
     &   nprocs_main, nprocs_sub, myrank_world, myrank_main

      implicit none

      integer :: mpi_stat, k

      integer, allocatable :: members(:)

      include 'mpif.h'

!     /*   member k:  nprocs_sub*(k-1) + remainder of this rank   */
      allocate( members(nprocs_main) )

      members(:) = &
     &   (/ ( nprocs_sub*(k-1) + mod(myrank_world,nprocs_sub), &
     &        k = 1, nprocs_main ) /)

!     /*   group  ->  communicator  ->  rank of this process   */
      call MPI_GROUP_INCL &
     &   ( mpi_group_world, nprocs_main, members, &
     &     mpi_group_main, mpi_stat )

      call MPI_COMM_CREATE &
     &   ( MPI_COMM_WORLD, mpi_group_main, mpi_comm_main, mpi_stat )

      call MPI_COMM_RANK ( mpi_comm_main, myrank_main, mpi_stat )

      deallocate( members )

      end subroutine setup_mpi_main_group_2





!***********************************************************************
      subroutine setup_mpi_sub_group_2
!***********************************************************************
!=======================================================================
!
!     Build the sub group of the calling process: the nprocs_sub
!     consecutive world ranks that start at the largest multiple of
!     nprocs_sub not exceeding myrank_world.  The communicator
!     mpi_comm_sub is created for that group and the rank of the
!     calling process within it is stored in myrank_sub.
!
!=======================================================================

      use common_variables, only : &
     &   mpi_group_world, mpi_group_sub, mpi_comm_sub, &
     &   nprocs_sub, myrank_world, myrank_sub

      implicit none

      integer :: mpi_stat, k

      integer, allocatable :: members(:)

      include 'mpif.h'

!     /*   integer division truncates, so the first term is the   */
!     /*   first world rank of the block containing this rank     */
      allocate( members(nprocs_sub) )

      members(:) = &
     &   (/ ( (myrank_world/nprocs_sub)*nprocs_sub + k - 1, &
     &        k = 1, nprocs_sub ) /)

!     /*   group  ->  communicator  ->  rank of this process   */
      call MPI_GROUP_INCL &
     &   ( mpi_group_world, nprocs_sub, members, &
     &     mpi_group_sub, mpi_stat )

      call MPI_COMM_CREATE &
     &   ( MPI_COMM_WORLD, mpi_group_sub, mpi_comm_sub, mpi_stat )

      call MPI_COMM_RANK ( mpi_comm_sub, myrank_sub, mpi_stat )

      deallocate( members )

      end subroutine setup_mpi_sub_group_2





#ifdef debug
!***********************************************************************
      subroutine my_mpi_debug_2
!***********************************************************************
!=======================================================================
!
!     Debugging aid: rank 0 prints the parallelization parameters
!     and a table header, then every rank prints one line with its
!     ranks.  The lines are written in rank order by letting rank i
!     write in iteration i, with a barrier after each iteration.
!
!=======================================================================

      use common_variables, only : &
     &   np_beads, np_force, np_cycle, nprocs_world, nprocs, nprocs_sub, &
     &   myrank, myrank_world, myrank_main, myrank_sub, nprocs_main, &
     &   nbead

      implicit none

      integer :: i

      if ( myrank .eq. 0 ) then
         write( 6, '(a)' )
         write( 6, '(a)' ) "Debugging information for MPI:"
         write( 6, '(a)' )
         write( 6, '(a,i6)' ) "nbead:        ", nbead
         write( 6, '(a,i6)' ) "np_beads:     ", np_beads
         write( 6, '(a,i6)' ) "np_force:     ", np_force
         write( 6, '(a,i6)' ) "np_cycle:     ", np_cycle
         write( 6, '(a,i6)' ) "nprocs_world: ", nprocs_world
         write( 6, '(a,i6)' ) "nprocs:       ", nprocs
         write( 6, '(a,i6)' ) "nprocs_main:  ", nprocs_main
         write( 6, '(a,i6)' ) "nprocs_sub:   ", nprocs_sub

!        /*   table header: text only, so the format is '(a)'   */
         write( 6, '(a)' )
         write( 6, '(a)' ) &
     &      "myrank_world myrank myrank_main myrank_sub"
      end if

      call my_mpi_barrier

!     /*   one line per rank; the items follow the header order   */
!     /*   and the field widths equal the header column widths    */
      do i = 0, nprocs-1
         if ( myrank .eq. i ) then
            write( 6, '(i12,i7,i12,i11)' ) &
     &         myrank_world, myrank, myrank_main, myrank_sub
            flush( 6 )
         end if
         call my_mpi_barrier
      end do

!cc      call my_mpi_finalize_2; stop

      return
      end
#endif





!***********************************************************************
      subroutine my_mpi_allreduce_real_0_main ( a )
!***********************************************************************
!=======================================================================
!
!     Sum the scalar a over the ranks of mpi_comm_main with
!     MPI_ALLREDUCE; on return a holds the sum.
!
!=======================================================================

      use common_variables, only : mpi_comm_main

      implicit none

      real(8) :: a
      real(8) :: sendbuf(1), recvbuf(1)
      integer :: mpi_stat

      include 'mpif.h'

!     /*   the scalar travels through one-element buffers   */
      sendbuf(1) = a

      call MPI_ALLREDUCE &
     &   ( sendbuf, recvbuf, 1, MPI_DOUBLE_PRECISION, MPI_SUM, &
     &     mpi_comm_main, mpi_stat )

      a = recvbuf(1)

      end subroutine my_mpi_allreduce_real_0_main





!***********************************************************************
      subroutine my_mpi_allreduce_real_1_main ( a, n )
!***********************************************************************
!=======================================================================
!
!     Element-wise sum of a(1:n) over the ranks of mpi_comm_main.
!
!     a   (in/out)  local contribution on entry, sums on return
!     n   (in)      number of elements
!
!     The reduction is done in place (MPI_IN_PLACE as send buffer),
!     so no temporary copies of the array are created.
!
!=======================================================================

      use common_variables, only : mpi_comm_main

      implicit none

      integer, intent(in)    :: n
      real(8), intent(inout) :: a(n)

      integer :: ierr

      include 'mpif.h'

      call MPI_ALLREDUCE ( MPI_IN_PLACE, a, n, MPI_DOUBLE_PRECISION, &
     &                     MPI_SUM, mpi_comm_main, ierr )

      return
      end





!***********************************************************************
      subroutine my_mpi_allreduce_real_2_main ( a, n1, n2 )
!***********************************************************************
!=======================================================================
!
!     Element-wise sum of a(1:n1,1:n2) over the ranks of
!     mpi_comm_main.
!
!     a       (in/out)  local contribution on entry, sums on return
!     n1, n2  (in)      extents of a
!
!     a is an explicit-shape dummy, hence contiguous, so it is
!     handed to MPI as one buffer of n1*n2 elements and reduced in
!     place (MPI_IN_PLACE); no temporary copies are created.
!
!=======================================================================

      use common_variables, only : mpi_comm_main

      implicit none

      integer, intent(in)    :: n1, n2
      real(8), intent(inout) :: a(n1,n2)

      integer :: n, ierr

      include 'mpif.h'

      n = n1*n2

      call MPI_ALLREDUCE ( MPI_IN_PLACE, a, n, MPI_DOUBLE_PRECISION, &
     &                     MPI_SUM, mpi_comm_main, ierr )

      return
      end





!***********************************************************************
      subroutine my_mpi_allreduce_real_3_main ( a, n1, n2, n3 )
!***********************************************************************
!=======================================================================
!
!     Element-wise sum of a(1:n1,1:n2,1:n3) over the ranks of
!     mpi_comm_main.
!
!     a           (in/out)  local contribution on entry, sums on
!                           return
!     n1, n2, n3  (in)      extents of a
!
!     a is an explicit-shape dummy, hence contiguous, so it is
!     handed to MPI as one buffer of n1*n2*n3 elements and reduced
!     in place (MPI_IN_PLACE); no temporary copies are created.
!
!=======================================================================

      use common_variables, only : mpi_comm_main

      implicit none

      integer, intent(in)    :: n1, n2, n3
      real(8), intent(inout) :: a(n1,n2,n3)

      integer :: n, ierr

      include 'mpif.h'

      n = n1*n2*n3

      call MPI_ALLREDUCE ( MPI_IN_PLACE, a, n, MPI_DOUBLE_PRECISION, &
     &                     MPI_SUM, mpi_comm_main, ierr )

      return
      end





!***********************************************************************
      subroutine my_mpi_allreduce_real_0_sub ( a )
!***********************************************************************
!=======================================================================
!
!     Sum the scalar a over the ranks of mpi_comm_sub with
!     MPI_ALLREDUCE; on return a holds the sum.
!
!=======================================================================

      use common_variables, only : mpi_comm_sub

      implicit none

      real(8) :: a
      real(8) :: sendbuf(1), recvbuf(1)
      integer :: mpi_stat

      include 'mpif.h'

!     /*   the scalar travels through one-element buffers   */
      sendbuf(1) = a

      call MPI_ALLREDUCE &
     &   ( sendbuf, recvbuf, 1, MPI_DOUBLE_PRECISION, MPI_SUM, &
     &     mpi_comm_sub, mpi_stat )

      a = recvbuf(1)

      end subroutine my_mpi_allreduce_real_0_sub





!***********************************************************************
      subroutine my_mpi_allreduce_real_1_sub ( a, n )
!***********************************************************************
!=======================================================================
!
!     Element-wise sum of a(1:n) over the ranks of mpi_comm_sub.
!
!     a   (in/out)  local contribution on entry, sums on return
!     n   (in)      number of elements
!
!     The reduction is done in place (MPI_IN_PLACE as send buffer),
!     so no temporary copies of the array are created.
!
!=======================================================================

      use common_variables, only : mpi_comm_sub

      implicit none

      integer, intent(in)    :: n
      real(8), intent(inout) :: a(n)

      integer :: ierr

      include 'mpif.h'

      call MPI_ALLREDUCE ( MPI_IN_PLACE, a, n, MPI_DOUBLE_PRECISION, &
     &                     MPI_SUM, mpi_comm_sub, ierr )

      return
      end





!***********************************************************************
      subroutine my_mpi_allreduce_real_2_sub ( a, n1, n2 )
!***********************************************************************
!=======================================================================
!
!     Element-wise sum of a(1:n1,1:n2) over the ranks of
!     mpi_comm_sub.
!
!     a       (in/out)  local contribution on entry, sums on return
!     n1, n2  (in)      extents of a
!
!     a is an explicit-shape dummy, hence contiguous, so it is
!     handed to MPI as one buffer of n1*n2 elements and reduced in
!     place (MPI_IN_PLACE); no temporary copies are created.
!
!=======================================================================

      use common_variables, only : mpi_comm_sub

      implicit none

      integer, intent(in)    :: n1, n2
      real(8), intent(inout) :: a(n1,n2)

      integer :: n, ierr

      include 'mpif.h'

      n = n1*n2

      call MPI_ALLREDUCE ( MPI_IN_PLACE, a, n, MPI_DOUBLE_PRECISION, &
     &                     MPI_SUM, mpi_comm_sub, ierr )

      return
      end





!***********************************************************************
      subroutine my_mpi_bcast_real_0_sub ( a )
!***********************************************************************
!=======================================================================
!
!     Broadcast the scalar a from rank 0 of mpi_comm_sub to the
!     other ranks of that communicator with MPI_BCAST.
!
!=======================================================================

      use common_variables, only : mpi_comm_sub

      implicit none

      real(8) :: a
      real(8) :: buffer(1)
      integer :: mpi_stat

      integer, parameter :: root = 0

      include 'mpif.h'

!     /*   the scalar travels through a one-element buffer   */
      buffer(1) = a

      call MPI_BCAST &
     &   ( buffer, 1, MPI_DOUBLE_PRECISION, root, &
     &     mpi_comm_sub, mpi_stat )

      a = buffer(1)

      end subroutine my_mpi_bcast_real_0_sub





!***********************************************************************
      subroutine my_mpi_bcast_real_1_sub ( a, n )
!***********************************************************************
!=======================================================================
!
!     Broadcast the n elements of a from rank 0 of mpi_comm_sub to
!     the other ranks of that communicator with MPI_BCAST.  The
!     array is passed to MPI directly.
!
!=======================================================================

      use common_variables, only : mpi_comm_sub

      implicit none

      integer :: n
      real(8) :: a(n)
      integer :: mpi_stat

      integer, parameter :: root = 0

      include 'mpif.h'

      call MPI_BCAST &
     &   ( a, n, MPI_DOUBLE_PRECISION, root, mpi_comm_sub, mpi_stat )

      end subroutine my_mpi_bcast_real_1_sub





!***********************************************************************
      subroutine my_mpi_bcast_real_2_sub ( a, n1, n2 )
!***********************************************************************
!=======================================================================
!
!     Broadcast a(1:n1,1:n2) from rank 0 of mpi_comm_sub to the
!     other ranks of that communicator.
!
!     a       (in/out)  data on rank 0; overwritten elsewhere
!     n1, n2  (in)      extents of a
!
!     a is an explicit-shape dummy, hence contiguous, so it is
!     passed to MPI_BCAST directly as n1*n2 elements; packing it
!     into a temporary buffer is not needed.
!
!=======================================================================

      use common_variables, only : mpi_comm_sub

      implicit none

      integer, intent(in)    :: n1, n2
      real(8), intent(inout) :: a(n1,n2)

      integer :: n, ierr

      include 'mpif.h'

      n = n1*n2

      call MPI_BCAST ( a, n, MPI_DOUBLE_PRECISION, &
     &                 0, mpi_comm_sub, ierr )

      return
      end





!***********************************************************************
      subroutine my_mpi_bcast_real_3_sub ( a, n1, n2, n3 )
!***********************************************************************
!=======================================================================
!
!     Broadcast a(1:n1,1:n2,1:n3) from rank 0 of mpi_comm_sub to
!     the other ranks of that communicator.
!
!     a           (in/out)  data on rank 0; overwritten elsewhere
!     n1, n2, n3  (in)      extents of a
!
!     a is an explicit-shape dummy, hence contiguous, so it is
!     passed to MPI_BCAST directly as n1*n2*n3 elements; packing
!     it into a temporary buffer is not needed.
!
!=======================================================================

      use common_variables, only : mpi_comm_sub

      implicit none

      integer, intent(in)    :: n1, n2, n3
      real(8), intent(inout) :: a(n1,n2,n3)

      integer :: n, ierr

      include 'mpif.h'

      n = n1*n2*n3

      call MPI_BCAST ( a, n, MPI_DOUBLE_PRECISION, &
     &                 0, mpi_comm_sub, ierr )

      return
      end
