/****************************************************************************
 *
 * DFT++:  density functional package developed by
 *         the research group of Prof. Tomas Arias, MIT.
 *
 * Principal author: Sohrab Ismail-Beigi
 *
 * Modifications for MPI version: Kenneth P Esler,
 *                                Sohrab Ismail-Beigi, and
 *                                Tairan Wang.
 *
 * Modifications for LSD version: Jason A Cline
 *
 * Modifications for lattice/Pulay forces: Gabor Csanyi and
 *                                         Sohrab Ismail-Beigi
 *
 * Copyright (C) 1996-1998 The Massachusetts Institute of Technology (MIT).
 *
 ****************************************************************************/

/*
 * Routines that do various matrix multiplications.  These routines
 * are the computational Kernels for all the matrix multiplies in the code,
 * so optimizing them is the way to improve matrix multiplication
 * performance.
 */

/* $Id: matrix_mult.c,v 1.2 1999/11/11 02:40:31 tairan Exp $ */

#include "math.h"
#include "header.h"

//
// This routine does the multiplication
//
//       Bres[i][j] += sum_k { B1[i][k]*B2[j][k] }
//
// B1   is s1 x cs in memory layout
// B2   is s2 x cs in memory layout
// Bres is s1 x s2 in memory layout
//
// The matrix multiply is done for the sublocks in memory of size:
// B1   n1 x nc
// B2   n2 x nc
// Bres n1 x n2
//
void
small_block_matrix_mult(int n1, int n2, int nc,
			int s1, int s2, int cs,
			scalar *B1, scalar *B2, scalar *Bres)
{
  // The code blocks below does this:
  //    scalar sum = 0.0;
  //    for (register int k=0; k<nc; k++)
  //       sum += B1[r1][k] * B2[r2][k];
  //    Bres[r1][r2] += sum;
  //
  // If we have even sized blocks, do register level 2 x 2 blocks
  if ( (n1%2==0) && (n2%2==0) && (nc%2==0) )
    {
      for (int r1 = 0; r1 < n1; r1+=2)		
	for (int r2 = 0; r2 < n2; r2+=2)
	  {
	    register int r1cs = r1*cs;
	    register int r1s2 = r1*s2;
	    register int r2cs = r2*cs;
	    register double ax, bx, cx, dx, ex, fx, gx, hx;
	    register double ay, by, cy, dy, ey, fy, gy, hy;
	    double wx, xx, yx, zx, wy, xy, yy, zy;
	    wx = xx = yx = zx = wy = xy = yy = zy = 0.0;
	    for (register int k=0; k < nc; k+=2)
	      {
		// a = B1[r1][k]    b = B1[r1][k+1],
		// c = B1[r1+1][k]  d = B1[r1+1][k+1]
		ax = B1[r1cs+k].x;
		ay = B1[r1cs+k].y;
		bx = B1[r1cs+k+1].x;
		by = B1[r1cs+k+1].y;
		cx = B1[r1cs+cs+k].x;
		cy = B1[r1cs+cs+k].y;
		dx = B1[r1cs+cs+k+1].x;
		dy = B1[r1cs+cs+k+1].y;
		// e = B2[r2][k]    f = B2[r2][k+1]
		// g = B2[r2+1][k]  h = B2[r2+1][k+1]
		ex = B2[r2cs+k].x;
		ey = B2[r2cs+k].y;
		fx = B2[r2cs+k+1].x;
		fy = B2[r2cs+k+1].y;
		gx = B2[r2cs+cs+k].x;
		gy = B2[r2cs+cs+k].y;
		hx = B2[r2cs+cs+k+1].x;
		hy = B2[r2cs+cs+k+1].y;
		// w = a*e + b*f    x = a*g + b*h
		// y = c*e + d*f    z = c*g + d*h
		wx += ax*ex - ay*ey + bx*fx - by*fy;
		wy += ax*ey + ay*ex + bx*fy + by*fx;
		xx += ax*gx - ay*gy + bx*hx - by*hy;
		xy += ax*gy + ay*gx + bx*hy + by*hx;
		yx += cx*ex - cy*ey + dx*fx - dy*fy;
		yy += cx*ey + cy*ex + dx*fy + dy*fx;
		zx += cx*gx - cy*gy + dx*hx - dy*hy;
		zy += cx*gy + cy*gx + dx*hy + dy*hx;
	      }
	    // Bres[r1][r2]   += w    Bres[r1][r2+1]   += x
	    // Bres[r1+1][r2] += y    Bres[r1+1][r2+1] += z
	    Bres[r1s2+r2].x += wx;
	    Bres[r1s2+r2].y += wy;
	    Bres[r1s2+r2+1].x += xx;
	    Bres[r1s2+r2+1].y += xy;
	    Bres[r1s2+s2+r2].x += yx;
	    Bres[r1s2+s2+r2].y += yy;
	    Bres[r1s2+s2+r2+1].x += zx;
	    Bres[r1s2+s2+r2+1].y += zy;	
	  }
    }
  // The blocks aren't even sized; so do the simpler matrix multiply
  else
    {
      for (int r1 = 0; r1 < n1; r1++)
	for (int r2 = 0; r2 < n2; r2++)
	  {
	    register int r1cs = r1*cs;
	    register int r1s2 = r1*s2;
	    register int r2cs = r2*cs;
	    register double sx = 0.0;
	    register double sy = 0.0;
	    for (register int k=0; k < nc; k++)
	      {
		register double ax;
		register double ay;
		register double bx;
		register double by;
		ax = B1[r1cs+k].x;
		ay = B1[r1cs+k].y;
		bx = B2[r2cs+k].x;
		by = B2[r2cs+k].y;
		sx += ax * bx - ay * by;
		sy += ay * bx + ax * by;
	      }
	    Bres[r1s2+r2].x += sx;
	    Bres[r1s2+r2].y += sy;
	  }
    }
}

