/*SMDSearch.cc
 *Implementation (definitions) of the sequential version of Torczon's
 *Multi-Directional Search (MDS) algorithm.
 *Adam Gurson, College of William & Mary, 2000
 *
 * slightly modified by Anne Shepherd (pls), 8/00
 */

#include "SMDSearch.h"
#include <iostream.h>
#include <iomanip.h>

// constructors & destructors

SMDSearch::SMDSearch(int dim)
{
   // Default constructor: allocates storage for a search over `dim`
   // variables with the default shrink factor sigma = 0.5.
   //
   // The primary and reflection simplices each hold dim+1 vertices
   // (one per row), plus a value and a valid bit per vertex.
   // delta = -1.0 marks the step length as "not yet set"; an Init*Simplex()
   // call establishes it.
   //
   // Every member is given a defined value so that copying or querying an
   // object before a simplex is initialized does not read indeterminate data.
   dimensions = dim;
   simplex = new Matrix<double>(dimensions+1,dimensions,0.0);
   simplexValues = new double[dimensions+1]();     // value-initialized to 0.0
   simplexVBits = new int[dimensions+1]();         // 0 == not yet evaluated
   refSimplex = new Matrix<double>(dimensions+1,dimensions,0.0);
   refSimplexValues = new double[dimensions+1]();
   refSimplexVBits = new int[dimensions+1]();
   minPoint = new Vector<double>(dimensions,0.0);
   minValue = 0.0;                                  // placeholder until InitGeneralSimplex()
   minIndex = currentIndex = refCurrentIndex = dimensions;
   toleranceHit = 0;
   delta = -1.0;
   sigma = 0.5;
   functionCalls = 0;
   scratch = new Vector<double>(dimensions,0.0);
   scratch2 = new Vector<double>(dimensions,0.0);
} // SMDSearch() (default)

SMDSearch::SMDSearch(int dim, double Sigma)
{
   // Constructor with an explicit shrink factor: identical to
   // SMDSearch(int) except that sigma is set to `Sigma`.
   //
   // delta = -1.0 marks the step length as "not yet set"; an Init*Simplex()
   // call establishes it.  Every member is given a defined value so that
   // copying or querying an object before a simplex is initialized does not
   // read indeterminate data.
   dimensions = dim;
   simplex = new Matrix<double>(dimensions+1,dimensions,0.0);
   simplexValues = new double[dimensions+1]();     // value-initialized to 0.0
   simplexVBits = new int[dimensions+1]();         // 0 == not yet evaluated
   refSimplex = new Matrix<double>(dimensions+1,dimensions,0.0);
   refSimplexValues = new double[dimensions+1]();
   refSimplexVBits = new int[dimensions+1]();
   minPoint = new Vector<double>(dimensions,0.0);
   minValue = 0.0;                                  // placeholder until InitGeneralSimplex()
   minIndex = currentIndex = refCurrentIndex = dimensions;
   toleranceHit = 0;
   delta = -1.0;
   sigma = Sigma;
   functionCalls = 0;
   scratch = new Vector<double>(dimensions,0.0);
   scratch2 = new Vector<double>(dimensions,0.0);
} // SMDSearch() (special)

/*Slightly modified by Anne Shepherd, 8/00, to initialize some
  fields that were left out in the original.

  Copy constructor: deep-copies every heap-owned member of Original so the
  two objects never share storage (each destructor frees only its own).
  The scratch vectors are always overwritten before they are read, so
  fresh ones are allocated rather than copied.  They MUST be allocated,
  because ~SMDSearch() deletes them.

  NOTE(review): maxCalls, stoppingStepLength and fcn are not copied here.
  If they are members of this class (rather than of a base class), they
  need copying too; confirm against SMDSearch.h.  No copy-assignment
  operator is defined in this file either, so confirm the header does not
  leave the implicit (shallow, double-freeing) one in place.
*/
SMDSearch::SMDSearch(const SMDSearch& Original)
{
   dimensions = Original.GetVarNo();
   Original.GetCurrentSimplex(simplex);
   Original.GetCurrentSimplexValues(simplexValues);
   Original.GetCurrentSimplexVBits(simplexVBits);
   refSimplex = new Matrix<double>(*(Original.refSimplex)); //added --pls
   refSimplexValues = new double[dimensions+1];             //added --pls
   refSimplexVBits = new int[dimensions+1];                 //added --pls
   for(int i = 0; i <= dimensions; i++) {
     refSimplexValues[i] = Original.refSimplexValues[i];
     refSimplexVBits[i] = Original.refSimplexVBits[i];
   } //for
   minPoint = new Vector<double>(*(Original.minPoint));
   minValue = Original.minValue;
   minIndex = Original.minIndex;
   currentIndex = Original.currentIndex;
   refCurrentIndex = Original.refCurrentIndex;
   toleranceHit = Original.toleranceHit;
   delta = Original.delta;
   sigma = Original.sigma;
   functionCalls = Original.functionCalls;

   // Scratch buffers carry no state between calls, but the destructor
   // deletes them, so leaving them uninitialized would delete garbage.
   scratch = new Vector<double>(dimensions,0.0);
   scratch2 = new Vector<double>(dimensions,0.0);
} // SMDSearch() (copy constructor)

