/*NMSearch.cc
 *class implementation of Nelder Mead Simplex Search
 *Adam Gurson College of William & Mary 1999
 *
 * modified slightly by Anne Shepherd (pls), 8/00
 */

#include "NMSearch.h"
#include "iostream.h"

// constructors & destructors

NMSearch::NMSearch(int dim)
{
   // Builds a search object for a dim-dimensional problem with the default
   // coefficients: alpha = 1.0 (reflection, used by FindReflectionPt),
   // beta = 0.5 (contraction, FindContractionPt), gamma = 2.0 (expansion,
   // FindExpansionPt) and sigma = 0.5 (shrink, ShrinkSimplex).
   // The simplex itself is created later by one of the Init*Simplex
   // routines or ReadSimplexFile().
   dimensions = dim;
   functionCalls = 0;
   simplex = NULL;
   simplexValues = NULL;

   // Give bookkeeping members defined values so that queries made before the
   // first search (GetTolHit(), the copy constructor) never read
   // indeterminate data. FindMinMaxIndices() overwrites the indices before
   // the search uses them.
   minIndex = 0;
   maxIndex = 0;
   toleranceHit = 0;
   reflectionPtValue = 0.0;
   expansionPtValue = 0.0;
   contractionPtValue = 0.0;
   maxPrimePtValue = 0.0;
   maxPrimePtId = 0;

   centroid = new Vector<double>(dimensions,0.0);
   reflectionPt = new Vector<double>(dimensions,0.0);
   expansionPt = new Vector<double>(dimensions,0.0);
   contractionPt = new Vector<double>(dimensions,0.0);
   alpha = 1.0;
   beta = 0.5;
   gamma = 2.0;
   sigma = 0.5;
   // work buffers reused by CalculateFunctionValue, FindContractionPt,
   // ShrinkSimplex and InitGeneralSimplex
   scratch = new Vector<double>(dimensions,0.0);
   scratch2 = new Vector<double>(dimensions,0.0);
} // NMSearch() (default)

NMSearch::NMSearch(int dim, double Alpha, double Beta,
                            double Gamma, double Sigma)
{
   // Builds a search object for a dim-dimensional problem with caller-chosen
   // coefficients: Alpha (reflection), Beta (contraction), Gamma (expansion)
   // and Sigma (shrink). The simplex itself is created later by one of the
   // Init*Simplex routines or ReadSimplexFile().
   dimensions = dim;
   functionCalls = 0;
   simplex = NULL;
   simplexValues = NULL;

   // Give bookkeeping members defined values so that queries made before the
   // first search (GetTolHit(), the copy constructor) never read
   // indeterminate data. FindMinMaxIndices() overwrites the indices before
   // the search uses them.
   minIndex = 0;
   maxIndex = 0;
   toleranceHit = 0;
   reflectionPtValue = 0.0;
   expansionPtValue = 0.0;
   contractionPtValue = 0.0;
   maxPrimePtValue = 0.0;
   maxPrimePtId = 0;

   centroid = new Vector<double>(dimensions,0.0);
   reflectionPt = new Vector<double>(dimensions,0.0);
   expansionPt = new Vector<double>(dimensions,0.0);
   contractionPt = new Vector<double>(dimensions,0.0);
   alpha = Alpha;
   beta = Beta;
   gamma = Gamma;
   sigma = Sigma;
   // work buffers reused by CalculateFunctionValue, FindContractionPt,
   // ShrinkSimplex and InitGeneralSimplex
   scratch = new Vector<double>(dimensions,0.0);
   scratch2 = new Vector<double>(dimensions,0.0);
} // NMSearch() (special)

