/*
 * Decompiled with CFR 0.152.
 */
package smile.clustering;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.clustering.KMeans;
import smile.math.Math;
import smile.util.MulticoreExecutor;

/**
 * Deterministic annealing clustering.
 * <p>
 * Internally every cluster is represented by a pair of code vectors stored at
 * indices {@code 2c} and {@code 2c + 1}; the second vector is re-seeded as the
 * first one scaled by 1.01. Training starts from a single pair placed at the data
 * mean, at an initial temperature derived from the data covariance. At each
 * temperature the soft assignments, priors and code vectors are re-estimated.
 * Whenever the two vectors of a pair drift apart (squared distance above 0.01),
 * the cluster is split into two pairs, until {@code 2 * Kmax} code vectors are in
 * use. Finally every observation is hard-assigned to the nearest cluster and the
 * centroids are recomputed as plain means.
 * <p>
 * For data sets with at least 1000 rows and a thread pool of at least two threads,
 * the soft-assignment and centroid steps are run through {@link MulticoreExecutor}.
 * If that throws, the sequential code path is used instead.
 */
public class DeterministicAnnealing
extends KMeans
implements Serializable {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = LoggerFactory.getLogger(DeterministicAnnealing.class);

    /** The temperature decay factor passed to the constructor; always in (0, 1). */
    private double alpha;
    /**
     * Row-partitioned soft-assignment tasks, or {@code null} when the sequential
     * path is used. Only needed during training; cleared afterwards so the fitted
     * model does not keep references to the training data.
     */
    private transient List<UpdateThread> tasks = null;
    /** One centroid re-estimation task per code vector; same lifetime as {@code tasks}. */
    private transient List<CentroidThread> ctasks = null;

    /**
     * Clusters the data with the default temperature decay factor 0.9.
     *
     * @param data the observations, one row per observation.
     * @param Kmax the maximum number of clusters.
     */
    public DeterministicAnnealing(double[][] data, int Kmax) {
        this(data, Kmax, 0.9);
    }

    /**
     * Clusters the data by deterministic annealing.
     *
     * @param data the observations, one row per observation.
     * @param Kmax the maximum number of clusters.
     * @param alpha the temperature decay factor, in (0, 1).
     * @throws IllegalArgumentException if {@code alpha} is not in (0, 1),
     *     {@code data} is empty, or {@code Kmax} is less than 1.
     */
    public DeterministicAnnealing(double[][] data, int Kmax, double alpha) {
        if (alpha <= 0.0 || alpha >= 1.0) {
            throw new IllegalArgumentException("Invalid alpha: " + alpha);
        }
        if (data.length == 0) {
            throw new IllegalArgumentException("Empty data");
        }
        if (Kmax < 1) {
            throw new IllegalArgumentException("Invalid Kmax: " + Kmax);
        }
        this.alpha = alpha;

        int n = data.length;
        int d = data[0].length;
        int maxCodes = 2 * Kmax;
        this.centroids = new double[maxCodes][d];
        double[][] posteriori = new double[n][maxCodes];
        double[] priori = new double[maxCodes];

        createTasks(data, posteriori, priori);
        anneal(data, maxCodes, alpha, posteriori, priori);
        // The tasks hold references to the data and the posterior matrix; they are
        // not needed once training is over.
        this.tasks = null;
        this.ctasks = null;

        finish(data);
    }

    /**
     * Returns the temperature decay factor given to the constructor.
     *
     * @return the temperature decay factor.
     */
    public double getAlpha() {
        return this.alpha;
    }

    /**
     * Creates the multi-core tasks when there are at least 1000 rows and the thread
     * pool has at least two threads. Rows are split into contiguous chunks of at
     * least 100 rows, and every chunk lies within {@code [0, n)}.
     */
    private void createTasks(double[][] data, double[][] posteriori, double[] priori) {
        int n = data.length;
        int np = MulticoreExecutor.getThreadPoolSize();
        if (n < 1000 || np < 2) {
            return;
        }

        int step = n / np;
        if (step < 100) {
            step = 100;
        }
        // Fix: the decompiled code always created np - 1 chunks of width `step`, which
        // ran past n whenever the 100-row minimum made (np - 1) * step exceed n. Every
        // E-step then failed and fell back to the sequential path.
        int fullChunks = n / step;
        if (fullChunks > np - 1) {
            fullChunks = np - 1;
        }

        this.tasks = new ArrayList<>(fullChunks + 1);
        for (int c = 0; c < fullChunks; c++) {
            this.tasks.add(new UpdateThread(data, this.centroids, posteriori, priori, c * step, (c + 1) * step));
        }
        int start = fullChunks * step;
        if (start < n) {
            this.tasks.add(new UpdateThread(data, this.centroids, posteriori, priori, start, n));
        }

        int maxCodes = this.centroids.length;
        this.ctasks = new ArrayList<>(maxCodes);
        for (int i = 0; i < maxCodes; i++) {
            this.ctasks.add(new CentroidThread(data, this.centroids, posteriori, priori, i));
        }
    }

    /**
     * Runs the annealing schedule on {@code this.centroids}, growing {@code this.k}
     * (the number of code vectors in use, always even) up to {@code maxCodes}.
     */
    private void anneal(double[][] data, int maxCodes, double alpha, double[][] posteriori, double[] priori) {
        int n = data.length;
        int d = data[0].length;
        double[][] codes = this.centroids;

        // Start with one cluster: the data mean plus a copy scaled by 1.01.
        for (double[] x : data) {
            for (int j = 0; j < d; j++) {
                codes[0][j] += x[j];
            }
        }
        for (int j = 0; j < d; j++) {
            codes[0][j] /= n;
            codes[1][j] = codes[0][j] * 1.01;
        }
        priori[0] = 0.5;
        priori[1] = 0.5;

        // Initial temperature from the covariance eigen-solver. Presumably lambda is
        // the leading eigenvalue (start vector of ones, tolerance 1e-4); TODO confirm.
        double[][] cov = Math.cov(data, codes[0]);
        double[] ev = new double[d];
        Arrays.fill(ev, 1.0);
        double lambda = Math.eigen(cov, ev, 1.0E-4);
        double T = 2.0 * lambda + 0.01;

        this.k = 2;
        boolean stop = false;
        boolean split = false;
        while (!stop) {
            update(data, T, this.k, codes, posteriori, priori);
            if (this.k >= maxCodes && split) {
                stop = true;
            }

            int currentK = this.k;
            for (int i = 0; i < currentK; i += 2) {
                double norm = 0.0;
                for (int j = 0; j < d; j++) {
                    double diff = codes[i][j] - codes[i + 1][j];
                    norm += diff * diff;
                }

                if (norm > 0.01) {
                    if (this.k < maxCodes) {
                        // The new pair (k, k + 1) takes over code vector i + 1 and its mass ...
                        for (int j = 0; j < d; j++) {
                            codes[this.k][j] = codes[i + 1][j];
                            codes[this.k + 1][j] = codes[i + 1][j] * 1.01;
                        }
                        double halfNext = priori[i + 1] / 2.0;
                        priori[this.k] = halfNext;
                        priori[this.k + 1] = halfNext;
                        // ... while pair (i, i + 1) keeps code vector i and splits its mass
                        // evenly. Fix: the decompiled code halved priori[i] in place and then
                        // halved the already-halved value, leaving priori[i + 1] a quarter.
                        double halfThis = priori[i] / 2.0;
                        priori[i] = halfThis;
                        priori[i + 1] = halfThis;
                        this.k += 2;
                    }
                    if (currentK >= maxCodes) {
                        split = true;
                    }
                }

                // Re-seed vector i + 1 as a perturbed copy of vector i.
                for (int j = 0; j < d; j++) {
                    codes[i + 1][j] = codes[i][j] * 1.01;
                }
            }

            if (split) {
                // All code vectors are in use and a pair separated: raise the temperature.
                T /= alpha;
            } else if (this.k - currentK > 2) {
                // More than one split in a single pass: back off and slow the schedule.
                T /= alpha;
                alpha = slowDown(alpha);
            } else {
                if (this.k > currentK && this.k == maxCodes - 2) {
                    alpha = slowDown(alpha);
                }
                T *= alpha;
            }

            if (alpha >= 1.0) {
                break;
            }
        }
    }

    /**
     * Moves the decay factor toward 1. Algebraically
     * {@code 5 * 10^(log10(1 - a) - 1) == (1 - a) / 2}, i.e. halfway to 1. The
     * original expression is kept so floating-point results are unchanged.
     */
    private static double slowDown(double a) {
        return a + 5.0 * Math.pow(10.0, Math.log10(1.0 - a) - 1.0);
    }

    /**
     * Hard-assigns every observation to the nearest cluster (the even-indexed code
     * vector of each pair) and accumulates the distortion against those code
     * vectors. Then replaces the code vectors by the cluster means.
     */
    private void finish(double[][] data) {
        int n = data.length;
        int d = data[0].length;
        double[][] codes = this.centroids;
        // k counted code vectors (two per cluster); from here on it counts clusters.
        this.k /= 2;

        this.y = new int[n];
        this.distortion = 0.0;
        for (int i = 0; i < n; i++) {
            double nearest = Double.MAX_VALUE;
            // Fix: the decompiled code stepped j = 0, 2, ... < k AFTER halving k, so
            // only about the first half of the clusters were ever candidates. The
            // remaining clusters stayed empty and their centroids became 0/0 = NaN.
            for (int c = 0; c < this.k; c++) {
                double dist = Math.squaredDistance(data[i], codes[2 * c]);
                if (dist < nearest) {
                    this.y[i] = c;
                    nearest = dist;
                }
            }
            this.distortion += nearest;
        }

        this.size = new int[this.k];
        this.centroids = new double[this.k][d];
        for (int i = 0; i < n; i++) {
            int c = this.y[i];
            this.size[c]++;
            double[] sum = this.centroids[c];
            for (int j = 0; j < d; j++) {
                sum[j] += data[i][j];
            }
        }
        for (int c = 0; c < this.k; c++) {
            if (this.size[c] == 0) {
                // No observation is nearest to this cluster: keep its annealed code
                // vector instead of dividing by zero.
                System.arraycopy(codes[2 * c], 0, this.centroids[c], 0, d);
            } else {
                for (int j = 0; j < d; j++) {
                    this.centroids[c][j] /= this.size[c];
                }
            }
        }
    }

    /**
     * Alternates soft assignment and re-estimation of priors and code vectors at
     * temperature {@code T}. Stops when the objective {@code D - T * H} (soft
     * distortion minus temperature times entropy) no longer decreases, or after
     * 100 iterations.
     *
     * @param k the number of code vectors in use.
     * @return the objective value of the iteration before the last one (the
     *     constructor does not use it).
     */
    private double update(double[][] data, double T, int k, double[][] centroids, double[][] posteriori, double[] priori) {
        int n = data.length;
        double D = 0.0;
        double H = 0.0;
        double currentDistortion = Double.MAX_VALUE;
        double newDistortion = Double.MAX_VALUE / 2;
        int iter;
        for (iter = 0; iter < 100 && currentDistortion > newDistortion; iter++) {
            currentDistortion = newDistortion;

            // E-step: posteriors, soft distortion D and entropy H.
            boolean done = false;
            if (this.tasks != null) {
                try {
                    for (UpdateThread t : this.tasks) {
                        t.k = k;
                        t.T = T;
                    }
                    double sumD = 0.0;
                    double sumH = 0.0;
                    for (UpdateThread t : MulticoreExecutor.run(this.tasks)) {
                        sumD += t.D;
                        sumH += t.H;
                    }
                    D = sumD;
                    H = sumH;
                    done = true;
                } catch (Exception ex) {
                    logger.error("Failed to run Deterministic Annealing on multi-core", ex);
                }
            }
            if (!done) {
                UpdateThread all = new UpdateThread(data, centroids, posteriori, priori, 0, n);
                all.k = k;
                all.T = T;
                all.call();
                D = all.D;
                H = all.H;
            }

            // Priors: average posterior of each code vector.
            for (int i = 0; i < k; i++) {
                double mass = 0.0;
                for (int m = 0; m < n; m++) {
                    mass += posteriori[m][i];
                }
                priori[i] = mass / n;
            }

            // M-step: code vectors.
            boolean parallel = false;
            if (this.ctasks != null) {
                try {
                    for (CentroidThread t : this.ctasks) {
                        t.k = k;
                    }
                    MulticoreExecutor.run(this.ctasks);
                    parallel = true;
                } catch (Exception ex) {
                    logger.error("Failed to run Deterministic Annealing on multi-core", ex);
                }
            }
            if (!parallel) {
                for (int i = 0; i < k; i++) {
                    updateCentroid(data, posteriori, priori, centroids[i], i);
                }
            }

            newDistortion = D - T * H;
        }
        logger.info(String.format("Deterministic Annealing clustering entropy after %3d iterations at temperature %.4f and k = %d: %.5f (soft distortion = %.5f )%n", iter, T, k / 2, H, D));
        return currentDistortion;
    }

    /**
     * Re-estimates code vector {@code i} as the posterior-weighted sum of the data
     * divided by {@code n * priori[i]}.
     */
    private static void updateCentroid(double[][] data, double[][] posteriori, double[] priori, double[] centroid, int i) {
        int n = data.length;
        int d = centroid.length;
        Arrays.fill(centroid, 0.0);
        // Row-major traversal; each coordinate is still summed in row order 0..n-1.
        for (int m = 0; m < n; m++) {
            double w = posteriori[m][i];
            double[] x = data[m];
            for (int j = 0; j < d; j++) {
                centroid[j] += x[j] * w;
            }
        }
        double mass = n * priori[i];
        for (int j = 0; j < d; j++) {
            centroid[j] /= mass;
        }
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(String.format("Deterministic Annealing clustering distortion: %.5f%n", this.distortion));
        sb.append(String.format("Clusters of %d data points:%n", this.y.length));
        for (int i = 0; i < this.k; i++) {
            // Cluster share in tenths of a percent, printed as "xx.y%".
            int r = (int) Math.round(1000.0 * this.size[i] / this.y.length);
            sb.append(String.format("%3d\t%5d (%2d.%1d%%)%n", i, this.size[i], r / 10, r % 10));
        }
        return sb.toString();
    }

    /** Re-estimates a single code vector; does nothing when its index is not below {@code k}. */
    static class CentroidThread
    implements Callable<CentroidThread> {
        final int i;
        final double[][] data;
        int k;
        double[][] centroids;
        double[][] posteriori;
        double[] priori;

        CentroidThread(double[][] data, double[][] centroids, double[][] posteriori, double[] priori, int i) {
            this.data = data;
            this.centroids = centroids;
            this.posteriori = posteriori;
            this.priori = priori;
            this.i = i;
        }

        @Override
        public CentroidThread call() {
            if (this.i < this.k) {
                updateCentroid(this.data, this.posteriori, this.priori, this.centroids[this.i], this.i);
            }
            return this;
        }
    }

    /**
     * Soft assignment for rows {@code [start, end)}. Fills those rows of the
     * posterior matrix and accumulates the soft distortion {@code D} and the
     * entropy {@code H} over them.
     */
    static class UpdateThread
    implements Callable<UpdateThread> {
        final int start;
        final int end;
        final double[][] data;
        final double[][] centroids;
        int k;
        double T;
        double D;
        double H;
        double[][] posteriori;
        double[] priori;
        double[] dist;

        UpdateThread(double[][] data, double[][] centroids, double[][] posteriori, double[] priori, int start, int end) {
            this.data = data;
            this.centroids = centroids;
            this.posteriori = posteriori;
            this.priori = priori;
            this.start = start;
            this.end = end;
            this.dist = new double[centroids.length];
        }

        @Override
        public UpdateThread call() {
            this.D = 0.0;
            this.H = 0.0;
            for (int i = this.start; i < this.end; i++) {
                double[] post = this.posteriori[i];

                // Shift the distances by their minimum before exponentiating. The shift
                // cancels on normalization, but it keeps exp() from underflowing to 0
                // for every code vector at low temperature (which made p == 0 and the
                // posteriors NaN).
                double minDist = Double.POSITIVE_INFINITY;
                for (int j = 0; j < this.k; j++) {
                    this.dist[j] = Math.squaredDistance(this.data[i], this.centroids[j]);
                    if (this.dist[j] < minDist) {
                        minDist = this.dist[j];
                    }
                }

                double p = 0.0;
                for (int j = 0; j < this.k; j++) {
                    post[j] = this.priori[j] * Math.exp(-(this.dist[j] - minDist) / this.T);
                    p += post[j];
                }

                double r = 0.0;
                for (int j = 0; j < this.k; j++) {
                    post[j] /= p;
                    this.D += post[j] * this.dist[j];
                    // Use the 0 * log(0) = 0 convention; evaluating it directly gives NaN,
                    // which made the convergence test in update() stop immediately.
                    if (post[j] > 0.0) {
                        r -= post[j] * Math.log(post[j]);
                    }
                }
                this.H += r;
            }
            return this;
        }
    }
}

