/****************************************************************************
 *
 * DFT++:  density functional package developed by
 *         the research group of Prof. Tomas Arias, MIT.
 *
 * Principal author: Sohrab Ismail-Beigi
 *
 * Modifications for MPI version: Kenneth P Esler,
 *                                Sohrab Ismail-Beigi, and
 *                                Tairan Wang.
 *
 * Modifications for LSD version: Jason A Cline
 *
 * Modifications for lattice/Pulay forces: Gabor Csanyi and
 *                                         Sohrab Ismail-Beigi
 *
 * Copyright (C) 1996-1998 The Massachusetts Institute of Technology (MIT).
 *
 ****************************************************************************/

/*
 *     Sohrab Ismail-Beigi           Mar. 29, 1997
 *
 * Calculate all the various energy terms
 *
 */

/* $Id: calcener.c,v 1.1.1.1 1999/11/10 01:30:17 tairan Exp $ */

#include <stdio.h>
#include <math.h>
#include "header.h"
// #include "parallel.h"

/*
 * Calculate the kinetic energy and store it in ener->KE:
 *
 *    KE = -0.5*sum_k { w[k]*trace(diag(C[k]^F[k]*L(C[k]))) }
 *
 * A single workspace bundle LC, sized for the longest column found over
 * all k-points, is reused for every k-point; its col_length is reset to
 * match C[k] before each use.
 */
void
calc_KE(Elecinfo *einfo,Elecvars *evars,Energies *ener)
{
#ifdef DFT_PROFILING
  timerOn(18);   // Turn on calc_KE timer
#endif // DFT_PROFILING

  int ik;
  const int nkpts = einfo->nkpts;
  column_bundle *C = evars->C;
  diag_matrix *F = einfo->F;
  real *w = einfo->w;

  /* Find the longest column so one workspace fits every k-point. */
  int longest_col = 0;
  for (ik = 0; ik < nkpts; ik++)
    {
      if (C[ik].col_length > longest_col)
        longest_col = C[ik].col_length;
    }

  /* temporary workspace holding L(C[k]) */
  column_bundle LC(C[0].tot_ncols,longest_col);

  ener->KE = 0.0;
  for (ik = 0; ik < nkpts; ik++)
    {
      column_bundle &Ck = C[ik];

      /* Shrink/grow the logical column length to this k-point's and
       * carry over C[k]'s innards (presumably its bookkeeping data --
       * confirm against copy_innards_column_bundle). */
      LC.col_length = Ck.col_length;
      copy_innards_column_bundle(&Ck,&LC);

      /* Equivalent to:
       *   ener->KE += REAL(sum_vector(diaginner(F[k],C[k],L(C[k]))))*w[k];
       * with L(C[k]) written into the preallocated LC. */
      apply_L(Ck,LC);
      ener->KE += REAL(sum_vector(diaginner(F[ik],Ck,LC)))*w[ik];
    }

  ener->KE *= -0.5;

#ifdef DFT_PROFILING
  timerOff(18);   // Turn off calc_KE timer
#endif // DFT_PROFILING
}

/*
 * Calculate the local pseudopotential energy and store it in ener->Eloc:
 *
 *    Eloc = real{ n^Jdag(Vlocps) }
 *
 * The routine reads the local pseudopotential already stored in
 * evars->Vlocps; it does NOT recalculate it.
 */
void
calc_Eloc(Elecvars *evars,
          Energies *ener)
{
#ifdef DFT_PROFILING
  timerOn(19);   // Turn on calc_Eloc timer
#endif // DFT_PROFILING

  vector &density = evars->n;
  vector &Vps = evars->Vlocps;

  ener->Eloc = REAL( density^Jdag(Vps) );

#ifdef DFT_PROFILING
  timerOff(19);   // Turn off calc_Eloc timer
#endif // DFT_PROFILING
}



