/****************************************************************************
 *
 * DFT++:  density functional package developed by
 *         the research group of Prof. Tomas Arias, MIT.
 *
 * Principal author: Sohrab Ismail-Beigi
 *
 * Modifications for MPI version: Kenneth P Esler,
 *                                Sohrab Ismail-Beigi, and
 *                                Tairan Wang.
 *
 * Modifications for LSD version: Jason A Cline
 *
 * Modifications for lattice/Pulay forces: Gabor Csanyi and
 *                                         Sohrab Ismail-Beigi
 *
 * Copyright (C) 1996-1998 The Massachusetts Institute of Technology (MIT).
 *
 ****************************************************************************/

/*
 * Sohrab Ismail-Beigi,           October 16, 1997
 *
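 * A program that reads previously computed wave-functions and builds a
 * partial electron density from a chosen set of k-points and band
 * fillings, which it dumps to a file named 'n.<date-stamp>'.  The
 * program prints everything out to stdout (LOGFILE below).  The input
 * file keeps the layout used by the minimization program, so the
 * algorithm, niter and init_step entries are still read (and the
 * algorithm name is validated), but no minimization is performed here.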
 *
 * The signal handling ensures that if certain signals are received,
 * wave-functions are dumped, etc.  See signal.c.
 *
 * There are two invocations of the program from the command-line:
 * (1) with no arguments, the file INPUTFILE (below) is read
 * (2) with one argument, which is the name of the input file to be read.
 *
 * The format of the input file is:
 *
 * #
 * # comments are lines that start with '#'
 * #
 * algorithm   # minimization algorithm to use
 * niter       # number of minimization iterations
 * init_step   # size of initial minimization step (-1.0 is okay probably)
 * elecfile    # name of file with electronic state information
 * latfile     # name of file with the lattice vectors
 * ionsfile    # name of file with ionic info:  Z, positions, pseudopots,...
 * ewaldfile   # name of file with ewald setup information
 * <Yflag>     # <Yflag> is a string describing initial wave-functions
 * [filename]  # name of file to read for wave-functions if Yflag == 'read'
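 * <FFTflag>   # a string describing how FFT box sizes are picked
 * [nx ny nz]  # x,y,z FFT box sizes if FFTflag == 'specified'
 * nkpts nbands  # how many k-point multipliers and band fillings follow
 * <nkpts integers>  one multiplier per k-point, applied to its weight
 * <nbands reals>    the fillings used to build the partial density
 *
 * The last two entries are read directly with fscanf as whitespace-
 * separated numbers, so '#' comment lines are not skipped there.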
 *
 * <Yflag> must be either 'read' or 'random' (without quotes).  'random'
 * means that the initial wave-functions are filled with random-numbers.
 * 'read' means that a file should be read:  the name is given on the
 * next line.
 *
 * <FFTflag> must be either 'automatic' or 'specified'.  'automatic'
 * means the program figures out its own FFT box sizes.  'specified'
 * means it tries to use the sizes provided on the next line.
 *
 */

/* $Id: grab_n.c,v 1.1.1.1 1999/11/10 01:30:17 tairan Exp $ */

#include <math.h>
#include <stdio.h>
#include <string.h>
#include <time.h>
#include <unistd.h>
/* My header files */
#include "header.h"

#define INPUTFILE    "grab_n.in"
#define REPORT_LEVEL 1
#define LOGFILE      stdout

#define LN_LEN  150
#define STR_LEN 80

/* Builds the partial density evars->n from the per-k-point multipliers w0
 * and the fillings F; defined at the end of this file. */
void calc_partial_n(Elecinfo *einfo, Elecvars *evars, diag_matrix& F, int* w0);

int
main(int argc,char**argv)
{
  /* Initialize MPI through column_bundle class. */
  System::GlobalInit(&argc,&argv);  

  Basis *basis;        /* The basis set for the calculation */
  Ioninfo ioninfo;     /* Ionic information */
  Elecinfo elecinfo;   /* Electronic state information */
  Elecvars elecvars;   /* The electronic variables: Y, C, U, n, ... */
  Energies ener;       /* Holds energies */
  Control cntrl;       /* Holds convergence control data */

  /* various other local vars */
  int k;
  int niter;
  real stepsize;
  MPI_FILE *filep;  // change to special file handler
  char line[LN_LEN],elecfilename[STR_LEN],latticefilename[STR_LEN],ionsfilename[STR_LEN],
       init_Y_filename[STR_LEN],init_Y_action[STR_LEN],FFT_action[STR_LEN],
       ewaldfilename[STR_LEN],algorithm[STR_LEN];
  int nx,ny,nz;
  time_t timenow;
  int my_report_level;

  /* If we're processor responsible for IO, then we get to print stuff to the screen! */
  if ( System::Get_ID() == System::Get_IO() )
    my_report_level = REPORT_LEVEL;
  else
    my_report_level = 0;

  /* Read input file for information */
  if (argc == 1)
    {
      if ( (filep = MPI_fopen(INPUTFILE,"r")) == (MPI_FILE *)0 )
	{
	  sprintf(line,"\n%s:  can't read '%s'.  Aborting.\n\n",
		  argv[0],INPUTFILE);
	  die(line);
	}
    }
  else if (argc == 2)
    {
      if ( (filep = MPI_fopen(argv[1],"r")) == (MPI_FILE *)0 )
	{
	  sprintf(line,"\n%s:  can't read '%s'.  Aborting.\n\n",
		  argv[0],argv[1]);
	  die(line);
	}
    }
  else
    {
      sprintf(line,"\nUsage:  %s [inputfile]\n\n",argv[0]);
      die(line);
    }

  timenow = time(0);
  if (my_report_level > 0) {
    fprintf(LOGFILE,"\n");
    fprintf(LOGFILE,"******************************************************\n");
    fprintf(LOGFILE,"Current date and time: %s\n",ctime(&timenow));
    if (argc==1)
      fprintf(LOGFILE,"%s:  reading file '%s'\n",argv[0],INPUTFILE);
    else
      fprintf(LOGFILE,"%s:  reading file '%s'\n",argv[0],argv[1]);
  }

  do { fgets(line,LN_LEN,filep); } while(line[0] == '#');
  sscanf(line,"%s",algorithm);
  if (my_report_level > 0)
    fprintf(LOGFILE,"minimization algorithm = %s\n",algorithm);
  if ( strcmp(algorithm,"EOM") != 0 &&
       strcmp(algorithm,"CG") != 0 &&
       strcmp(algorithm,"CG_nocos") != 0 &&
       strcmp(algorithm,"PCG") != 0 &&
       strcmp(algorithm,"PCG_nocos") != 0    )
    die("\nalgorithm must be one of {EOM,CG,CG_nocos,PCG,PCG_nocos}!\n\n");

  do { fgets(line,LN_LEN,filep); } while(line[0] == '#');
  sscanf(line,"%d",&niter);
  if (my_report_level > 0)
    fprintf(LOGFILE,"number of minimization iterations = %d\n",niter);

  do { fgets(line,LN_LEN,filep); } while(line[0] == '#');
  sscanf(line,"%lg",&stepsize);
  if (my_report_level > 0)
    fprintf(LOGFILE,"stepsize = %lg\n",stepsize);

  do { fgets(line,LN_LEN,filep); } while(line[0] == '#');
  sscanf(line,"%s",elecfilename);
  if (my_report_level > 0)
    fprintf(LOGFILE,"electronic state file = '%s'\n",elecfilename);

  do { fgets(line,LN_LEN,filep); } while(line[0] == '#');
  sscanf(line,"%s",latticefilename);
  if (my_report_level > 0)
    fprintf(LOGFILE,"lattice  file = '%s'\n",latticefilename);

  do { fgets(line,LN_LEN,filep); } while(line[0] == '#');
  sscanf(line,"%s",ionsfilename);
  if (my_report_level > 0)
    fprintf(LOGFILE,"ions file = '%s'\n",ionsfilename);

  do { fgets(line,LN_LEN,filep); } while(line[0] == '#');
  sscanf(line,"%s",ewaldfilename);
  if (my_report_level > 0)
    fprintf(LOGFILE,"ewald file = '%s'\n",ewaldfilename);

  do { fgets(line,LN_LEN,filep); } while(line[0] == '#');
  sscanf(line,"%s",init_Y_action);
  if (my_report_level > 0)
    fprintf(LOGFILE,"Yflag = %s",init_Y_action);

  if (strcmp(init_Y_action,"random") == 0)
    {
      if (my_report_level > 0)
	fprintf(LOGFILE,"\n");
    }
  else if (strcmp(init_Y_action,"read") == 0)
    {
      do { fgets(line,LN_LEN,filep); } while(line[0] == '#');
      sscanf(line,"%s",init_Y_filename);
      if (my_report_level > 0)
	fprintf(LOGFILE," file '%s'\n",init_Y_filename);
    }
  else
    {
      sprintf(line,"%s:  initial Yflag must be 'random' or 'read'.\n\n",
	      argv[0]);
      die(line);
    }

  do { fgets(line,LN_LEN,filep); } while(line[0] == '#');
  sscanf(line,"%s",FFT_action);
  if (my_report_level > 0)
    fprintf(LOGFILE,"FFTflag = %s",FFT_action);
  nx = ny = nz = 0;
  if (strcmp(FFT_action,"automatic") == 0)
    {
      if (my_report_level > 0)
	fprintf(LOGFILE,"\n");
    }
  else if (strcmp(FFT_action,"specified") == 0)
    {
      do { fgets(line,LN_LEN,filep); } while(line[0] == '#');
      sscanf(line,"%d %d %d",&nx,&ny,&nz);
      if (my_report_level > 0)
	fprintf(LOGFILE,":  %d by %d by %d\n",nx,ny,nz);
    }
  else
    {
      sprintf(line,
	      "%s:  initial FFTflag must be 'specified' or 'automatic'.\n\n",
	      argv[0]);
      die(line);
    }

  // Add additional stuffs for grabbing the partial charge density.
  int nkpts, nbands;
  do { fgets(line,LN_LEN,filep); } while(line[0] == '#');
  sscanf(line,"%d %d",&nkpts, &nbands);
  if (my_report_level > 0)
    fprintf(LOGFILE,"nkpts = %d\tnbands = %d\n",nkpts,nbands);

  diag_matrix F(nbands);
  real f;
  int *w0;
  w0 = (int *)mymalloc(sizeof(int)*nkpts,"w0","grab_n");
  myfprintf(LOGFILE,my_report_level,"Partial density k-points choice ...\n");
  for (k=0; k< nkpts; k++) {
    fscanf(filep,"%d",w0+k);
    myfprintf(LOGFILE,my_report_level,"%d ",w0[k]);
  }
  myfprintf(LOGFILE,my_report_level,"Partial density fillings ...\n");
  for (int b=0; b< nbands; b++) {
    fscanf(filep,"%lg",&f);
    F.c[b] = f;
    myfprintf(LOGFILE,my_report_level,"%lg ",REAL(F.c[b]));
  }
  myfprintf(LOGFILE,my_report_level,"\n\n");

  if (my_report_level > 0) {
    fprintf(LOGFILE,"******************************************************\n");
    fprintf(LOGFILE,"\n");
  }

  fclose(filep);

  /* Read the electronic state information: k-points, fillings, weights... */
  setup_elecinfo(&elecinfo,elecfilename,&basis,cntrl,my_report_level,LOGFILE);

  /* Read the lattice vectors and set up the basis */
  setup_basis(basis,latticefilename,elecinfo,
	      nx,ny,nz,my_report_level,LOGFILE);

  F.basis = &basis[elecinfo.nkpts];

  /* Read the ionic positions and pseudopotentials */
  setup_ioninfo(&basis[elecinfo.nkpts],&ioninfo,ionsfilename,&elecinfo,my_report_level,LOGFILE);

  /* Setup Ewald calculation */
  setup_Ewald(ewaldfilename,my_report_level,LOGFILE);

  /* Setup the electronic variables */
  init_elecvars(&elecinfo,basis,&elecvars);

  /* If the flag==1, then randomize initial wavefunctions and then
   * orthonormalize them. */
  if (strcmp(init_Y_action,"random") == 0)
    {
      if (my_report_level > 0)
	fprintf(LOGFILE,"\nYou must read in the wavefunction!@!!!\n");
      System::GlobalFinalize();
      exit(-1);
    }
  else
    {
      int dieflag = 0;
      /* Try to read the Y-file...check first that it's legible */
      if (my_report_level > 0)
	{
	  // cannot use MPI_fopen, since we don't want to read in
	  // the whole binary file!!
	  FILE* testfilep = fopen(init_Y_filename,"r");
	  if (testfilep == (FILE *)0)
	    {
	      sprintf(line,
		      "\nCan't open '%s' to read initial wave-functions.\n\n",
		      init_Y_filename);
	      dieflag = 1;
	    }
	  fclose(testfilep);

	  fprintf(LOGFILE,
		  "\n-------> Reading Y from '%s'\n\n",init_Y_filename);
	}
      System::GlobalCheck(dieflag, line);
      read_column_bundle_array(init_Y_filename,elecinfo.nkpts,elecvars.Y);
    }

  /* setup the FFT3D() routines */
  setupFFT3D(basis[elecinfo.nkpts].Nx,
	     basis[elecinfo.nkpts].Ny,
	     basis[elecinfo.nkpts].Nz,
	     my_report_level,LOGFILE);

  /* Calculate core and Ewald energies */
  /* calc_core_ewald_pulay_energies(&basis,&ioninfo,&elecinfo,&ener,
			   my_report_level,LOGFILE);*/

  /* Calculate the local pseudopotential */
  /* Vloc_pseudo(&basis,&ioninfo,elecvars.Vlocps.c,my_report_level,LOGFILE); */

  /* Setup signal handling */
  setup_signals(&elecinfo,&elecvars,LOGFILE);

#ifdef DFT_PROFILING
  timerOff(0);  // turn off initialization timer.

  timerOn(1);  // turn on total computation timer.
#endif // DFT_PROFILING

#ifdef TRACE_MEM
  if (my_report_level > 0)
    mem_trace_report(LOGFILE);
#endif // TRACE_MEM

  // do our own stuff
  calc_U(&elecinfo,&elecvars);
  calc_C(&elecinfo,&elecvars);
  calc_partial_n(&elecinfo,&elecvars,F,w0);


#ifdef TRACE_MEM
  if (my_report_level > 0)
    mem_trace_report(LOGFILE);
#endif // TRACE_MEM

#ifdef DFT_PROFILING
  timerOff(1); // turn off total computation timer.
#endif // DFT_PROFILING

  /* Write out final electronic variables */
  if (my_report_level > 0)
    {
      fprintf(LOGFILE,"\nDone!  Dumping final variables:\n\n");
      fflush(LOGFILE);
    }

  /* Set up date/time stamp */
  struct tm *mytm;
  char stamp[80],fname[100];

  timenow = time(0);
  mytm = localtime(&timenow);
  sprintf(stamp,"%d.%d.%d:%d:%d",
	  mytm->tm_mon+1,mytm->tm_mday,
	  mytm->tm_hour,mytm->tm_min,mytm->tm_sec);

  /* Dump the data! */

  sprintf(fname,"n.%s",stamp);
  if (my_report_level > 0)
    {
      fprintf(LOGFILE,"Dumping '%s'...",fname);
      fflush(LOGFILE);
    }
  elecvars.n.write(fname); // vector::write is parallelized.
  if (my_report_level > 0)
    {
      fprintf(LOGFILE,"done.\n\n");
      fflush(LOGFILE);
    }

  /* Free up all the used memory */
  free_basis(basis,elecinfo.nkpts);
  free_ioninfo(&ioninfo);
  free_elecinfo(&elecinfo);
  free_elecvars(&elecinfo,&elecvars);

  /* End the MPI stuff */
  System::GlobalFinalize();

  return 0;
}

/*
 * Accumulate the partial electron density into evars->n.
 *
 * Each k-point weight einfo->w[k] is first multiplied by the integer
 * multiplier w0[k]; the scaled weights are then divided by their sum, and
 * n is built as the sum over k of w[k]*diagouterI(F,C[k]).
 *
 * einfo -- supplies nkpts and the weight array w.  NOTE: w is overwritten
 *          in place with the scaled, normalized weights.
 * evars -- supplies C (read) and n (zeroed, then accumulated into).
 * F     -- diagonal matrix handed to diagouterI for every k-point.
 * w0    -- must hold at least einfo->nkpts multipliers.
 *
 * If the scaled weights sum to zero (for example every w0[k] is 0), the
 * normalization would divide by zero and fill n with NaN/Inf, so the
 * condition is reported through die() instead.
 */
void
calc_partial_n(Elecinfo *einfo, Elecvars *evars, diag_matrix& F, int* w0)
{
  int k;

  real *w = einfo->w;
  column_bundle *C = evars->C;
  vector &n = evars->n;

  /* Scale each weight by its multiplier and total the result. */
  real wsum = 0.0;
  for (k=0; k < einfo->nkpts; k++) {
    w[k] *= w0[k];
    wsum += w[k];
  }

  /* Exact comparison is intended: the sum is exactly zero when all
   * selected contributions are zero. */
  if (wsum == 0.0)
    die("\ncalc_partial_n:  selected k-point weights sum to zero.\n\n");

  n.zero_out();
  for (k=0; k < einfo->nkpts; k++) {
    w[k] /= wsum;  /* normalize so the weights used below sum to one */
    n += ((scalar)w[k])*diagouterI(F,C[k]);
  }
}