NMSearch::NMSearch(const NMSearch& Original)
{
   // Deep copy of another search object: the simplex, its values, the
   // coefficients and the cached reflection/expansion/contraction points.
   dimensions = Original.GetVarNo();

   // The original may not have a simplex yet; in that case leave ours NULL
   // instead of dereferencing its NULL pointers.
   simplex = NULL;
   simplexValues = NULL;
   if(Original.simplex != NULL) Original.GetCurrentSimplex(simplex);
   if(Original.simplexValues != NULL)
      Original.GetCurrentSimplexValues(simplexValues);

   alpha = Original.alpha;
   beta = Original.beta;
   gamma = Original.gamma;
   sigma = Original.sigma;
   minIndex = Original.minIndex;
   maxIndex = Original.maxIndex;
   toleranceHit = Original.toleranceHit;

   // This object is brand new, so its pointer members hold indeterminate
   // values. Allocate them; never delete them here.
   centroid = new Vector<double>(*(Original.centroid));
   reflectionPt = new Vector<double>(*(Original.reflectionPt));
   reflectionPtValue = Original.reflectionPtValue;
   expansionPt = new Vector<double>(*(Original.expansionPt));
   expansionPtValue = Original.expansionPtValue;
   contractionPt = new Vector<double>(*(Original.contractionPt));
   contractionPtValue = Original.contractionPtValue;
   functionCalls = Original.functionCalls;

   // Work buffers are used and deleted by other members, so they must exist.
   // Their contents carry nothing between calls, so fresh ones suffice.
   scratch = new Vector<double>(dimensions,0.0);
   scratch2 = new Vector<double>(dimensions,0.0);
} // NMSearch() (copy constructor)

NMSearch::~NMSearch()
{
   // Releases every heap object owned by the search. Applying delete or
   // delete[] to a NULL pointer is a no-op, so an object that never built a
   // simplex needs no special handling.
   delete simplex;
   delete [] simplexValues;

   // per-iteration trial points
   delete centroid;
   delete reflectionPt;
   delete expansionPt;
   delete contractionPt;

   // work buffers
   delete scratch;
   delete scratch2;
} // ~NMSearch

// algorithmic routines

// Runs Nelder-Mead iterations on the current simplex until Stop() reports that
// the function-call budget is exhausted or the value-spread tolerance is met.
// Each iteration reflects the worst vertex (maxIndex) through the centroid of
// the other vertices. It then compares the reflected value with the best
// (minIndex) and second-highest values and, accordingly:
//   - expands (the reflection beats the best vertex),
//   - accepts the reflection (between best and second-highest), or
//   - contracts, shrinking the whole simplex if the contraction fails
//     (the reflection is no better than the second-highest).
// Assumes an initialized simplex. If simplexValues is NULL, the helpers print
// errors and SecondHighestPtIndex() returns -1, which is then used as an
// index below. toleranceHit is reset to 0 here; Stop() sets it to 1 when the
// tolerance test ends the search.
void NMSearch::ExploratoryMoves()
{
   double secondHighestPtValue; // used for contraction/reflection decision
   toleranceHit = 0;

   FindMinMaxIndices();
   do {
      if(DEBUG) printSimplex();
      FindCentroid();
      // captured before the reflection changes anything
      secondHighestPtValue = simplexValues[SecondHighestPtIndex()];
    // reflection step
      FindReflectionPt();

      // stop if at maximum function calls and update the simplex
      /*changed  8/8/00 to fix the problem of maxCalls == -1  --pls
        formerly read if(functionCalls <= maxCalls)
      */
      // A negative maxCalls disables the budget. When the budget is spent,
      // the reflection point replaces the worst vertex unconditionally.
      // NOTE(review): this happens even if the reflected value is higher than
      // the worst vertex's value.
      if (maxCalls > -1
	  && functionCalls >= maxCalls) {
	FindMinMaxIndices();
	ReplaceSimplexPoint(maxIndex, *reflectionPt);
	simplexValues[maxIndex] = reflectionPtValue;
	FindMinMaxIndices(); 
	return;
      } // if using call budget
      
      // possibility 1: reflection is better than the best vertex, so try to
      // go further; keep whichever of expansion/reflection has the lower value
      if(simplexValues[minIndex] > reflectionPtValue) {
	FindExpansionPt(); // expansion step
	
	if (reflectionPtValue > expansionPtValue) {
	  ReplaceSimplexPoint(maxIndex, *expansionPt);
	  simplexValues[maxIndex] = expansionPtValue;
	} // inner if
	else {
	  ReplaceSimplexPoint(maxIndex, *reflectionPt);
	  simplexValues[maxIndex] = reflectionPtValue;
	} // else         
      } // if for possibility 1
      
    // possibility 2: reflection lies between the best and the second-highest
    // value, so simply accept it

      else if( (secondHighestPtValue > reflectionPtValue        ) &&
               (   reflectionPtValue >= simplexValues[minIndex]) ) {
         ReplaceSimplexPoint(maxIndex, *reflectionPt);
         simplexValues[maxIndex] = reflectionPtValue;
      } // else if for possibility 2

    // possibility 3: reflection is no better than the second-highest vertex,
    // so contract. FindContractionPt() contracts toward the reflection point
    // (maxPrimePtId == 0) when that beats the worst vertex; otherwise it
    // contracts toward the worst vertex (maxPrimePtId == 1). A contraction
    // that fails to improve on that reference value triggers a shrink
    // (ties shrink only in the id == 1 case).
      else if( reflectionPtValue >= secondHighestPtValue ) {
         FindContractionPt(); // contraction step
         if(maxPrimePtId == 0) {
           if( contractionPtValue > maxPrimePtValue ) {
             ShrinkSimplex();
           } // inner if
           else {
             ReplaceSimplexPoint(maxIndex, *contractionPt);
             simplexValues[maxIndex] = contractionPtValue;
           } // inner else
         } // maxPrimePtId == 0
         else if(maxPrimePtId == 1) {
           if( contractionPtValue >= maxPrimePtValue ) {
             ShrinkSimplex();
           } // inner if
           else {
             ReplaceSimplexPoint(maxIndex, *contractionPt);
             simplexValues[maxIndex] = contractionPtValue;
           } // inner else
         } // maxPrimePtId == 1
      } // else if for possibility 3

    // if we haven't taken care of the current simplex, something's wrong
    // (the three tests above cover every ordered comparison, so this is only
    // reachable when one of the compared values is NaN)
      else {
         cerr << "Error in ExploratoryMoves() - "
              << "Unaccounted for case.\nTerminating.\n";
         return;
      }
   // refresh best/worst indices for Stop() and the next iteration
   FindMinMaxIndices();
   } while (!Stop());   // while stopping criteria is not satisfied
} // ExploratoryMoves()

