package linear.regression;

import com.google.common.collect.ImmutableList;
import data.DoubleFunctions;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import org.ejml.alg.dense.mult.MatrixVectorMult;
import org.ejml.data.D1Matrix64F;
import org.ejml.data.DenseMatrix64F;
import org.ejml.data.Matrix;
import org.ejml.factory.LinearSolverFactory;
import org.ejml.interfaces.decomposition.QRDecomposition;
import org.ejml.interfaces.linsol.LinearSolver;
import org.ejml.ops.CommonOps;
import stats.Statistics;

/**
 * Multiple linear regression fitted by ordinary least squares.
 *
 * <p>The design matrix holds a leading column of ones when the model has an intercept, followed by
 * one column per predictor in the order the predictors were supplied; {@link #beta()} and
 * {@link #standardErrors()} follow the same column order. All estimates are computed once, when the
 * instance is built, and every list this class exposes is immutable.
 *
 * <p>Instances are created through {@link #builder()} or one of the {@code with...} methods.
 */
public final class MultipleLinearRegression implements LinearRegression {
    private final List<List<Double>> predictors;
    private final List<Double> response;
    private final List<Double> beta;
    private final List<Double> standardErrors;
    private final List<Double> fitted;
    private final List<Double> residuals;
    private final double sigma2;
    private final boolean hasIntercept;

    /**
     * Builder for {@link MultipleLinearRegression}. A response is required; predictors are optional
     * (with the intercept enabled, no predictors means an intercept-only model). The model includes
     * an intercept unless {@link #hasIntercept(boolean)} is called with {@code false}.
     */
    public static final class Builder {
        private ImmutableList.Builder<List<Double>> listBuilder;
        private List<Double> response;
        private boolean hasIntercept = true;

        /**
         * Replaces this builder's predictors, response and intercept flag with copies of those of
         * {@code linearRegression}. The estimates themselves are not copied; {@link #build()} refits.
         */
        public Builder from(LinearRegression linearRegression) {
            this.listBuilder = ImmutableList.builder();
            for (List<Double> predictor : linearRegression.predictors()) {
                this.listBuilder.add(ImmutableList.copyOf(predictor));
            }
            this.response = ImmutableList.copyOf(linearRegression.response());
            this.hasIntercept = linearRegression.hasIntercept();
            return this;
        }

        /** Replaces all predictors set so far with immutable copies of the columns in {@code list}. */
        Builder predictors(List<List<Double>> list) {
            this.listBuilder = ImmutableList.builder();
            for (List<Double> predictor : list) {
                this.listBuilder.add(ImmutableList.copyOf(predictor));
            }
            return this;
        }

        /** Appends an immutable copy of {@code list} as the next predictor column. */
        public Builder predictor(List<Double> list) {
            if (this.listBuilder == null) {
                this.listBuilder = ImmutableList.builder();
            }
            this.listBuilder.add(ImmutableList.copyOf(list));
            return this;
        }

        /** Sets the response, one value per observation, to an immutable copy of {@code list}. */
        public Builder response(List<Double> list) {
            this.response = ImmutableList.copyOf(list);
            return this;
        }

        /** Sets whether the design matrix gets a leading column of ones. */
        public Builder hasIntercept(boolean z) {
            this.hasIntercept = z;
            return this;
        }

        /**
         * Fits the model.
         *
         * @throws NullPointerException if no response has been set
         * @throws IllegalArgumentException if a predictor's length differs from the response's, if
         *     the model has no columns at all, or if there are fewer observations than coefficients
         */
        public MultipleLinearRegression build() {
            return new MultipleLinearRegression(this);
        }
    }

    /**
     * Least-squares computations for one fit, run once from the constructor of
     * {@link MultipleLinearRegression}, which copies the results out.
     *
     * <p>With the design matrix factored as X = QR, the coefficients come from the QR solver and
     * (X^T X)^-1 is formed as R^-1 R^-T, so X^T X is never built explicitly.
     */
    private static final class MatrixFormulation {
        /** n x 1 fitted values, X * beta. */
        private final DenseMatrix64F fittedValues;
        /** p x 1 coefficient estimates. */
        private final DenseMatrix64F coefficients;
        private final List<Double> residuals;
        private final double sigma2;
        /** p x p estimated coefficient covariance, sigma2 * (X^T X)^-1. */
        private final DenseMatrix64F covariance;

