/*SHHSearch.cc
 *class implementation of Spendley, Hext and Himsworth Simplex Search
 *Adam Gurson College of William & Mary 1999
 *
 * Modified in 8/00 by Anne Shepherd (pls) to fix a few little bugs
 */

#include "SHHSearch.h"
#include <iostream.h>
#include <iomanip.h>
// Convergence threshold used by Stop(): the search ends once the RMS spread
// of the simplex function values falls below this value.
// NOTE(review): 10e-8 is 1.0e-7, not 1.0e-8 -- confirm which was intended.
#define stoppingStepLength 10e-8
// constructors & destructors

SHHSearch::SHHSearch(int dim)
{
   // Builds a search for a `dim`-dimensional problem with the default shrink
   // factor sigma = 0.5. No simplex exists yet: one of the Init*Simplex()
   // routines or ReadSimplexFile() must be called before searching.
   dimensions = dim;
   functionCalls = 0;
   minIndex = 0;
   // These scalars used to be left indeterminate. The copy constructor reads
   // them, and GetTolHit() may be called before ExploratoryMoves() runs.
   replacementIndex = -1;
   reflectionPtValue = 0.0;
   toleranceHit = 0;
   simplex = NULL;
   simplexValues = NULL;
   simplexAges = NULL;
   centroid = new Vector<double>(dimensions,0.0);
   reflectionPt = new Vector<double>(dimensions,0.0);
   sigma = 0.5;
   // Per-object workspace used when evaluating / shrinking vertices.
   scratch = new Vector<double>(dimensions,0.0);
   scratch2 = new Vector<double>(dimensions,0.0);
} // SHHSearch() (default)

SHHSearch::SHHSearch(int dim, double Sigma)
{
   // Builds a search for a `dim`-dimensional problem with a caller-chosen
   // shrink factor `Sigma` (used by ShrinkSimplex()). No simplex exists yet.
   dimensions = dim;
   functionCalls = 0;
   minIndex = 0;
   // These scalars used to be left indeterminate. The copy constructor reads
   // them, and GetTolHit() may be called before ExploratoryMoves() runs.
   replacementIndex = -1;
   reflectionPtValue = 0.0;
   toleranceHit = 0;
   simplex = NULL;
   simplexValues = NULL;
   simplexAges = NULL;
   centroid = new Vector<double>(dimensions,0.0);
   reflectionPt = new Vector<double>(dimensions,0.0);
   sigma = Sigma;
   // Per-object workspace used when evaluating / shrinking vertices.
   scratch = new Vector<double>(dimensions,0.0);
   scratch2 = new Vector<double>(dimensions,0.0);
} // SHHSearch() (special)

SHHSearch::SHHSearch(const SHHSearch& Original)
{
   // Deep copy of another search's state.
   dimensions = Original.GetVarNo();

   // An Original that was never given a simplex holds NULL arrays. Copy the
   // NULLs instead of dereferencing them through the Get* helpers.
   simplex = NULL;
   simplexValues = NULL;
   simplexAges = NULL;
   if(Original.simplex != NULL) Original.GetCurrentSimplex(simplex);
   if(Original.simplexValues != NULL) Original.GetCurrentSimplexValues(simplexValues);
   if(Original.simplexAges != NULL) Original.GetCurrentSimplexAges(simplexAges);

   sigma = Original.sigma;
   minIndex = Original.minIndex;
   replacementIndex = Original.replacementIndex;
   // This is a brand-new object, so centroid/reflectionPt hold garbage here.
   // They must be allocated, never deleted (the old code deleted them: UB).
   centroid = new Vector<double>(*(Original.centroid));
   reflectionPt = new Vector<double>(*(Original.reflectionPt));
   reflectionPtValue = Original.reflectionPtValue;
   functionCalls = Original.functionCalls;
   toleranceHit = Original.toleranceHit;

   // Workspace vectors. Their contents don't matter, but they must exist
   // because the destructor deletes them.
   scratch = new Vector<double>(dimensions,0.0);
   scratch2 = new Vector<double>(dimensions,0.0);

   // NOTE(review): maxCalls and fcn are not copied here. If they are members
   // of this class rather than of a base class, the copy starts without
   // them -- confirm against SHHSearch.h.
} // SHHSearch() (copy constructor)