void NMSearch::ReplaceSimplexPoint(int index, const Vector<double>& newPoint)
{
   for( int i = 0; i < dimensions; i++ ) {
      (*simplex)[index][i] = newPoint[i];
   } // for
} // ReplaceSimplexPoint()

void NMSearch::CalculateFunctionValue(int index)
{
   *scratch = (*simplex).row(index);
   int success;
   fcnCall(dimensions, (*scratch).begin(), simplexValues[index], success);
   if(!success) cerr<<"Error calculating point in CalculateFunctionValue().\n";
} // CalculateFunctionValue()

void NMSearch::SetAlpha(double newAlpha)
{
   alpha = newAlpha;
} // SetAlpha()

void NMSearch::SetBeta(double newBeta)
{
   beta = newBeta;
} // SetBeta()

void NMSearch::SetGamma(double newGamma)
{
   gamma = newGamma;
} // SetGamma()

void NMSearch::SetSigma(double newSigma)
{
   // shrink coefficient, used by ShrinkSimplex()
   sigma = newSigma;
} // SetSigma()

bool NMSearch::Stop()
{
   // Returns true when the search should end, in either of two cases:
   //  * maxCalls is non-negative and functionCalls has reached it, or
   //  * the spread of the n+1 simplex values falls below stoppingStepLength.
   //    In that case toleranceHit is also set to 1.
   // The spread is sqrt( sum_k (f_k - m)^2 / (n+1) ), where m is the mean of
   // the n values other than the current best (minIndex).
   const bool budgetActive = (maxCalls > -1);
   if( budgetActive && functionCalls >= maxCalls ) return true;

   double sumOthers = 0.0;
   for( int k = 0; k <= dimensions; k++ ) {
      if( k == minIndex ) continue;
      sumOthers += simplexValues[k];
   }
   const double mean = sumOthers / (double)dimensions;

   // Test for the suggested Nelder-Mead stopping criteria
   double sumSquares = 0.0;
   for( int k = 0; k <= dimensions; k++ ) {
      const double deviation = simplexValues[k] - mean;
      sumSquares += pow( deviation, 2 );
   }
   const double spread = sqrt( sumSquares / ((double)dimensions + 1.0) );

   if( !(spread < stoppingStepLength) ) return false;

   toleranceHit = 1;
   return true;
} // Stop()