// define block sizes used for block matrix multiplies
//
// The blocked routines below size their stack work arrays with these:
//   BL1 x BL2 : the output tile        (e.g. out[BL1][BL2])
//   BL1 x BLK : the first input tile   (rows of the output)
//   BL2 x BLK : the second input tile  (columns of the output)
// so BLK is the chunk length along the summed-over (k) direction.
// The same three values are passed to small_block_matrix_mult() as the
// storage sizes (s1, s2, cs) of those tiles.
#define BL1 32
#define BL2 32
#define BLK 32

//
// This routine does the operation M = Y1^Y2 via blocked matrix
// multiplies that call the routine above for the actual work.
// So all that is done below is to loop over blocks, load the
// blocks from memory, and then to output the result blocks to M.
//
// n1 and n2 are the number of columns of Y1 and Y2 resp.
// N is the length of the columns of Y1 and Y2.
//
// offsetMrow/col are offsets to the output M (see the formulae below).
//
// transpose==1 means Y1 and Y2 are actually input in transpose format.
// transpose==0 means they are in "normal" format
//
// i.e. in transpose==0 mode, the routine does
//
// M(i+offsetMrow,j+offsetMcol) =
//         sum_k { conjugate(Y1.col[i].c[k])*Y2.col[j].c[k] }
//
// offsetY2rowtranspose is an offset added to the row accessing of Y2
// when in transpose mode (specially needed when doing distributed
// case when data comes in transposed... see dist_multiply.c for more
// juicy details).
//
void
Y1dagY2_block_matrix_mult(const column_bundle &Y1,
			  const column_bundle &Y2,
			  matrix &M,
			  int n1, int n2, int N,
			  int offsetMrow,int offsetMcol,
			  int transpose,int offsetY2rowtranspose)
{
  // loop over blocks of size BL1xBL2 of the output matrix M
  int ib,jb;
  for (ib=0; ib < n1; ib+=BL1)
    for (jb=0; jb < n2; jb+=BL2)
      {
	// calculate sizes of output block
	int si = (n1-ib) >= BL1 ? BL1 : n1%BL1;
	int sj = (n2-jb) >= BL2 ? BL2 : n2%BL2;

	// zero output block
	scalar out[BL1][BL2];
	int i,j;
	for (i=0; i < si; i++)
	  for (j=0; j < sj; j++)
	    out[i][j] = 0.0;

	// input blocks for multiply loop below
	scalar b1[BL1][BLK],b2[BL2][BLK];

	// loop over long direction k for the sum
	int kb;
	for (kb=0; kb < N; kb+=BLK)
	  {
	    // size of k-block
	    int sk = (N-kb) >= BLK ? BLK : N%BLK;

	    // get data from Y1 and Y2 into input blocks:
	    // if in transpose mode...
	    int k;
	    if (transpose)
	      {
		for (k=0; k < sk; k++)
		  for (i=0; i < si; i++)
		    {
#if defined SCALAR_IS_COMPLEX
		      b1[i][k] = conjugate(Y1.col[kb+k].c[ib+i]);
#elif defined SCALAR_IS_REAL
		      b1[i][k] = Y1.col[kb+k].c[ib+i];
#else
#error scalar is neither real nor complex!
#endif
		    }
		for (k=0; k < sk; k++)
		  for (j=0; j < sj; j++)
		    b2[j][k] = Y2.col[kb+k].c[jb+j+offsetY2rowtranspose];
	      } 
	    // non-transpose mode
	    else
	      {
		for (i=0; i < si; i++)
		  for (k=0; k < sk; k++)
		    {
#if defined SCALAR_IS_COMPLEX
		      b1[i][k] = conjugate(Y1.col[ib+i].c[kb+k]);
#elif defined SCALAR_IS_REAL
		      b1[i][k] = Y1.col[ib+i].c[kb+k];
#else
#error scalar is neither real nor complex!
#endif
		    }
		for (j=0; j < sj; j++)
		  for (k=0; k < sk; k++)
		    b2[j][k] = Y2.col[jb+j].c[kb+k];
	      }
	    // multiply blocks: out[i][j] += sum_k { b1[i][k]*b2[j][k] }
	    small_block_matrix_mult(si,sj,sk,BL1,BL2,BLK,
				    (scalar *)b1,(scalar *)b2,(scalar *)out);
	  } // over kb

	// now write out the output block to the matrix M
	for (i=0; i < si; i++)
	  for (j=0; j < sj; j++)
	    M(ib+i+offsetMrow,jb+j+offsetMcol) = out[i][j];
      } // over ib,jb
}