SHHSearch::~SHHSearch()
{
   // Workspace vectors, allocated by every constructor.
   delete scratch2;
   delete scratch;
   delete reflectionPt;
   delete centroid;

   // Simplex storage. These stay NULL until a simplex is initialized, and
   // deleting a NULL pointer is a no-op, so no explicit checks are needed.
   delete [] simplexAges;
   delete [] simplexValues;
   delete simplex;
   // Matrix and Vector release their own element storage.
} // ~SHHSearch

// algorithmic routines

void SHHSearch::ExploratoryMoves()
{
   // Main Spendley-Hext-Himsworth iteration. Each pass reflects the
   // highest-valued vertex (excluding the one inserted on the previous pass)
   // through the centroid of the others. It always accepts the reflected
   // point, even if it is worse.
   // When the best vertex's age exceeds n+1 (see AgesTooOld()), the simplex
   // is shrunk toward the best vertex first.
   // Runs until Stop() is satisfied or the call budget (maxCalls, -1 = no
   // limit) is used up. On return, minIndex and replacementIndex describe
   // the final simplex.
   toleranceHit = 0;

   // -1: no vertex is excluded from replacement on the first pass.
   replacementIndex = -1;
   do {
     // Pass the vertex just inserted so it is not picked again immediately.
     FindMinReplacementIndices(replacementIndex);
     if(DEBUG) printSimplex();

     // if any point has been here for a significantly long
     // time, the simplex is most likely circling a local
     // minimum, so shrink the simplex
     if( AgesTooOld() ) {
       ShrinkSimplex();
       ResetAges();
       FindMinReplacementIndices(-1);
       if(DEBUG) printSimplex();
       
       /* Changed to fix the problem of maxCalls == -1.
          --pls, 8/8/00
       */
        // ShrinkSimplex() can itself exhaust the budget; stop here in that case.
        if(maxCalls > -1 && functionCalls >= maxCalls) {            // pls
           // if (functionCalls >= maxCalls)  {               //original
                FindMinReplacementIndices(-1);
                return;
        } //if     
     }
     // Reflect the chosen vertex and install the result unconditionally.
     FindCentroid();
     FindReflectionPt();
     ReplaceSimplexPoint(replacementIndex, *reflectionPt);
     simplexValues[replacementIndex] = reflectionPtValue;
     UpdateAges(replacementIndex);
   } while (!Stop());   // while stopping criterion is not satisfied
   FindMinReplacementIndices(-1);
} // ExploratoryMoves()

void SHHSearch::ReplaceSimplexPoint(int index, const Vector<double>& newPoint)
{
   // Overwrite vertex (row) `index` of the simplex with newPoint, one
   // coordinate at a time. The stored function value is NOT updated here.
   int col = 0;
   while( col < dimensions ) {
      (*simplex)[index][col] = newPoint[col];
      ++col;
   }
} // ReplaceSimplexPoint()

void SHHSearch::CalculateFunctionValue(int index)
{
   *scratch = (*simplex).row(index);
   int success;
   fcnCall(dimensions, (*scratch).begin(), simplexValues[index], success);
   if(!success) cerr<<"Error calculating point in CalculateFunctionValue().\n";
} // CalculateFunctionValue()

void SHHSearch::SetSigma(double newSigma)
{
   sigma = newSigma;
} // SetSigma()

bool SHHSearch::Stop()
{
   // Stopping test. Returns true when the call budget is exhausted
   // (maxCalls == -1 means "no limit"). It also returns true, and sets
   // toleranceHit, when the function values have converged.

   if(maxCalls > -1 && functionCalls >= maxCalls)
      return true;

   // Mean of the n function values other than the current best vertex.
   double avg = 0.0;
   for( int k = 0; k <= dimensions; k++ )
      if( k != minIndex )
         avg += simplexValues[k];
   avg /= (double)dimensions;

   // Test for the suggested Nelder-Mead stopping criteria:
   // RMS deviation of all n+1 values from that mean.
   double spread = 0.0;
   for( int k = 0; k <= dimensions; k++ )
      spread += pow( simplexValues[k] - avg ,2);
   spread = sqrt( spread / ((double)dimensions + 1.0) );

   const bool converged = ( spread < stoppingStepLength );
   if( converged )
      toleranceHit = 1;
   return converged;
} // Stop()

void SHHSearch::fcnCall(int n, double *x, double& f, int& flag)
{
   // Every objective evaluation goes through here so that functionCalls
   // counts them all. The value lands in f; flag reports success.
   fcn(n, x, f, flag);
   ++functionCalls;
} // fcnCall()

