/*
 * Decompiled with CFR 0.152.
 */
package timeseries.models.arima;

import com.google.common.primitives.Doubles;
import java.awt.Color;
import java.awt.Component;
import java.awt.Font;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import javax.swing.JFrame;
import javax.swing.SwingUtilities;
import javax.swing.WindowConstants;
import org.knowm.xchart.XChartPanel;
import org.knowm.xchart.XYChart;
import org.knowm.xchart.XYChartBuilder;
import org.knowm.xchart.XYSeries;
import org.knowm.xchart.internal.chartpart.Chart;
import org.knowm.xchart.style.Styler;
import org.knowm.xchart.style.XYStyler;
import org.knowm.xchart.style.markers.Circle;
import org.knowm.xchart.style.markers.Marker;
import stats.distributions.Normal;
import timeseries.TimeSeries;
import timeseries.models.Forecast;
import timeseries.models.arima.Arima;
import timeseries.operators.LagPolynomial;

public final class ArimaForecast
implements Forecast {
    private final Arima model;
    private final TimeSeries forecast;
    private final TimeSeries upperValues;
    private final TimeSeries lowerValues;
    private final double alpha;
    private final double criticalValue;
    private final TimeSeries fcstErrors;

    private ArimaForecast(Arima model, int steps, double alpha) {
        this.model = model;
        this.forecast = model.pointForecast(steps);
        this.alpha = alpha;
        this.criticalValue = new Normal().quantile(1.0 - alpha / 2.0);
        this.fcstErrors = this.getFcstErrors(this.criticalValue);
        this.upperValues = this.computeUpperPredictionBounds(steps, alpha);
        this.lowerValues = this.computeLowerPredictionBounds(steps, alpha);
    }

    public static ArimaForecast forecast(Arima model, int steps, double alpha) {
        return new ArimaForecast(model, steps, alpha);
    }

    public static ArimaForecast forecast(Arima model, int steps) {
        return new ArimaForecast(model, steps, 0.05);
    }

    public static ArimaForecast forecast(Arima model) {
        return new ArimaForecast(model, 12, 0.05);
    }

    @Override
    public TimeSeries forecast() {
        return this.forecast;
    }

    @Override
    public TimeSeries upperPredictionValues() {
        return this.upperValues;
    }

    @Override
    public TimeSeries lowerPredictionValues() {
        return this.lowerValues;
    }

    @Override
    public TimeSeries computeUpperPredictionBounds(int steps, double alpha) {
        double criticalValue = new Normal().quantile(1.0 - alpha / 2.0);
        double[] upperPredictionValues = new double[steps];
        double[] errors = this.getStdErrors(criticalValue);
        for (int t = 0; t < steps; ++t) {
            upperPredictionValues[t] = this.forecast.at(t) + errors[t];
        }
        return new TimeSeries(this.forecast.timePeriod(), this.forecast.observationTimes().get(0), upperPredictionValues);
    }

    @Override
    public TimeSeries computeLowerPredictionBounds(int steps, double alpha) {
        double criticalValue = new Normal().quantile(alpha / 2.0);
        double[] lowerPredictionValues = new double[steps];
        double[] errors = this.getStdErrors(criticalValue);
        for (int t = 0; t < steps; ++t) {
            lowerPredictionValues[t] = this.forecast.at(t) + errors[t];
        }
        return new TimeSeries(this.forecast.timePeriod(), this.forecast.observationTimes().get(0), lowerPredictionValues);
    }

    private double[] getPsiCoefficients() {
        int steps = this.forecast.n();
        LagPolynomial arPoly = LagPolynomial.autoRegressive(this.model.arSarCoefficients());
        LagPolynomial diffPoly = LagPolynomial.differences(this.model.order().d);
        LagPolynomial seasDiffPoly = LagPolynomial.seasonalDifferences(this.model.seasonalFrequency(), this.model.order().D);
        double[] phi = diffPoly.times(seasDiffPoly).times(arPoly).inverseParams();
        double[] theta = this.model.maSmaCoefficients();
        double[] psi = new double[steps];
        psi[0] = 1.0;
        System.arraycopy(theta, 0, psi, 1, Math.min(steps - 1, theta.length));
        for (int j = 1; j < psi.length; ++j) {
            for (int i = 0; i < Math.min(j, phi.length); ++i) {
                int n = j;
                psi[n] = psi[n] + psi[j - i - 1] * phi[i];
            }
        }
        return psi;
    }

    private TimeSeries getFcstErrors(double criticalValue) {
        double[] errors = this.getStdErrors(criticalValue);
        return new TimeSeries(this.forecast.timePeriod(), this.forecast.observationTimes().get(0), errors);
    }

    private double[] getStdErrors(double criticalValue) {
        double[] psiCoeffs = this.getPsiCoefficients();
        double[] stdErrors = new double[this.forecast.n()];
        double sigma = Math.sqrt(this.model.sigma2());
        double psiWeightSum = 0.0;
        for (int i = 0; i < stdErrors.length; ++i) {
            double sd = sigma * Math.sqrt(psiWeightSum += psiCoeffs[i] * psiCoeffs[i]);
            stdErrors[i] = criticalValue * sd;
        }
        return stdErrors;
    }

    @Override
    public void plot() {
        new Thread(() -> {
            ArrayList<Date> xAxis = new ArrayList<Date>(this.forecast.observationTimes().size());
            ArrayList<Date> xAxisObs = new ArrayList<Date>(this.model.timeSeries().n());
            for (OffsetDateTime dateTime : this.model.timeSeries().observationTimes()) {
                xAxisObs.add(Date.from(dateTime.toInstant()));
            }
            for (OffsetDateTime dateTime : this.forecast.observationTimes()) {
                xAxis.add(Date.from(dateTime.toInstant()));
            }
            List errorList = Doubles.asList((double[])this.fcstErrors.asArray());
            List seriesList = Doubles.asList((double[])this.model.timeSeries().asArray());
            List forecastList = Doubles.asList((double[])this.forecast.asArray());
            XYChart chart = ((XYChartBuilder)((XYChartBuilder)((XYChartBuilder)((XYChartBuilder)new XYChartBuilder().theme(Styler.ChartTheme.GGPlot2)).height(800)).width(1200)).title("ARIMA Forecast")).build();
            XYSeries observationSeries = chart.addSeries("Past", xAxisObs, seriesList);
            XYSeries forecastSeries = chart.addSeries("Future", xAxis, forecastList, errorList);
            observationSeries.setMarker((Marker)new Circle());
            observationSeries.setMarkerColor(Color.DARK_GRAY);
            forecastSeries.setMarker((Marker)new Circle());
            forecastSeries.setMarkerColor(Color.BLUE);
            observationSeries.setLineWidth(1.0f);
            forecastSeries.setLineWidth(1.0f);
            ((XYStyler)chart.getStyler()).setDefaultSeriesRenderStyle(XYSeries.XYSeriesRenderStyle.Line).setErrorBarsColor(Color.RED);
            observationSeries.setLineColor(Color.DARK_GRAY);
            forecastSeries.setLineColor(Color.BLUE);
            XChartPanel panel = new XChartPanel((Chart)chart);
            JFrame frame = new JFrame("ARIMA Forecast");
            frame.setDefaultCloseOperation(2);
            frame.add((Component)panel);
            frame.pack();
            frame.setVisible(true);
        }).start();
    }

    @Override
    public void plotForecast() {
        new Thread(() -> {
            ArrayList<Date> xAxis = new ArrayList<Date>(this.forecast.observationTimes().size());
            for (OffsetDateTime dateTime : this.forecast.observationTimes()) {
                xAxis.add(Date.from(dateTime.toInstant()));
            }
            List errorList = Doubles.asList((double[])this.fcstErrors.asArray());
            List forecastList = Doubles.asList((double[])this.forecast.asArray());
            XYChart chart = ((XYChartBuilder)((XYChartBuilder)((XYChartBuilder)((XYChartBuilder)new XYChartBuilder().theme(Styler.ChartTheme.GGPlot2)).height(600)).width(800)).title("ARIMA Forecast")).build();
            chart.setXAxisTitle("Time");
            chart.setYAxisTitle("Forecast Values");
            ((XYStyler)chart.getStyler()).setAxisTitleFont(new Font("Arial", 0, 14));
            ((XYStyler)chart.getStyler()).setDefaultSeriesRenderStyle(XYSeries.XYSeriesRenderStyle.Line).setErrorBarsColor(Color.RED).setChartFontColor(new Color(112, 112, 112));
            XYSeries forecastSeries = chart.addSeries("Forecast", xAxis, forecastList, errorList);
            forecastSeries.setMarker((Marker)new Circle());
            forecastSeries.setMarkerColor(Color.BLUE);
            forecastSeries.setLineWidth(1.0f);
            forecastSeries.setLineColor(Color.BLUE);
            XChartPanel panel = new XChartPanel((Chart)chart);
            JFrame frame = new JFrame("ARIMA Forecast");
            frame.setDefaultCloseOperation(2);
            frame.add((Component)panel);
            frame.pack();
            frame.setVisible(true);
        }).start();
    }

    public String toString() {
        StringBuilder builder = new StringBuilder();
        builder.append(String.format("%-18.18s", "| Date ")).append("  ").append(String.format("%-13.13s", "| Forecast ")).append("  ").append(String.format("%-13.13s", "| Lower " + String.format("%.1f", (1.0 - this.alpha) * 100.0) + "%")).append("  ").append(String.format("%-13.13s", "| Upper " + String.format("%.1f", (1.0 - this.alpha) * 100.0) + "%")).append(" |").append("\n").append(String.format("%-70.70s", " -------------------------------------------------------------- ")).append("\n");
        for (int i = 0; i < this.forecast.n(); ++i) {
            builder.append(String.format("%-18.18s", "| " + this.forecast.observationTimes().get(i).toLocalDateTime())).append("  ").append(String.format("%-13.13s", "| " + Double.toString(this.forecast.at(i)))).append("  ").append(String.format("%-13.13s", "| " + Double.toString(this.lowerValues.at(i)))).append("  ").append(String.format("%-13.13s", "| " + Double.toString(this.upperValues.at(i)))).append(" |").append("\n");
        }
        return builder.toString();
    }

    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        ArimaForecast that = (ArimaForecast)o;
        if (Double.compare(that.alpha, this.alpha) != 0) {
            return false;
        }
        if (Double.compare(that.criticalValue, this.criticalValue) != 0) {
            return false;
        }
        if (!this.model.equals(that.model)) {
            return false;
        }
        if (!this.forecast.equals(that.forecast)) {
            return false;
        }
        if (!this.upperValues.equals(that.upperValues)) {
            return false;
        }
        if (!this.lowerValues.equals(that.lowerValues)) {
            return false;
        }
        return this.fcstErrors.equals(that.fcstErrors);
    }

    public int hashCode() {
        int result = this.model.hashCode();
        result = 31 * result + this.forecast.hashCode();
        result = 31 * result + this.upperValues.hashCode();
        result = 31 * result + this.lowerValues.hashCode();
        long temp = Double.doubleToLongBits(this.alpha);
        result = 31 * result + (int)(temp ^ temp >>> 32);
        temp = Double.doubleToLongBits(this.criticalValue);
        result = 31 * result + (int)(temp ^ temp >>> 32);
        result = 31 * result + this.fcstErrors.hashCode();
        return result;
    }
}