SMDSearch::~SMDSearch()
{
   // Release everything the constructors allocate.
   // Single objects (Matrix / Vector) use scalar delete; their own
   // destructors free their internal storage.
   delete simplex;
   delete refSimplex;
   delete minPoint;
   delete scratch;
   delete scratch2;

   // Per-vertex value and valid-bit arrays came from new[], so use delete [].
   delete [] simplexValues;
   delete [] refSimplexValues;
   delete [] simplexVBits;
   delete [] refSimplexVBits;
} // ~SMDSearch

// algorithmic routines

// Main search loop.
//
// Each pass:
//   1. builds the reflection simplex: every vertex j of the primary simplex
//      is reflected through the pivot vertex `currentIndex`
//      (see CreateRefSimplex());
//   2. evaluates unevaluated reflected vertices one at a time, stopping the
//      pass at the FIRST one that beats minValue; in that case the two
//      simplices are swapped so the reflection becomes the primary simplex;
//   3. if no reflected vertex improved, evaluates unevaluated vertices of
//      the primary simplex the same way;
//   4. if still no improvement, shrinks the simplex toward the best vertex.
// Passes repeat until Stop() reports that the call budget or the step
// tolerance was reached.
//
// Early exit: whenever a non-negative maxCalls is reached in the middle of
// a pass, the function returns immediately WITHOUT calling Stop(), so
// toleranceHit remains 0 in that case.
void SMDSearch::ExploratoryMoves()
{
   int done;                        // 1 once this pass found a new minimum
   int lastMinIndex = minIndex;
   toleranceHit = 0;

   do {
     done = 0;
     CreateRefSimplex();

     if(DEBUG) {
       printSimplex();
       printRefSimplex();
     }

     // Go through the Reflection Simplex First
     // Start the scan at lastMinIndex, overriding the value CreateRefSimplex()
     // just set.  GetAnotherIndex() keeps that index if its vertex is still
     // unevaluated, otherwise advances (cyclically) to the next unevaluated one.
     refCurrentIndex = lastMinIndex;
     while( !done && GetAnotherIndex(refCurrentIndex, refSimplexVBits) ) {
       CalculateRefFunctionValue(refCurrentIndex);
       refSimplexVBits[refCurrentIndex] = 1;

       if(DEBUG) printRefSimplex();

       if( refSimplexValues[refCurrentIndex] < minValue ) {
         // New best point found in the reflection: record it, then make the
         // reflection the primary simplex.  SwitchSimplices() also sets
         // currentIndex = refCurrentIndex, so the winner is the next pivot.
         (*minPoint) = (*refSimplex).row(refCurrentIndex);
         minValue = refSimplexValues[refCurrentIndex];
         lastMinIndex = minIndex;
         minIndex = refCurrentIndex;
         SwitchSimplices();
         done = 1;
       } // if

       /* Changed to fix the problem of maxCalls == -1.
          --pls, 8/8/00
       */
       // maxCalls == -1 means "no budget"; otherwise bail out once it is used up.
       //if( functionCalls >= maxCalls ) return; 
       if( maxCalls > -1                            // pls
           && functionCalls >= maxCalls ) return;
     } // while (reflection search)

     // Go through the Primary Simplex Next
     // currentIndex is both the scan cursor here and the reflection pivot
     // for the next pass, so a winning primary vertex becomes the next pivot.
     while( !done && GetAnotherIndex(currentIndex, simplexVBits) ) {
       CalculateFunctionValue(currentIndex); 
       simplexVBits[currentIndex] = 1;
       // NOTE: currentIndex initialized in InitGeneralSimplex()

       if(DEBUG) printSimplex();

       if( simplexValues[currentIndex] < minValue ) {
         (*minPoint) = (*simplex).row(currentIndex);
         minValue = simplexValues[currentIndex];
         lastMinIndex = minIndex;
         minIndex = currentIndex;
         done = 1;
       } // if

       /* Changed to fix the problem of maxCalls == -1.
          --pls, 8/8/00
       */
       // maxCalls == -1 means "no budget"; otherwise bail out once it is used up.
       //if( functionCalls >= maxCalls ) return; 
       if( maxCalls > -1                            // pls
           && functionCalls >= maxCalls ) return;
     } // while (primary search)
     
     // Still there's no new min now, shrink
     // (ShrinkSimplex() reduces delta and makes minIndex the next pivot.)
     if( !done )
       ShrinkSimplex();

   } while (!Stop());   // while stopping criteria is not satisfied
} // ExploratoryMoves()

