package org.dromara.easyai.randomForest;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;

/* loaded from: input_file:org/dromara/easyai/randomForest/Tree.class */
/**
 * A single decision tree that is grown from a {@link DataTable} and then used to
 * classify objects.
 *
 * <p>Training ({@link #study()}) recursively splits the row indices of the table on
 * one attribute per level. The attribute is chosen among those whose information
 * gain is at least the average gain, preferring the highest gain ratio. After the
 * tree is grown, sub-trees whose children do not classify the training rows better
 * than their parent are collapsed into leaves.
 *
 * <p>Classification ({@link #judge(Object)}) walks the tree using the attribute
 * values read reflectively from the object and multiplies the result's trust by
 * {@code trustPunishment} each time a penalty applies.
 */
public class Tree {
    /** Gain ratio assigned to an attribute whose split information is exactly zero. */
    private static final float ZERO_SPLIT_GAIN_RATIO = 1000000.0f;
    /** Tolerance used when comparing an attribute's gain with the average gain. */
    private static final double GAIN_TOLERANCE = 1.0E-6d;

    private DataTable dataTable;
    /** Column name mapped to the column's values, indexed by row. */
    private Map<String, List<Integer>> table;
    private Node rootNode;
    /** Values of the target column (the class label of every row). */
    private List<Integer> endList;
    /** Leaves created while growing the tree; pruning starts from their parents. */
    private final List<Node> lastNodes = new ArrayList<>();
    private final Random random = new Random();
    /** Factor the trust is multiplied by for every penalty during {@link #judge(Object)}. */
    private final float trustPunishment;

    /** Information gain and gain ratio computed for one candidate split attribute. */
    public static class Gain {
        private float gain;
        private float gainRatio;

        private Gain() {
        }
    }

    /**
     * Creates a tree without training data, for use with a root supplied through
     * {@link #setRootNode(Node)}.
     *
     * @param f trust multiplier applied for every penalty during classification
     */
    public Tree(float f) {
        this.trustPunishment = f;
    }

    /**
     * Creates a tree that will be trained on the given table.
     *
     * @param dataTable training data; must be non-null and have a target key
     * @param f         trust multiplier applied for every penalty during classification
     * @throws Exception if {@code dataTable} or its key is {@code null}
     */
    public Tree(DataTable dataTable, float f) throws Exception {
        if (dataTable == null || dataTable.getKey() == null) {
            throw new Exception("dataTable is empty");
        }
        this.trustPunishment = f;
        this.dataTable = dataTable;
    }

    /** @return the root of the tree, or {@code null} if none has been built or set */
    public Node getRootNode() {
        return this.rootNode;
    }

    /** @return the training table, or {@code null} if the tree was created without one */
    public DataTable getDataTable() {
        return this.dataTable;
    }

    /**
     * Replaces the root of the tree.
     *
     * @param node the new root node
     */
    public void setRootNode(Node node) {
        this.rootNode = node;
    }

    /** Base-2 logarithm in single precision. */
    private float log2(float f) {
        return ((float) Math.log(f)) / ((float) Math.log(2.0d));
    }

    /**
     * Counts how often each target value occurs among the given rows.
     *
     * @param list row indices into {@link #endList}
     * @return target value mapped to its number of occurrences
     */
    private Map<Integer, Integer> countTypes(List<Integer> list) {
        Map<Integer, Integer> counts = new HashMap<>();
        for (int row : list) {
            int type = this.endList.get(row);
            Integer seen = counts.get(type);
            counts.put(type, seen == null ? 1 : seen + 1);
        }
        return counts;
    }

    /**
     * Computes the entropy (base 2) of the target values of the given rows.
     *
     * @param list row indices into {@link #endList}
     * @return the entropy of the class distribution of those rows
     */
    private float getEnt(List<Integer> list) {
        // Divide in floating point: an int/int division would truncate every
        // proportion below 1 to 0 and turn the entropy into NaN.
        float total = list.size();
        float sum = 0.0f;
        for (int count : countTypes(list).values()) {
            float proportion = count / total;
            sum += proportion * log2(proportion);
        }
        return -sum;
    }

    /**
     * Adds one branch's weighted entropy to a running total.
     *
     * @param f  entropy of the branch
     * @param f2 weight of the branch
     * @param f3 running total so far
     * @return {@code f3 + f * f2}
     */
    private float getGain(float f, float f2, float f3) {
        return f3 + (f * f2);
    }