/*
 * Calculate the non-local pseudopotential energy and store it in ener->Enl:
 *
 *    Enl = sum_{species,lm,k,...}
 *            { w[k]*trace( M_lm*(Vnl^C[k])*F[k]*(Vnl^C[k])^dagger ) }
 *
 * Two code paths are used per (species, lm) pair:
 *
 *  - Multiple projectors (ngamma[lm] > 1): atoms are handled one at a
 *    time with a "local" column_bundle holding that atom's projectors.
 *
 *  - Kleinman-Bylander (single projector): the projectors of all atoms
 *    of the species are gathered into one column_bundle (one column per
 *    atom) that is created as a distributed bundle; only the columns
 *    owned here (my_ncols of them, starting at start_ncol) are filled.
 */
void
calc_Enl(Basis *basis,
         Ioninfo *ioninfo,
         Elecinfo *einfo,Elecvars *evars,
         Energies *ener)
{
#ifdef DFT_PROFILING
  timerOn(20);   // Turn on calc_Enl timer
#endif // DFT_PROFILING

  int sp,lm,i,k;
  column_bundle *C = evars->C;
  diag_matrix *F = einfo->F;
  real *w = einfo->w;
  real Enl;

  /* Longest column over all k-points.  It depends only on C, so it is
   * computed once here instead of once per (species, lm) pair. */
  int max_col_length = 0;
  for (k=0; k < einfo->nkpts; k++)
    if (max_col_length < C[k].col_length)
      max_col_length = C[k].col_length;

  Enl = (real)0.0;
  for (sp=0; sp < ioninfo->nspecies; sp++)
    for (lm=0; lm < ioninfo->species[sp].nlm; lm++)
      {
        if (ioninfo->species[sp].ngamma[lm] > 1)
          {
            dft_log(DFT_SILENCE,
                    "\nMultiple-projectors:  running slow calc_Enl!\n");

            /* this is the slow way where we go one atom at a time...
             * the smarter way would be to somehow make a new class
             * which is a block-diagonal matrix class (a string of matrix
             * classes on the diagonal of a bigger one), where each
             * diagonal is just Mnl below, and to define an
             * block_diag_matrix*matrix (returning matrix) operator.
             * Then we can do what we do with the Kleinman-Bylander
             * below with minimal changes. */
            matrix VdagC(ioninfo->species[sp].ngamma[lm],einfo->nbands);
            matrix &Mnl = ioninfo->species[sp].M[lm]; /* reference */

            column_bundle Vnl(ioninfo->species[sp].ngamma[lm],
                              max_col_length,"local");

            for (k=0; k < einfo->nkpts; k++) {
              // manually adjust the number of basis;
              Vnl.col_length = C[k].col_length;
              copy_innards_column_bundle(&(C[k]),&Vnl);
              for (i=0; i < ioninfo->species[sp].natoms; i++) {
                  /* Vnl is refilled for each atom and its contribution
                   * is accumulated immediately. */
                  Vnl_pseudo(sp,i,lm,einfo->kvec[k],&basis[k],ioninfo,Vnl);
                  VdagC = Vnl^C[k];
                  Enl += REAL(w[k]*trace(Mnl*VdagC*F[k]*herm_adjoint(VdagC)));
              }
            }
          }
        /* Kleinman-Bylander:  bunch up all local potentials for
         * the atoms of this species and state into a big column_bundle
         * and work on them instead (should be faster due to ^ and *
         * operators being block-multiplies, etc.) */
        else
          {
            matrix VdagC(ioninfo->species[sp].natoms,einfo->nbands);
            scalar Mnl = ioninfo->species[sp].M[lm](0,0);

            // Vnl is created as distributed column_bundle.
            // the dimension that's distributed is ioninfo->species[sp].natoms
            column_bundle Vnl(ioninfo->species[sp].natoms,max_col_length);

            for (k=0; k < einfo->nkpts; k++)
              {
                // manually adjust the number of basis;
                Vnl.col_length = C[k].col_length;
                copy_innards_column_bundle(&(C[k]),&Vnl);

                /* Fill only the columns owned here; column i holds the
                 * projector of atom i+Vnl.start_ncol. */
                for (i=0; i < Vnl.my_ncols; i++)
                  {
                    int j;
                    column_bundle Vnloneatom(1,basis[k].nbasis,"local");
                    Vnl_pseudo(sp,i+Vnl.start_ncol,lm,einfo->kvec[k],
                               &basis[k],ioninfo,Vnloneatom);
                    for (j=0; j < basis[k].nbasis; j++)
                      Vnl.col[i].c[j] = Vnloneatom.col[0].c[j];
                  }

                /* Now use Vnl! */
                VdagC = Vnl^C[k];
                Enl += REAL(w[k]*trace(Mnl*VdagC*F[k]*herm_adjoint(VdagC)));
              }
          }

      }
  ener->Enl = Enl;

#ifdef DFT_PROFILING
  timerOff(20);   // Turn off calc_Enl timer
#endif // DFT_PROFILING
}