        private MatrixFormulation(
                List<List<Double>> predictors, List<Double> response, boolean hasIntercept) {
            int n = response.size();
            int p = predictors.size() + (hasIntercept ? 1 : 0);
            DenseMatrix64F design = createDesignMatrix(predictors, n, p, hasIntercept);

            this.coefficients = new DenseMatrix64F(p, 1);
            DenseMatrix64F xtxInverse = solveSystem(design, response, coefficients, n, p);

            this.fittedValues = new DenseMatrix64F(n, 1);
            MatrixVectorMult.mult(design, coefficients, fittedValues);
            this.residuals = computeResiduals(response, fittedValues);

            // Residual variance with n - p degrees of freedom; not finite when n == p.
            // NOTE(review): assumes Statistics.sumOfSquared is the plain sum of squares — confirm.
            this.sigma2 = Statistics.sumOfSquared(DoubleFunctions.arrayFrom(residuals)) / (n - p);

            this.covariance = new DenseMatrix64F(p, p);
            CommonOps.scale(sigma2, xtxInverse, covariance);
        }

        /** Builds the n x p design matrix from a column-major array: optional ones, then predictors. */
        private static DenseMatrix64F createDesignMatrix(
                List<List<Double>> predictors, int n, int p, boolean hasIntercept) {
            double[] columnMajor = new double[n * p];
            int k = 0;
            if (hasIntercept) {
                for (int i = 0; i < n; i++) {
                    columnMajor[k++] = 1.0;
                }
            }
            for (List<Double> predictor : predictors) {
                for (double value : predictor) {
                    columnMajor[k++] = value;
                }
            }
            // rowMajor == false: the array is read column by column.
            return new DenseMatrix64F(n, p, false, columnMajor);
        }

        /**
         * Solves design * coefficients = response in the least-squares sense via QR, writing the
         * solution into {@code coefficients}, and returns (X^T X)^-1 computed from the R factor.
         */
        private static DenseMatrix64F solveSystem(DenseMatrix64F design, List<Double> response,
                DenseMatrix64F coefficients, int n, int p) {
            LinearSolver<DenseMatrix64F> qrSolver = LinearSolverFactory.qr(n, p);
            // The design matrix is reused for the fitted values, so never let the solver overwrite it.
            DenseMatrix64F toFactor = qrSolver.modifiesA() ? design.copy() : design;
            if (!qrSolver.setA(toFactor)) {
                throw new IllegalArgumentException("QR decomposition of the design matrix failed");
            }
            DenseMatrix64F y = new DenseMatrix64F(n, 1);
            y.setData(DoubleFunctions.arrayFrom(response));
            qrSolver.solve(y, coefficients);

            // X = QR  =>  (X^T X)^-1 = (R^T R)^-1 = R^-1 R^-T.
            QRDecomposition<DenseMatrix64F> qr = qrSolver.getDecomposition();
            DenseMatrix64F r = qr.getR(null, true);
            LinearSolver<DenseMatrix64F> rSolver = LinearSolverFactory.linear(p);
            if (!rSolver.setA(r)) {
                throw new IllegalArgumentException("failed to factor the R matrix of the design matrix");
            }
            DenseMatrix64F rInverse = new DenseMatrix64F(p, p);
            rSolver.invert(rInverse);
            DenseMatrix64F xtxInverse = new DenseMatrix64F(p, p);
            CommonOps.multOuter(rInverse, xtxInverse); // rInverse * rInverse^T
            return xtxInverse;
        }

        /** Returns response[i] - fitted[i] for every observation. */
        private static List<Double> computeResiduals(List<Double> response, DenseMatrix64F fittedValues) {
            double[] fittedData = fittedValues.getData();
            List<Double> result = new ArrayList<>(response.size());
            for (int i = 0; i < response.size(); i++) {
                result.add(response.get(i) - fittedData[i]);
            }
            return result;
        }

        private List<Double> getFittedValues() {
            return DoubleFunctions.listFrom(fittedValues.getData());
        }

        private List<Double> getResiduals() {
            return residuals;
        }

        private List<Double> getBetaEstimates() {
            return DoubleFunctions.listFrom(coefficients.getData());
        }

        /** Square roots of the diagonal of the coefficient covariance matrix. */
        private List<Double> getBetaStandardErrors() {
            DenseMatrix64F variances = new DenseMatrix64F(covariance.getNumRows(), 1);
            CommonOps.extractDiag(covariance, variances);
            return DoubleFunctions.listFrom(DoubleFunctions.sqrt(variances.getData()));
        }

        private double getSigma2() {
            return sigma2;
        }
    }

