I am trying to fit a 2nd-order polynomial with a simple gradient descent algorithm in Java. It does not converge and becomes unstable. How do I fix it?

public class PolyGradientDescent {
    // Target function to fit: y = 3x^2 - 4x + 3.5
    public static double getValue(double input) {
        return 3 * input * input - 4 * input + 3.5;
    }

    public static void fit() {
        // Coefficients of y_predict = x2 * x^2 + x1 * x + x0, randomly initialized in [0, 1)
        double x0 = Math.random();
        double x1 = Math.random();
        double x2 = Math.random();
        double size = 15;
        double learningrate = 0.0001;

        for(int i = 0; i < 400; i++) {
            double partial_x2 = 0;
            double partial_x1 = 0;
            double partial_x0 = 0;
            // Iterates over x = 0, 1, ..., 15 (16 points)
            for(double x = 0; x < size+0.001; x++) {
                double xx = x * x;
                double y_predict = xx * x2 + x * x1 + x0;
                double delta = getValue(x) - y_predict;
                // Accumulate sum(x^k * (y - y_predict)) for each coefficient
                partial_x2 += xx * delta;
                partial_x1 += x * delta;
                partial_x0 += delta;
            }
            // Step against the gradient of the mean squared error
            x0 = x0 + (2 / size) * partial_x0 * learningrate;
            x1 = x1 + (2 / size) * partial_x1 * learningrate;
            x2 = x2 + (2 / size) * partial_x2 * learningrate;
            System.out.println(x0 + "\t" + x1 + "\t" + x2 + "\t" + "\t" + partial_x2 + "\t" + partial_x1 + "\t" + partial_x0);
        }
        // Compare targets with predictions after training
        for(double x = 0; x < size+0.001; x++) {
            System.out.println("Y: " + getValue(x) + ", Y_Predict: " + (x2 * x * x + x1 * x + x0));
        }
    }

    public static void main(String[] args) {
        fit();
    }
}
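
For reference, the update in the code above is gradient descent on the mean squared error. Writing θ_0, θ_1, θ_2 for the coefficients the code calls x0, x1, x2, η for learningrate, and N for size (the loop sums over the 16 points x = 0 to 15 but divides by N = 15), the cost, its gradient, the update, and its Hessian are:

```latex
\hat y(x) = \theta_2 x^2 + \theta_1 x + \theta_0, \qquad
J(\theta) = \frac{1}{N}\sum_{x=0}^{15}\bigl(y(x) - \hat y(x)\bigr)^2

\frac{\partial J}{\partial \theta_k} = -\frac{2}{N}\sum_{x=0}^{15} x^k\,\bigl(y(x) - \hat y(x)\bigr), \qquad
\theta_k \leftarrow \theta_k - \eta\,\frac{\partial J}{\partial \theta_k}

\frac{\partial^2 J}{\partial \theta_j\,\partial \theta_k} = \frac{2}{N}\sum_{x=0}^{15} x^{j+k}
```

Since partial_xk in the code is the sum of x^k * delta, the line `x0 = x0 + (2 / size) * partial_x0 * learningrate` is exactly this step; the factor 2 comes from differentiating the square. Because the Hessian does not depend on θ, fixed-step gradient descent on this quadratic converges only if η < 2/λ_max, where λ_max is the Hessian's largest eigenvalue.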

1 Answer

I tested your code in Python and it works fine once I decrease the learning rate further (divided by 100) and increase the number of epochs (multiplied by 100).

I also changed the way the derivative was calculated to make it more mathematically correct :)

import random

def getValue(x):
    return 3 * x * x - 4 * x + 3.5

def fit():
    x0 = random.randrange(-100, 101) / 100
    x1 = random.randrange(-100, 101) / 100
    x2 = random.randrange(-100, 101) / 100
    size = 15
    learningrate = 0.000001

    for i in range(40000):
        partial_x2 = 0
        partial_x1 = 0
        partial_x0 = 0
        for x in range(16):
            xx = x * x
            y_predict = xx * x2 + x * x1 + x0
            delta = getValue(x) - y_predict

            # For the partial derivatives, I pulled the sign and the factor 2 into this step (and dropped them from the update below), so partial_xn is the true derivative of the summed squared error
            partial_x2 -= 2 * xx * delta
            partial_x1 -= 2 * x * delta
            partial_x0 -= 2 * delta

        x0 = x0 - (1 / size) * partial_x0 * learningrate
        x1 = x1 - (1 / size) * partial_x1 * learningrate
        x2 = x2 - (1 / size) * partial_x2 * learningrate

    for x in range(16):
        print("Y: " + str(getValue(x)) + ", Y_Predict: " + str(x2 * x * x + x1 * x + x0))

fit()
Evator
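
Why dividing the learning rate by 100 helps can be checked numerically. The loss both versions of fit() minimize, (1/size) * sum((y - y_predict)^2) over x = 0 to 15, is quadratic in the coefficients, so its Hessian H[j][k] = (2/size) * sum(x^(j+k)) is constant, and plain gradient descent on it only converges when the learning rate is below 2 / lambda_max(H). The minimal sketch below estimates lambda_max with power iteration; the helper names `hessian` and `largest_eigenvalue` are illustrative and not part of the original code.

```python
def hessian(xs, size):
    # H[j][k] = (2 / size) * sum(x**(j + k)); constant because the loss is quadratic
    return [[(2 / size) * sum(x ** (j + k) for x in xs) for k in range(3)]
            for j in range(3)]


def largest_eigenvalue(m, iterations=100):
    # Power iteration: H is symmetric positive definite and its top
    # eigenvalue dominates the others, so this converges quickly.
    v = [1.0, 1.0, 1.0]
    for _ in range(iterations):
        w = [sum(m[i][j] * v[j] for j in range(3)) for i in range(3)]
        norm = sum(c * c for c in w) ** 0.5
        v = [c / norm for c in w]
    mv = [sum(m[i][j] * v[j] for j in range(3)) for i in range(3)]
    return sum(v[i] * mv[i] for i in range(3))  # Rayleigh quotient


xs = range(16)   # x = 0, 1, ..., 15, as in both versions of fit()
size = 15        # the divisor used in both versions
H = hessian(xs, size)
lam_max = largest_eigenvalue(H)
max_stable_lr = 2 / lam_max  # gradient descent on this quadratic diverges above this

print(f"lambda_max ~ {lam_max:.1f}, max stable learning rate ~ {max_stable_lr:.2e}")
print("0.0001 stable?", 0.0001 < max_stable_lr)
print("0.000001 stable?", 0.000001 < max_stable_lr)
```

The limit comes out at roughly 8.4e-5. The question's 0.0001 sits just above it, so the error along the steepest direction (mostly the x2 coefficient) grows by a factor of about 1.4 in magnitude every epoch, which matches the blow-up; 0.000001 is well below it. H is also badly conditioned (its smallest eigenvalue is below 1, versus roughly 24,000 for the largest), which may explain why even a stable run makes very slow progress along the flattest direction.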
  • Thanks for your help, I appreciate it. Does the factor of 2 come from the chain rule applied to the ordinary least squares cost function? If so, that makes sense. I ran your function and it performs similarly to mine. In fact, I think they are equal, no? I'm a little disappointed with the performance of this algorithm. Here, we are only trying to fit 15 points which lie exactly on a 2nd-order polynomial, and even after 1 million iterations the delta between y and y_predict is still greater than 1 for a couple of points. Is that expected? – PentiumPro200 Jan 21 '22 at 07:02
  • Yes, you are right: they should be equal, since I just moved the -2 from one term into the other. Your algorithm is inefficient because you are working with polynomials that produce large values. The label for X: 15 is Y: 618.5. That means that if your function predicts 0 at the beginning, the x = 15 term alone contributes 2 * 15 * 15 * 618.5 (about 278,000) to the magnitude of partial_x2. That's why you need the small learning rate. What you could try is gradient clipping, where you clip the partial_xn values to a range such as -20 to 20; then you can choose a much higher learning rate. I have tried it and it worked quite well. I hope this helps :) – Evator Jan 23 '22 at 23:43
  • Thanks. I found this helpful too (I added momentum): https://www.cs.toronto.edu/~lczhang/321/notes/notes08.pdf – PentiumPro200 Jan 25 '22 at 04:10
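
The two follow-up suggestions in the comments, clipping partial_xn to a fixed range and adding momentum, can both be sketched as options on the answer's loop. This is a minimal Python sketch under assumptions not in the original: coefficients start at zero instead of random values so runs are reproducible, `beta` is the coefficient of heavy-ball (Polyak) momentum, which is one common form, and `clip` bounds each partial before it is divided by size, in the same units as the answer's partial_xn. The names `fit`, `loss`, `beta`, and `clip` are illustrative.

```python
def get_value(x):
    return 3 * x * x - 4 * x + 3.5


def loss(c, size=15):
    # Mean squared error over x = 0..15, divided by size as in fit()
    total = 0.0
    for x in range(16):
        r = get_value(x) - (c[2] * x * x + c[1] * x + c[0])
        total += r * r
    return total / size


def fit(learning_rate, epochs, beta=0.0, clip=None, size=15):
    # Assumption: start from zero (the original code starts from random values)
    c = [0.0, 0.0, 0.0]          # c[k] plays the role of x0, x1, x2
    velocity = [0.0, 0.0, 0.0]   # running update direction used by momentum
    for _ in range(epochs):
        partial = [0.0, 0.0, 0.0]
        for x in range(16):
            delta = get_value(x) - (c[2] * x * x + c[1] * x + c[0])
            for k in range(3):
                # Derivative of the summed squared error w.r.t. c[k]
                partial[k] -= 2 * x ** k * delta
        if clip is not None:
            # Gradient clipping: limit each partial to [-clip, clip]
            partial = [max(-clip, min(clip, p)) for p in partial]
        for k in range(3):
            # beta = 0 gives plain gradient descent, as in the answer
            velocity[k] = beta * velocity[k] - learning_rate * partial[k] / size
            c[k] += velocity[k]
    return c


if __name__ == "__main__":
    print("start loss:      ", loss([0.0, 0.0, 0.0]))
    print("plain, lr 1e-6:  ", loss(fit(1e-6, 2000)))
    print("momentum 0.9:    ", loss(fit(1e-6, 2000, beta=0.9)))
    print("clipped, lr 1e-4:", fit(1e-4, 400, clip=20))
```

With learning_rate = 1e-6 and 2000 epochs, the momentum run should end at a lower loss than the plain run. Clipping caps how far each coefficient moves per epoch at learning_rate * clip / size; once the partials fall inside the clipping range, the update is plain gradient descent again, so the learning rate still matters near the minimum.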