/*
 * Calculate the Hartree energy and store it in ener->EH:
 *
 *    EH = 0.5*real{ n^Jdag(Obar(d)) }
 *
 * The Poisson equation is assumed to have been solved already; the
 * electrostatic potential is read from evars->d.
 */
void
calc_EH(Elecvars *evars,
        Energies *ener)
{
#ifdef DFT_PROFILING
  timerOn(21);   // Turn on calc_EH timer
#endif // DFT_PROFILING

  vector &density = evars->n;
  vector &potential = evars->d;

  ener->EH = 0.5*REAL(density^Jdag(Obar(potential)));

#ifdef DFT_PROFILING
  timerOff(21);   // Turn off calc_EH timer
#endif // DFT_PROFILING
}


/*
 * Calculate the exchange-correlation energy and store it in ener->Exc.
 *
 * When a core charge is present (evars->ncore.n > 0) it is added to
 * evars->n for the duration of the evaluation (non-linear core
 * correction) and subtracted again afterwards.
 *
 * NOTE(review): evars->n is modified in place and restored by
 * subtraction; whether the restored density is bit-identical to the
 * input depends on vector's += / -= -- confirm.
 */
void
calc_Exc(Elecvars *evars,
         Energies *ener)
{
#ifdef DFT_PROFILING
  timerOn(22);   // Turn on calc_Exc timer
#endif // DFT_PROFILING

  vector &n = evars->n;
  const int with_core = (evars->ncore.n > 0);

  /* non-linear core correction: work with valence + core density */
  if (with_core)
    {
      n += evars->ncore;
    }

  /* GGA adds the gradient correction exGC; anything else is plain LDA */
  if (evars->ex_opt == DFT_EXCORR_GGA)
    {
      ener->Exc = REAL(n^Jdag_O_J(exc(n)+exGC(n)));
    }
  else
    {
      ener->Exc = REAL(n^Jdag_O_J(exc(n)));
    }

  /* put the valence-only density back */
  if (with_core)
    {
      n -= evars->ncore;
    }

#ifdef DFT_PROFILING
  timerOff(22);   // Turn off calc_Exc timer
#endif // DFT_PROFILING
}


/*
 * Local pseudopotential core energy: stores the value returned by
 * Vloc_pseudoGzeroEnergy() in ener->Ecore.
 *
 * The basis passed on is element einfo->nkpts of the basis array, i.e.
 * the entry after the per-k-point ones (presumably a k-independent
 * basis -- confirm against the code that allocates the array).
 */
void
calc_Ecore(Ioninfo *ioninfo,
           Basis *basis,
           Elecinfo *einfo,Energies *ener)
{
#ifdef DFT_PROFILING
  timerOn(23);   // Turn on calc_Ecore timer
#endif // DFT_PROFILING

  ener->Ecore = Vloc_pseudoGzeroEnergy(einfo->nelectrons,
                                       basis + einfo->nkpts,
                                       ioninfo);

#ifdef DFT_PROFILING
  timerOff(23);   // Turn off calc_Ecore timer
#endif // DFT_PROFILING
}

/*
 * Ewald energy: stores the value returned by Ewald() in ener->Eewald.
 * The lattice vectors are taken from the first entry of the basis array.
 */
void
calc_Eewald(Ioninfo *ioninfo,
            Basis *basis,
            Energies *ener)
{
#ifdef DFT_PROFILING
  timerOn(24);   // Turn on calc_Eewald timer
#endif // DFT_PROFILING

  ener->Eewald = Ewald(ioninfo,basis->latvec);

#ifdef DFT_PROFILING
  timerOff(24);   // Turn off calc_Eewald timer
#endif // DFT_PROFILING
}

