!///////////////////////////////////////////////////////////////////////
!
!      Author:          M. Shiga
!      Last updated:    July 20, 2024 by M. Shiga
!      Description:     energy and force from LIBNNP calculation
!
!///////////////////////////////////////////////////////////////////////


!     // LIBNNP defined
#ifdef libnnp


!***********************************************************************
      module libnnp_variables
!***********************************************************************
!
!     Persistent state for force_libnnp_MPI: the communicator handed
!     to libnnp, the three NN model objects and the work arrays passed
!     to the libnnp prediction routines.
!
!     force_libnnp_MPI initializes the NN models and allocates the
!     arrays only on its first call, so all of this must survive
!     between calls; the whole module therefore has the SAVE attribute
!     (explicit, rather than relying on the Fortran 2008 implicit save
!     of module variables).
!
!-----------------------------------------------------------------------

      use nnp_kinds, only : &
     &   stin, sp, stsh, dp, isp

      use nnp_message_passing, only : &
     &   mp_world_type

      use nnp_types, only : &
     &   nnp_type

      implicit none

!     //   every module variable persists between calls
      save

!     //   communicator passed to libnnp (set up from mpi_comm_sub)
      type(mp_world_type) :: mp_world ! communicator

!     //   NN model objects
      type(nnp_type) :: nnp1 ! global nnp object (NN-PES)
      type(nnp_type) :: nnp2 ! global nnp object (NN-DMS)
      type(nnp_type) :: nnp3 ! global nnp object (NN-IP)

!     //   NN-PES / NN-DMS data for the matom atoms of one bead:
!     //   element symbols, coordinates (3,atom), forces (3,atom),
!     //   dipole moment, charges (presumably per-atom charges returned
!     //   by nnp_predict_dip -- TODO confirm) and energy
      character(len=stsh), dimension(:),   allocatable :: symb
      real(kind=dp),       dimension(:,:), allocatable :: pos
      real(kind=dp),       dimension(:,:), allocatable :: forces
      real(kind=dp),       dimension(3)                :: dip
      real(kind=dp),       dimension(:),   allocatable :: charges
      real(kind=dp)                                    :: ener

!     //   reference charge read from <charge_libnnp>
      real(kind=dp)                                    :: ref_charge

!     //   flags returned by libnnp for nnp1, nnp2, nnp3
!     //   (presumably extrapolation warnings -- TODO confirm)
      logical                                          :: expol1
      logical                                          :: expol2
      logical                                          :: expol3

!     //   cell vectors passed to libnnp (zero for free boundary)
      real(kind=dp),       dimension(3,3)              :: lattice

!     //   NN-IP data: the matom atoms plus one interacting atom
      character(len=stsh), dimension(:),   allocatable :: symb_int
      real(kind=dp),       dimension(:,:), allocatable :: pos_int
      real(kind=dp),       dimension(:,:), allocatable :: forces_int
      real(kind=dp)                                    :: ener_int

!     //   NN input file name handed to nnp_initialize
      character(len=stin)  :: inputname

!     //   not referenced in force_libnnp_MPI
      character(len=stsh)  :: dummy

!***********************************************************************
      end module libnnp_variables
!***********************************************************************





!***********************************************************************
      subroutine force_libnnp_MPI
!***********************************************************************
!
!     Energy, forces and (optionally) dipole moment of every bead from
!     LIBNNP neural networks.
!
!        nnp1 : NN-PES, energy and forces of atoms 1..matom
!        nnp2 : NN-DMS, dipole moment of atoms 1..matom
!               (only if ioption_dms_libnnp = 1)
!        nnp3 : NN-IP, interaction of atoms 1..matom with each of the
!               atoms matom+1..natom, taken one at a time
!               (only if natom_ip_libnnp >= 1)
!
!     where matom = natom - natom_ip_libnnp.
!
!     Beads are distributed over the main communicator (bead m is done
!     where mod(m-1,nprocs_main) = myrank_main), and libnnp runs on
!     the sub communicator mpi_comm_sub.  Only the sub-root writes
!     pot, fx, fy, fz, vir and dipx, dipy, dipz; the results are then
!     combined by my_mpi_allreduce_md.
!
!     The first call reads the input keywords, initializes the NN
!     models and allocates the work arrays of libnnp_variables.
!
!     NOTE(review): NN-IP contributions are ADDED to pot and to the
!     forces of atoms matom+1..natom, so this routine presumably
!     relies on the caller zeroing those arrays beforehand -- confirm.
!
!-----------------------------------------------------------------------
!     //   shared variables
!-----------------------------------------------------------------------

      use common_variables, only : &
     &   x, y, z, fx, fy, fz, pot, dipx, dipy, dipz, vir, species, &
     &   natom, nbead, iounit, iboundary, box, myrank, nprocs_main, &
     &   myrank_main, nprocs_sub, myrank_sub, mpi_comm_sub

      use nnp_message_passing, only : &
     &   mp_bcast

      use nnp_predict, only : &
     &   nnp_initialize, nnp_predict_dip, nnp_predict_ener_forces

      use libnnp_variables, only : &
     &   mp_world, nnp1, nnp2, nnp3, symb, pos, forces, dip, charges, &
     &   ener, ref_charge, expol1, expol2, expol3, lattice, inputname, &
     &   symb_int, pos_int, forces_int, ener_int

