/*
 * Decompiled with CFR 0.152.
 */
package timeseries.models;

import com.google.common.primitives.Doubles;
import java.awt.Color;
import java.awt.Component;
import java.awt.Font;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import javax.swing.JFrame;
import javax.swing.SwingUtilities;
import org.knowm.xchart.XChartPanel;
import org.knowm.xchart.XYChart;
import org.knowm.xchart.XYChartBuilder;
import org.knowm.xchart.XYSeries;
import org.knowm.xchart.internal.chartpart.Chart;
import org.knowm.xchart.style.Styler;
import org.knowm.xchart.style.XYStyler;
import org.knowm.xchart.style.markers.Marker;
import org.knowm.xchart.style.markers.None;
import stats.distributions.Normal;
import timeseries.TimeSeries;
import timeseries.models.Forecast;
import timeseries.models.Model;
import timeseries.models.RandomWalk;

public final class RandomWalkForecast
implements Forecast {
    private final Model model;
    private final TimeSeries forecast;
    private final TimeSeries upperValues;
    private final TimeSeries lowerValues;
    private final double criticalValue;
    private final TimeSeries fcstErrors;

    public RandomWalkForecast(RandomWalk model, int steps, double alpha) {
        this.model = model;
        this.forecast = model.pointForecast(steps);
        this.criticalValue = new Normal(0.0, model.residuals().stdDeviation()).quantile(1.0 - alpha / 2.0);
        this.fcstErrors = this.getFcstErrors();
        this.upperValues = this.computeUpperPredictionBounds(steps, alpha);
        this.lowerValues = this.computeLowerPredictionBounds(steps, alpha);
    }

    public RandomWalkForecast(TimeSeries series, int steps, double alpha) {
        this.model = new RandomWalk(series);
        this.forecast = this.model.pointForecast(steps);
        this.criticalValue = new Normal(0.0, this.model.residuals().stdDeviation()).quantile(1.0 - alpha / 2.0);
        this.fcstErrors = this.getFcstErrors();
        this.upperValues = this.computeUpperPredictionBounds(steps, alpha);
        this.lowerValues = this.computeLowerPredictionBounds(steps, alpha);
    }

    @Override
    public TimeSeries forecast() {
        return this.forecast;
    }

    @Override
    public TimeSeries upperPredictionValues() {
        return this.upperValues;
    }

    @Override
    public TimeSeries lowerPredictionValues() {
        return this.lowerValues;
    }

    @Override
    public TimeSeries computeUpperPredictionBounds(int steps, double alpha) {
        double[] upperPredictionValues = new double[steps];
        double criticalValue = new Normal(0.0, this.model.residuals().stdDeviation()).quantile(1.0 - alpha / 2.0);
        for (int t = 0; t < steps; ++t) {
            upperPredictionValues[t] = this.forecast.at(t) + criticalValue * Math.sqrt(t + 1);
        }
        return new TimeSeries(this.forecast.timePeriod(), this.forecast.observationTimes().get(0), upperPredictionValues);
    }

    @Override
    public TimeSeries computeLowerPredictionBounds(int steps, double alpha) {
        double[] upperPredictionValues = new double[steps];
        double criticalValue = new Normal(0.0, this.model.residuals().stdDeviation()).quantile(1.0 - alpha / 2.0);
        for (int t = 0; t < steps; ++t) {
            upperPredictionValues[t] = this.forecast.at(t) - criticalValue * Math.sqrt(t + 1);
        }
        return new TimeSeries(this.forecast.timePeriod(), this.forecast.observationTimes().get(0), upperPredictionValues);
    }

    private TimeSeries getFcstErrors() {
        double[] errors = new double[this.forecast.n()];
        for (int t = 0; t < errors.length; ++t) {
            errors[t] = this.criticalValue * Math.sqrt(t + 1);
        }
        return new TimeSeries(this.forecast.timePeriod(), this.forecast.observationTimes().get(0), errors);
    }

    @Override
    public void plot() {
        new Thread(() -> {
            ArrayList<Date> xAxis = new ArrayList<Date>(this.forecast.observationTimes().size());
            ArrayList<Date> xAxisObs = new ArrayList<Date>(this.model.timeSeries().n());
            for (OffsetDateTime dateTime : this.model.timeSeries().observationTimes()) {
                xAxisObs.add(Date.from(dateTime.toInstant()));
            }
            for (OffsetDateTime dateTime : this.forecast.observationTimes()) {
                xAxis.add(Date.from(dateTime.toInstant()));
            }
            List errorList = Doubles.asList((double[])this.fcstErrors.asArray());
            List seriesList = Doubles.asList((double[])this.model.timeSeries().asArray());
            List forecastList = Doubles.asList((double[])this.forecast.asArray());
            XYChart chart = ((XYChartBuilder)((XYChartBuilder)((XYChartBuilder)((XYChartBuilder)new XYChartBuilder().theme(Styler.ChartTheme.GGPlot2)).height(800)).width(1200)).title("Random Walk Past and Future")).build();
            XYSeries observationSeries = chart.addSeries("Past", xAxisObs, seriesList);
            XYSeries forecastSeries = chart.addSeries("Future", xAxis, forecastList, errorList);
            observationSeries.setMarker((Marker)new None());
            forecastSeries.setMarker((Marker)new None());
            observationSeries.setLineWidth(0.75f);
            forecastSeries.setLineWidth(1.5f);
            ((XYStyler)chart.getStyler()).setDefaultSeriesRenderStyle(XYSeries.XYSeriesRenderStyle.Line).setErrorBarsColor(Color.RED);
            observationSeries.setLineColor(Color.BLACK);
            forecastSeries.setLineColor(Color.BLUE);
            XChartPanel panel = new XChartPanel((Chart)chart);
            JFrame frame = new JFrame("Random Walk Past and Future");
            frame.setDefaultCloseOperation(2);
            frame.add((Component)panel);
            frame.pack();
            frame.setVisible(true);
        }).start();
    }

    @Override
    public void plotForecast() {
        new Thread(() -> {
            ArrayList<Date> xAxis = new ArrayList<Date>(this.forecast.observationTimes().size());
            for (OffsetDateTime dateTime : this.forecast.observationTimes()) {
                xAxis.add(Date.from(dateTime.toInstant()));
            }
            List errorList = Doubles.asList((double[])this.fcstErrors.asArray());
            List forecastList = Doubles.asList((double[])this.forecast.asArray());
            XYChart chart = ((XYChartBuilder)((XYChartBuilder)((XYChartBuilder)((XYChartBuilder)new XYChartBuilder().theme(Styler.ChartTheme.GGPlot2)).height(600)).width(800)).title("Random Walk Forecast")).build();
            chart.setXAxisTitle("Time");
            chart.setYAxisTitle("Forecast Values");
            ((XYStyler)chart.getStyler()).setAxisTitleFont(new Font("Arial", 0, 14));
            ((XYStyler)chart.getStyler()).setDefaultSeriesRenderStyle(XYSeries.XYSeriesRenderStyle.Line).setErrorBarsColor(Color.RED).setChartFontColor(new Color(112, 112, 112));
            XYSeries forecastSeries = chart.addSeries("Forecast", xAxis, forecastList, errorList);
            forecastSeries.setMarker((Marker)new None());
            forecastSeries.setLineWidth(1.5f);
            forecastSeries.setLineColor(Color.BLUE);
            XChartPanel panel = new XChartPanel((Chart)chart);
            JFrame frame = new JFrame("Random Walk Forecast");
            frame.setDefaultCloseOperation(2);
            frame.add((Component)panel);
            frame.pack();
            frame.setVisible(true);
        }).start();
    }
}