void SMDSearch::ReplaceSimplexPoint(int index, const Vector<double>& newPoint)
{
   for( int i = 0; i < dimensions; i++ ) {
      (*simplex)[index][i] = newPoint[i];
   } // for
} // ReplaceSimplexPoint()

void SMDSearch::CalculateFunctionValue(int index)
{
   // Evaluate the objective at vertex `index` of the PRIMARY simplex and
   // store the result in simplexValues[index].  The evaluation is counted
   // via fcnCall().  The vertex is first copied into the scratch vector so
   // its storage can be handed to fcnCall() as a double*.
   // A failed evaluation is only reported on cerr; whatever fcn wrote into
   // simplexValues[index] is kept.
   *scratch = (*simplex).row(index);
   int success;
   fcnCall(dimensions, (*scratch).begin(), simplexValues[index], success);
   if(!success) cerr<<"Error calculating point at index "
                    << index << " in CalculateFunctionValue().\n";
} // CalculateFunctionValue()

void SMDSearch::SetSigma(double newSigma)
{
   sigma = newSigma;
} // SetSigma()

bool SMDSearch::Stop()
{
   // Returns true when the search should end.
   //  - Call budget exhausted (only when maxCalls > -1): stop, but this is
   //    NOT recorded as a tolerance hit.
   //  - Step length delta has dropped below stoppingStepLength: stop and
   //    set toleranceHit = 1.
   const bool hasBudget = (maxCalls > -1);
   if( hasBudget && functionCalls >= maxCalls )
      return true;

   if( !(delta < stoppingStepLength) )
      return false;

   toleranceHit = 1;
   return true;
} // Stop()

void SMDSearch::fcnCall(int n, double *x, double& f, int& flag)
{
   // Single choke point for objective evaluations: forward to fcn, then
   // bump the call counter that Stop() and ExploratoryMoves() check.
   fcn(n, x, f, flag);
   ++functionCalls;
} // fcnCall()

// Simplex-altering functions

void SMDSearch::InitRegularTriangularSimplex(const Vector<double> *basePoint,
                                             const double edgeLength)
{
  //  This routine constructs a regular simplex (i.e., one in which all of 
  //  the edges are of equal length) following an algorithm given by Jacoby,
  //  Kowalik, and Pizzo in "Iterative Methods for Nonlinear Optimization 
  //  Problems," Prentice-Hall (1972).  This algorithm also appears in 
  //  Spendley, Hext, and Himsworth, "Sequential Application of Simplex 
  //  Designs in Optimisation and Evolutionary Operation," Technometrics, 
  //  Vol. 4, No. 4, November 1962, pages 441--461.
  //
  //  Layout: InitGeneralSimplex() requires the base point to be in the
  //  LAST row (index `dimensions`), so it is stored there (the same layout
  //  InitVariableLengthRightSimplex() uses).  Every other row i
  //  (0 <= i < dimensions) is basePoint offset by p in coordinate i and by
  //  q in all other coordinates.
  //
  //  @param basePoint   starting point (dimensions coordinates)
  //  @param edgeLength  common length of every edge; also the initial delta
  //
  //  NOTE(review): q divides by `dimensions`, so dimensions must be > 0.

   const double n = dimensions;
   const double q = ((sqrt(n + 1.0) - 1.0) / (n * sqrt(2.0))) * edgeLength;
   const double p = q + ((1.0 / sqrt(2.0)) * edgeLength);

   Matrix<double> plex(dimensions+1, dimensions, 0.0);

   // base point in the last row
   for( int col = 0; col < dimensions; col++ ) {
      plex[dimensions][col] = (*basePoint)[col];
   }

   // remaining vertices: p on the "diagonal" coordinate, q elsewhere
   for( int i = 0; i < dimensions; i++ ) {
      for( int j = 0; j < dimensions; j++ ) {
         plex[i][j] = (*basePoint)[j] + ( (i == j) ? p : q );
      } // inner for
   } // outer for

   delta = edgeLength;
   InitGeneralSimplex(&plex);
} // InitRegularTriangularSimplex()

