I am trying to implement a simple 2nd order polynomial gradient descent algorithm in Java. It is not converging and becomes unstable. How do I fix it?
public class PolyGradientDescent {
public static double getValue(double input) {
return 3 * input * input - 4 * input + 3.5;
}
public static void fit() {
double x0 = Math.random();
double x1 = Math.random();
double x2 = Math.random();
double size = 15;
double learningrate = 0.0001;
for(int i = 0; i < 400; i++) {
double partial_x2 = 0;
double partial_x1 = 0;
double partial_x0 = 0;
for(double x = 0; x < size+0.001; x++) {
double xx = x * x;
double y_predict = xx * x2 + x * x1 + x0;
double delta = getValue(x) - y_predict;
partial_x2 += xx * delta;
partial_x1 += x * delta;
partial_x0 += delta;
}
x0 = x0 + (2 / size) * partial_x0 * learningrate;
x1 = x1 + (2 / size) * partial_x1 * learningrate;
x2 = x2 + (2 / size) * partial_x2 * learningrate;
System.out.println(x0 + "\t" + x1 + "\t" + x2 + "\t" + "\t" + partial_x2 + "\t" + partial_x1 + "\t" + partial_x0);
}
for(double x = 0; x < size+0.001; x++) {
System.out.println("Y: " + getValue(x) + ", Y_Predict: " + (x2 * x * x + x1 * x + x0));
}
}
public static void main(String[] args) {
fit();
}
}