void NMSearch::fcnCall(int n, double *x, double& f, int& flag)
{
   // Single funnel for objective evaluations: forwards to fcn() and counts
   // the call, so every evaluation is charged against the maxCalls budget.
   fcn(n, x, f, flag);
   ++functionCalls;
} // fcnCall()

// Simplex-altering functions

void NMSearch::InitRegularTriangularSimplex(const Vector<double> *basePoint,
                                            const double edgeLength)
{
  //  Builds a regular simplex (all edges of length edgeLength) with
  //  basePoint as vertex 0, following Jacoby, Kowalik, and Pizzo,
  //  "Iterative Methods for Nonlinear Optimization Problems,"
  //  Prentice-Hall (1972); also in Spendley, Hext, and Himsworth,
  //  "Sequential Application of Simplex Designs in Optimisation and
  //  Evolutionary Operation," Technometrics, Vol. 4, No. 4, November 1962,
  //  pages 441--461.
  //  Vertex i (1..n) is basePoint shifted by p in coordinate i-1 and by q in
  //  every other coordinate. The result is handed to InitGeneralSimplex(),
  //  which evaluates every vertex.

   const double root2 = sqrt(2.0);
   const double q = ((sqrt(dimensions + 1.0) - 1.0) / (dimensions * root2))
                    * edgeLength;
   const double p = q + ((1.0 / root2) * edgeLength);

   Matrix<double> plex(dimensions+1, dimensions, 0.0);
   for( int col = 0; col < dimensions; col++ )
      plex[0][col] = (*basePoint)[col];

   for( int row = 1; row <= dimensions; row++ ) {
      for( int col = 0; col < dimensions; col++ ) {
         const double offset = (col == row - 1) ? p : q;
         plex[row][col] = plex[0][col] + offset;
      }
   }

   InitGeneralSimplex(&plex);
} // InitRegularTriangularSimplex()

void NMSearch::InitFixedLengthRightSimplex(const Vector<double> *basePoint,
                                           const double edgeLength)
{
  // to take advantage of code reuse, this function simply turns
  // edgeLength into a vector of dimensions length, and then
  // calls InitVariableLengthRightSimplex()

   double* edgeLengths = new double[dimensions];
   for( int i = 0; i < dimensions; i++ ) {
      edgeLengths[i] = edgeLength;
   }
   InitVariableLengthRightSimplex(basePoint,edgeLengths);
   delete [] edgeLengths;
} // InitFixedLengthRightSimplex()

void NMSearch::InitVariableLengthRightSimplex(const Vector<double> *basePoint,
                                              const double* edgeLengths)
{
   // Right-angled simplex: the last row (index n) is basePoint itself, and
   // row i (0..n-1) is basePoint moved by edgeLengths[i] along axis i.
   // edgeLengths must hold at least `dimensions` entries.
   Matrix<double> plex(dimensions+1, dimensions, 0.0);

   // every vertex starts as a copy of the base point ...
   for( int row = 0; row <= dimensions; row++ )
      for( int col = 0; col < dimensions; col++ )
         plex[row][col] = (*basePoint)[col];

   // ... then vertex k is pushed out along coordinate axis k
   for( int k = 0; k < dimensions; k++ )
      plex[k][k] += edgeLengths[k];

   InitGeneralSimplex(&plex);
} // InitVariableLengthRightSimplex()

void NMSearch::InitGeneralSimplex(const Matrix<double> *plex)
{
   functionCalls = 0;
   if( simplex != NULL ) { delete simplex; }
   if( simplexValues != NULL ) { delete [] simplexValues;}
   simplex = new Matrix<double>((*plex));
   simplexValues = new double[dimensions+1];

   int success;
   for( int i = 0; i <= dimensions; i++ ) {
      *scratch = (*plex).row(i);
      fcnCall(dimensions, (*scratch).begin(), simplexValues[i], success);
      if(!success) cerr<<"Error with point #"<<i<<" in initial simplex.\n";
   } // for
   FindMinMaxIndices();
} // InitGeneralSimplex()