void SMDSearch::InitFixedLengthRightSimplex(const Vector<double> *basePoint,
                                            const double edgeLength)
{
  // to take advantage of code reuse, this function simply turns
  // edgeLength into a vector of dimensions length, and then
  // calls InitVariableLengthRightSimplex()

  //  cerr << *basePoint;
   double* edgeLengths = new double[dimensions];
   for( int i = 0; i < dimensions; i++ ) {
      edgeLengths[i] = edgeLength;
   }
   InitVariableLengthRightSimplex(basePoint,edgeLengths);
   delete [] edgeLengths;
} // InitFixedLengthRightSimplex()

void SMDSearch::InitVariableLengthRightSimplex(const Vector<double> *basePoint,
                                               const double* edgeLengths)
{
   // Build a right simplex at basePoint: vertex i (0 <= i < dimensions) is
   // basePoint moved by edgeLengths[i] along coordinate i, and the base
   // point itself occupies the last row (as InitGeneralSimplex() requires).
   // delta is raised to the largest entry of edgeLengths if that exceeds
   // its current value.
   Matrix<double> plex(dimensions+1, dimensions, 0.0);

   for( int vertex = 0; vertex <= dimensions; vertex++ ) {
      for( int coord = 0; coord < dimensions; coord++ )
         plex[vertex][coord] = (*basePoint)[coord];
      if( vertex < dimensions )
         plex[vertex][vertex] += edgeLengths[vertex];
   } // for each vertex

   for( int axis = 0; axis < dimensions; axis++ ) {
      if( edgeLengths[axis] > delta )
         delta = edgeLengths[axis];
   } // for each axis

   InitGeneralSimplex(&plex);
} // InitVariableLengthRightSimplex()

void SMDSearch::InitGeneralSimplex(const Matrix<double> *plex)
{
   functionCalls = 0;
   (*simplex) = (*plex);

   // zero out the valid bits
   for(int i = 0; i < dimensions; i++)
     simplexVBits[i] = 0;

   // NOTE: basePoint MUST be located in the last row of plex
   Vector<double> basePoint = (*plex).row(dimensions);

   // evaluate f(basePoint) and initialize it as the min
   int success;
   fcnCall(dimensions, (basePoint).begin(), simplexValues[dimensions], success);
   if(!success) cerr<<"Error with basePoint in initial simplex.\n";
   simplexVBits[dimensions] = 1;
   (*minPoint) = (basePoint);
   minValue = simplexValues[dimensions];
   currentIndex = minIndex = dimensions;

   // if we still haven't defined delta, go through the simplex and
   // define delta to be the length of the LONGEST simplex side
   double temp;
   if( delta < 0.0 ) {
     for( int j = 0; j < dimensions; j++ ) {
       for ( int k = j+1; k <= dimensions; k++ ) {
         temp = ( ((*simplex).row(j)) - ((*simplex).row(k)) ).l2norm();
         if( temp > delta ) delta = temp;
       } // inner for
     } // outer for
   } // outer if

   // if delta is still not defined, there is a definite problem
   if( delta < 0.0 )
     cout << "Error in simplex initialization: delta not set.\n";
   // cout << "\nminValue = " << minValue << " and minPoint = " << *minPoint << endl;
   //printSimplex();
} // InitGeneralSimplex()

void SMDSearch::ReadSimplexFile(istream& fp)
{
   // Read a starting simplex from `fp` and install it via InitGeneralSimplex().
   //
   // Expected input: (dimensions+1) rows of `dimensions` whitespace-separated
   // numbers, read row by row.  The LAST row is the base (starting) point.
   //
   // A stream that is already unusable, or that fails while the numbers are
   // being read, is reported on cerr and the current simplex is left as-is.
   //
   // NOTE(review): delta is not reset here.  On an object reused after a
   // previous search, InitGeneralSimplex() will keep the old (non-negative)
   // delta instead of recomputing it; confirm whether callers rely on that.

   // `fp == NULL` relied on the pre-standard operator void*; `!fp` tests the
   // same fail/bad state and is valid standard C++.
   if(!fp) {
      cerr<<"No Input Stream in ReadSimplexFile()!\n";
      return; // There's no usable input stream!!
   }

   Matrix<double> plex(dimensions+1, dimensions);
   for( int i = 0; i <= dimensions; i++ ) {
      for ( int j = 0; j < dimensions; j++ ) {
         fp >> plex[i][j];
      } // inner for
   } // outer for

   // Do not initialize the search from a partially read (garbage) simplex.
   if(!fp) {
      cerr<<"Error reading simplex in ReadSimplexFile()!\n";
      return;
   }

   InitGeneralSimplex(&plex);
} // ReadSimplexFile()

