
import java.util.*;

/**
 * Iterative minimizer for a function of n variables using steepest descent
 * or nonlinear conjugate gradient (Fletcher-Reeves or Polak-Ribiere).
 *
 * The function, its gradient and (optionally) its Hessian are supplied as
 * expression strings that are compiled by XToken.compileToPostFix. The
 * constructor creates n variables named "x1" ... "xn"; the expressions are
 * presumably expected to refer to those names (see the examples in main) --
 * confirm against XToken / Variable.
 *
 * Each iteration performs a one-dimensional minimization along the current
 * search direction h. The class implements OneVarFn so that it can hand
 * itself to the OneVarMinimizer as the one-dimensional function
 * alpha -> f( xpos - alpha*h ).
 */
public class GradientMinimizer
  implements OneVarFn
{
  public static final int ALG_SD = 0;   /* Steepest descent */
  public static final int ALG_CGFR = 1; /* Conjugate gradient, Fletcher-Reeves */
  public static final int ALG_CGPR = 2; /* Conjugate gradient, Polak-Ribiere */
  private int algo = ALG_SD; /* Algorithm used by next() */

  private int n;        /* Number of variables */
  private Variable X[]; /* Variables x1 ... xn, set before every evaluation */

  /* Expressions for the function, gradient, and Hessian in postfix */
  private Vector fnPF;
  private Vector gradPF[];
  private Vector hessPF[][];

  /* What information do we have about the function? */
  private boolean haveFunction;
  private boolean haveGradient;
  private boolean haveHessian;

  private double xpos[];   /* Current position */
  private double grad[];   /* Gradient at the current position */
  private double ograd[];  /* Previous gradient (Polak-Ribiere only) */
  private double h[];      /* Current search direction */
  private double xtmp[];   /* Scratch point used during the line search */
  private double hess[][]; /* Hessian at the point passed to computeHessian */

  private int nIter;   /* Number of calls to iterCGSD since setInitialPoint */
  private int nFnEval; /* NOTE(review): declared but never updated in this class */

  private OneVarMinimizer ovmz; /* One-dimensional minimizer for the line search */

  private double fnVal; /* Function value at xpos */

  /**
   * Creates a minimizer for a function of dim variables and allocates all
   * working arrays. The variables are named x1, x2, ..., x&lt;dim&gt;.
   *
   * @param dim number of variables
   */
  public GradientMinimizer( int dim )
  {
    n = dim;
    X = new Variable[n];
    for( int i = 0; i < dim; i++ )
      X[i] = Variable.create( "x" + (i+1), 0.0 );
    // The first variable is called x1

    xpos = new double[n];
    grad = new double[n];
    ograd = new double[n];
    h = new double[n];
    xtmp = new double[n];
    hess = new double[n][n];

    gradPF = new Vector[n];
    hessPF = new Vector[n][n];

    ovmz = new OneVarMinimizer( this );
  }

  /**
   * Selects the algorithm used by next().
   *
   * @param a one of ALG_SD, ALG_CGFR, ALG_CGPR; any other value makes
   *          next() do nothing
   */
  public void setAlgorithm( int a )
  {
    algo = a;
  }

  /**
   * Compiles and installs the expression of the function to minimize.
   * If compilation fails the previously installed function (if any) is
   * left untouched.
   *
   * @param fnExpr expression text of the function
   * @return true if the expression compiled, false otherwise
   */
  public boolean setFunction( String fnExpr )
  {
    Vector compiled = XToken.compileToPostFix( fnExpr );
    if( compiled == null )
      return false;
    fnPF = compiled;
    haveFunction = true;
    return true;
  }

  /**
   * Compiles and installs the n gradient component expressions.
   * All components are compiled before anything is stored, so a failure
   * never leaves a half-replaced gradient behind.
   *
   * @param gradExpr expression text of each partial derivative; the first
   *                 n entries are used
   * @return true if every component compiled, false otherwise
   */
  public boolean setGradient( String gradExpr[] )
  {
    Vector compiled[] = new Vector[n];
    for( int i = 0; i < n; i++ ) {
      compiled[i] = XToken.compileToPostFix( gradExpr[i] );
      if( compiled[i] == null )
        return false;
    }
    System.arraycopy( compiled, 0, gradPF, 0, n );
    haveGradient = true;
    return true;
  }

  /**
   * Compiles and installs the n x n Hessian entry expressions.
   * All entries are compiled before anything is stored, so a failure
   * never leaves a half-replaced Hessian behind.
   *
   * @param hessExpr expression text of each second derivative; the leading
   *                 n x n entries are used
   * @return true if every entry compiled, false otherwise
   */
  public boolean setHessian( String hessExpr[][] )
  {
    Vector compiled[][] = new Vector[n][n];
    for( int i = 0; i < n; i++ ) {
      for( int j = 0; j < n; j++ ) {
        compiled[i][j] = XToken.compileToPostFix( hessExpr[i][j] );
        if( compiled[i][j] == null )
          return false;
      }
    }
    for( int i = 0; i < n; i++ )
      System.arraycopy( compiled[i], 0, hessPF[i], 0, n );
    haveHessian = true;
    return true;
  }

  /**
   * Evaluates the function at x. The variables X[0..n-1] are overwritten
   * with x as a side effect.
   *
   * @param x point of evaluation; the first n entries are used
   * @return value of the compiled function expression
   */
  public double evalF( double x[] )
  {
    assignVariables( x );
    return XToken.eval( fnPF );
  }

  /**
   * One-dimensional restriction of the function along the search
   * direction: f( xpos - alpha*h ). This is the function handed to the
   * OneVarMinimizer (presumably the OneVarFn callback -- confirm against
   * the interface).
   *
   * @param alpha step length along -h
   * @param user_data not used
   * @return function value at the trial point
   */
  public double eval( double alpha, Object user_data )
  {
    for( int i = 0; i < n; i++ )
      xtmp[i] = xpos[i] - alpha * h[i];
    return evalF( xtmp );
  }

  /**
   * Evaluates the gradient at x into the internal gradient array.
   * Does nothing if no gradient has been installed.
   *
   * @param x point of evaluation; the first n entries are used
   */
  public void computeGradient( double x[] )
  {
    if( !haveGradient )
      return;
    assignVariables( x );
    for( int i = 0; i < n; i++ )
      grad[i] = XToken.eval( gradPF[i] );
  }

  /**
   * Evaluates the Hessian at x into the internal Hessian array.
   * Does nothing if no Hessian has been installed.
   *
   * @param x point of evaluation; the first n entries are used
   */
  public void computeHessian( double x[] )
  {
    if( !haveHessian )
      return;
    assignVariables( x );
    for( int i = 0; i < n; i++ )
      for( int j = 0; j < n; j++ )
        hess[i][j] = XToken.eval( hessPF[i][j] );
  }

  /* Copies the first n entries of x into the expression variables. */
  private void assignVariables( double x[] )
  {
    for( int i = 0; i < n; i++ )
      X[i].setValue( x[i] );
  }

  /**
   * Sets the current position to a copy of x and resets the iteration
   * counter.
   *
   * @param x starting point; the first n entries are used
   */
  public void setInitialPoint( double x[] )
  {
    System.arraycopy( x, 0, xpos, 0, n );
    nIter = 0;
  }

  /**
   * Prepares a steepest descent / conjugate gradient run from x: stores
   * the position, evaluates the gradient, uses it as the first search
   * direction and records the function value.
   *
   * @param x starting point; the first n entries are used
   */
  public void initCGSD( double x[] )
  {
    setInitialPoint( x );
    computeGradient( x );
    System.arraycopy( grad, 0, h, 0, n );
    fnVal = evalF( x );
  }

  /**
   * Same as initCGSD, and additionally evaluates the Hessian at x.
   *
   * @param x starting point; the first n entries are used
   */
  public void initNewton( double x[] )
  {
    initCGSD( x );
    computeHessian( x );
  }

  /**
   * Formats an array as "( a, b, c )".
   *
   * @param x values to format
   * @return textual form of x
   */
  public String dblArrayToString( double x[] )
  {
    StringBuilder txt = new StringBuilder( "(" );
    for( int i = 0; i < x.length; i++ ) {
      if( i > 0 )
        txt.append( "," );
      txt.append( " " ).append( x[i] );
    }
    txt.append( " )" );
    return txt.toString();
  }

  /**
   * Returns the 1-norm of x, i.e. the sum of the absolute values of all
   * its entries.
   *
   * @param x vector
   * @return sum of |x[i]|
   */
  public double normOne( double x[] )
  {
    double s = 0.0;
    for( int i = 0; i < x.length; i++ )
      s += Math.abs( x[i] );
    return s;
  }

  /**
   * Returns the dot product of the first n entries of x and y, where n is
   * the dimension of this minimizer (not the array length).
   *
   * @param x first vector
   * @param y second vector
   * @return sum of x[i]*y[i] for i in [0, n)
   */
  public double dotProduct( double x[], double y[] )
  {
    double sum = 0.0;
    for( int i = 0; i < n; i++ )
      sum += x[i] * y[i];
    return sum;
  }

  /**
   * Prints the iteration count, position, function value, gradient and
   * search direction to standard output.
   */
  public void debugState()
  {
    System.out.println( "" );
    System.out.println( "Iteration:" + nIter );
    System.out.println( "Position x:" + dblArrayToString( xpos ) );
    System.out.println( "Function value f(x):" + fnVal );
    System.out.println( "Gradient g:" + dblArrayToString( grad ) );
    System.out.println( "Search direction h:" + dblArrayToString( h ) );
  }

  /**
   * Performs one iteration: a line search along -h, a position update,
   * and the computation of the next gradient and search direction.
   * The iteration counter is incremented even when the search direction
   * is already negligible and nothing else is done.
   *
   * @param doCG true for conjugate gradient, false for steepest descent
   * @param doPolakRibiere with doCG, true selects the Polak-Ribiere
   *                       formula for beta, false Fletcher-Reeves
   */
  public void iterCGSD( boolean doCG, boolean doPolakRibiere )
  {
    double num, den, beta, alpha;
    int i;

    nIter++;
    /* If h is already too small there is nothing to do. */
    if( getHOneNorm() < 1e-10 )
      return;
    /* Go along the search direction ... */
    /* NOTE(review): the meaning of the numeric arguments is defined by
       OneVarMinimizer.minimize, which is not visible here. */
    if( !ovmz.minimize( 0.0, 1e-5, this, 1e-7, 1000 ) )
      System.out.println( "NO CONVERGENCE" );
    alpha = ovmz.xBest;
    /* ... and find where the function is minimized. */

    /* Update new position */
    for( i = 0; i < n; i++ )
      xpos[i] -= alpha*h[i];

    if( doCG ) { /* Conjugate Gradient */
      /* Save gradient for Polak-Ribiere */
      if( doPolakRibiere )
        System.arraycopy( grad, 0, ograd, 0, n );
      den = dotProduct( grad, grad );
      /* Compute new gradient */
      computeGradient( xpos );
      num = dotProduct( grad, grad );
      if( doPolakRibiere )
        num -= dotProduct( grad, ograd );

      /* Compute new search direction */
      if( den > 0.0 ) {
        beta = num / den;
        for( i = 0; i < n; i++ )
          h[i] = grad[i] + beta * h[i];
      } else {
        /* The previous gradient was zero, so beta would be 0/0 or x/0
           and would poison h with NaN/Inf. Restart from the gradient. */
        System.arraycopy( grad, 0, h, 0, n );
      }
    } else { /* Steepest Descent */
      /* Compute new gradient */
      computeGradient( xpos );
      /* h is same as g  */
      System.arraycopy( grad, 0, h, 0, n );
    }
    fnVal = evalF( xpos );
  }

  /** Performs one steepest descent iteration. */
  public void iterSD()
  {
    iterCGSD( false, false );
  }

  /** Performs one Fletcher-Reeves conjugate gradient iteration. */
  public void iterCGFletcherReeves()
  {
    iterCGSD( true, false );
  }

  /** Performs one Polak-Ribiere conjugate gradient iteration. */
  public void iterCGPolakRibiere()
  {
    iterCGSD( true, true );
  }

  /**
   * Performs one iteration of the algorithm chosen with setAlgorithm.
   * Does nothing if the selected algorithm is not one of the ALG_*
   * constants.
   */
  public void next()
  {
    switch( algo ) {
      case ALG_SD:
        iterSD();
        break;
      case ALG_CGFR:
        iterCGFletcherReeves();
        break;
      case ALG_CGPR:
        iterCGPolakRibiere();
        break;
      default:
        /* Unknown algorithm: deliberately a no-op. */
        break;
    }
  }

  /**
   * @return 1-norm of the current search direction
   */
  public double getHOneNorm()
  {
    return normOne( h );
  }

  /**
   * @return the internal position array itself (not a copy)
   */
  public double[] getCurrentX()
  {
    return xpos;
  }

  /**
   * @return the internal gradient array itself (not a copy)
   */
  public double[] getCurrentG()
  {
    return grad;
  }

  /**
   * @return the internal search direction array itself (not a copy)
   */
  public double[] getCurrentH()
  {
    return h;
  }

  /**
   * @return function value recorded at the current position
   */
  public double getCurrentFnVal()
  {
    return fnVal;
  }

  /**
   * @return number of iterations performed since the initial point was set
   */
  public int countIterations()
  {
    return nIter;
  }

  /**
   * Demo: minimizes the Rosenbrock function from (-1.2, 1) with steepest
   * descent, printing the state after every iteration, until the 1-norm
   * of the search direction drops below 1e-3 or 10000 iterations are done.
   */
  public static void main( String args[] )
  {
    int dim = 2;
    String fExp;
    String grExp[] = new String[dim];
    double x[] = new double[dim];

    // RosenBrock
    x[0] = -1.2;
    x[1] = 1;
    fExp = "100*( x2 - x1*x1)^2 + (1-x1)^2";
    grExp[0] = "-400*x1*(x2-x1*x1)-2*(1-x1)";
    grExp[1] = "200*(x2-x1*x1)";

    // Quadratic
    /*
    x[0] = 8;
    x[1] = 0.6;
    fExp = "x1*x1/100 + x2*x2";
    grExp[0] = "x1/50";
    grExp[1] = "2*x2";
    */

    GradientMinimizer gmz = new GradientMinimizer(dim);
    gmz.setFunction( fExp );
    gmz.setGradient( grExp );

    gmz.initCGSD( x );
    gmz.debugState();
    for( int i = 0; i < 10000; i++ ) {
      // gmz.iterCGPolakRibiere();
      // gmz.iterCGFletcherReeves();
      gmz.iterSD();
      gmz.debugState();
      if( gmz.getHOneNorm() < 1e-3 )
        break;
    }
  }
}