void NMSearch::ReadSimplexFile(istream& fp)
{
   // Reads (n+1) x n whitespace-separated numbers from fp, row by row, and
   // installs them as the current simplex via InitGeneralSimplex().
   // A reference can never be NULL. The real precondition is a usable
   // stream, so test its state (this is what the old `fp == NULL` did via
   // operator void*, and it still compiles under C++11's explicit
   // operator bool).
   if(!fp) {
      cerr<<"No Input Stream in ReadSimplexFile()!\n";
      return; // There's no usable input stream.
   }

   Matrix<double> plex(dimensions+1, dimensions, 0.0);
   for( int i = 0; i <= dimensions; i++ ) {
      for ( int j = 0; j < dimensions; j++ ) {
         fp >> plex[i][j];
      } // inner for
   } // outer for

   // A short or malformed file leaves the stream failed. Keep the current
   // simplex rather than installing a partially read one.
   if(!fp) {
      cerr<<"Error reading simplex in ReadSimplexFile().\n";
      return;
   }

   InitGeneralSimplex(&plex);
} // ReadSimplexFile()

// Query functions

int NMSearch::GetFunctionCalls() const
{
   return functionCalls;
} // GetFunctionCalls()

void NMSearch::GetMinPoint(Vector<double>* &minimum) const
{
   // Hands back a newly allocated copy of the best vertex; the caller owns
   // (and must delete) it.
   minimum = new Vector<double>(simplex->row(minIndex));
} // GetMinPoint()

double NMSearch::GetMinVal() const
{
   // objective value at the best vertex
   const double best = simplexValues[minIndex];
   return best;
} // GetMinVal()

void NMSearch::GetCurrentSimplex(Matrix<double>* &plex) const
{
   // Hands back a newly allocated copy of the simplex matrix; the caller
   // owns (and must delete) it.
   plex = new Matrix<double>(*simplex);
} // GetCurrentSimplex()

void NMSearch::GetCurrentSimplexValues(double* &simValues) const
{
   // Hands back a newly allocated array holding the n+1 vertex values; the
   // caller owns it and must release it with delete[].
   simValues = new double[dimensions+1];
   for( int k = dimensions; k >= 0; --k )
      simValues[k] = simplexValues[k];
} // GetCurrentSimplexValues()

int NMSearch::GetVarNo() const
{
   return dimensions;
} // GetVarNo()

int NMSearch::GetTolHit() const
{
   return toleranceHit;
} // GetTolHit()

// private functions

void NMSearch::FindMinMaxIndices()
{
   // Sets minIndex/maxIndex to the vertices with the lowest/highest value.
   // Ties for the lowest go to the smallest index; ties for the highest go to
   // the largest index. When all values are equal the two indices therefore
   // differ (0 and n).
   if(simplexValues == NULL) {
      cerr << "Error in FindMinMaxIndices() - "
           << "The vector of simplexValues is NULL!!\n";
      return;
   }

   // forward scan for the lowest value
   int lowest = 0;
   for( int k = 1; k <= dimensions; ++k ) {
      if( simplexValues[k] < simplexValues[lowest] ) lowest = k;
   }

   // backward scan for the highest value
   int highest = dimensions;
   for( int k = dimensions - 1; k >= 0; --k ) {
      if( simplexValues[k] > simplexValues[highest] ) highest = k;
   }

   minIndex = lowest;
   maxIndex = highest;
} // FindMinMaxIndices()

int NMSearch::SecondHighestPtIndex()
{
   if(simplexValues == NULL) {
      cerr << "Error in SecondHighestPtValue() - "
           << "The vector of simplexValues is NULL!!\n";
      return -1;
   }
   int secondMaxIndex = minIndex;
   double secondMax = simplexValues[minIndex];
   for( int i = 0; i <= dimensions; i++ ) {
      if(i != maxIndex) {
         if( simplexValues[i] > secondMax ) {
            secondMax = simplexValues[i];
            secondMaxIndex = i;
         } // inner if
      } // outer if
   } // for
   return secondMaxIndex;
} // SecondHighestPtValue()