// Simplex-altering functions

void SHHSearch::InitRegularTriangularSimplex(const Vector<double> *basePoint,
                                                       const double edgeLength)
{
  //  This routine constructs a regular simplex (i.e., one in which all of 
  //  the edges are of equal length) following an algorithm given by Jacoby,
  //  Kowalik, and Pizzo in "Iterative Methods for Nonlinear Optimization 
  //  Problems," Prentice-Hall (1972).  This algorithm also appears in 
  //  Spendley, Hext, and Himsworth, "Sequential Application of Simplex 
  //  Designs in Optimisation and Evolutionary Operation," Technometrics, 
  //  Vol. 4, No. 4, November 1962, pages 441--461.
  //
  //  Vertex 0 is basePoint. Vertex i (1..n) is basePoint plus p in
  //  coordinate i-1 and plus q in every other coordinate.

   const double root2 = sqrt(2.0);
   const double q = ((sqrt(dimensions + 1.0) - 1.0) / (dimensions * root2)) * edgeLength;
   const double p = q + ((1.0 / root2) * edgeLength);

   Matrix<double> plex(dimensions+1,dimensions,0.0);
   for( int col = 0; col < dimensions; col++ )
      plex[0][col] = (*basePoint)[col];

   for( int row = 1; row <= dimensions; row++ ) {
      for( int col = 0; col < dimensions; col++ ) {
         const double offset = ( col == row - 1 ) ? p : q;
         plex[row][col] = plex[0][col] + offset;
      }
   }

   // InitGeneralSimplex() copies the matrix, so a local is sufficient.
   InitGeneralSimplex(&plex);
} // InitRegularTriangularSimplex()

void SHHSearch::InitFixedLengthRightSimplex(const Vector<double> *basePoint,
                                           const double edgeLength)
{
  // to take advantage of code reuse, this function simply turns
  // edgeLength into a vector of dimensions length, and then
  // calls InitVariableLengthRightSimplex()

   double* edgeLengths = new double[dimensions];
   for( int i = 0; i < dimensions; i++ ) {
      edgeLengths[i] = edgeLength;
   }
   InitVariableLengthRightSimplex(basePoint,edgeLengths);
   delete [] edgeLengths;
} // InitFixedLengthRightSimplex()

void SHHSearch::InitVariableLengthRightSimplex(const Vector<double> *basePoint,
                                              const double* edgeLengths)
{
   // Right-angled simplex. Vertex i (0..n-1) is basePoint displaced by
   // edgeLengths[i] along axis i. The last vertex (row n) is basePoint
   // itself.
   Matrix<double> plex(dimensions+1,dimensions,0.0);
   for( int row = 0; row < dimensions; row++ ) {
      for( int col = 0; col < dimensions; col++ )
         plex[row][col] = (*basePoint)[col];
      plex[row][row] += edgeLengths[row];

      // basePoint is assembled component-by-component into the last row.
      plex[dimensions][row] = (*basePoint)[row];
   }

   // InitGeneralSimplex() copies the matrix, so a local is sufficient.
   InitGeneralSimplex(&plex);
} // InitVariableLengthRightSimplex()

void SHHSearch::InitGeneralSimplex(const Matrix<double> *plex)
{
   // Install `plex` ((n+1) x n, one vertex per row) as the current simplex.
   // The matrix is copied, so the caller keeps ownership of `plex`. Resets
   // the call counter and ages, evaluates every vertex, and then locates the
   // best and worst vertices.
   functionCalls = 0;

   // Discard any previous simplex state. simplexAges used to be skipped
   // here and leaked on every re-initialization. Deleting NULL is a no-op.
   delete simplex;
   delete [] simplexValues;
   delete [] simplexAges;
   simplex = new Matrix<double>((*plex));
   simplexValues = new double[dimensions+1];
   simplexAges = new double[dimensions+1];
   ResetAges();

   int success;
   for( int i = 0; i <= dimensions; i++ ) {
      *scratch = (*plex).row(i);
      fcnCall(dimensions, (*scratch).begin(), simplexValues[i], success);
      if(!success) cerr<<"Error with point #"<<i<<" in initial simplex.\n";
   } // for
   FindMinReplacementIndices(-1);
} // InitGeneralSimplex()