// Query functions

int SMDSearch::GetFunctionCalls() const
{
   return functionCalls;
} // GetFunctionCalls()

void SMDSearch::GetMinPoint(Vector<double>* &minimum) const
{
   // Hand back a newly allocated copy of the best point found so far.
   // The caller owns the result and must delete it.
   Vector<double>* bestCopy = new Vector<double>(*minPoint);
   minimum = bestCopy;
} // GetMinPoint()

double SMDSearch::GetMinVal() const
{
   return minValue;
} // GetMinVal()

void SMDSearch::GetCurrentSimplex(Matrix<double>* &plex) const
{
   // Hand back a newly allocated copy of the primary simplex (one vertex
   // per row).  The caller owns the result and must delete it.
   Matrix<double>* simplexCopy = new Matrix<double>(*simplex);
   plex = simplexCopy;
} // GetCurrentSimplex()

void SMDSearch::GetCurrentSimplexValues(double* &simValues) const
{
   // Hand back a newly allocated (new[]) copy of the dimensions+1 stored
   // vertex values.  The caller owns it and must release it with delete [].
   const int vertexCount = dimensions + 1;
   simValues = new double[vertexCount];
   for( int v = 0; v < vertexCount; ++v )
      simValues[v] = simplexValues[v];
} // GetCurrentSimplexValues()

void SMDSearch::GetCurrentSimplexVBits(int* &simVBits) const
{
   // Hand back a newly allocated (new[]) copy of the dimensions+1 valid
   // bits (nonzero = vertex value has been evaluated).  The caller owns it
   // and must release it with delete [].
   const int vertexCount = dimensions + 1;
   simVBits = new int[vertexCount];
   for( int v = 0; v < vertexCount; ++v )
      simVBits[v] = simplexVBits[v];
} // GetCurrentSimplexVBits()

int SMDSearch::GetVarNo() const
{
   return dimensions;
} // GetVarNo()

int SMDSearch::GetTolHit() const
{
   return toleranceHit;
} // GetTolHit()

// private functions

void SMDSearch::CreateRefSimplex()
{
  // Build the reflection simplex through the pivot vertex currentIndex:
  //   ref[j] = 2 * simplex[pivot] - simplex[j]   for every j != pivot.
  // The pivot itself is shared, so its point, value and valid bit are copied
  // unchanged.  Every reflected vertex is marked unevaluated.
  const int pivot = currentIndex;
  refCurrentIndex = pivot;

  for( int row = 0; row <= dimensions; row++ ) {
    if( row == pivot ) {
      for( int col = 0; col < dimensions; col++ )
        (*refSimplex)[row][col] = (*simplex)[row][col];
      refSimplexValues[row] = simplexValues[row];
      refSimplexVBits[row] = simplexVBits[row];
      continue;
    } // pivot

    refSimplexVBits[row] = 0;
    (*scratch) = ( (*simplex).row(pivot) * 2.0 ) - (*simplex).row(row);
    for( int col = 0; col < dimensions; col++ )
      (*refSimplex)[row][col] = (*scratch)[col];
  } // for each vertex
} // CreateRefSimplex()

void SMDSearch::SwitchSimplices()
{
  // this allows us to remove the need to delete and
  // reallocate memory by simply swapping pointers
  // and using the same two "simplex memory slots"
  // for the entire search

  Matrix<double> *tmp1 = simplex;
  double         *tmp2 = simplexValues;
  int            *tmp3 = simplexVBits;
  int             tmp4 = currentIndex;

  simplex = refSimplex;
  simplexValues = refSimplexValues;
  simplexVBits = refSimplexVBits;
  currentIndex = refCurrentIndex;

  refSimplex = tmp1;
  refSimplexValues = tmp2;
  refSimplexVBits = tmp3;
  refCurrentIndex = tmp4;
} // SwitchSimplices()