//
// This routine does the operation YM = Y*M or YM += Y*M via blocked matrix
// multiplies that call the routine small_block_matrix_mult()
// to do the actual FLOP work. So all that is done below is to loop over
// blocks, load the blocks from memory, and then to output the result
// blocks to YM.
//
// N is the length of the columns of YM and Y.
// nrM is the number of rows of M that are used (the summed-over index).
// ncM is the number of cols of M that are used (= columns of the result).
//
// offsetMrow/col are offsets of the input M (see the formulae below) used
// to read information out of it.
//
// accum==0 does YM =  Y*M
// accum==1 does YM += Y*M
//
// transpose==0 means they are in normal format
// transpose==1 means Y and YM are actually input in transpose format.
//
// The overall routine calculation does, for 0 <= i < N and 0 <= j < ncM,
//
//        YM(i,j) = sum_k { Y(i,k)*M(k+offsetMrow,j+offsetMcol) }
//
// with k running over 0 <= k < nrM.
//
// In transpose mode, often the routine is called with Y and YM being
// the same column_bundle, so the routine is written so as to be able to do
// the multiplication in place:  i.e. YM and Y being the same will not
// affect things.  This is the role of the temporary matrix temp
// below: a whole strip of up to BL1 result rows is finished in temp
// before anything is stored into YM.
//
void
Y_M_block_matrix_mult(const column_bundle &Y,
		      const matrix &M,
		      column_bundle &YM,
		      int N, int nrM, int ncM,
		      int offsetMrow, int offsetMcol,
		      int transpose,
		      int accum)
{
  // temporary work space: one strip of BL1 result rows by ncM columns
  matrix temp(BL1,ncM);

  // loop over blocks of size BL1 on the rows of Y (or cols of Ytranspose)
  int rYb;
  for (rYb=0; rYb < N; rYb+=BL1)
    {
      // compute size of block in the rY direction
      // (N%BL1 equals N-rYb on the final, shorter block)
      int srY =   (N-rYb) >= BL1 ? BL1 :   N%BL1;

      // loop over blocks of size BL2 on the columns of M
      int cMb;
      for (cMb=0; cMb < ncM; cMb+=BL2)
	{
	  // compute actual size of block in the cM direction
	  int scM = (ncM-cMb) >= BL2 ? BL2 : ncM%BL2;
	  
	  // zero out output block
	  scalar out[BL1][BL2];
	  int rY,cM;
	  for (rY=0; rY < srY; rY++)
	    for (cM=0; cM < scM; cM++)
	      out[rY][cM] = 0.0;
	  
	  // input blocks for multiply loop below
	  scalar by[BL1][BLK],bm[BL2][BLK];
	  
	  // loop over blocks of rows of M
	  int kb;
	  for (kb=0; kb < nrM; kb+=BLK)
	    {
	      // size of block in k direction
	      int sk = (nrM-kb) >= BLK ? BLK : nrM%BLK;
	      
	      // read in data of Y and M into input blocks
	      int k;
	      if (transpose)
		{
		  for (rY=0; rY < srY; rY++)
		    for (k=0; k < sk; k++)
		      by[rY][k] = Y.col[rYb+rY].c[kb+k];
		}
	      else
		{
		  for (k=0; k < sk; k++)
		    for (rY=0; rY < srY; rY++)
		      by[rY][k] = Y.col[kb+k].c[rYb+rY];
		}
	      // M is stored into bm transposed (bm[cM][k]) because the
	      // kernel computes sum_k { by[rY][k]*bm[cM][k] }
	      for (k=0; k < sk; k++)
		for (cM=0; cM < scM; cM++)
		  bm[cM][k] = M(kb+k+offsetMrow,cMb+cM+offsetMcol);
	      
	      // do the multiply
	      small_block_matrix_mult(srY,scM,sk,BL1,BL2,BLK,
				      (scalar *)by,(scalar *)bm,(scalar *)out);
	    } // over kb
	  
	  // write out to the output block to the temporary space
	  for (rY=0; rY < srY; rY++)
	    for (cM=0; cM < scM; cM++)
	      temp(rY,cMb+cM) = out[rY][cM];
	  
	} // over cMb

      // write out temporary block to main memory (depending on
      // accumulate and transpose flags).
      //
      // Only columns 0..ncM-1 of temp exist and were computed above, so
      // the copy-out is bounded by ncM.  (Bounding it by M.nc would run
      // past temp and touch unrelated entries of YM whenever only a
      // sub-block of M is being used, i.e. ncM < M.nc.)
      int rY,cM;
      if (transpose)
	{
	  if (accum)
	    {
	      for (rY=0; rY < srY; rY++)
		for (cM=0; cM < ncM; cM++)
		  YM.col[rYb+rY].c[cM] += temp(rY,cM);
	    }
	  else
	    {
	      for (rY=0; rY < srY; rY++)
		for (cM=0; cM < ncM; cM++)
		  YM.col[rYb+rY].c[cM] = temp(rY,cM);
	    }
	}
      else
	{
	  if (accum)
	    {
	      for (cM=0; cM < ncM; cM++)
		for (rY=0; rY < srY; rY++)
		  YM.col[cM].c[rYb+rY] += temp(rY,cM);
	    }
	  else
	    {
	      for (cM=0; cM < ncM; cM++)
		for (rY=0; rY < srY; rY++)
		  YM.col[cM].c[rYb+rY] = temp(rY,cM);
	    }
	}

    } // over rYb
}  