!-----------------------------------------------------------------------
!     //   local variables
!-----------------------------------------------------------------------

      implicit none

      integer, save :: iset = 0
      integer :: i, m, l
      integer :: matom
      character(len=3) :: char_rank
      integer, save :: ioption_dms_libnnp = 0
      integer, save :: natom_ip_libnnp = 0
      character(len=80), save :: file_pes_libnnp
      character(len=80), save :: file_dms_libnnp
      character(len=80), save :: file_ip_libnnp

!-----------------------------------------------------------------------
!     //   initial settings (first call only)
!-----------------------------------------------------------------------

      if ( iset .eq. 0 ) then

!-----------------------------------------------------------------------
!        //   set communicator for libnnp
!-----------------------------------------------------------------------

         mp_world%iroot = 0
         mp_world%iproc = myrank_sub
         mp_world%nproc = nprocs_sub

         mp_world%root = .false.
         if (mp_world%iproc .eq. 0 ) mp_world%root = .true.

         mp_world%comm = mpi_comm_sub

!-----------------------------------------------------------------------
!        //   reference charge
!-----------------------------------------------------------------------

         call read_real1_MPI &
     &      ( ref_charge, '<charge_libnnp>', 15, iounit )

!-----------------------------------------------------------------------
!        //   option: 1 calculate dipole moment or 0 not
!-----------------------------------------------------------------------

         call read_int1_MPI &
     &      ( ioption_dms_libnnp, '<ioption_dms_libnnp>', 20, iounit )

!-----------------------------------------------------------------------
!        //   number of interacting atoms (0 = no interaction potential)
!-----------------------------------------------------------------------

         call read_int1_MPI &
     &      ( natom_ip_libnnp, '<natom_ip_libnnp>', 17, iounit )

!        //   at least one atom must remain for NN-PES; otherwise the
!        //   atom indexing below runs out of range.  Every rank is
!        //   presumed to hold the same value, so all ranks abort.
         if ( ( natom_ip_libnnp .lt. 0 ) .or. &
     &        ( natom_ip_libnnp .ge. natom ) ) then
            if ( myrank .eq. 0 ) then
               write(6, '(a)') 'Error termination.'
               write(6, '(a)') ''
               write(6, '(a,i0)') &
     &            'Invalid <natom_ip_libnnp>: ', natom_ip_libnnp
               write(6, '(a,i0,a)') &
     &            'It must be between 0 and natom-1 = ', natom-1, '.'
               write(6, '(a)')
            end if
            call my_mpi_barrier
            call my_mpi_abort
         end if

!-----------------------------------------------------------------------
!        //   input filename for potential energy surface
!-----------------------------------------------------------------------

!         file_pes_libnnp = "NN-PES/input.nn"
         call read_char_MPI &
     &      ( file_pes_libnnp, 80, '<file_pes_libnnp>', 17, iounit )

!-----------------------------------------------------------------------
!        //   input filename for dipole moment surface
!-----------------------------------------------------------------------

!         file_dms_libnnp = "NN-DMS/input.nn"
         call read_char_MPI &
     &      ( file_dms_libnnp, 80, '<file_dms_libnnp>', 17, iounit )

!-----------------------------------------------------------------------
!        //   input filename for interaction potential
!-----------------------------------------------------------------------

!         file_ip_libnnp = "NN-IP/input.nn"
         call read_char_MPI &
     &      ( file_ip_libnnp, 80, '<file_ip_libnnp>', 16, iounit )

!-----------------------------------------------------------------------
!        //   read neural networks
!-----------------------------------------------------------------------

