!///////////////////////////////////////////////////////////////////////
!
!      Author:          M. Shiga
!      Last updated:    Nov 1, 2019 by M. Shiga
!      Description:     energy and force from VASP calculation
!
!///////////////////////////////////////////////////////////////////////
!***********************************************************************
      module vasp_variables
!***********************************************************************
!
!     Data shared between the PIMD driver (force_vasp_MPI,
!     init_vasp_MPI) and the VASP interface routines: step/bead
!     bookkeeping, output settings, coordinates/forces/stress/dipole
!     of the bead currently being computed, and per-bead wavefunction
!     storage.
!
!-----------------------------------------------------------------------

!     /*   every name in this module is declared explicitly   */
      implicit none

!-----------------------------------------------------------------------
!     general variables
!-----------------------------------------------------------------------

!     /*   current number of MD steps   */
      integer :: pimd_istep

!     /*   current bead   */
      integer :: pimd_ibead

!     /*   if 1, all processes generate output file   */
      integer :: vasp_output_all_proc

!     /*   VASP output results every n-step   */
      integer :: vasp_output_every_nstep

!     /*   flag whether to reuse wave-function   */
      integer :: vasp_reuse_wavefunction

!-----------------------------------------------------------------------
!     energy, coordinate, and force
!-----------------------------------------------------------------------

!     /*    total energy (au)   */
      real(8) :: total_energy

!     /*    nuclear coordinate (au,cartesian)   */
      real(8), allocatable :: coord_x(:), coord_y(:), coord_z(:)

!     /*    force (au)   */
      real(8), allocatable :: force_x(:), force_y(:), force_z(:)

!     /*    stress tensor (au)   */
      real(8) :: stress_tensor(3,3)

!     /*    stress tensor (kB)   */
      real(8) :: stress_tensor_kb(3,3)

!     /*    cell parameter   */
      real(8) :: cell_param(3,3)

!     /*    dipole moment   */
      real(8) :: dipole_x, dipole_y, dipole_z

!-----------------------------------------------------------------------
!     wavefunction information
!-----------------------------------------------------------------------

!     /*   vasp wavefunction type   */
      type :: vaspwave

!        /*   flag whether empty or not   */
         logical :: empty

!        /*   wavefunction CPTWFP(NRPLWV,NBANDS,NKPTS,ISPIN)   */
!        /*   null-initialized so that associated() is well defined   */
!        /*   before the array is first allocated or pointed to       */
         complex(selected_real_kind(10)), pointer :: &
     &      CPTWFP(:,:,:,:) => null()

!     /*   vasp wavefunction type   */
      end type vaspwave

!     /*   wavefunctions for all beads   */
      type(vaspwave), allocatable, target :: vwaves(:)

!     /*   wavefunctions for current bead (points into vwaves)   */
      type(vaspwave), pointer :: vwave => null()

!***********************************************************************
      end module vasp_variables
!***********************************************************************





!***********************************************************************
      subroutine force_vasp_MPI
!***********************************************************************
!
!     Energy, forces, stress and dipole moment of all beads from VASP.
!
!     Beads are distributed round-robin over the main communicator:
!     bead ibead is computed by the rank whose myrank_main equals
!     mod(ibead-1,nprocs_main).  Per-bead results are then combined
!     on all ranks by the my_mpi_allreduce_* calls.  These are
!     presumably sums, which requires the entries of beads not handled
!     locally to be zero on entry.  TODO confirm where pot/fx/fy/fz/
!     dipx/dipy/dipz/vir_bead are cleared.
!
!     Outputs: pot, fx, fy, fz, dipx, dipy, dipz for every bead;
!     vir is incremented; when method is REHMC, vir_bead is
!     incremented as well.
!
!-----------------------------------------------------------------------

!-----------------------------------------------------------------------
!     /*   shared variables                                           */
!-----------------------------------------------------------------------

      use common_variables, only : &
     &   x, y, z, fx, fy, fz, dipx, dipy, dipz, pot, au_energy, volume, &
     &   volume_bead, au_length, box, vir, vir_bead, box_bead, method, &
     &   istep, natom, nbead, myrank_main, nprocs_main, myrank

      use vasp_variables, only : &
     &   coord_x, coord_y, coord_z, dipole_x, dipole_y, dipole_z, &
     &   total_energy, stress_tensor, stress_tensor_kb, vwave, vwaves, &
     &   pimd_istep, pimd_ibead, cell_param, force_x, force_y, force_z

!-----------------------------------------------------------------------
!     /*   local variables                                            */
!-----------------------------------------------------------------------

      implicit none

!     /*   kilobar to pascal: 1 kbar = 1.0e+8 Pa   */
      real(8), parameter :: kbar_to_pascal = 1.d+8

      integer :: ierr, ibead, i, j
      real(8) :: stress_tensor_kb_sum(3,3)

!     /*   true when method is REHMC: per-bead boxes (box_bead) and   */
!     /*   volumes (volume_bead) are used instead of box / volume     */
      logical :: rehmc

      rehmc = ( method(1:6) .eq. 'REHMC ' )

!-----------------------------------------------------------------------
!     /*   initialize stress tensor and dipole moment                 */
!-----------------------------------------------------------------------

      stress_tensor(:,:) = 0.d0
      stress_tensor_kb_sum(:,:) = 0.d0

      dipole_x = 0.d0
      dipole_y = 0.d0
      dipole_z = 0.d0

!-----------------------------------------------------------------------
!     /*   loop over beads                                            */
!-----------------------------------------------------------------------

      do ibead = 1, nbead

!        /*   cell passed to VASP (set on every rank)   */
         if ( rehmc ) then
            cell_param(:,:) = box_bead(:,:,ibead)
         else
            cell_param(:,:) = box(:,:)
         end if

!        /*   skip beads assigned to other ranks   */
         if ( mod(ibead-1,nprocs_main) .ne. myrank_main ) cycle

!        /*   step and bead index seen by the VASP side   */
         pimd_istep = istep
         pimd_ibead = ibead

!        /*   coordinates of this bead   */
         do i = 1, natom
            coord_x(i) = x(i,ibead)
            coord_y(i) = y(i,ibead)
            coord_z(i) = z(i,ibead)
         end do

!        /*   wavefunction storage of this bead   */
         vwave => vwaves(ibead)

!        /*   execute VASP subroutine   */
         ierr = 0
         call vasp_force( ierr )
         if ( ierr .ne. 0 ) then
            if ( myrank .eq. 0 ) then
               write( 6, '(a)' ) &
     &           'Error - unable to execute VASP. See log file.'
               write( 6, '(a)' )
            end if
            call error_handling_MPI &
     &           ( 1, 'subroutine force_vasp_MPI', 25 )
         end if

!        /*   total energy   */
         pot(ibead) = total_energy

!        /*   forces   */
         do i = 1, natom
            fx(i,ibead) = force_x(i)
            fy(i,ibead) = force_y(i)
            fz(i,ibead) = force_z(i)
         end do

!        /*   accumulate stress (kbar) over the beads of this rank   */
         stress_tensor_kb_sum(:,:) = stress_tensor_kb_sum(:,:) &
     &                             + stress_tensor_kb(:,:)

!        /*   virial of each bead: stress converted from kbar to   */
!        /*   hartree/bohr**3, times the bead volume               */
         if ( rehmc ) then
            vir_bead(:,:,ibead) = vir_bead(:,:,ibead) &
     &                          + stress_tensor_kb(:,:) &
     &                          * kbar_to_pascal / au_energy &
     &                          * au_length**3 * volume_bead(ibead)
         end if

!        /*   dipole moment   */
         dipx(ibead) = dipole_x
         dipy(ibead) = dipole_y
         dipz(ibead) = dipole_z

      end do

!-----------------------------------------------------------------------
!     /*   change units from kilobar to hartree/bohr**3               */
!     /*   (au_energy / au_length presumably in SI units)             */
!-----------------------------------------------------------------------

      stress_tensor(:,:) = stress_tensor_kb_sum(:,:) &
     &                   * kbar_to_pascal / au_energy * au_length**3

!-----------------------------------------------------------------------
!     /*   all-reduce communication                                   */
!-----------------------------------------------------------------------

!     /*   potential   */
      call my_mpi_allreduce_real_1_main ( pot, nbead )

!     /*   force   */
      call my_mpi_allreduce_real_2_main ( fx, natom, nbead )
      call my_mpi_allreduce_real_2_main ( fy, natom, nbead )
      call my_mpi_allreduce_real_2_main ( fz, natom, nbead )

!     /*   stress: reduced exactly once; a second reduction would     */
!     /*   multiply the already-combined total by nprocs_main         */
      call my_mpi_allreduce_real_2_main ( stress_tensor, 3, 3 )

!     /*   dipole moment   */
      call my_mpi_allreduce_real_1_main ( dipx, nbead )
      call my_mpi_allreduce_real_1_main ( dipy, nbead )
      call my_mpi_allreduce_real_1_main ( dipz, nbead )

!     /*   virial of each bead   */
      if ( rehmc ) then
         call my_mpi_allreduce_real_3_main ( vir_bead, 3, 3, nbead )
      end if

!-----------------------------------------------------------------------
!     /*   virial                                                     */
!-----------------------------------------------------------------------

      if ( rehmc ) then

!        /*   bead average of the per-bead virials   */
         do j = 1, nbead
            vir(1:3,1:3) = vir(1:3,1:3) + vir_bead(1:3,1:3,j) / nbead
         end do

      else

!        /*   NOTE(review): stress_tensor already holds the sum over   */
!        /*   all beads and this loop adds it nbead times / nbead,     */
!        /*   so vir gains sum_beads(stress) * volume, whereas the     */
!        /*   REHMC branch adds the bead average.  Kept unchanged;     */
!        /*   confirm the intended normalization.                      */
         do j = 1, nbead
            vir(1:3,1:3) = vir(1:3,1:3) &
     &                   + stress_tensor(1:3,1:3) * volume / nbead
         end do

      end if

      return
      end





!***********************************************************************
      subroutine init_vasp_MPI
!***********************************************************************
!
!     Read the VASP interface settings and initialize VASP.
!
!     - <vasp_reuse_wavefunction> is read through read_int1_MPI.
!     - <vasp_output> (vasp_output_all_proc, vasp_output_every_nstep)
!       is read by rank 0 from input.dat, falling back to
!       input_default.dat, and then broadcast to all ranks.
!     - vasp_init is called last.
!
!     On a read or initialization failure an error message is printed
!     by rank 0 and error_handling_MPI is called (presumably aborting
!     the run for a nonzero code -- confirm in error_handling_MPI).
!
!-----------------------------------------------------------------------

!-----------------------------------------------------------------------
!     /*   shared variables                                           */
!-----------------------------------------------------------------------

      use common_variables, only : myrank, iounit

      use vasp_variables, only : &
     &   vasp_output_all_proc, vasp_output_every_nstep, &
     &   vasp_reuse_wavefunction

!-----------------------------------------------------------------------
!     /*   local variables                                            */
!-----------------------------------------------------------------------

      implicit none

      integer :: ierr

!-----------------------------------------------------------------------
!     /*   set VASP parameters                                        */
!-----------------------------------------------------------------------

!c      call read_int1_MPI ( vasp_output_all_proc,
!c     &     '<vasp_output_all_proc>', 22, iounit )

!c      call read_int1_MPI ( vasp_output_every_nstep,
!c     &     '<vasp_output_every_nstep>', 25, iounit )

      call read_int1_MPI ( vasp_reuse_wavefunction, &
     &     '<vasp_reuse_wavefunction>', 25, iounit )

!     /*   parent process only   */
      if ( myrank .eq. 0 ) then

!        /*   file open   */
         open ( iounit, file = 'input.dat' )

!        /*   search for tag    */
         call search_tag ( '<vasp_output>', 13, iounit, ierr )

!        /*   read a line, but only if the tag was found; otherwise   */
!        /*   keep the nonzero status from search_tag                 */
         if ( ierr .eq. 0 ) then
            read ( iounit, *, iostat=ierr ) &
     &         vasp_output_all_proc, &
     &         vasp_output_every_nstep
         end if

!        /*   file close   */
         close ( iounit )

!        /*   otherwise, read from default input   */
         if ( ierr .ne. 0 ) then

!           /*   file open   */
            open ( iounit, file = 'input_default.dat' )

!           /*   search for tag    */
            call search_tag( '<vasp_output>', 13, iounit, ierr )

!           /*   read a line, but only if the tag was found   */
            if ( ierr .eq. 0 ) then
               read ( iounit, *, iostat=ierr ) &
     &            vasp_output_all_proc, &
     &            vasp_output_every_nstep
            end if

!           /*   file close   */
            close ( iounit )

!        /*   end of fallback to default input   */
         end if

!     /*   parent process only   */
      end if

!     /*   communicate read status   */
      call my_mpi_bcast_int_0( ierr )

!     /*   error message   */
      if ( ierr .ne. 0 ) then
         if ( myrank .eq. 0 ) then
            write ( 6, '(a)' ) &
     &         'Error - <vasp_output> read incorrectly.'
            write ( 6, '(a)' )
         end if
      end if

!     /*   error termination   */
      call error_handling_MPI &
     &     ( ierr, 'subroutine init_vasp_MPI', 24 )

!     /*   communicate   */
      call my_mpi_bcast_int_0( vasp_output_all_proc )

!     /*   communicate   */
      call my_mpi_bcast_int_0( vasp_output_every_nstep )

!-----------------------------------------------------------------------
!     /*   print input parameters (currently disabled)                */
!-----------------------------------------------------------------------

!c     /*   parent process only   */
!      if ( myrank .eq. 0 ) then
!
!         write( 6, '(a)' )
!     &      'VASP input files: INCAR, POSCAR, KPOINTS, POTCAR.'
!
!         if      ( vasp_output_all_proc .eq. 0 ) then
!            write( 6, '(a,i0,a)' )
!     &         'VASP log files: STDOUT, OUTCAR ' //
!     &         'printed in master process every ',
!     &         vasp_output_every_nstep, ' steps.'
!         else if ( vasp_output_all_proc .eq. 1 ) then
!            write( 6, '(a,i0,a)' )
!     &         'VASP log files: STDOUT, OUTCAR ' //
!     &         'printed in all processes every ',
!     &         vasp_output_every_nstep, ' steps.'
!         else
!            ierr = 1
!         end if
!
!         if      ( vasp_reuse_wavefunction .eq. 0 ) then
!            write( 6, '(a)' )
!     &         'VASP: wave functions initialized every step.'
!         else if ( vasp_reuse_wavefunction .eq. 1 ) then
!            write( 6, '(a)' )
!     &         'VASP: wave functions reused every step.'
!         else
!            ierr = 1
!         end if
!
!         write( 6, '(a)' )
!
!c     /*   parent process only   */
!      end if

!-----------------------------------------------------------------------
!     /*   error termination (currently disabled)                     */
!-----------------------------------------------------------------------

!      if( ierr .ne. 0 ) then
!
!         if ( myrank .eq. 0 ) then
!
!            write( 6, '(a)' )
!     &        'Error - VASP settings incorrect.'
!            write( 6, '(a)' )
!
!         end if
!
!         call error_handling_MPI
!     &        ( 1, 'subroutine init_vasp_MPI', 24 )
!
!      end if

!-----------------------------------------------------------------------
!     /*   initialize VASP                                            */
!-----------------------------------------------------------------------

      call vasp_init( ierr )

!-----------------------------------------------------------------------
!     /*   error termination                                          */
!-----------------------------------------------------------------------

      if( ierr .ne. 0 ) then

         if ( myrank .eq. 0 ) then

            write( 6, '(a)' ) &
     &        'Error - unable to initialize VASP. See log file.'
            write( 6, '(a)' )

         end if

         call error_handling_MPI &
     &        ( 1, 'subroutine init_vasp_MPI', 24 )

      end if

      return
      end





!***********************************************************************
      subroutine finalize_vasp_MPI
!***********************************************************************
!
!     Shut down the VASP interface by calling vasp_finalize.
!     A nonzero status makes rank 0 print a message, after which
!     error_handling_MPI is called with code 1.
!
!-----------------------------------------------------------------------

      use common_variables, only : myrank

      implicit none

!     /*   status returned by vasp_finalize (0 = success)   */
      integer :: istat

      istat = 0
      call vasp_finalize( istat )

!     /*   nothing more to do on success   */
      if ( istat == 0 ) return

!     /*   failure: only the master rank reports it   */
      if ( myrank == 0 ) then
         write( 6, '(a)' ) &
     &     'Error - unable to finalize VASP. See log file.'
         write( 6, '(a)' )
      end if

      call error_handling_MPI &
     &     ( 1, 'subroutine finalize_vasp_MPI', 28 )

      return
      end





!=======================================================================
!
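!     The following dummy subroutines are compiled only when the
!     preprocessor macro vasp is undefined, i.e. when libvasp.a is
!     not linked; each one reports the problem and aborts the run.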
!
!=======================================================================

#ifndef vasp

!***********************************************************************
      subroutine vasp_init( ierr )
!***********************************************************************
!
!     Stand-in for the VASP initializer, compiled when the macro vasp
!     is undefined.  Sets ierr = 1, lets rank 0 print rebuild
!     instructions, then calls my_mpi_barrier and my_mpi_abort.
!
!-----------------------------------------------------------------------

      use common_variables, only : myrank

      implicit none

      integer, intent(out) :: ierr

!     /*   always report failure   */
      ierr = 1

!     /*   master rank explains how to relink   */
      if ( myrank == 0 ) then
         write( 6, '(a)' ) &
     &      'Error termination - VASP is not linked to PIMD.'
         write( 6, '(a)' ) ''
         write( 6, '(a)' ) &
     &      'Recompile pimd.mpi.x with ../lib/libvasp.a with the options'
         write( 6, '(a)' ) ''
         write( 6, '(a)' ) '  VASP = -Dvasp'
         write( 6, '(a)' ) '  LIBVASP = -L../lib -lvasp -ldmy'
         write( 6, '(a)' ) ''
      end if

!     /*   synchronize, then stop all processes   */
      call my_mpi_barrier
      call my_mpi_abort

      return
      end

#endif





#ifndef vasp

!***********************************************************************
      subroutine vasp_force( ierr )
!***********************************************************************
!
!     Stand-in for the VASP force routine, compiled when the macro
!     vasp is undefined.  Sets ierr = 1, lets rank 0 print build
!     instructions, then calls my_mpi_barrier and my_mpi_abort.
!
!-----------------------------------------------------------------------

      use common_variables, only : myrank

      implicit none

      integer, intent(out) :: ierr

!     /*   always report failure   */
      ierr = 1

!     /*   master rank explains how to relink   */
      if ( myrank == 0 ) then
         write( 6, '(a)' ) 'Error termination.'
         write( 6, '(a)' )
         write( 6, '(a)' ) 'ipotential=VASP is not available because'
         write( 6, '(a)' ) 'VASP subroutines were not linked.'
         write( 6, '(a)' ) &
     &      'Compile the pimd.mpi.x with the followings.'
         write( 6, '(a)' ) '(You need libvasp.a in ../lib)'
         write( 6, '(a)' )
         write( 6, '(a)' ) '  FLAGMP = -Dvasp'
         write( 6, '(a)' ) '  LINKMP = -L../lib -lvasp -ldmy'
         write( 6, '(a)' )
      end if

!     /*   synchronize, then stop all processes   */
      call my_mpi_barrier
      call my_mpi_abort

      return
      end

#endif





#ifndef vasp

!***********************************************************************
      subroutine vasp_finalize( ierr )
!***********************************************************************
!
!     Stand-in for the VASP finalizer, compiled when the macro vasp is
!     undefined.  Sets ierr = 1, lets rank 0 print linking
!     instructions, then calls my_mpi_barrier and my_mpi_abort.
!
!-----------------------------------------------------------------------

      use common_variables, only : myrank

      implicit none

      integer, intent(out) :: ierr

!     /*   always report failure   */
      ierr = 1

!     /*   master rank explains how to relink   */
      if ( myrank == 0 ) then
         write( 6, '(a)' ) 'Error termination.'
         write( 6, '(a)' )
         write( 6, '(a)' ) &
     &      'VASP is not available. To link VASP routines, copy' &
     &      // ' libvasp.a to'
         write( 6, '(a)' ) &
     &      ' ../lib directory, then recompile pimd.mpi.x with the' &
     &      // ' following options'
         write( 6, '(a)' ) 'included in makefile.'
         write( 6, '(a)' )
         write( 6, '(a)' ) '  FLAGMP = -Dvasp'
         write( 6, '(a)' ) '  LINKMP = -L../lib -lvasp -ldmy'
         write( 6, '(a)' )
      end if

!     /*   synchronize, then stop all processes   */
      call my_mpi_barrier
      call my_mpi_abort

      return
      end

#endif