void NMSearch::FindCentroid()
{
   (*centroid) = 0.0;
   for( int i = 0; i <= dimensions; i++ ) {
      if( i != maxIndex ) {
         (*centroid) = (*centroid) + (*simplex).row(i);
      } // if
   } // for
   (*centroid) = (*centroid) * ( 1.0 / (double)dimensions );
} // FindCentroid()

void NMSearch::FindReflectionPt()
{ 
   (*reflectionPt) = 0.0;
   (*reflectionPt) = ( (*centroid) * (1.0 + alpha) ) -
                     ( alpha * (*simplex).row(maxIndex) );
   int success;
   fcnCall(dimensions, (*reflectionPt).begin(), reflectionPtValue, success);
   if(!success) {
      cerr << "Error finding f(x) for reflection point at"
           << "function call #" << functionCalls << ".\n";
   } // if
} // FindReflectionPt()

void NMSearch::FindExpansionPt()
{
   (*expansionPt) = 0.0;
   (*expansionPt) = ( (*centroid) * (1.0 - gamma) ) +
                    ( gamma * (*reflectionPt) );
   int success;
   fcnCall(dimensions, (*expansionPt).begin(), expansionPtValue, success);
   if(!success) {
      cerr << "Error finding f(x) for expansion point at"
           << "function call #" << functionCalls << ".\n";
   } // if
} // FindExpansionPt()

void NMSearch::FindContractionPt()
{
   // need to first define maxPrimePt
   Vector<double> *maxPrimePt = scratch;
   if(simplexValues[maxIndex] <= reflectionPtValue) {
      *maxPrimePt = (*simplex).row(maxIndex);
      maxPrimePtValue = simplexValues[maxIndex];
      maxPrimePtId = 1;
   } // if
   else {
      maxPrimePt = reflectionPt;
      maxPrimePtValue = reflectionPtValue;
      maxPrimePtId = 0;
   } // else

   (*contractionPt) = ( (*centroid) * (1.0 - beta) ) +
                      ( beta * (*maxPrimePt) );
   int success;
   fcnCall(dimensions, (*contractionPt).begin(), contractionPtValue, success);
   if(!success) {
      cerr << "Error finding f(x) for contraction point at"
           << "function call #" << functionCalls << ".\n";
   } // if
} // FindContractionPt()

void NMSearch::ShrinkSimplex()
{
   // Moves every vertex except the best (minIndex) a fraction sigma of the
   // way toward the best vertex, re-evaluating each moved vertex.
   // Returns early, possibly part-way through, once the call budget is
   // spent. Every vertex moved before that point has already been
   // re-evaluated.
   //
   // Budget test: a non-negative maxCalls is a limit and a negative one means
   // unlimited. This matches Stop() and ExploratoryMoves(). The former
   // `maxCalls != -1` test treated values such as -2 as an already-exhausted
   // budget, so the simplex never shrank while Stop() kept iterating.
   if( maxCalls > -1 && functionCalls >= maxCalls ) return;

   Vector<double> *lowestPt = scratch;
   *lowestPt = (*simplex).row(minIndex);
   Vector<double> *tempPt = scratch2;
   int success;
   for( int i = 0; i <= dimensions; i++ ) {
      if( i == minIndex ) continue;

      *tempPt = (*simplex).row(i);
      (*tempPt) = (*tempPt) + ( sigma * ( (*lowestPt)-(*tempPt) ) );
      for( int j = 0; j < dimensions; j++ ) {
         (*simplex)[i][j] = (*tempPt)[j];
      } // inner for
      fcnCall(dimensions,(*tempPt).begin(),simplexValues[i],success);
      if (!success) cerr << "Error shrinking the simplex.\n";

      // stop if at maximum function calls
      if( maxCalls > -1 && functionCalls >= maxCalls ) return;
   } // outer for
} // ShrinkSimplex()

void NMSearch::printSimplex() const
{
  // Debug dump to cout: one line per vertex (its coordinates, tab-separated,
  // then its value), followed by the current function-call count.
  for( int row = 0; row <= dimensions; ++row ) {
     cout << "   Point:";
     for( int col = 0; col < dimensions; ++col )
        cout << (*simplex)[row][col] << "\t";
     cout << "Value:" << simplexValues[row] << "\n";
  }

  cout << "\nFCalls: " << functionCalls << endl << endl;
} // printSimplex()