    /**
     * Recursively grows the sub-tree below {@code node}.
     *
     * <p>If the node has no attributes left it becomes a leaf labelled with the
     * majority target value of its rows. Otherwise its rows are partitioned by every
     * remaining attribute, one attribute is selected, and the children created for
     * that attribute are grown in turn.
     *
     * @param node node to split; its {@code attribute} and {@code fatherList} must be set
     * @return the children of {@code node}, or {@code null} if it became a leaf
     */
    private List<Node> createNode(Node node) {
        Set<String> attributes = node.attribute;
        List<Integer> rows = node.fatherList;
        if (attributes.isEmpty()) {
            node.isEnd = true;
            node.type = getType(rows);
            this.lastNodes.add(node);
            return null;
        }

        // attribute -> attribute value -> rows holding that value
        Map<String, Map<Integer, List<Integer>>> partitions = new HashMap<>();
        for (int row : rows) {
            for (String attribute : attributes) {
                Map<Integer, List<Integer>> byValue = partitions.get(attribute);
                if (byValue == null) {
                    byValue = new HashMap<>();
                    partitions.put(attribute, byValue);
                }
                int value = this.table.get(attribute).get(row);
                List<Integer> bucket = byValue.get(value);
                if (bucket == null) {
                    bucket = new ArrayList<>();
                    byValue.put(value, bucket);
                }
                bucket.add(row);
            }
        }

        float parentEntropy = getEnt(rows);
        int parentSize = rows.size();
        Map<String, List<Node>> childrenByAttribute = new HashMap<>();
        Map<String, Gain> gains = new HashMap<>();
        float gainSum = 0.0f;
        for (Map.Entry<String, Map<Integer, List<Integer>>> partition : partitions.entrySet()) {
            String attribute = partition.getKey();
            List<Node> children = new ArrayList<>();
            childrenByAttribute.put(attribute, children);
            float conditionalEntropy = 0.0f;
            float splitInfo = 0.0f; // sum of w * log2(w); zero or negative
            for (Map.Entry<Integer, List<Integer>> bucket : partition.getValue().entrySet()) {
                List<Integer> childRows = bucket.getValue();
                Node child = new Node();
                child.attribute = removeAttribute(attributes, attribute);
                child.fatherList = childRows;
                child.typeId = bucket.getKey();
                children.add(child);
                // The weight must be a real fraction; int/int would yield 0 or 1 only.
                float weight = (float) childRows.size() / parentSize;
                splitInfo += weight * log2(weight);
                conditionalEntropy = getGain(getEnt(childRows), weight, conditionalEntropy);
            }
            Gain gain = new Gain();
            gain.gain = parentEntropy - conditionalEntropy;
            gain.gainRatio = splitInfo != 0.0f ? gain.gain / (-splitInfo) : ZERO_SPLIT_GAIN_RATIO;
            gains.put(attribute, gain);
            gainSum += gain.gain;
        }

        // Select, among attributes whose gain is at least the average, the one with
        // the highest gain ratio. A sole candidate is always taken.
        float averageGain = gainSum / gains.size();
        boolean onlyCandidate = gains.size() == 1;
        float bestRatio = 0.0f;
        String bestKey = null;
        for (Map.Entry<String, Gain> candidate : gains.entrySet()) {
            Gain gain = candidate.getValue();
            boolean gainOk = gain.gain >= averageGain
                    || Math.abs(gain.gain - averageGain) < GAIN_TOLERANCE;
            boolean ratioOk = bestKey == null || gain.gainRatio >= bestRatio;
            if (onlyCandidate || (gainOk && ratioOk)) {
                bestRatio = gain.gainRatio;
                bestKey = candidate.getKey();
            }
        }

        node.key = bestKey;
        List<Node> chosen = childrenByAttribute.get(bestKey);
        for (Node child : chosen) {
            child.fatherNode = node;
        }
        for (Node child : chosen) {
            child.nodeList = createNode(child);
        }
        return chosen;
    }

    /**
     * Returns the most frequent target value among the given rows.
     *
     * @param list row indices into {@link #endList}
     * @return the majority target value, or {@code 0} when {@code list} is empty
     */
    private int getType(List<Integer> list) {
        int bestType = 0;
        int bestCount = 0;
        for (Map.Entry<Integer, Integer> entry : countTypes(list).entrySet()) {
            int count = entry.getValue();
            if (count > bestCount) {
                bestType = entry.getKey();
                bestCount = count;
            }
        }
        return bestType;
    }

    /**
     * Returns a new set containing every element of {@code set} except {@code str}.
     * The input set is not modified.
     */
    private Set<String> removeAttribute(Set<String> set, String str) {
        Set<String> remaining = new HashSet<>();
        for (String attribute : set) {
            if (!attribute.equals(str)) {
                remaining.add(attribute);
            }
        }
        return remaining;
    }

    /**
     * Reads an attribute value from {@code obj} by invoking its public no-argument
     * getter {@code get<Str>()} and parsing the result's {@code toString()} as an int.
     *
     * @param obj object to read from
     * @param str attribute name; its first character is upper-cased to form the getter name
     * @return the attribute value
     * @throws Exception if the getter is missing, fails, or does not yield an integer
     */
    private int getTypeId(Object obj, String str) throws Exception {
        // Character.toUpperCase is locale-independent, unlike String.toUpperCase(),
        // which would build a wrong getter name under e.g. a Turkish default locale.
        String getter = "get" + Character.toUpperCase(str.charAt(0)) + str.substring(1);
        Object value = obj.getClass().getMethod(getter).invoke(obj);
        return Integer.parseInt(value.toString());
    }

