/* Ergo, version 3.3, a program for linear scaling electronic structure
 * calculations.
 * Copyright (C) 2013 Elias Rudberg, Emanuel H. Rubensson, and Pawel Salek.
 * 
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 * 
 * Primary academic reference:
 * Kohn−Sham Density Functional Theory Electronic Structure Calculations 
 * with Linearly Scaling Computational Time and Memory Usage,
 * Elias Rudberg, Emanuel H. Rubensson, and Pawel Salek,
 * J. Chem. Theory Comput. 7, 340 (2011),
 * <http://dx.doi.org/10.1021/ct100611z>
 * 
 * For further information about Ergo, see <http://www.ergoscf.org>.
 */

#include <stdio.h>
#include <unistd.h>
#include <memory>
#include <limits>
#include "basisinfo.h"

#ifdef USE_CHUNKS_AND_TASKS

#include "chunks_and_tasks.h"
#include "IntegralInfoChunk.h"
#include "matrix_utilities.h"
#include "integrals_general.h"
#include "integral_matrix_wrappers.h"
#include "utilities.h"
#include "CreateAtomCenteredBasisSet.h"
#include "ComputeOverlapMatRecursive.h"

/* CHTTL registration stuff: register the generic chttl chunk and task
   template instantiations (basic values, vectors, and their Add tasks)
   with the Chunks and Tasks framework.
   NOTE(review): presumably required so that the framework can recreate
   these types on workers; some of them (e.g. the Add tasks) are not used
   directly in this file and are likely used indirectly by chtml helpers.
   Confirm against the CHT documentation. */
CHTTL_REGISTER_CHUNK_TYPE((chttl::ChunkBasic<int>));
CHTTL_REGISTER_CHUNK_TYPE((chttl::ChunkBasic<size_t>));
CHTTL_REGISTER_CHUNK_TYPE((chttl::ChunkBasic<double>));
CHTTL_REGISTER_CHUNK_TYPE((chttl::ChunkBasic<MatrixInfoStruct>));
CHTTL_REGISTER_CHUNK_TYPE((chttl::ChunkBasic<basisset_struct>));
CHTTL_REGISTER_CHUNK_TYPE((chttl::ChunkVector<int>));
CHTTL_REGISTER_CHUNK_TYPE((chttl::ChunkVector<double>));
CHTTL_REGISTER_CHUNK_TYPE((chttl::ChunkVector<Atom>));
CHTTL_REGISTER_CHUNK_TYPE((chttl::ChunkVector<CoordStruct>));
CHTTL_REGISTER_TASK_TYPE((chttl::ChunkBasicAdd<int>));
CHTTL_REGISTER_TASK_TYPE((chttl::ChunkBasicAdd<size_t>));
CHTTL_REGISTER_TASK_TYPE((chttl::ChunkVectorAdd<double>));

/* CHTML registration stuff: register the CHT matrix-library chunk type
   (CHTMLMatType), its parameter chunk, and the LeafMatType matrix tasks
   (element extraction, multiply, nnz, rescale, add, Frobenius-norm
   truncation helpers, ...). Several of these are presumably executed
   internally by chtml::truncFrobenius / chtml::normFrobenius rather than
   directly by this file. */
CHTML_REGISTER_CHUNK_TYPE((CHTMLMatType));
CHTML_REGISTER_CHUNK_TYPE((chtml::MatrixParams<LeafMatType>));
CHTML_REGISTER_TASK_TYPE((chtml::MatrixGetElements<LeafMatType>));
CHTML_REGISTER_TASK_TYPE((chtml::MatrixMultiply<LeafMatType, false, false>));
CHTML_REGISTER_TASK_TYPE((chtml::MatrixNNZ<LeafMatType>));
CHTML_REGISTER_TASK_TYPE((chtml::MatrixAssignFromChunkIDs<LeafMatType>));
CHTML_REGISTER_TASK_TYPE((chtml::MatrixCombineElements<double>));
CHTML_REGISTER_TASK_TYPE((chtml::MatrixRescale<LeafMatType>));
CHTML_REGISTER_TASK_TYPE((chtml::MatrixAdd<LeafMatType>));
CHTML_REGISTER_TASK_TYPE((chtml::MatrixAssignFromSparse<LeafMatType>));
CHTML_REGISTER_TASK_TYPE((chtml::MatrixSquaredFrobOfErrorMatrix<LeafMatType>));
CHTML_REGISTER_TASK_TYPE((chtml::MatrixFrobTruncLowestLevel<LeafMatType>));


/* Prepare the HML-side (mat:: library) matrix block structure and the
   basis function permutation for basisInfo.
   sizeBlockInfo       : output, block layout for noOfBasisFuncs x noOfBasisFuncs matrices.
   permutation,
   inversePermutation  : output, basis function index permutation and its inverse.
   blockSizeHML        : lowest-level block size.
   The three block-factor arguments of both helper calls are all given the
   same value. */