void SMDSearch::ShrinkSimplex()
{
  // Contract the primary simplex toward its best vertex (minIndex):
  //   v_i <- v_best + sigma * (v_i - v_best)     for every i != minIndex.
  // Every edge of the simplex is therefore scaled by sigma, which is exactly
  // what the `delta *= sigma` step-length update assumes.  The former form,
  // v_i + sigma * (v_best - v_i), scaled edges by (1 - sigma) and only agreed
  // with delta when sigma == 0.5.
  // Shrunk vertices are marked unevaluated, and the best vertex becomes the
  // next reflection pivot.
  if(DEBUG) cout << "Shrinking Simplex.\n\n";

   delta *= sigma;
   currentIndex = minIndex;
   Vector<double> *lowestPt = scratch;
   *lowestPt = (*simplex).row(minIndex);
   Vector<double> *tempPt = scratch2;
 
   for( int i = 0; i <= dimensions; i++ ) {
      if( i != minIndex ) {
         *tempPt = (*simplex).row(i);
         (*tempPt) = (*lowestPt) + ( sigma * ( (*tempPt)-(*lowestPt) ) );
         for( int j = 0; j < dimensions; j++ ) {
            (*simplex)[i][j] = (*tempPt)[j];
         } // inner for

         simplexVBits[i] = 0;         

      } // if
   } // outer for
} // ShrinkSimplex()

int SMDSearch::GetAnotherIndex(int& index, int*& validBits)
{
  // Find an unevaluated vertex, starting at `index`.
  // If validBits[index] is already 0, index is kept and 1 is returned.
  // Otherwise the following indices are scanned cyclically (wrapping from
  // dimensions back to 0).  The first one with a 0 bit is stored in `index`
  // and 1 is returned.  If every vertex is valid, `index` is left unchanged
  // and 0 is returned.
  if ( !validBits[index] ) return 1;

  const int vertexCount = dimensions + 1;
  for( int step = 1; step < vertexCount; step++ ) {
    const int candidate = (index + step) % vertexCount;
    if( !validBits[candidate] ) {
      index = candidate;
      return 1;
    }
  } // for

  return 0;
} // GetAnotherIndex()

void SMDSearch::CalculateRefFunctionValue(int index)
{
   // Evaluate the objective at vertex `index` of the REFLECTION simplex and
   // store the result in refSimplexValues[index].  The evaluation is counted
   // via fcnCall().  The vertex is first copied into the scratch vector so
   // its storage can be handed to fcnCall() as a double*.
   // A failed evaluation is only reported on cerr.  The message now names
   // this function (it previously said CalculateFunctionValue()) and has
   // the missing space restored.
   *scratch = (*refSimplex).row(index);
   int success;
   fcnCall(dimensions, (*scratch).begin(), 
           refSimplexValues[index], success);
   if(!success) cerr<<"Error calculating point at index "
                    << index << " in CalculateRefFunctionValue().\n";
} // CalculateRefFunctionValue()

void SMDSearch::printSimplex() const
{
  // Debug dump of the primary simplex to cout: one line per vertex with
  // its coordinates, stored value and valid/invalid flag, followed by the
  // call count and current delta.
  cout << "Primary Simplex:\n";

  int row = 0;
  while( row <= dimensions ) {
     cout << "Point: ";
     for ( int col = 0; col < dimensions; col++ )
       cout << (*simplex)[row][col] << " ";
     cout << "   Value: " << simplexValues[row]
          << ( simplexVBits[row] ? "   Valid\n" : "   Invalid\n" );
     ++row;
  } // while

  cout << "FCalls: " << functionCalls 
       << "   Delta: " << delta << "\n\n";
} // printSimplex()

void SMDSearch::printRefSimplex() const
{
  // Debug dump of the reflection simplex to cout: one line per vertex with
  // its coordinates, stored value and valid/invalid flag, followed by the
  // call count and current delta.
  cout << "Reflection Simplex:\n";

  int row = 0;
  while( row <= dimensions ) {
     cout << "Point: ";
     for ( int col = 0; col < dimensions; col++ )
       cout << (*refSimplex)[row][col] << " ";
     cout << "   Value: " << refSimplexValues[row]
          << ( refSimplexVBits[row] ? "   Valid\n" : "   Invalid\n" );
     ++row;
  } // while

  cout << "FCalls: " << functionCalls 
       << "   Delta: " << delta << "\n\n";
} // printRefSimplex()