    private MultipleLinearRegression(Builder builder) {
        this.response = Objects.requireNonNull(builder.response, "response must be set before build()");
        // listBuilder stays null when no predictor was ever added; treat that as "no predictors".
        this.predictors = builder.listBuilder == null
                ? ImmutableList.<List<Double>>of()
                : builder.listBuilder.build();
        this.hasIntercept = builder.hasIntercept;
        checkDimensions(this.predictors, this.response, this.hasIntercept);

        MatrixFormulation formulation =
                new MatrixFormulation(this.predictors, this.response, this.hasIntercept);
        // Freeze the results once so the accessors can hand them out without copying.
        this.beta = ImmutableList.copyOf(formulation.getBetaEstimates());
        this.fitted = ImmutableList.copyOf(formulation.getFittedValues());
        this.residuals = ImmutableList.copyOf(formulation.getResiduals());
        this.sigma2 = formulation.getSigma2();
        this.standardErrors = ImmutableList.copyOf(formulation.getBetaStandardErrors());
    }

    /**
     * Rejects inputs that cannot form a valid design matrix.
     *
     * @throws IllegalArgumentException if a predictor's length differs from the response's, if the
     *     model has no columns, or if there are fewer observations than coefficients
     */
    private static void checkDimensions(
            List<List<Double>> predictors, List<Double> response, boolean hasIntercept) {
        int n = response.size();
        for (int j = 0; j < predictors.size(); j++) {
            int length = predictors.get(j).size();
            if (length != n) {
                throw new IllegalArgumentException("predictor " + j + " has " + length
                        + " observations but the response has " + n);
            }
        }
        int p = predictors.size() + (hasIntercept ? 1 : 0);
        if (p == 0) {
            throw new IllegalArgumentException(
                    "the model has no columns: add a predictor or enable the intercept");
        }
        if (n < p) {
            throw new IllegalArgumentException("need at least as many observations as coefficients, got "
                    + n + " observations for " + p + " coefficients");
        }
    }

    /** Returns the predictor columns, in design-matrix order (after the intercept, if any). */
    @Override
    public List<List<Double>> predictors() {
        return this.predictors;
    }

    /** Returns the coefficient estimates: intercept first when present, then one per predictor. */
    @Override
    public List<Double> beta() {
        return this.beta;
    }

    /** Returns the standard error of each coefficient, in the same order as {@link #beta()}. */
    @Override
    public List<Double> standardErrors() {
        return this.standardErrors;
    }

    /** Returns the observed response values. */
    @Override
    public List<Double> response() {
        return this.response;
    }

    /** Returns the fitted value for each observation. */
    @Override
    public List<Double> fitted() {
        return this.fitted;
    }

    /** Returns response minus fitted value for each observation. */
    @Override
    public List<Double> residuals() {
        return this.residuals;
    }

    /** Returns the residual variance estimate (divided by n - p; not finite when n == p). */
    @Override
    public double sigma2() {
        return this.sigma2;
    }

    /** Returns whether the design matrix has a leading column of ones. */
    @Override
    public boolean hasIntercept() {
        return this.hasIntercept;
    }

    /** Returns a refitted copy of this model with the intercept switched on or off. */
    public MultipleLinearRegression withHasIntercept(boolean z) {
        return new Builder().from(this).hasIntercept(z).build();
    }

    /** Returns a refitted copy of this model with {@code list} as the response. */
    public MultipleLinearRegression withResponse(List<Double> list) {
        return new Builder().from(this).response(list).build();
    }

    /** Returns a refitted copy of this model with {@code list} appended as the last predictor. */
    public MultipleLinearRegression withPredictor(List<Double> list) {
        return new Builder().from(this).predictor(list).build();
    }

    /** Returns a refitted copy of this model whose predictors are replaced by {@code list}. */
    public MultipleLinearRegression withPredictors(List<List<Double>> list) {
        return new Builder().from(this).predictors(list).build();
    }

    /** Returns a new, empty builder. */
    public static Builder builder() {
        return new Builder();
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof MultipleLinearRegression)) {
            return false;
        }
        MultipleLinearRegression other = (MultipleLinearRegression) obj;
        return Objects.equals(this.predictors, other.predictors)
                && Objects.equals(this.response, other.response)
                && Objects.equals(this.beta, other.beta)
                && Objects.equals(this.standardErrors, other.standardErrors)
                && Objects.equals(this.fitted, other.fitted)
                && Objects.equals(this.residuals, other.residuals)
                && Double.compare(this.sigma2, other.sigma2) == 0
                && this.hasIntercept == other.hasIntercept;
    }

    @Override
    public int hashCode() {
        return Objects.hash(predictors, response, beta, standardErrors, fitted, residuals, sigma2,
                hasIntercept);
    }

    @Override
    public String toString() {
        return "MultipleLinearRegression(predictors=" + this.predictors + ", response=" + this.response + ", beta=" + this.beta + ", standardErrors=" + this.standardErrors + ", fitted=" + this.fitted + ", residuals=" + this.residuals + ", sigma2=" + this.sigma2 + ", hasIntercept=" + this.hasIntercept + ")";
    }
}