void SHHSearch::ReadSimplexFile(istream& fp)
{
   // Read an (n+1) x n simplex from `fp`, one vertex per row and
   // whitespace-separated values, then install it via InitGeneralSimplex().
   //
   // A reference can never be NULL. The old `fp == NULL` comparison really
   // tested the stream's state, and it is ill-formed since C++11. `!fp`
   // performs the same state test portably.
   if(!fp) {
      cerr<<"No Input Stream in ReadSimplexFile()!\n";
      return; // There's no file handle!!
   }

   Matrix<double> plex(dimensions+1,dimensions);
   for( int i = 0; i <= dimensions; i++ ) {
      for ( int j = 0; j < dimensions; j++ ) {
         fp >> plex[i][j];
      } // inner for
   } // outer for

   // A short or malformed file would otherwise install a simplex built from
   // entries that were never read. Keep the existing simplex instead.
   if(fp.fail()) {
      cerr<<"Error reading simplex in ReadSimplexFile()!\n";
      return;
   }
   InitGeneralSimplex(&plex);
} // ReadSimplexFile()

// Query functions

int SHHSearch::GetFunctionCalls() const
{
   return functionCalls;
} // GetFunctionCalls()

void SHHSearch::GetMinPoint(Vector<double>* &minimum) const
{
   // Hands back a newly allocated copy of the best vertex. The caller owns
   // it and must delete it. Before any simplex has been initialized there is
   // no best vertex, so `minimum` is set to NULL rather than dereferencing
   // the NULL simplex.
   if(simplex == NULL) {
      minimum = NULL;
      return;
   }
   minimum = new Vector<double>((*simplex).row(minIndex));
} // GetMinPoint()

double SHHSearch::GetMinVal() const
{
   // Function value at the best vertex (index minIndex).
   const int best = minIndex;
   return simplexValues[best];
} // GetMinVal()

void SHHSearch::GetCurrentSimplex(Matrix<double>* &plex) const
{
   // Hands back a newly allocated copy of the simplex matrix. The caller
   // owns it and must delete it. If no simplex has been initialized yet,
   // `plex` is set to NULL instead of dereferencing a NULL pointer.
   if(simplex == NULL) {
      plex = NULL;
      return;
   }
   plex = new Matrix<double>((*simplex));
} // GetCurrentSimplex()

void SHHSearch::GetCurrentSimplexValues(double* &simValues) const
{
   // Hands back a newly allocated array of the n+1 vertex function values.
   // The caller owns it and must delete [] it. If no simplex has been
   // initialized yet, `simValues` is set to NULL instead of reading through
   // a NULL pointer.
   if(simplexValues == NULL) {
      simValues = NULL;
      return;
   }
   simValues = new double[dimensions+1];
   for( int i = 0; i <= dimensions; i++ ) {
      simValues[i] = simplexValues[i];
   } // for
} // GetCurrentSimplexValues()

void SHHSearch::GetCurrentSimplexAges(double* &simAges) const
{
   // Hands back a newly allocated array of the n+1 vertex ages. The caller
   // owns it and must delete [] it. If no simplex has been initialized yet,
   // `simAges` is set to NULL instead of reading through a NULL pointer.
   if(simplexAges == NULL) {
      simAges = NULL;
      return;
   }
   simAges = new double[dimensions+1];
   for( int i = 0; i <= dimensions; i++ ) {
      simAges[i] = simplexAges[i];
   } // for
} // GetCurrentSimplexAges()

int SHHSearch::GetVarNo() const
{
   return dimensions;
} // GetVarNo()

int SHHSearch::GetTolHit() const
{
   return toleranceHit;
} // GetTolHit()

// private functions