/*
 * Calculate the Pulay correction energy and store it in ener->Epulay:
 *
 *    Epulay = (NGideal/Vol - NGactual/Vol) *
 *             sum_species { natoms * dEperNatoms_dNGperVol }
 *
 * with
 *    NGideal/Vol  = sqrt(2)*Ecut^1.5/(3*pi^2)
 *    NGactual/Vol = nbasis/unit_cell_volume of basis[einfo->nkpts]
 */
void
calc_Epulay(Ioninfo *ioninfo,
            Elecinfo *einfo,
            Basis *basis,
            Energies *ener)
{
  int isp;
  /* the basis entry following the per-k-point ones */
  Basis &ref_basis = basis[einfo->nkpts];

  real NGidealperVol = sqrt(2.0)*pow(einfo->Ecut,1.5)/(3.0*M_PI*M_PI);
  real NGactualperVol = (real)ref_basis.nbasis/ref_basis.unit_cell_volume;

  /* species sum first, then scale by the basis-density mismatch */
  ener->Epulay = 0.0;
  for (isp = 0; isp < ioninfo->nspecies; isp++)
    {
      ener->Epulay +=
        ioninfo->species[isp].natoms*
        ioninfo->species[isp].dEperNatoms_dNGperVol;
    }
  ener->Epulay *= (NGidealperVol-NGactualperVol);
}
 
/*
 * Total energy: sets ener->Etot to the sum of the eight individual
 * terms KE, Eloc, Enl, EH, Exc, Ecore, Eewald and Epulay, which must
 * already have been computed.
 */
void
calc_Etot(Energies *ener)
{
  /* Terms are added strictly left to right in this order; keep it so
   * that floating-point round-off stays reproducible. */
  ener->Etot = ener->KE + ener->Eloc + ener->Enl + ener->EH + ener->Exc
             + ener->Ecore + ener->Eewald + ener->Epulay;
}

/*
 * Calculates the Helmholtz free energy for the system :  F = Etot - T*S
 * and stores it in ener->F.
 *
 * Etot is assumed to be already calculated (ener->Etot).  The temperature
 * kT, the fillings and the k-point weights are read from einfo.  This is
 * used only in the case where we calculate fermi fillings
 * (i.e. einfo->calc_fillings == 1); the routine dies if
 * einfo->calc_fillings == 0.
 *
 * If w[k] are the k-point weights and f[k][i] is the i'th band's filling
 * at the k'th k-point (f[k][i] = REAL(einfo->F[k].c[i])/2.0, which must
 * lie in [0,1]; we assume we have a spin-compensated system), then
 *
 * T*S = -2.0*kT*sum_k { w[k]*sum_i { f[k][i]*ln(f[k][i]) +
 *                                    (1-f[k][i])*ln(1-f[k][i]) } }
 *
 * Bands whose filling is within 1e-30 of 0 or 1 contribute nothing and
 * are skipped (this also avoids evaluating log(0)).
 */
void
calc_F(Elecinfo *einfo,Energies *ener)
{
  real E,TS,wk,fki,kT;
  int k,i;

  if (einfo->calc_fillings == 0)
    die("\n\ncalc_F() called with einfo->calc_fillings == 0!!!\n\n");

  /* Calculate T*S */
  TS = (real)0.0;
  kT = einfo->kT;
  E = ener->Etot;
  for (k=0; k < einfo->nkpts; k++)
    {
      /* the weight is the same for every band at this k-point */
      wk = einfo->w[k];
      for (i=0; i < einfo->nbands; i++)
        {
          /* stored fillings are halved to get the per-spin occupation */
          fki = REAL(einfo->F[k].c[i])/2.0;
          if (fki < 0.0 || fki > 1.0)
            die("\n\nf[%d][%d] is not in [0.0,1.0] in calc_F()!\n\n", k,i);

          /* If fki is not 0.0 or 1.0 (to one part in 10^30), then add its
           * contribution to TS */
          if ( fki > 1.0e-30 && (1.0-fki) > 1.0e-30)
            TS -= kT*2.0*wk*( fki*log(fki) + (1.0-fki)*log(1.0-fki) );
        }
    }
  ener->F = E - TS;
}

