package org.dromara.easyai.nerveEntity;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.dromara.easyai.i.OutBack;

/**
 * Terminal nerve that converts the values collected for one event into a softmax
 * probability distribution.
 *
 * <p>When not studying, the most probable class (1-based) and its probability are
 * reported to an {@link OutBack}. When studying, the per-class error
 * {@code oneHot(target) - probability} is computed and pushed to each
 * {@link OutNerve} via {@code getGBySoftMax}. This is the negated gradient of
 * softmax cross-entropy with respect to the softmax inputs.
 */
public class SoftMax extends Nerve {
    /** Nerves that receive errors; index {@code k} gets the error for softmax class {@code k + 1}. */
    private final List<OutNerve> outNerves;
    /** When {@code true}, every training step prints the expected label and the prediction to stdout. */
    private final boolean isShowLog;

    /** Result of one softmax evaluation. */
    public static class Mes {
        /** 1-based index of the most probable class, or 0 if no probability was positive (e.g. empty input). */
        int typeID;
        /** Probability of {@link #typeID}; 0 when {@code typeID} is 0. */
        float poi;
        /** Softmax probability for every collected input, in input order. */
        List<Float> softMax;

        Mes() {
        }
    }

    /**
     * Creates the softmax nerve.
     *
     * @param i          forwarded as the 2nd argument of {@code Nerve}'s constructor
     *                   (NOTE(review): presumably the number of upstream inputs — confirm)
     * @param z          forwarded as the 8th argument of {@code Nerve}'s constructor (semantics defined there — TODO confirm)
     * @param outNerves  nerves that receive the per-class error during study; index {@code k} maps to class {@code k + 1}
     * @param isShowLog  whether to print a diagnostic line on every training step
     * @param i2         forwarded as the 15th argument of {@code Nerve}'s constructor (semantics defined there — TODO confirm)
     * @throws Exception propagated from {@code Nerve}'s constructor
     */
    public SoftMax(int i, boolean z, List<OutNerve> outNerves, boolean isShowLog, int i2) throws Exception {
        super(0, i, "softMax", 0, 0.0f, false, null, z, 0, 0.0f, 0, 0, 0, 0, i2, 0, 0.0f, false);
        this.outNerves = outNerves;
        this.isShowLog = isShowLog;
    }

    /**
     * Accepts one input value for an event. Once {@code insertParameter} reports the event
     * ready, it evaluates softmax over all values collected for that event.
     *
     * @param eventId key under which this event's values are collected in {@code features}
     * @param value   the incoming value
     * @param isStudy {@code true} to compute errors and push them to {@link #outNerves};
     *                {@code false} to report the prediction to {@code outBack}
     * @param labels  expected output (used only when studying); the first entry whose value is
     *                above 0.9 gives the 1-based target class
     * @param outBack receiver of the prediction (required when not studying)
     * @throws Exception if not studying and {@code outBack} is {@code null}
     */
    @Override // org.dromara.easyai.nerveEntity.Nerve
    protected void input(long eventId, float value, boolean isStudy, Map<Integer, Float> labels, OutBack outBack) throws Exception {
        if (!insertParameter(eventId, value)) {
            // NOTE(review): presumably not every upstream value for this event has arrived yet.
            return;
        }
        Mes result = softMax(eventId);
        if (!isStudy) {
            destoryParameter(eventId);
            if (outBack == null) {
                throw new Exception("not find outBack");
            }
            outBack.getBack(result.poi, result.typeID, eventId);
            outBack.getSoftMaxBack(eventId, result.softMax);
            return;
        }
        int label = findTargetLabel(labels);
        if (this.isShowLog) {
            System.out.println("softMax==" + label + ",out==" + result.poi + ",nerveId==" + result.typeID);
        }
        List<Float> errors = error(result, label);
        this.features.remove(Long.valueOf(eventId));
        int count = this.outNerves.size();
        for (int k = 0; k < count; k++) {
            float err = errors.get(k);
            this.outNerves.get(k).getGBySoftMax(err, eventId);
        }
    }

    /**
     * Returns the key of the first entry (in the map's iteration order) whose value is
     * above 0.9, or 0 when no entry qualifies.
     */
    private static int findTargetLabel(Map<Integer, Float> labels) {
        for (Map.Entry<Integer, Float> entry : labels.entrySet()) {
            if (entry.getValue() > 0.9d) {
                return entry.getKey();
            }
        }
        return 0;
    }

    /**
     * Computes {@code oneHot(i) - probability} for every class.
     *
     * @param mes softmax result
     * @param i   1-based target class; 0 (no target) makes every error {@code -probability}
     * @return one error per softmax output, in output order
     */
    private List<Float> error(Mes mes, int i) {
        int targetIndex = i - 1; // classes are 1-based, list indices 0-based
        List<Float> probabilities = mes.softMax;
        List<Float> errors = new ArrayList<>(probabilities.size());
        for (int k = 0; k < probabilities.size(); k++) {
            float p = probabilities.get(k);
            errors.add(k == targetIndex ? 1.0f - p : -p);
        }
        return errors;
    }

    /**
     * Evaluates softmax over the values collected for {@code eventId}.
     *
     * <p>Inputs are shifted by their maximum before exponentiation. This is mathematically
     * identical, but it keeps {@code exp} from overflowing to {@code Infinity}, which happens
     * for float inputs above roughly 88.7. Unshifted, such inputs would yield {@code NaN}
     * probabilities. After shifting, the sum is at least 1, so the division is always defined.
     *
     * @param eventId key of the collected values in {@code features}
     * @return probabilities plus the 1-based arg-max class and its probability
     */
    private Mes softMax(long eventId) {
        List<Float> inputs = this.features.get(Long.valueOf(eventId));
        int size = inputs.size();

        float max = Float.NEGATIVE_INFINITY;
        for (float v : inputs) {
            if (v > max) {
                max = v;
            }
        }

        float[] exps = new float[size];
        float sum = 0.0f;
        for (int k = 0; k < size; k++) {
            float e = (float) Math.exp(inputs.get(k) - max);
            exps[k] = e;
            sum += e;
        }

        List<Float> probabilities = new ArrayList<>(size);
        int bestType = 0;
        float bestProbability = 0.0f;
        for (int k = 0; k < size; k++) {
            float p = exps[k] / sum;
            probabilities.add(p);
            if (p > bestProbability) {
                bestProbability = p;
                bestType = k + 1; // 1-based class id
            }
        }

        Mes mes = new Mes();
        mes.softMax = probabilities;
        mes.typeID = bestType;
        mes.poi = bestProbability;
        return mes;
    }
}