void SHHSearch::FindMinReplacementIndices(int replacementSkipIndex)
{
   // Scan the vertex function values to:
   //   (a) choose replacementIndex, the vertex with the largest value other
   //       than replacementSkipIndex; and
   //   (b) move minIndex if some vertex is strictly better than the current
   //       best. When that happens, all ages are reset via ResetAges().
   //
   // replacementSkipIndex: vertex that may not be chosen for replacement.
   //    ExploratoryMoves() passes the vertex it just inserted; -1 excludes
   //    nothing.
   // NOTE(review): when replacementSkipIndex == 0 this reads
   //    simplexValues[1], so it assumes dimensions >= 1.
   if(simplexValues == NULL) {
      cerr << "Error in FindMinReplacementIndices() - "
           << "The vector of simplexValues is NULL!!\n";
      return;
   }
   int newMinIndex = 0;
   replacementIndex = 0;
   double min = simplexValues[0];
   double replaceVal = simplexValues[0];
   // Vertex 0 is the default replacement candidate unless it is the excluded
   // vertex, in which case vertex 1 starts as the candidate.
   if (replacementSkipIndex == 0) {
     replacementIndex = 1;
     replaceVal = simplexValues[1];
   }
   for( int i = 1; i <= dimensions; i++ ) {
      // Strict comparisons: among ties, the lowest-numbered vertex wins.
      if( simplexValues[i] < min ) {
         min = simplexValues[i];
         newMinIndex = i;
      } // if
      if( (i != replacementSkipIndex) && (simplexValues[i] > replaceVal) ) {
         replaceVal = simplexValues[i];
         replacementIndex = i;
      } // if
   } // for
   // minIndex only moves on a strict improvement over its own current value,
   // so a tie keeps the existing best vertex and its ages.
   if (simplexValues[newMinIndex] < simplexValues[minIndex]) {
     minIndex = newMinIndex;
     ResetAges();
   }
} // FindMinReplacementIndices()

void SHHSearch::FindCentroid()
{
   // Centroid of the n vertices that remain after dropping the vertex at
   // replacementIndex (the one about to be reflected).
   (*centroid) = 0.0;
   for( int v = 0; v <= dimensions; v++ ) {
      if( v == replacementIndex )
         continue;
      (*centroid) = (*centroid) + (*simplex).row(v);
   }
   (*centroid) = (*centroid) * ( 1.0 / (double)dimensions );
} // FindCentroid()

void SHHSearch::FindReflectionPt()
{ 
   (*reflectionPt) = 0.0;
   (*reflectionPt) = ( (*centroid) * 2.0 ) - (*simplex).row(replacementIndex);
   int success;
   fcnCall(dimensions, (*reflectionPt).begin(), reflectionPtValue, success);
   if(!success) {
      cerr << "Error finding f(x) for reflection point at"
           << "function call #" << functionCalls << ".\n";
   } // if
} // FindReflectionPt()

void SHHSearch::ShrinkSimplex()
{
   Vector<double> *lowestPt = scratch;
   *lowestPt = (*simplex).row(minIndex);
   Vector<double> *tempPt = scratch2;
   int success;
   for( int i = 0; i <= dimensions; i++ ) {
      if( i != minIndex ) {
         *tempPt = (*simplex).row(i);
         (*tempPt) = (*tempPt) + ( sigma * ( (*lowestPt)-(*tempPt) ) );
         for( int j = 0; j < dimensions; j++ ) {
            (*simplex)[i][j] = (*tempPt)[j];
         } // inner for
         fcnCall(dimensions,(*tempPt).begin(),simplexValues[i],success);
         if (!success) cerr << "Error shrinking the simplex.\n";
         
         // stop if at maximum function calls
         /* Modified 8/00 by Anne Shepherd to deal with the
            case where maxCalls == -1
         */
         if (maxCalls > -1                            // pls
             && functionCalls >= maxCalls) {return;}
      } // if
   } // outer for
} // ShrinkSimplex()

int SHHSearch::AgesTooOld()
{
  // 1 when the best vertex's age exceeds n+1, which ExploratoryMoves()
  // treats as a signal to shrink the simplex; 0 otherwise.
  return ( simplexAges[minIndex] > (dimensions+1) ) ? 1 : 0;
} // AgesTooOld()

void SHHSearch::UpdateAges(int newIndex)
{
  // The vertex just inserted at newIndex restarts at age 1; every other
  // vertex grows one iteration older.
  for( int v = 0; v <= dimensions; v++ )
    simplexAges[v] = ( v == newIndex ) ? 1 : simplexAges[v] + 1;
} // UpdateAges()

void SHHSearch::ResetAges()
{
   // Set every vertex's age back to 1.
   for( int v = dimensions; v >= 0; v-- )
      simplexAges[v] = 1;
} // ResetAges()

void SHHSearch::printSimplex() const
{
  // Debug dump to cout: one line per vertex (coordinates, value, age),
  // followed by the function-call count.
  for( int row = 0; row <= dimensions; row++ ) {
     cout << "Point: ";
     for ( int col = 0; col < dimensions; col++ )
       cout << (*simplex)[row][col] << " ";
     cout << "   Value: " << simplexValues[row]
          << "   Age: " << simplexAges[row] << "\n";
  }

  cout << "\nFCalls: " << functionCalls << "\n\n";
} // printSimplex()