!        //   libnnp.<myrank_main>.out is opened on unit 666 around the
!        //   initialization; presumably libnnp logs to that unit
!        //   (TODO confirm)
         call int3_to_char ( myrank_main, char_rank )

         if ( myrank_sub .eq. 0 ) &
     &      open ( 666, file = 'libnnp.' // char_rank // '.out' )

!        //   energy and force neural network
         inputname = trim(file_pes_libnnp)
         call nnp_initialize( nnp1, mp_world, inputname )

!        //   dipole neural network
         inputname = trim(file_dms_libnnp)
         if ( ioption_dms_libnnp .eq. 1 ) then
            call nnp_initialize( nnp2, mp_world, inputname )
         end if

!        //   interaction potential
         inputname = trim(file_ip_libnnp)
         if ( natom_ip_libnnp .ge. 1 ) then
            call nnp_initialize( nnp3, mp_world, inputname )
         end if

         if ( myrank_sub .eq. 0 ) close( 666 )

!-----------------------------------------------------------------------
!        //   keep only the output file of first bead
!-----------------------------------------------------------------------

         call my_mpi_barrier

         if ( myrank .eq. 0 ) then
            call system &
     &         ( "mv libnnp.000.out libnnp.out; " // &
     &           "rm -f libnnp.*.out" )
         end if

!-----------------------------------------------------------------------
!        //   memory allocation
!-----------------------------------------------------------------------

!        //   number of NN-PES atoms (equals natom without NN-IP)
         matom = natom - natom_ip_libnnp

!        //   NN-PES / NN-DMS work arrays
         if ( .not. allocated(symb)    ) allocate(symb(matom))
         if ( .not. allocated(pos)     ) allocate(pos(3,matom))
         if ( .not. allocated(forces)  ) allocate(forces(3,matom))
         if ( .not. allocated(charges) ) allocate(charges(matom))

!        //   NN-IP work arrays: matom atoms plus one interacting atom
         if ( natom_ip_libnnp .ge. 1 ) then
            if ( .not. allocated(symb_int)   ) &
     &         allocate(symb_int(matom+1))
            if ( .not. allocated(pos_int)    ) &
     &         allocate(pos_int(3,matom+1))
            if ( .not. allocated(forces_int) ) &
     &         allocate(forces_int(3,matom+1))
         end if

!-----------------------------------------------------------------------
!        //   atomic symbol (first two characters of species)
!-----------------------------------------------------------------------

         do i = 1, matom
            symb(i) = species(i)(1:2)
         end do

         if ( natom_ip_libnnp .ge. 1 ) then
!           //   slot matom+1 is refreshed for each interacting atom
!           //   inside the bead loop
            do i = 1, matom+1
               symb_int(i) = species(i)(1:2)
            end do
         end if

         call mp_bcast( symb, mp_world )
         if ( natom_ip_libnnp .ge. 1 ) call mp_bcast( symb_int, mp_world )

!-----------------------------------------------------------------------
!        //   set complete
!-----------------------------------------------------------------------

         iset = 1

!     //   initial setting
      end if

!-----------------------------------------------------------------------
!     //   box: cell vectors for periodic runs, zero otherwise
!-----------------------------------------------------------------------

      lattice(:,:) = 0.d0

      if ( ( iboundary .eq. 1 ) .or. ( iboundary .eq. 2 ) ) &
     &   lattice(:,:) = box(:,:)

!-----------------------------------------------------------------------
!     //   number of NN-PES atoms (interacting atoms excluded)
!-----------------------------------------------------------------------

      matom = natom - natom_ip_libnnp

!-----------------------------------------------------------------------
!     //   loop of beads
!-----------------------------------------------------------------------

      do m = 1, nbead

!        /*   bead parallel   */
         if ( mod( m-1, nprocs_main ) .ne. myrank_main ) cycle

!-----------------------------------------------------------------------
!        //   position of bead m
!-----------------------------------------------------------------------

         pos(1,1:matom) = x(1:matom,m)
         pos(2,1:matom) = y(1:matom,m)
         pos(3,1:matom) = z(1:matom,m)

         call mp_bcast( pos, mp_world )

!-----------------------------------------------------------------------
!        //   predict energy and forces of bead m
!-----------------------------------------------------------------------

         if ( iboundary .eq. 0 ) then
            call nnp_predict_ener_forces &
     &         ( nnp1, symb, pos, ener, forces, expol1, mp_world )
         else
            call nnp_predict_ener_forces &
     &         ( nnp1, symb, pos, ener, forces, expol1, mp_world, &
     &           lattice )
         end if

!-----------------------------------------------------------------------
!        //   substitution of bead m (sub-root only)
!-----------------------------------------------------------------------

         if ( myrank_sub .eq. 0 ) then

!           //   potential
            pot(m) = ener

!           //   forces
            fx(1:matom,m) = forces(1,1:matom)
            fy(1:matom,m) = forces(2,1:matom)
            fz(1:matom,m) = forces(3,1:matom)

!           //   virial: not implemented yet, set to zero
            vir(:,:) = 0.d0

         end if

!-----------------------------------------------------------------------
!        //   predict dipoles of bead m
!-----------------------------------------------------------------------

         if ( ioption_dms_libnnp .eq. 1 ) then

            if ( iboundary .eq. 0 ) then
               call nnp_predict_dip &
     &         ( nnp2, symb, pos, ref_charge, dip, expol2, mp_world, &
     &           charges )
            else
               call nnp_predict_dip &
     &         ( nnp2, symb, pos, ref_charge, dip, expol2, mp_world, &
     &           charges, lattice )
            end if

!           //   dipole moment (sub-root only)
            if ( myrank_sub .eq. 0 ) then
               dipx(m) = dip(1)
               dipy(m) = dip(2)
               dipz(m) = dip(3)
            end if

         end if

!-----------------------------------------------------------------------
!        //   predict interaction of bead m with each interacting atom
!-----------------------------------------------------------------------

         do l = 1, natom_ip_libnnp

!           //   NN-PES atoms of bead m
            pos_int(1,1:matom) = x(1:matom,m)
            pos_int(2,1:matom) = y(1:matom,m)
            pos_int(3,1:matom) = z(1:matom,m)

!           //   interacting atom l of bead m
            pos_int(1,matom+1) = x(matom+l,m)
            pos_int(2,matom+1) = y(matom+l,m)
            pos_int(3,matom+1) = z(matom+l,m)

!           //   element of interacting atom l (previously always taken
!           //   from atom matom+1, wrong for mixed interacting atoms)
            symb_int(matom+1) = species(matom+l)(1:2)

            call mp_bcast( symb_int, mp_world )
            call mp_bcast( pos_int, mp_world )

            if ( iboundary .eq. 0 ) then
               call nnp_predict_ener_forces &
     &         ( nnp3, symb_int, pos_int, ener_int, forces_int, &
     &           expol3, mp_world )
            else
               call nnp_predict_ener_forces &
     &         ( nnp3, symb_int, pos_int, ener_int, forces_int, &
     &           expol3, mp_world, lattice )
            end if

!-----------------------------------------------------------------------
!           //   accumulate into bead m (sub-root only)
!-----------------------------------------------------------------------

            if ( myrank_sub .eq. 0 ) then

!              //   potential
               pot(m) = pot(m) + ener_int

!              //   forces on NN-PES atoms
               fx(1:matom,m) = fx(1:matom,m) + forces_int(1,1:matom)
               fy(1:matom,m) = fy(1:matom,m) + forces_int(2,1:matom)
               fz(1:matom,m) = fz(1:matom,m) + forces_int(3,1:matom)

!              //   force on interacting atom l
               fx(matom+l,m) = fx(matom+l,m) + forces_int(1,matom+1)
               fy(matom+l,m) = fy(matom+l,m) + forces_int(2,matom+1)
               fz(matom+l,m) = fz(matom+l,m) + forces_int(3,matom+1)

            end if

         end do

!     /*   loop of beads   */
      end do

!-----------------------------------------------------------------------
!     /*   all-reduce communication                                   */
!-----------------------------------------------------------------------

      call my_mpi_allreduce_md

      return
      end subroutine force_libnnp_MPI


!     // LIBNNP undefined
#else


!***********************************************************************
      subroutine force_libnnp_MPI
!***********************************************************************
!
!     Placeholder compiled when the libnnp macro is not defined.
!     Global rank 0 prints how to rebuild with LIBNNP support, then
!     every rank passes a barrier and the run is aborted.
!
!-----------------------------------------------------------------------

      use common_variables, only : myrank

      implicit none

!     //   a single write with format '(a)': format reversion starts
!     //   a new record for every item, one line per string
      if ( myrank == 0 ) then
         write(6, '(a)') &
     &      'Error termination.', &
     &      '', &
     &      'ipotential=LIBNNP is not available because', &
     &      'LIBNNP subroutines were not linked.', &
     &      'Compile the pimd.mpi.x with the following.', &
     &      '', &
     &      '  FLAGMP = -Dlibnnp', &
     &      ''
      end if

!     //   make sure the message is out before terminating all ranks
      call my_mpi_barrier
      call my_mpi_abort

      end subroutine force_libnnp_MPI


!     // LIBNNP
#endif