    /**
     * Classifies an object with this tree.
     *
     * @param obj object exposing a {@code get<Attribute>()} getter for every attribute
     *            the tree splits on
     * @return the predicted type together with a trust that starts at 1 and is
     *         multiplied by the punishment factor for every penalty on the path
     * @throws Exception if the tree has no root or an attribute cannot be read
     */
    public TreeWithTrust judge(Object obj) throws Exception {
        if (this.rootNode == null) {
            throw new Exception("rootNode is null");
        }
        TreeWithTrust treeWithTrust = new TreeWithTrust();
        treeWithTrust.setTrust(1.0f);
        goTree(obj, this.rootNode, treeWithTrust, 0);
        return treeWithTrust;
    }

    /** Multiplies the result's trust by the punishment factor once. */
    private void punishment(TreeWithTrust treeWithTrust) {
        treeWithTrust.setTrust(treeWithTrust.getTrust() * this.trustPunishment);
    }

    /**
     * Descends from {@code node} to a leaf, following the object's attribute values.
     *
     * @param obj           object being classified
     * @param node          current node
     * @param treeWithTrust result accumulator (type and trust)
     * @param i             depth of {@code node} below the root
     * @throws Exception if an attribute cannot be read from {@code obj}
     */
    private void goTree(Object obj, Node node, TreeWithTrust treeWithTrust, int i) throws Exception {
        if (node.isEnd) {
            if (node.typeId == 0) {
                // Penalise once for every attribute level that was not traversed.
                int missingLevels = this.rootNode.attribute.size() - i;
                for (int level = 0; level < missingLevels; level++) {
                    punishment(treeWithTrust);
                }
            }
            treeWithTrust.setType(node.type);
            return;
        }
        int typeId = getTypeId(obj, node.key);
        if (typeId == 0) {
            punishment(treeWithTrust);
        }
        List<Node> children = node.nodeList;
        Node next = null;
        for (Node child : children) {
            if (child.typeId == typeId) {
                next = child;
                break;
            }
        }
        if (next == null) {
            // No branch for this value: penalise and continue down a random branch.
            punishment(treeWithTrust);
            next = children.get(this.random.nextInt(children.size()));
        }
        goTree(obj, next, treeWithTrust, i + 1);
    }

    /**
     * Builds the tree from the data table and prunes it.
     *
     * <p>Note that the target key is removed from the set returned by
     * {@code dataTable.getKeyType()}.
     *
     * @throws Exception if there is no data table or it reports a non-positive length
     */
    public void study() throws Exception {
        if (this.dataTable == null || this.dataTable.getLength() <= 0) {
            throw new Exception("dataTable is null");
        }
        this.rootNode = new Node();
        this.table = this.dataTable.getTable();
        this.endList = this.dataTable.getTable().get(this.dataTable.getKey());
        Set<String> keyType = this.dataTable.getKeyType();
        keyType.remove(this.dataTable.getKey());
        this.rootNode.attribute = keyType;
        List<Integer> allRows = new ArrayList<>();
        for (int row = 0; row < this.endList.size(); row++) {
            allRows.add(row);
        }
        this.rootNode.fatherList = allRows;
        this.rootNode.nodeList = createNode(this.rootNode);
        for (Node leaf : this.lastNodes) {
            prune(leaf.fatherNode);
        }
        this.lastNodes.clear();
    }

    /**
     * Collapses {@code node} into a leaf when {@link #isPrune(Node, List)} says so,
     * then tries the same for its parent. Stops at {@code null}, at nodes that are
     * already leaves, and at the first node that should be kept.
     */
    private void prune(Node node) {
        if (node == null || node.isEnd || !isPrune(node, node.nodeList)) {
            return;
        }
        deduction(node);
        prune(node.fatherNode);
    }

    /** Turns {@code node} into a leaf labelled with the majority target value of its rows. */
    private void deduction(Node node) {
        node.isEnd = true;
        node.nodeList = null;
        node.type = getType(node.fatherList);
    }

    /**
     * Decides whether splitting {@code node} is useless on the training rows.
     *
     * @param node parent node
     * @param list children of {@code node}
     * @return {@code true} if labelling each child with its own majority value
     *         classifies no larger a fraction of rows correctly than labelling the
     *         parent with its majority value
     */
    private boolean isPrune(Node node, List<Node> list) {
        List<Integer> parentRows = node.fatherList;
        // Accuracies are fractions; integer division would collapse them to 0 or 1.
        float parentAccuracy =
                (float) getRightPoint(parentRows, getType(parentRows)) / parentRows.size();
        int correct = 0;
        int total = 0;
        for (Node child : list) {
            List<Integer> childRows = child.fatherList;
            correct += getRightPoint(childRows, getType(childRows));
            total += childRows.size();
        }
        return (float) correct / total <= parentAccuracy;
    }

    /**
     * Counts the rows whose target value equals {@code i}.
     *
     * @param list row indices into {@link #endList}
     * @param i    target value to match
     * @return number of matching rows
     */
    private int getRightPoint(List<Integer> list, int i) {
        int matches = 0;
        for (int row : list) {
            if (this.endList.get(row) == i) {
                matches++;
            }
        }
        return matches;
    }
}