static void preparePermutationsHML(const BasisInfoStruct& basisInfo,
				   mat::SizesAndBlocks& sizeBlockInfo, 
				   std::vector<int>& permutation,
				   std::vector<int>& inversePermutation,
				   int blockSizeHML) {
  const int blockFactor = 4;
  sizeBlockInfo = prepareMatrixSizesAndBlocks(basisInfo.noOfBasisFuncs, blockSizeHML,
					      blockFactor, blockFactor, blockFactor);
  getMatrixPermutation(basisInfo, blockSizeHML,
                       blockFactor, blockFactor, blockFactor,
                       permutation, inversePermutation);
}

/* Returns a pseudo-random value from rand() scaled onto [0, 1].
   Both endpoints are reachable: 0.0 when rand() == 0 and exactly 1.0
   when rand() == RAND_MAX. */
static double rand_0_to_1() {
  return static_cast<double>(rand()) / RAND_MAX;
}

/* Returns a pseudo-random index in [0, n-1], based on rand_0_to_1().
   Throws std::runtime_error when no valid index exists (n < 1). */
static int get_random_index_0_to_nm1(int n) {
  int result = (int)((double)n * rand_0_to_1());
  // rand_0_to_1() returns exactly 1.0 when rand() == RAND_MAX, which would
  // give result == n; map that single case onto the last valid index
  // instead of failing the range check below.
  if(n > 0 && result == n)
    result = n - 1;
  if(result < 0 || result >= n)
    throw std::runtime_error("Error: (result < 0 || result >= n).");
  return result;
}

/* Reads selected elements from an n x n CHTML matrix chunk.
   rowind_in, colind_in : element positions (before permutation); must have equal length.
   resultValues         : output; resized to rowind_in.size() and filled in the same order.
   cid_matrix           : chunk ID of the CHTML matrix.
   blockSize            : leaf size passed to chtml::MatrixParams.
   permutation          : applied to every row and column index before lookup.
   Throws std::runtime_error if the index vectors differ in length, or if the
   MatrixGetElements task returns an unexpected number of values. In the
   latter case all temporary chunks are deleted before throwing. */
static void get_elements_from_cht_matrix(int n, 
					 const std::vector<int> & rowind_in, 
					 const std::vector<int> & colind_in, 
					 std::vector<double> & resultValues, 
					 cht::ChunkID cid_matrix, 
					 int blockSize,
					 const std::vector<int> & permutation) {
  int nValuesToGet = rowind_in.size();
  // colind_in is indexed with the same loop bound as rowind_in below.
  if((int)colind_in.size() != nValuesToGet)
    throw std::runtime_error("Error in get_elements_from_cht_matrix(): (rowind_in.size() != colind_in.size()).");
  // Create params
  int M = n;
  int N = n;
  int leavesSizeMax = blockSize;
  cht::ChunkID cid_param = cht::registerChunk(new chtml::MatrixParams<LeafMatType>(M, N, leavesSizeMax, 0, 0));
  // Create permuted row and column index vectors
  std::vector<int> rowind(nValuesToGet);
  std::vector<int> colind(nValuesToGet);
  for(int i = 0; i < nValuesToGet; i++) {
    rowind[i]= permutation[rowind_in[i]];
    colind[i]= permutation[colind_in[i]];
  }
  cht::ChunkID cid_rowind = cht::registerChunk(new chttl::ChunkVector<int>(rowind));
  cht::ChunkID cid_colind = cht::registerChunk(new chttl::ChunkVector<int>(colind));
  // Register task
  cht::ChunkID cid_result = 
    cht::executeMotherTask<chtml::MatrixGetElements<LeafMatType> >(cid_param, cid_rowind, cid_colind, cid_matrix);
  // Get resulting chunk object.
  cht::shared_ptr<chttl::ChunkVector<double> const> ptr_result;
  cht::getChunk(cid_result, ptr_result);
  // Only copy when exactly one value per requested element came back;
  // otherwise the copy loop would read and write out of bounds.
  bool sizeOK = (ptr_result->size() == (size_t)nValuesToGet);
  if(sizeOK) {
    resultValues.resize(nValuesToGet);
    for(int i = 0; i < nValuesToGet; i++)
      resultValues[i] = (*ptr_result)[i];
  }
  cht::deleteChunk(cid_param);
  cht::deleteChunk(cid_rowind);
  cht::deleteChunk(cid_colind);
  cht::deleteChunk(cid_result);
  if(!sizeOK)
    throw std::runtime_error("Error in get_elements_from_cht_matrix(): unexpected number of values returned.");
}

/* Checks that a Frobenius-norm truncation stayed within its threshold.
   The error matrix A_notrunc + (-1)*A_trunc is formed with CHTML tasks and
   its Frobenius norm is compared with frob_trunc_threshold.
   Prints the result; throws std::runtime_error if the norm exceeds the
   threshold. Temporary chunks are deleted before the check, so they are
   released on both the success and the error path. */