/*
 * Calculates ALL of the energy terms, then the total energy and, when
 * fillings are being calculated (einfo->calc_fillings == 1), the free
 * energy as well.
 */
void
calc_all_energies(Basis *basis,
                  Ioninfo *ioninfo,
                  Elecinfo *einfo,Elecvars *evars,
                  Energies *ener)
{
  /* terms depending on the electronic variables */
  calc_KE(einfo,evars,ener);
  calc_Eloc(evars,ener);
  calc_Enl(basis,ioninfo,einfo,evars,ener);
  calc_EH(evars,ener);
  calc_Exc(evars,ener);

  /* terms that do not take the electronic variables */
  calc_Ecore(ioninfo,basis,einfo,ener);
  calc_Eewald(ioninfo,basis,ener);
  calc_Epulay(ioninfo,einfo,basis,ener);

  /* totals: Etot must be current before calc_F reads it */
  calc_Etot(ener);
  if (einfo->calc_fillings == 1)
    {
      calc_F(einfo,ener);
    }
}

/*
 * Calculate the energy terms dependent on the electronic variables
 * (KE, Eloc, Enl, EH, Exc), then recalculate the total energy and,
 * when einfo->calc_fillings == 1, the free energy.
 * Ecore, Eewald and Epulay are left as they are in ener.
 */
void
calc_elec_dependent_energies(Basis *basis,
                             Ioninfo *ioninfo,
                             Elecinfo *einfo,Elecvars *evars,
                             Energies *ener)
{
  calc_KE(einfo,evars,ener);
  calc_Eloc(evars,ener);
  calc_Enl(basis,ioninfo,einfo,evars,ener);
  calc_EH(evars,ener);
  calc_Exc(evars,ener);

  /* totals: Etot must be current before calc_F reads it */
  calc_Etot(ener);
  if (einfo->calc_fillings == 1)
    {
      calc_F(einfo,ener);
    }
}

/*
 * Calculate the core, Ewald and Pulay energies only, then recalculate
 * the total energy and, when einfo->calc_fillings == 1, the free energy.
 * The electronic terms (KE, Eloc, Enl, EH, Exc) are left as they are
 * in ener.
 */
void
calc_core_ewald_pulay_energies(Basis *basis,
                               Ioninfo *ioninfo,
                               Elecinfo *einfo,
                               Energies *ener)
{
  calc_Ecore(ioninfo,basis,einfo,ener);
  calc_Eewald(ioninfo,basis,ener);
  calc_Epulay(ioninfo,einfo,basis,ener);

  /* totals: Etot must be current before calc_F reads it */
  calc_Etot(ener);
  if (einfo->calc_fillings == 1)
    {
      calc_F(einfo,ener);
    }
}

/*
 * Print the energies to out, one per line: the eight individual terms
 * first, then Etot (with more digits), then F when
 * einfo->calc_fillings == 1.  The output is flushed at the end.
 */
void
print_energies(Elecinfo *einfo,Energies *ener, Output *out)
{
  /* individual terms */
  out->printf("KE     = %15.7le\n",ener->KE);
  out->printf("Eloc   = %15.7le\n",ener->Eloc);
  out->printf("Enl    = %15.7le\n",ener->Enl);
  out->printf("EH     = %15.7le\n",ener->EH);
  out->printf("Exc    = %15.7le\n",ener->Exc);
  out->printf("Ecore  = %15.7le\n",ener->Ecore);
  out->printf("Eewald = %15.7le\n",ener->Eewald);
  out->printf("Epulay = %15.7le\n",ener->Epulay);

  /* totals, set apart in their own column */
  out->printf("Etot   =                        %19.12le\n",ener->Etot);
  if (einfo->calc_fillings == 1)
    {
      out->printf("F      =                        %19.12le\n",ener->F);
    }

  out->flush();
}
