/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.DoubleArrayList;
import smile.math.Math;
import smile.math.kernel.MercerKernel;
import smile.regression.Regression;
import smile.regression.RegressionTrainer;
import smile.util.MulticoreExecutor;

/**
 * Support vector regression with an epsilon error threshold and soft margin penalty.
 * <p>
 * The model is fitted in the constructor by repeatedly applying a pairwise
 * SMO-style update ({@code smo}) to the dual coefficients until the maximal
 * violation {@code gmax - gmin} falls below the convergence tolerance. Every
 * training instance carries two coefficients, {@code alpha[0]} and
 * {@code alpha[1]}, each kept within {@code [0, C_i]}, where {@code C_i} is the
 * soft margin penalty multiplied by the instance weight.
 * <p>
 * After training only instances with a non-zero coefficient are retained, and
 * {@link #predict(Object)} evaluates
 * {@code b + sum_v (alpha[1]_v - alpha[0]_v) * k(x_v, x)}.
 *
 * @param <T> the type of input objects accepted by the kernel.
 */
public class SVR<T> implements Regression<T>, Serializable {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = LoggerFactory.getLogger(SVR.class);
    /** Value substituted for a non-positive curvature in the pair update. */
    private static final double TAU = 1.0E-12;
    /** The Mercer kernel. */
    private MercerKernel<T> kernel;
    /** The soft margin penalty parameter. */
    private double C = 1.0;
    /** The error threshold (epsilon). */
    private double eps = 0.1;
    /** The tolerance of the convergence test. */
    private double tol = 0.001;
    /** Training instances; after training, only those with a non-zero coefficient. */
    private List<SupportVector> sv = new ArrayList<SupportVector>();
    /** The intercept. */
    private double b = 0.0;
    /** The number of support vectors. */
    private int nsv = 0;
    /** The number of support vectors with a coefficient at its upper bound. */
    private int nbsv = 0;
    // Working-pair state used only during training; refreshed by minmax().
    private transient SupportVector svmin = null;
    private transient SupportVector svmax = null;
    private transient double gmin = Double.MAX_VALUE;
    private transient double gmax = -Double.MAX_VALUE;
    private transient int gminindex;
    private transient int gmaxindex;

    /**
     * Trains an SVR with unit instance weights and the default tolerance 0.001.
     *
     * @param x training instances.
     * @param y response values.
     * @param kernel the Mercer kernel.
     * @param eps the error threshold.
     * @param C the soft margin penalty parameter.
     */
    public SVR(T[] x, double[] y, MercerKernel<T> kernel, double eps, double C) {
        this(x, y, null, kernel, eps, C);
    }

    /**
     * Trains an SVR with the default tolerance 0.001.
     *
     * @param x training instances.
     * @param y response values.
     * @param weight positive instance weights, or {@code null} for unit weights.
     * @param kernel the Mercer kernel.
     * @param eps the error threshold.
     * @param C the soft margin penalty parameter.
     */
    public SVR(T[] x, double[] y, double[] weight, MercerKernel<T> kernel, double eps, double C) {
        this(x, y, weight, kernel, eps, C, 0.001);
    }

    /**
     * Trains an SVR with unit instance weights.
     *
     * @param x training instances.
     * @param y response values.
     * @param kernel the Mercer kernel.
     * @param eps the error threshold.
     * @param C the soft margin penalty parameter.
     * @param tol the tolerance of the convergence test.
     */
    public SVR(T[] x, double[] y, MercerKernel<T> kernel, double eps, double C, double tol) {
        this(x, y, null, kernel, eps, C, tol);
    }

    /**
     * Trains an SVR.
     *
     * @param x training instances.
     * @param y response values.
     * @param weight positive instance weights, or {@code null} for unit weights.
     *               Each instance's coefficient bound is {@code weight[i] * C}.
     * @param kernel the Mercer kernel.
     * @param eps the error threshold; must be positive.
     * @param C the soft margin penalty parameter; must be positive.
     * @param tol the tolerance of the convergence test; must be positive.
     * @throws IllegalArgumentException if the array sizes differ, the training
     *         set is empty, or any parameter or weight is not positive (or NaN).
     */
    public SVR(T[] x, double[] y, double[] weight, MercerKernel<T> kernel, double eps, double C, double tol) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (x.length == 0) {
            // Without instances minmax() finds no working pair and smo() would fail with an NPE.
            throw new IllegalArgumentException("Empty training data");
        }
        if (weight != null && x.length != weight.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and instance weight don't match: %d != %d", x.length, weight.length));
        }
        // The negated comparisons also reject NaN (a NaN tolerance would never converge).
        if (!(eps > 0.0)) {
            throw new IllegalArgumentException("Invalid error threshold: " + eps);
        }
        // C == 0 leaves no coefficient free to move, so no working pair exists.
        if (!(C > 0.0)) {
            throw new IllegalArgumentException("Invalid soft margin penalty: " + C);
        }
        if (!(tol > 0.0)) {
            throw new IllegalArgumentException("Invalid tolerance of convergence test:" + tol);
        }