static void verify_frob_truncation(cht::ChunkID cid_A_notrunc, cht::ChunkID cid_A_trunc, ergo_real frob_trunc_threshold) {
  cht::ChunkID cid_minus1 = cht::registerChunk(new chttl::ChunkBasic<double>(-1));
  cht::ChunkID cid_A_trunc_xm1 = cht::executeMotherTask<chtml::MatrixRescale<LeafMatType> >(cid_A_trunc, cid_minus1);
  cht::ChunkID cid_error_matrix = cht::executeMotherTask<chtml::MatrixAdd<LeafMatType> >(cid_A_notrunc, cid_A_trunc_xm1);
  ergo_real frobNormOfErrorMatrix = chtml::normFrobenius<LeafMatType>(cid_error_matrix);
  cht::deleteChunk(cid_minus1);
  cht::deleteChunk(cid_A_trunc_xm1);
  cht::deleteChunk(cid_error_matrix);
  // Cast to double: %g expects a double, and ergo_real may be a different
  // floating-point type depending on the build configuration.
  printf("verify_frob_truncation: frob_trunc_threshold = %9.6g, frobNormOfErrorMatrix = %9.6g\n",
         (double)frob_trunc_threshold, (double)frobNormOfErrorMatrix);
  if(frobNormOfErrorMatrix <= frob_trunc_threshold)
    printf("OK, norm of error matrix is below threshold.\n");
  else
    throw std::runtime_error("Error in verify_frob_truncation(): norm of error matrix greater than truncation threshold.");
}

/* Returns the number of nonzero elements of a CHTML matrix chunk, as
   reported by the chtml::MatrixNNZ task. The temporary result chunk is
   deleted before returning. */
static size_t get_nnz_for_CHTML_matrix(cht::ChunkID cid_matrix) {
  cht::ChunkID nnzChunk =
    cht::executeMotherTask<chtml::MatrixNNZ<LeafMatType> >(cid_matrix);
  cht::shared_ptr<chttl::ChunkBasic<size_t> const> nnzValue;
  cht::getChunk(nnzChunk, nnzValue);
  const size_t count = nnzValue->x;
  cht::deleteChunk(nnzChunk);
  return count;
}

static double get_single_element_from_HML_matrix(int n, int row, int col, const symmMatrix & S, const std::vector<int> & permutationHML) {
  std::vector<int> rowind(1);
  std::vector<int> colind(1);
  std::vector<ergo_real> values(1);
  rowind[0] = row;
  colind[0] = col;
  S.get_values(rowind, colind, values, permutationHML, permutationHML);
  return values[0];
}

/* Compares a CHTML matrix with an HML matrix at randomly sampled positions
   and returns the largest absolute difference found.

   n                         : dimension of both matrices.
   cid_matrix_CHTML, blockSizeCHTML, permutationCHTML
                             : CHTML matrix chunk, its leaf size, and the
                               permutation handed to get_elements_from_cht_matrix().
   matrix_HML, permutationHML: reference HML matrix and the permutation used
                               to read it.
   permutation_HML_to_CHTML  : index map applied to each sampled row/column
                               index before reading the HML matrix. Note that
                               main() passes permutation_HML_to_CHTML_inv here,
                               i.e. a map from CHT-side to HML-side basis
                               function indices.
   Also prints the largest absolute value seen among the CHTML elements. */
static ergo_real compare_CHTML_matrix_to_HML_matrix(int n,
						    cht::ChunkID cid_matrix_CHTML,
						    int blockSizeCHTML,
						    const std::vector<int> & permutationCHTML,
						    const symmMatrix & matrix_HML,
						    const std::vector<int> & permutationHML,
						    const std::vector<int> & permutation_HML_to_CHTML
						    ) {
  const int nElementsToCheck = 888;
  printf("Checking result by looking at %d matrix elements...\n", nElementsToCheck);
  double maxAbsDiff = 0;
  std::vector<int> rowind(nElementsToCheck);
  std::vector<int> colind(nElementsToCheck);
  std::vector<double> values(nElementsToCheck);
  for(int i = 0; i < nElementsToCheck; i++) {
    int j = get_random_index_0_to_nm1(n);
    int k = get_random_index_0_to_nm1(n);
    rowind[i] = j;
    colind[i] = k;
  }
  get_elements_from_cht_matrix(n, 
			       rowind, colind, values, 
			       cid_matrix_CHTML, blockSizeCHTML, permutationCHTML);
  ergo_real maxAbsElement = 0;
  for(int i = 0; i < nElementsToCheck; i++) {
    // Translate sampled position (j,k) to HML-side indices and read the
    // reference element (j,k) from the HML matrix.
    int j = permutation_HML_to_CHTML[rowind[i]];
    int k = permutation_HML_to_CHTML[colind[i]];
    double refValue = get_single_element_from_HML_matrix(n, j, k, matrix_HML, permutationHML);
    // Matrix element from the CHTML matrix at the same sampled position.
    double matrixElementValue = values[i];
    double absDiff = fabs(matrixElementValue - refValue);
    if(absDiff > maxAbsDiff)
      maxAbsDiff = absDiff;
    ergo_real absElement = fabs(matrixElementValue);
    if(absElement > maxAbsElement)
      maxAbsElement = absElement;
  }
  printf("Checked %d matrix elements, maxAbsDiff = %8.4g\n", nElementsToCheck, maxAbsDiff);
  // Cast to double: %f expects a double, and ergo_real may differ per build.
  printf("Largest abs value of any checked element: %12.6f\n", (double)maxAbsElement);
  return maxAbsDiff;
}