//
// Does the matrix multiply mprod = m1*m2
// i.e. mprod(i,j) = sum_k { m1(i,k)*m2(k,j) }
// by doing block matrix multiplies.
//
void
matrix_matrix_block_matrix_mult(const matrix &m1,
				const matrix &m2,
				matrix &mprod)
{
  // loop over blocks of output mprod
  int ib,jb;
  for (ib=0; ib < m1.nr; ib+=BL1)
    for (jb=0; jb < m2.nc; jb+=BL2)
      {
	// input blocks
	scalar ini[BL1][BLK],inj[BL2][BLK];
	
	// output blocks
	scalar out[BL1][BL2];

	// compute size of blocks
	int si = (m1.nr-ib) >= BL1 ? BL1 : m1.nr%BL1;
	int sj = (m2.nc-jb) >= BL2 ? BL2 : m2.nc%BL2;

	// zero output block 
	register int i,j;
	for (i=0; i < si; i++)
	  for (j=0; j < sj; j++)
	    out[i][j] = 0.0;

	// loop over blocks of columns of m1 (i.e. rows of m2)
	int kb;
	for (kb=0; kb < m1.nc; kb+=BLK)
	  {
	    // Size of block in k-direction
	    int sk = (m1.nc-kb) >= BLK ? BLK : m1.nc%BLK;

	    // read in input blocks 
	    int k;
	    for (i=0; i < si; i++)
	      for (k=0; k < sk; k++)
		ini[i][k] = m1(ib+i,kb+k);
	    for (k=0; k < sk; k++)
	      for (j=0; j < sj; j++)
		inj[j][k] = m2(kb+k,jb+j);

	    // Do block mult
	    small_block_matrix_mult(si,sj,sk,BL1,BL2,BLK,
				    (scalar *)ini,(scalar *)inj,(scalar *)out);
	  } // kb blocks

	// write out block
	for (i=0; i < si; i++)
	  for (j=0; j < sj; j++)
	    mprod(ib+i,jb+j) = out[i][j];

      } // (ib,jb) blocks
}