        this.kernel = kernel;
        this.eps = eps;
        this.C = C;
        this.tol = tol;

        int n = x.length;
        for (int i = 0; i < n; i++) {
            double w = 1.0;
            if (weight != null) {
                w = weight[i];
                if (!(w > 0.0)) {
                    throw new IllegalArgumentException("Invalid instance weight: " + w);
                }
            }

            SupportVector v = new SupportVector();
            v.x = x[i];
            v.y = y[i];
            v.C = w * C;
            // Initial gradients with all coefficients at zero.
            v.g[0] = eps + y[i];
            v.g[1] = eps - y[i];
            v.k = kernel.k(x[i], x[i]);
            sv.add(v);
        }

        minmax();
        int phase = Math.min(n, 1000);
        int count = 1;
        while (smo(tol)) {
            if (count % phase == 0) {
                logger.info("SVR finishes {} SMO iterations", count);
            }
            count++;
        }
        logger.info("SVR finishes training");

        // Keep only instances that contribute to the prediction.
        Iterator<SupportVector> iter = sv.iterator();
        while (iter.hasNext()) {
            SupportVector v = iter.next();
            if (v.alpha[0] == 0.0 && v.alpha[1] == 0.0) {
                iter.remove();
            }
        }

        nsv = sv.size();
        nbsv = 0;
        for (SupportVector v : sv) {
            // The kernel row cache is only needed during training.
            v.kcache = null;
            // Compare with the per-instance bound, which includes the instance weight.
            if (v.alpha[0] == v.C || v.alpha[1] == v.C) {
                nbsv++;
            }
        }
        logger.info("{} support vectors, {} bounded", nsv, nbsv);
    }

    /**
     * Predicts the response of an instance.
     *
     * @param x the instance.
     * @return {@code b + sum_v (alpha[1]_v - alpha[0]_v) * k(x_v, x)} over the support vectors.
     */
    @Override
    public double predict(T x) {
        double f = b;
        for (SupportVector v : sv) {
            f += (v.alpha[1] - v.alpha[0]) * kernel.k(v.x, x);
        }
        return f;
    }

    /**
     * Scans all instances and records the coefficient with the smallest
     * ({@code svmin}/{@code gmin}/{@code gminindex}) and largest
     * ({@code svmax}/{@code gmax}/{@code gmaxindex}) signed gradient among
     * those that can still move in the relevant direction.
     */
    private void minmax() {
        gmin = Double.MAX_VALUE;
        gmax = -Double.MAX_VALUE;

        for (SupportVector v : sv) {
            // Coefficient 0 uses the negated gradient.
            double g = -v.g[0];
            double a = v.alpha[0];
            if (g < gmin && a > 0.0) {
                svmin = v;
                gmin = g;
                gminindex = 0;
            }
            if (g > gmax && a < v.C) {
                svmax = v;
                gmax = g;
                gmaxindex = 0;
            }

            g = v.g[1];
            a = v.alpha[1];
            if (g < gmin && a < v.C) {
                svmin = v;
                gmin = g;
                gminindex = 1;
            }
            if (g > gmax && a > 0.0) {
                svmax = v;
                gmax = g;
                gmaxindex = 1;
            }
        }
    }

    /**
     * Computes the kernel row {@code k(i.x, v.x)} for every training instance
     * {@code v} (in list order) and stores it in {@code i.kcache}.
     * <p>
     * Rows of at least 100 entries are split into chunks evaluated through
     * {@link MulticoreExecutor} when more than one thread is available; if the
     * parallel evaluation fails the row is recomputed sequentially.
     *
     * @param i the instance whose kernel row is computed.
     */
    private void gram(SupportVector i) {
        int n = sv.size();
        int m = MulticoreExecutor.getThreadPoolSize();
        i.kcache = new DoubleArrayList(n);

        if (n < 100 || m < 2) {
            for (SupportVector v : sv) {
                i.kcache.add(kernel.k(i.x, v.x));
            }
            return;
        }

        int step = Math.max(n / m, 100);
        List<KernelTask> tasks = new ArrayList<KernelTask>(m + 1);
        int start = 0;
        // At most m - 1 full chunks, never reaching past n; the last task takes the remainder.
        for (int l = 0; l < m - 1 && start + step < n; l++) {
            tasks.add(new KernelTask(i, start, start + step));
            start += step;
        }
        tasks.add(new KernelTask(i, start, n));

        try {
            for (double[] ki : MulticoreExecutor.run(tasks)) {
                for (double kij : ki) {
                    i.kcache.add(kij);
                }
            }
        } catch (Exception ex) {
            logger.warn("Failed to compute kernel row in parallel; falling back to sequential computation", ex);
            // Discard any partially filled row before recomputing it.
            i.kcache = new DoubleArrayList(n);
            for (SupportVector v : sv) {
                i.kcache.add(kernel.k(i.x, v.x));
            }
        }
    }

    /**
     * Performs one SMO step on a pair of coefficients.
     * <p>
     * The first coefficient is the maximal violator found by {@link #minmax()}.
     * The second is chosen as the candidate with the most negative
     * second-order gain {@code -(gi - gj)^2 / curvature}; if no candidate
     * improves on zero gain, the minimal violator from {@code minmax()} is used.
     * The pair is updated analytically, clipped to {@code [0, C_i]}, and all
     * gradients and the intercept are refreshed.
     *
     * @param epsgr the tolerance of the convergence test.
     * @return {@code true} if another step is needed, i.e. the violation
     *         {@code gmax - gmin} is not below {@code epsgr}.
     */
    private boolean smo(double epsgr) {
        SupportVector i = svmax;
        int ii = gmaxindex;
        double oldAlphaI = i.alpha[ii];
        if (i.kcache == null) {
            gram(i);
        }

        SupportVector j = svmin;
        int jj = gminindex;
        double oldAlphaJ = j.alpha[jj];

        double best = 0.0;
        double gi = ii == 0 ? -i.g[0] : i.g[1];
        int n = sv.size();
        for (int k = 0; k < n; k++) {
            SupportVector v = sv.get(k);
            // i.kcache holds kernel.k(i.x, v.x); reuse it instead of re-evaluating the kernel.
            double curvature = i.k + v.k - 2.0 * i.kcache.get(k);
            if (curvature <= 0.0) {
                curvature = TAU;
            }

            double gj = -v.g[0];
            if (v.alpha[0] > 0.0 && gj < gi) {
                double gain = -Math.sqr(gi - gj) / curvature;
                if (gain < best) {
                    best = gain;
                    j = v;
                    jj = 0;
                    oldAlphaJ = j.alpha[0];
                }
            }

            gj = v.g[1];
            if (v.alpha[1] < v.C && gj < gi) {
                double gain = -Math.sqr(gi - gj) / curvature;
                if (gain < best) {
                    best = gain;
                    j = v;
                    jj = 1;
                    oldAlphaJ = j.alpha[1];
                }
            }
        }

        if (j.kcache == null) {
            gram(j);
        }

        double curv = i.k + j.k - 2.0 * kernel.k(i.x, j.x);
        if (curv <= 0.0) {
            curv = TAU;
        }

        if (ii != jj) {
            double delta = (-i.g[ii] - j.g[jj]) / curv;
            double diff = i.alpha[ii] - j.alpha[jj];
            i.alpha[ii] += delta;
            j.alpha[jj] += delta;

            // Clip to the lower bound while keeping the difference fixed.
            if (diff > 0.0) {
                if (j.alpha[jj] < 0.0) {
                    j.alpha[jj] = 0.0;
                    i.alpha[ii] = diff;
                }
            } else if (i.alpha[ii] < 0.0) {
                i.alpha[ii] = 0.0;
                j.alpha[jj] = -diff;
            }

            // Clip to the upper bounds while keeping the difference fixed.
            if (diff > i.C - j.C) {
                if (i.alpha[ii] > i.C) {
                    i.alpha[ii] = i.C;
                    j.alpha[jj] = i.C - diff;
                }
            } else if (j.alpha[jj] > j.C) {
                j.alpha[jj] = j.C;
                i.alpha[ii] = j.C + diff;
            }
        } else {
            double delta = (i.g[ii] - j.g[jj]) / curv;
            double sum = i.alpha[ii] + j.alpha[jj];
            i.alpha[ii] -= delta;
            j.alpha[jj] += delta;

            // Clip to the box while keeping the sum fixed.
            if (sum > i.C) {
                if (i.alpha[ii] > i.C) {
                    i.alpha[ii] = i.C;
                    j.alpha[jj] = sum - i.C;
                }
            } else if (j.alpha[jj] < 0.0) {
                j.alpha[jj] = 0.0;
                i.alpha[ii] = sum;
            }

            if (sum > j.C) {
                if (j.alpha[jj] > j.C) {
                    j.alpha[jj] = j.C;
                    i.alpha[ii] = sum - j.C;
                }
            } else if (i.alpha[ii] < 0.0) {
                i.alpha[ii] = 0.0;
                j.alpha[jj] = sum;
            }
        }

        double deltaAlphaI = i.alpha[ii] - oldAlphaI;
        double deltaAlphaJ = j.alpha[jj] - oldAlphaJ;
        // Coefficient 0 contributes with sign -1, coefficient 1 with sign +1.
        int si = 2 * ii - 1;
        int sj = 2 * jj - 1;
        for (int k = 0; k < n; k++) {
            SupportVector v = sv.get(k);
            double dg = si * i.kcache.get(k) * deltaAlphaI + sj * j.kcache.get(k) * deltaAlphaJ;
            v.g[0] -= dg;
            v.g[1] += dg;
        }

        minmax();
        b = -(gmax + gmin) / 2.0;
        return !(gmax - gmin < epsgr);
    }

    /**
     * Returns the soft margin penalty parameter.
     *
     * @return the soft margin penalty parameter.
     */
    public double getC() {
        return C;
    }

    /**
     * Returns the error threshold.
     *
     * @return the error threshold.
     */
    public double getEpsilon() {
        return eps;
    }

    /**
     * Returns the tolerance of the convergence test.
     *
     * @return the tolerance of the convergence test.
     */
    public double getTolerance() {
        return tol;
    }

    /** Computes the kernel values {@code k(i.x, sv[j].x)} for {@code j} in {@code [start, end)}. */
    class KernelTask implements Callable<double[]> {
        SupportVector i;
        int start;
        int end;

        KernelTask(SupportVector i, int start, int end) {
            this.i = i;
            this.start = start;
            this.end = end;
        }

        @Override
        public double[] call() {
            double[] ki = new double[end - start];
            for (int j = start; j < end; j++) {
                ki[j - start] = kernel.k(i.x, sv.get(j).x);
            }
            return ki;
        }
    }

    /** Trainer for support vector regression. */
    public static class Trainer<T> extends RegressionTrainer<T> {
        private MercerKernel<T> kernel;
        private double C = 1.0;
        private double eps = 0.001;
        private double tol = 0.001;

        /**
         * Constructor.
         *
         * @param kernel the Mercer kernel.
         * @param eps the error threshold; must be positive.
         * @param C the soft margin penalty parameter; must be positive.
         * @throws IllegalArgumentException if {@code eps} or {@code C} is not positive (or NaN).
         */
        public Trainer(MercerKernel<T> kernel, double eps, double C) {
            if (!(eps > 0.0)) {
                throw new IllegalArgumentException("Invalid error threshold: " + eps);
            }
            if (!(C > 0.0)) {
                throw new IllegalArgumentException("Invalid soft margin penalty: " + C);
            }
            this.kernel = kernel;
            this.C = C;
            this.eps = eps;
        }

        /**
         * Sets the tolerance of the convergence test.
         *
         * @param tol the tolerance; must be positive.
         * @return this trainer.
         * @throws IllegalArgumentException if {@code tol} is not positive (or NaN).
         */
        public Trainer<T> setTolerance(double tol) {
            if (!(tol > 0.0)) {
                throw new IllegalArgumentException("Invalid tolerance of convergence test:" + tol);
            }
            this.tol = tol;
            return this;
        }

        @Override
        public SVR<T> train(T[] x, double[] y) {
            return new SVR<T>(x, y, kernel, eps, C, tol);
        }
    }

    /** A training instance together with its dual coefficients and solver state. */
    class SupportVector implements Serializable {
        private static final long serialVersionUID = 1L;
        /** The instance. */
        T x;
        /** The response value. */
        double y;
        /** The two dual coefficients, each within {@code [0, C]}. */
        double[] alpha = new double[2];
        /** Upper bound of the coefficients: soft margin penalty times instance weight. */
        private double C = 1.0;
        /** Gradients associated with {@code alpha[0]} and {@code alpha[1]}. */
        double[] g = new double[2];
        /** The kernel value {@code k(x, x)}. */
        double k;
        /** Cached kernel row against all training instances; null when not computed or after training. */
        DoubleArrayList kcache;

        SupportVector() {
        }
    }
}