/* Prints the wall-clock time elapsed since tm was started, labelled with s. */
static void report_timing(const Util::TimeMeter & tm, const char* s) {
  const double nowWall = Util::TimeMeter::get_wall_seconds();
  const double elapsedWall = nowWall - tm.get_start_time_wall_seconds();
  printf("'%s' took %12.5f wall seconds.\n", s, elapsedWall);
}

/* Prints one sparsity summary line for an n x n matrix named matrixName.
   The line contains the nonzero count, the percentage of nonzero elements,
   and the average number of nonzeros per row. */
static void report_nnz_for_matrix(size_t nnz, size_t n, const char* matrixName) {
  const double nnzAsDouble = (double)nnz;
  const double nAsDouble = (double)n;
  const double percentNonzero = nnzAsDouble * 100.0 / (nAsDouble * nAsDouble);
  const double nonzerosPerRow = nnzAsDouble / nAsDouble;
  printf("NNZ for '%15s' : %12.0f  <-->  %9.3f %% nonzero elements  <-->  %9.3f nonzero elements per row.\n", 
	 matrixName, nnzAsDouble, percentNonzero, nonzerosPerRow);
}


int main(int argc, char *argv[])
{
  int nAtoms = 3;
  if(argc > 1)
    nAtoms = atoi(argv[1]);
  std::string moleculeStr;
  if(argc > 2)
    moleculeStr = argv[2];
  else
    moleculeStr = "random";
  int nWorkers = 2;
  if(argc > 3)
    nWorkers = atoi(argv[3]);
  if(nWorkers < 1) {
    printf("Error: (nWorkers < 1).\n");
    return 1;
  }
  int blockSizeCHTML = 8;
  if(argc > 4)
    blockSizeCHTML = atoi(argv[4]);
  if(blockSizeCHTML < 1) {
    printf("Error: (blockSizeCHTML < 1).\n");
    return 1;
  }
  double coordDiffLimitForBasis = 10.0;
  if(argc > 5)
    coordDiffLimitForBasis = atof(argv[5]);
  if(coordDiffLimitForBasis < 1e-5) {
    printf("Error: (coordDiffLimitForBasis < 1e-5).\n");
    return 1;
  }
  int blockSizeHML = 16;
  if(argc > 6)
    blockSizeHML = atoi(argv[6]);
  if(blockSizeHML < 1) {
    printf("Error: (blockSizeHML < 1).\n");
    return 1;
  }
  double frob_trunc_threshold = 1e-8;
  if(argc > 7)
    frob_trunc_threshold = atof(argv[7]);
  if(frob_trunc_threshold < 0) {
    printf("Error: (frob_trunc_threshold < 0).\n");
    return 1;
  }
  double cache_size_in_GB = 0.5;
  if(argc > 8)
    cache_size_in_GB = atof(argv[8]);
  if(cache_size_in_GB < 0) {
    printf("Error: (cache_size_in_GB < 0).\n");
    return 1;
  }
  int nThreads = 1;
  if(argc > 9)
    nThreads = atoi(argv[9]);
  if(nThreads < 1) {
    printf("Error: (nThreads < 1).\n");
    return 1;
  }
  int compareToOtherPermutation = 1;
  if(argc > 10)
    compareToOtherPermutation = atoi(argv[10]);
  if(compareToOtherPermutation != 0 && compareToOtherPermutation != 1) {
    printf("Error: (compareToOtherPermutation != 0 && compareToOtherPermutation != 1).\n");
    return 1;
  }
  int runOldCodeAlso = 1;
  if(argc > 11)
    runOldCodeAlso = atoi(argv[11]);
  if(runOldCodeAlso != 0 && runOldCodeAlso != 1) {
    printf("Error: (runOldCodeAlso != 0 && runOldCodeAlso != 1).\n");
    return 1;
  }
  int leafInternalBlockSizeCHTML = 2;
  if(argc > 12)
    leafInternalBlockSizeCHTML = atoi(argv[12]);
  if(leafInternalBlockSizeCHTML < 1) {
    printf("Error: (leafInternalBlockSizeCHTML < 1).\n");
    return 1;
  }
  if(blockSizeCHTML % leafInternalBlockSizeCHTML != 0) {
    printf("Error: blockSizeCHTML does not match leafInternalBlockSizeCHTML: (blockSizeCHTML %% leafInternalBlockSizeCHTML != 0).\n");
    return -1;
  }
  int runOldNosymmMultAlso = 1;
  if(argc > 13)
    runOldNosymmMultAlso = atoi(argv[13]);
  if(runOldNosymmMultAlso != 0 && runOldNosymmMultAlso != 1) {
    printf("Error: (runOldNosymmMultAlso != 0 && runOldNosymmMultAlso != 1).\n");
    return 1;
  }

  cht::extras::setNoOfWorkerThreads(nThreads);
  cht::setOutputMode(cht::Output::AllInTheEnd);
  cht::extras::setNWorkers(nWorkers);
  cht::extras::setCacheSize((size_t)(cache_size_in_GB * 1000000000));

  Util::TimeMeter tmChtStart;
  cht::start();
  report_timing(tmChtStart, "cht::start()");

  printf("===== Parameters (set by command-line args in the same order as below) =======\n");
  printf("nAtoms = %d\n", nAtoms);
  printf("moleculeStr = '%s'\n", moleculeStr.c_str());
  printf("nWorkers = %d\n", nWorkers);
  printf("blockSizeCHTML = %4d\n", blockSizeCHTML);
  printf("coordDiffLimitForBasis = %8.3f\n", coordDiffLimitForBasis);
  printf("blockSizeHML = %d\n", blockSizeHML);
  printf("frob_trunc_threshold = %7.3g\n", frob_trunc_threshold);
  printf("cache_size_in_GB = %8.4f GB\n", cache_size_in_GB);
  printf("nThreads    = %2d\n", nThreads);
  printf("compareToOtherPermutation = %d\n", compareToOtherPermutation);
  printf("runOldCodeAlso = %d\n", runOldCodeAlso);
  printf("leafInternalBlockSizeCHTML = %d\n", leafInternalBlockSizeCHTML);
  printf("runOldNosymmMultAlso = %d\n", runOldNosymmMultAlso);
  printf("===== End of parameters ======================================================\n");

  bool useLinearMolecule = false;
  if(moleculeStr == "linear")
    useLinearMolecule = true;

  bool useRandomMolecule = false;
  if(moleculeStr == "random")
    useRandomMolecule = true;

  std::auto_ptr<IntegralInfo> biBasic(new IntegralInfo(true));
  BasisInfoStruct* bis = new BasisInfoStruct();

  static Molecule m; /* Don't allocate it on stack, it's too large. FIXME: IS THIS STILL NEEDED, MAYBE IT CAN BE ON STACK NOW? */

  if(useLinearMolecule) {
    if(nAtoms < 1)
      throw std::runtime_error("Error: (nAtoms < 1), not allowed in useLinearMolecule case.\n");
    double spacing = 5.0;
    printf("Creating linear molecule with spacing %7.3f between atoms.\n", spacing);
    for(int i = 0; i < nAtoms; i++) {
      double x = 0;;
      double y = 0;
      double z = i * spacing;
      m.addAtom(1, x, y, z);
    }
  }
  else if(useRandomMolecule) {
    if(nAtoms < 1)
      throw std::runtime_error("Error: (nAtoms < 1), not allowed in useRandomMolecule case.\n");
    double atomsPerUnitVolume = 0.001;
    double boxVolume = nAtoms / atomsPerUnitVolume;
    double boxWidth = pow(boxVolume, 1.0/3.0);
    printf("Creating random 3D molecule, boxVolume = %9.3f, boxWidth = %9.3f\n", boxVolume, boxWidth);
    for(int i = 0; i < nAtoms; i++) {
      double x = boxWidth * rand_0_to_1();
      double y = boxWidth * rand_0_to_1();
      double z = boxWidth * rand_0_to_1();
      int atomType = 1;
      if(i % 2 == 0)
	atomType = 1;
      m.addAtom(atomType, x, y, z);
    }
  }
  else {
    // Not linear and not random, then we expect a filename for a molecule file.
    if(m.setFromMoleculeFile(moleculeStr.c_str(), 
			     0, /* we are guessing the net charge here */
			     NULL) != 0) {
      std::cerr << "Error in setFromMoleculeFile for filename '" << moleculeStr << "'." << std::endl;
      throw std::runtime_error("Error in m.setFromMoleculeFile().");
    }
    // Verify that nAtoms is -1 in this case.
    if(nAtoms != -1)
      throw std::runtime_error("Error: when using molecule file, nAtoms argument should be set to -1.");
    nAtoms = m.getNoOfAtoms();
    printf("Molecule file '%s' read OK, nAtoms = %d\n", moleculeStr.c_str(), nAtoms);
  }

  const char* basisSetName = "STO-3G";

  static const char *dirv[] = {
    ".", "basis", "../basis",
    ERGO_DATA_PREFIX "/basis",
    ERGO_DATA_PREFIX,
    ERGO_SPREFIX "/basis",
    ERGO_SPREFIX
  };
  basisset_struct* basissetDef = new basisset_struct;
  memset(basissetDef, 0, sizeof(basissetDef));
  if(read_basisset_file(basissetDef, basisSetName, 6, dirv, 0) != 0)
    throw std::runtime_error("Error in read_basisset_file().");

  std::vector<Atom> atomList(m.getNoOfAtoms());
  for(int i = 0; i < m.getNoOfAtoms(); i++)
    atomList[i] = m.getAtom(i);

  cht::ChunkID cid_atomList = cht::registerChunk(new chttl::ChunkVector<Atom>(atomList));
  cht::ChunkID cid_basissetDef = cht::registerChunk(new chttl::ChunkBasic<basisset_struct>(*basissetDef));
  cht::ChunkID cid_integralInfo = cht::registerChunk(new IntegralInfoChunk(*biBasic));
  cht::ChunkID cid_coordDiffLimit = cht::registerChunk(new chttl::ChunkBasic<ergo_real>(coordDiffLimitForBasis));

  Util::TimeMeter tmCreateAtomCenteredBasisSet;
  cht::ChunkID cid_basisSet = cht::executeMotherTask<CreateAtomCenteredBasisSet>(cid_atomList, cid_basissetDef, cid_integralInfo, cid_coordDiffLimit);
  report_timing(tmCreateAtomCenteredBasisSet, "executeMotherTask<CreateAtomCenteredBasisSet>");
  Util::TimeMeter tmGetBasisSetCoords;
  cht::ChunkID cid_coordList = cht::executeMotherTask<GetBasisSetCoords>(cid_basisSet);
  report_timing(tmGetBasisSetCoords, "executeMotherTask<GetBasisSetCoords>");

  cht::shared_ptr<DistrBasisSetChunk const> ptr_c;
  cht::getChunk(cid_basisSet, ptr_c);
  int nBasisFuncs = ptr_c->noOfBasisFuncs;
  printf("nBasisFuncs = %d\n", nBasisFuncs);

  cht::shared_ptr<chttl::ChunkVector<CoordStruct> const> ptr_coordList;
  cht::getChunk(cid_coordList, ptr_coordList);

  std::vector<int> permutationCHTML, inversePermutationCHTML;
  std::vector<ergo_real> xcoords(nBasisFuncs);
  std::vector<ergo_real> ycoords(nBasisFuncs);
  std::vector<ergo_real> zcoords(nBasisFuncs);
  for(int i = 0; i < nBasisFuncs; i++) {
    xcoords[i] = (*ptr_coordList)[i].coords[0];
    ycoords[i] = (*ptr_coordList)[i].coords[1];
    zcoords[i] = (*ptr_coordList)[i].coords[2];
  }
  if(blockSizeCHTML % leafInternalBlockSizeCHTML != 0) {
    printf("Error: (blockSizeCHTML %% leafInternalBlockSizeCHTML != 0).\n");
    return -1;
  }
  int sparse_block_size_lowest = leafInternalBlockSizeCHTML;
  int first_factor = blockSizeCHTML / sparse_block_size_lowest;
  getMatrixPermutationOnlyFactor2(xcoords, ycoords, zcoords,
				  sparse_block_size_lowest,
				  first_factor,
				  permutationCHTML,
				  inversePermutationCHTML);

  cht::ChunkID cid_indexList = cht::registerChunk(new chttl::ChunkVector<int>(permutationCHTML));

  cht::ChunkID cid_basisSet_2 = cht::executeMotherTask<SetFuncIndexesForBasisSet>(cid_basisSet, cid_indexList);

  cht::ChunkID cid_largest_simple_integral = cht::executeMotherTask<GetLargestSimpleIntegralForBasisSet>(cid_basisSet_2);
  cht::shared_ptr<chttl::ChunkBasic<ergo_real> const> ptr_largest_simple_integral;
  cht::getChunk(cid_largest_simple_integral, ptr_largest_simple_integral);
  ergo_real largest_simple_integral = ptr_largest_simple_integral->x;
  printf("largest_simple_integral = %22.11f\n", largest_simple_integral);
  const ergo_real MATRIX_ELEMENT_THRESHOLD_VALUE = 1e-12;
  ergo_real maxAbsValue = MATRIX_ELEMENT_THRESHOLD_VALUE / largest_simple_integral;
  cht::ChunkID cid_maxAbsValue = cht::registerChunk(new chttl::ChunkBasic<ergo_real>(maxAbsValue));

  cht::ChunkID cid_basisSet_3 = cht::executeMotherTask<SetExtentsForBasisSet>(cid_basisSet_2, cid_maxAbsValue);

  MatrixInfoStruct info;
  info.n = nBasisFuncs;
  info.leavesSizeMax = blockSizeCHTML;
  info.leafInternalBlocksize = leafInternalBlockSizeCHTML;
  cht::ChunkID cid_info = cht::registerChunk(new chttl::ChunkBasic<MatrixInfoStruct>(info));

  Util::TimeMeter tmComputeOverlapMatRecursive;
  cht::ChunkID cid_matrix_S_notrunc = cht::executeMotherTask<ComputeOverlapMatRecursive>(cid_basisSet_3, cid_basisSet_3, cid_info);
  report_timing(tmComputeOverlapMatRecursive, "executeMotherTask<ComputeOverlapMatRecursive>");

  /* OK, now we have computed the overlap matrix using CHT, the matrix chunk is cid_matrix_S. Now truncate it. */
  Util::TimeMeter tmTruncFrobenius;
  cht::ChunkID cid_matrix_S_trunc = chtml::truncFrobenius<LeafMatType>(cid_matrix_S_notrunc, frob_trunc_threshold);
  report_timing(tmTruncFrobenius, "chtml::truncFrobenius");

  /* Check that truncation did not remove too much.  */
  Util::TimeMeter tmVerifyFrobTrunc;
  verify_frob_truncation(cid_matrix_S_notrunc, cid_matrix_S_trunc, frob_trunc_threshold);
  report_timing(tmVerifyFrobTrunc, "verify_frob_truncation");

  if(bis->addBasisfuncsForMolecule(m, basisSetName,
                                   0, NULL, *biBasic, 0, 0, 0) != 0)
    throw std::runtime_error("bis->addBasisfuncsForMolecule failed.");

  int n = bis->noOfBasisFuncs;

  // Get permutation vector to translate between the two basis function orderings.
  std::vector<int> permutation_HML_to_CHTML;
  if(compareToOtherPermutation) {
    permutation_HML_to_CHTML.resize(n);
    for(int i = 0; i < n; i++) {
      // Find basis function that matches this position in space.
      ergo_real x = bis->basisFuncList[i].centerCoords[0];
      ergo_real y = bis->basisFuncList[i].centerCoords[1];
      ergo_real z = bis->basisFuncList[i].centerCoords[2];
      const ergo_real MAX_COORD_DIFF_FOR_SAME_POS = 1e-7;
      int foundIndex = -1;
      int foundCount = 0;
      for(int j = 0; j < n; j++) {
	ergo_real absdx = fabs(x - xcoords[j]);
	ergo_real absdy = fabs(y - ycoords[j]);
	ergo_real absdz = fabs(z - zcoords[j]);
	if(absdx < MAX_COORD_DIFF_FOR_SAME_POS && absdy < MAX_COORD_DIFF_FOR_SAME_POS && absdz < MAX_COORD_DIFF_FOR_SAME_POS) {
	  foundIndex = j;
	  foundCount++;
	}
      }
      if(foundCount != 1)
	throw std::runtime_error("Error getting permutation vector to translate between the two basis function orderings; (foundCount != 1)");
      permutation_HML_to_CHTML[i] = foundIndex;
    }
  }
  std::vector<int> permutation_HML_to_CHTML_inv;
  if(compareToOtherPermutation) {
    permutation_HML_to_CHTML_inv.resize(n);
    for(int i = 0; i < n; i++)
      permutation_HML_to_CHTML_inv[permutation_HML_to_CHTML[i]] = i;
  }

  // Get overlap matrix
  std::vector<int> permutationHML, inversePermutationHML;
  mat::SizesAndBlocks sizeBlockInfo;
  preparePermutationsHML(*bis, sizeBlockInfo, permutationHML, inversePermutationHML, blockSizeHML);
  symmMatrix S_notrunc;
  S_notrunc.resetSizesAndBlocks(sizeBlockInfo, sizeBlockInfo); 
  Util::TimeMeter tmComputeOverlapOld;
  if(runOldCodeAlso) {
    if(compute_overlap_matrix_sparse(*bis, S_notrunc, 
				     permutationHML) != 0)
      throw std::runtime_error("Error in compute_overlap_matrix_sparse.");
    report_timing(tmComputeOverlapOld, "Old-style compute_overlap_matrix_sparse()");
  }
  else
    printf("Skipping old-style compute_overlap_matrix_sparse().\n");
  /* Truncate S. */
  symmMatrix S_trunc(S_notrunc);
  S_trunc.frob_thresh(frob_trunc_threshold);

  symmMatrix S2;
  S2.resetSizesAndBlocks(sizeBlockInfo, sizeBlockInfo);
  if(runOldCodeAlso) {
    Util::TimeMeter tmComputeS2Old;
    S2 = 1.0 * S_trunc * S_trunc;
    report_timing(tmComputeS2Old, "Computing S*S using HML for old-style computed S_trunc");
  }
  else
    printf("Skipping old-style computation of S*S.\n");

  if(runOldNosymmMultAlso) {
    normalMatrix S_trunc_nosymm(S_trunc);
    normalMatrix S2_nosymm;
    S2_nosymm.resetSizesAndBlocks(sizeBlockInfo, sizeBlockInfo);
    Util::TimeMeter tmComputeS2OldNosymm;
    S2_nosymm = 1.0 * S_trunc_nosymm * S_trunc_nosymm;
    report_timing(tmComputeS2OldNosymm, "Computing S*S using HML without symmetry for old-style computed S_trunc_nosymm");
    normalMatrix S2_for_comparison(S2);
    normalMatrix diffMatrix(S2_nosymm);
    diffMatrix += ((ergo_real)-1.0) * S2_for_comparison;
    ergo_real frobNormOfDiffMatrix = diffMatrix.frob();
    printf("Checking HML nosymm S2 computation: frobNormOfDiffMatrix = %9.4g\n", frobNormOfDiffMatrix);
    if(frobNormOfDiffMatrix > 1e-3)
      throw std::runtime_error("Error: too large diff for S*S using HML without symmetry.");
  }

  printf("nAtoms = %d, n = %d\n", nAtoms, n);
  printf("noOfShells = %d\n", bis->noOfShells);
  printf("noOfSimplePrimitives = %d\n", bis->noOfSimplePrimitives);

  if(compareToOtherPermutation) {
    // Check result by checking some elements.
    ergo_real maxAbsDiff_S_notrunc = compare_CHTML_matrix_to_HML_matrix(n, cid_matrix_S_notrunc, blockSizeCHTML, permutationCHTML, S_notrunc, permutationHML, permutation_HML_to_CHTML_inv);
    if(maxAbsDiff_S_notrunc > 1e-8)
      throw std::runtime_error("Error: wrong result for S, (maxAbsDiff_S_notrunc > 1e-8).");
  }

  // Get frob norm of overlap matrix.
  printf("Getting frob norms for S...\n");
  ergo_real frob_norm_S_expected = S_notrunc.frob();
  ergo_real frob_norm_S_from_cht = chtml::normFrobenius<LeafMatType>(cid_matrix_S_notrunc);
  ergo_real frob_norm_S_absdiff = fabs(frob_norm_S_from_cht-frob_norm_S_expected);
  printf("frob_norm_S_expected  = %15.9f, frob_norm_S_from_cht  = %15.9f, absdiff = %9.4g\n", frob_norm_S_expected, frob_norm_S_from_cht, frob_norm_S_absdiff);
  if(runOldCodeAlso) {
    if(frob_norm_S_absdiff > 1e-5)
      throw std::runtime_error("Error: too large diff in computed frob norm of S matrix.");
  }

  // Check nnz for computed overlap matrix
  report_nnz_for_matrix(get_nnz_for_CHTML_matrix(cid_matrix_S_notrunc), n, "S no trunc");

  // Check nnz for computed overlap matrix after truncation
  report_nnz_for_matrix(get_nnz_for_CHTML_matrix(cid_matrix_S_trunc), n, "S truncated");

  // Check nnz for HML overlap matrix 
  report_nnz_for_matrix(S_notrunc.nnz(), n, "HML S no trunc");

  // Check nnz for HML overlap matrix after truncation
  report_nnz_for_matrix(S_trunc.nnz(), n, "HML S trunc");

  // Do multiplication S*S
  std::vector<cht::ChunkID> inputChunksMmul(2);
  inputChunksMmul[0] = cid_matrix_S_trunc;
  inputChunksMmul[1] = cid_matrix_S_trunc;
  cht::resetStatistics();
  Util::TimeMeter tmMatrixMultiply;
  printf("Before cht::executeMotherTask for chtml::MatrixMultiply\n");
  cht::ChunkID cid_matrix_S2 = 
    cht::executeMotherTask< chtml::MatrixMultiply<LeafMatType, false, false> > (inputChunksMmul);
  printf("After cht::executeMotherTask for chtml::MatrixMultiply\n");
  report_timing(tmMatrixMultiply, "executeMotherTask for chtml::MatrixMultiply");
  cht::reportStatistics();

  // Get frob norm of overlap matrix.
  printf("Getting frob norms for S2...\n");
  ergo_real frob_norm_S2_expected = S2.frob();
  ergo_real frob_norm_S2_from_cht = chtml::normFrobenius<LeafMatType>(cid_matrix_S2);
  ergo_real frob_norm_S2_absdiff = fabs(frob_norm_S2_from_cht-frob_norm_S2_expected);
  printf("frob_norm_S2_expected = %15.9f, frob_norm_S2_from_cht = %15.9f, absdiff = %9.4g\n", frob_norm_S2_expected, frob_norm_S2_from_cht, frob_norm_S2_absdiff);
  if(runOldCodeAlso) {
    if(frob_norm_S2_absdiff > 1e-5)
      throw std::runtime_error("Error: too large diff in computed frob norm of S2 matrix.");
  }

  // Check nnz for S2 matrix
  report_nnz_for_matrix(get_nnz_for_CHTML_matrix(cid_matrix_S2), n, "S2");

  // Check nnz for HML computed S2 matrix
  report_nnz_for_matrix(S2.nnz(), n, "HML S2");

  if(compareToOtherPermutation) {
    // Compare to S2 computed in traditional way.
    // Check S2 result by checking some elements.
    ergo_real maxAbsDiff_S2 = compare_CHTML_matrix_to_HML_matrix(n, cid_matrix_S2, blockSizeCHTML, permutationCHTML, S2, permutationHML, permutation_HML_to_CHTML_inv);
    if(maxAbsDiff_S2 > 1e-7)
      throw std::runtime_error("Error: wrong result for S2, (maxAbsDiff_S2 > 1e-7).");
  }

  printf("Calling cht::deleteChunk...\n");
  cht::deleteChunk(cid_atomList);
  cht::deleteChunk(cid_basissetDef);
  cht::deleteChunk(cid_integralInfo);
  cht::deleteChunk(cid_coordDiffLimit);
  cht::deleteChunk(cid_basisSet);
  cht::deleteChunk(cid_coordList);
  cht::deleteChunk(cid_indexList);
  cht::deleteChunk(cid_basisSet_2);
  cht::deleteChunk(cid_maxAbsValue);
  cht::deleteChunk(cid_basisSet_3);
  cht::deleteChunk(cid_largest_simple_integral);
  cht::deleteChunk(cid_info);
  cht::deleteChunk(cid_matrix_S_notrunc);
  cht::deleteChunk(cid_matrix_S2);
  cht::deleteChunk(cid_matrix_S_trunc);
  printf("After cht::deleteChunk.\n");

  cht::stop();

  delete bis;
  delete basissetDef;

  puts("CHT test succeeded."); 
  unlink("ergoscf.out");
  return 0;
}

#else

int main(int argc, char *argv[])
{
  // Fallback entry point for builds without USE_CHUNKS_AND_TASKS: report that
  // the test is skipped and exit with status 0. argc/argv are not used.
  const char *skipMessage = "Skipping Chunks&Tasks overlap matrix creation test since USE_CHUNKS_AND_TASKS macro not defined.\n";
  printf("%s", skipMessage);
  return 0;
}

#endif
