/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.confignode.manager.load.balancer.router.leader;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.iotdb.common.rpc.thrift.TConsensusGroupId;
import org.apache.iotdb.common.rpc.thrift.TDataNodeLocation;
import org.apache.iotdb.common.rpc.thrift.TRegionReplicaSet;
import org.apache.iotdb.confignode.manager.load.balancer.router.leader.ILeaderBalancer;

/**
 * Leader balancer that models leader selection as a min-cost max-flow problem.
 *
 * <p>Graph layout:
 *
 * <ul>
 *   <li>{@code S_NODE} -> every RegionGroup node (capacity 1, cost 0): each RegionGroup receives at
 *       most one leader.
 *   <li>RegionGroup node -> every DataNode node that holds one of its replicas (capacity 1, cost 0
 *       if that DataNode is the current leader, otherwise 1): keeping the current leader is
 *       preferred.
 *   <li>Enabled DataNode node -> {@code T_NODE}: k parallel unit-capacity edges with costs 1, 4, 9,
 *       ..., k^2, where k is the number of replicas the DataNode holds. The growing cost makes
 *       stacking leaders on one DataNode increasingly expensive. Disabled DataNodes get no edge to
 *       the sink and therefore never receive a leader.
 * </ul>
 *
 * <p>Each phase computes shortest-path costs from the source with a queue-based Bellman-Ford
 * (SPFA). It then pushes flow by DFS along edges lying on those shortest paths, reusing a per-node
 * current-edge pointer as in Dinic's algorithm.
 *
 * <p>Per-call working state lives in unsynchronized instance fields, so a single instance should
 * not be used by several threads at once.
 */
public class MinCostFlowLeaderBalancer implements ILeaderBalancer {

    /** "Unreachable" distance, and the unbounded flow offered at the source of an augmentation. */
    private static final int INFINITY = Integer.MAX_VALUE;

    /** Index of the source node. */
    private static final int S_NODE = 0;

    /** Index of the sink node. */
    private static final int T_NODE = 1;

    // Inputs of the current computation, copied so the caller's collections are never touched
    private final Map<TConsensusGroupId, TRegionReplicaSet> regionReplicaSetMap = new HashMap<>();
    private final Map<TConsensusGroupId, Integer> regionLeaderMap = new HashMap<>();
    private final Set<Integer> disabledDataNodeSet = new HashSet<>();

    // Node indices 0 and 1 are reserved for S_NODE and T_NODE; maxNode is the next free index
    private int maxNode = 2;
    // RegionGroupId -> RegionGroup node index
    private final Map<TConsensusGroupId, Integer> rNodeMap = new HashMap<>();
    // DataNodeId -> DataNode node index
    private final Map<Integer, Integer> dNodeMap = new HashMap<>();
    // DataNode node index -> DataNodeId
    private final Map<Integer, Integer> dNodeReflect = new HashMap<>();

    // Edges are added in pairs, so edge i and edge (i ^ 1) are each other's residual counterpart
    private int maxEdge = 0;
    private final List<MinCostFlowEdge> minCostFlowEdges = new ArrayList<>();
    // Head of each node's singly linked adjacency list (index into minCostFlowEdges, -1 if empty)
    private int[] nodeHeadEdge;
    // Per-node current-edge pointer used by dfsAugmentation within one phase
    private int[] nodeCurrentEdge;
    private boolean[] isNodeVisited;
    private int[] nodeMinimumCost;

    // Statistics of the most recent run; intentionally not reset by clear()
    private int maximumFlow = 0;
    private int minimumCost = 0;

    /**
     * Computes a balanced leader assignment.
     *
     * @param regionReplicaSetMap RegionGroupId -> replica set of that RegionGroup
     * @param regionLeaderMap RegionGroupId -> DataNodeId of the current leader
     * @param disabledDataNodeSet DataNodeIds that must not receive any leader
     * @return RegionGroupId -> DataNodeId of the chosen leader. A RegionGroup that cannot be
     *     routed through the flow keeps its current leader, or -1 if it has none.
     */
    @Override
    public Map<TConsensusGroupId, Integer> generateOptimalLeaderDistribution(
            Map<TConsensusGroupId, TRegionReplicaSet> regionReplicaSetMap,
            Map<TConsensusGroupId, Integer> regionLeaderMap,
            Set<Integer> disabledDataNodeSet) {
        try {
            initialize(regionReplicaSetMap, regionLeaderMap, disabledDataNodeSet);
            constructMCFGraph();
            dinicAlgorithm();
            return collectLeaderDistribution();
        } finally {
            // Always drop working state so the instance stays reusable even after a failure
            clear();
        }
    }

    /** Copies the inputs into this instance's working collections. */
    private void initialize(
            Map<TConsensusGroupId, TRegionReplicaSet> regionReplicaSetMap,
            Map<TConsensusGroupId, Integer> regionLeaderMap,
            Set<Integer> disabledDataNodeSet) {
        this.regionReplicaSetMap.putAll(regionReplicaSetMap);
        this.regionLeaderMap.putAll(regionLeaderMap);
        this.disabledDataNodeSet.addAll(disabledDataNodeSet);
    }

    /** Resets all per-call working state (but not maximumFlow/minimumCost). */
    private void clear() {
        regionReplicaSetMap.clear();
        regionLeaderMap.clear();
        disabledDataNodeSet.clear();
        rNodeMap.clear();
        dNodeMap.clear();
        dNodeReflect.clear();
        minCostFlowEdges.clear();
        nodeHeadEdge = null;
        nodeCurrentEdge = null;
        isNodeVisited = null;
        nodeMinimumCost = null;
        maxNode = 2;
        maxEdge = 0;
    }

    /** Allocates node indices and builds the flow network described in the class Javadoc. */
    private void constructMCFGraph() {
        maximumFlow = 0;
        minimumCost = 0;

        // How many replicas each DataNode hosts, i.e. the most leaders it could ever receive
        Map<Integer, Integer> maxLeaderCounter = new HashMap<>();

        // Allocate one node per RegionGroup and one node per distinct DataNode
        for (TRegionReplicaSet regionReplicaSet : regionReplicaSetMap.values()) {
            rNodeMap.put(regionReplicaSet.getRegionId(), maxNode++);
            for (TDataNodeLocation dataNodeLocation : regionReplicaSet.getDataNodeLocations()) {
                int dataNodeId = dataNodeLocation.getDataNodeId();
                maxLeaderCounter.merge(dataNodeId, 1, Integer::sum);
                if (!dNodeMap.containsKey(dataNodeId)) {
                    dNodeMap.put(dataNodeId, maxNode);
                    dNodeReflect.put(maxNode, dataNodeId);
                    maxNode++;
                }
            }
        }

        isNodeVisited = new boolean[maxNode];
        nodeMinimumCost = new int[maxNode];
        nodeCurrentEdge = new int[maxNode];
        nodeHeadEdge = new int[maxNode];
        Arrays.fill(nodeHeadEdge, -1);

        // S_NODE -> rNode: every RegionGroup elects at most one leader
        for (int rNode : rNodeMap.values()) {
            addAdjacentEdges(S_NODE, rNode, 1, 0);
        }

        // rNode -> dNode: keeping the current leader is free, moving it costs 1
        for (TRegionReplicaSet regionReplicaSet : regionReplicaSetMap.values()) {
            int rNode = rNodeMap.get(regionReplicaSet.getRegionId());
            // null means "no known leader", so every candidate costs 1
            Integer currentLeaderId = regionLeaderMap.get(regionReplicaSet.getRegionId());
            for (TDataNodeLocation dataNodeLocation : regionReplicaSet.getDataNodeLocations()) {
                int dataNodeId = dataNodeLocation.getDataNodeId();
                int dNode = dNodeMap.get(dataNodeId);
                int cost = currentLeaderId != null && currentLeaderId == dataNodeId ? 0 : 1;
                addAdjacentEdges(rNode, dNode, 1, cost);
            }
        }

        // dNode -> T_NODE: unit edges of cost 1, 4, 9, ... make each additional leader on the same
        // DataNode more expensive. Disabled DataNodes get no edge and thus never become leaders.
        for (Map.Entry<Integer, Integer> dNodeEntry : dNodeMap.entrySet()) {
            int dataNodeId = dNodeEntry.getKey();
            if (disabledDataNodeSet.contains(dataNodeId)) {
                continue;
            }
            int dNode = dNodeEntry.getValue();
            int maxLeaderCount = maxLeaderCounter.get(dataNodeId);
            for (int extraEdge = 1; extraEdge <= maxLeaderCount; extraEdge++) {
                addAdjacentEdges(dNode, T_NODE, 1, extraEdge * extraEdge);
            }
        }
    }

    /** Adds a forward edge and its zero-capacity, negated-cost residual edge. */
    private void addAdjacentEdges(int fromNode, int destNode, int capacity, int cost) {
        addEdge(fromNode, destNode, capacity, cost);
        addEdge(destNode, fromNode, 0, -cost);
    }

    /** Appends an edge and prepends it to fromNode's adjacency list. */
    private void addEdge(int fromNode, int destNode, int capacity, int cost) {
        MinCostFlowEdge edge =
                new MinCostFlowEdge(destNode, capacity, cost, nodeHeadEdge[fromNode]);
        minCostFlowEdges.add(edge);
        // Link the new edge into fromNode's list; without this no node would have any edges
        nodeHeadEdge[fromNode] = maxEdge++;
    }

    /**
     * Computes the cheapest residual-path cost from S_NODE to every node (SPFA).
     *
     * @return whether T_NODE is still reachable in the residual graph
     */
    private boolean bellmanFordCheck() {
        Arrays.fill(isNodeVisited, false);
        Arrays.fill(nodeMinimumCost, INFINITY);

        // isNodeVisited marks "currently queued" during this search
        ArrayDeque<Integer> queue = new ArrayDeque<>();
        nodeMinimumCost[S_NODE] = 0;
        isNodeVisited[S_NODE] = true;
        queue.offer(S_NODE);
        while (!queue.isEmpty()) {
            int currentNode = queue.poll();
            isNodeVisited[currentNode] = false;
            for (int currentEdge = nodeHeadEdge[currentNode];
                    currentEdge >= 0;
                    currentEdge = minCostFlowEdges.get(currentEdge).nextEdge) {
                MinCostFlowEdge edge = minCostFlowEdges.get(currentEdge);
                if (edge.capacity > 0
                        && nodeMinimumCost[currentNode] + edge.cost
                                < nodeMinimumCost[edge.destNode]) {
                    nodeMinimumCost[edge.destNode] = nodeMinimumCost[currentNode] + edge.cost;
                    if (!isNodeVisited[edge.destNode]) {
                        isNodeVisited[edge.destNode] = true;
                        queue.offer(edge.destNode);
                    }
                }
            }
        }
        return nodeMinimumCost[T_NODE] < INFINITY;
    }

    /**
     * Pushes flow from currentNode toward T_NODE along shortest-path (tight) edges.
     *
     * @param currentNode node to push from
     * @param inputFlow maximum flow that may leave currentNode
     * @return flow actually pushed to T_NODE
     */
    private int dfsAugmentation(int currentNode, int inputFlow) {
        if (currentNode == T_NODE || inputFlow == 0) {
            return inputFlow;
        }

        int outputFlow = 0;
        // Guards against cycling through zero-cost edges
        isNodeVisited[currentNode] = true;
        int currentEdge;
        for (currentEdge = nodeCurrentEdge[currentNode];
                currentEdge >= 0;
                currentEdge = minCostFlowEdges.get(currentEdge).nextEdge) {
            MinCostFlowEdge edge = minCostFlowEdges.get(currentEdge);
            boolean onShortestPath =
                    nodeMinimumCost[currentNode] + edge.cost == nodeMinimumCost[edge.destNode];
            if (onShortestPath && edge.capacity > 0 && !isNodeVisited[edge.destNode]) {
                int subOutputFlow =
                        dfsAugmentation(edge.destNode, Math.min(inputFlow, edge.capacity));
                minimumCost += subOutputFlow * edge.cost;
                edge.capacity -= subOutputFlow;
                minCostFlowEdges.get(currentEdge ^ 1).capacity += subOutputFlow;
                outputFlow += subOutputFlow;
                inputFlow -= subOutputFlow;
                if (inputFlow == 0) {
                    // Stay on this edge: it may still have capacity for the next augmentation
                    break;
                }
            }
        }
        nodeCurrentEdge[currentNode] = currentEdge;

        // A node that forwarded nothing stays marked, pruning it for the rest of the phase
        if (outputFlow > 0) {
            isNodeVisited[currentNode] = false;
        }
        return outputFlow;
    }

    /** Runs augmentation phases until T_NODE is unreachable in the residual graph. */
    private void dinicAlgorithm() {
        while (bellmanFordCheck()) {
            System.arraycopy(nodeHeadEdge, 0, nodeCurrentEdge, 0, maxNode);
            int currentFlow;
            while ((currentFlow = dfsAugmentation(S_NODE, INFINITY)) > 0) {
                maximumFlow += currentFlow;
            }
        }
    }

    /**
     * Reads the chosen leader of each RegionGroup from the saturated rNode -> dNode edges.
     *
     * @return RegionGroupId -> leader DataNodeId; falls back to the current leader (or -1)
     */
    private Map<TConsensusGroupId, Integer> collectLeaderDistribution() {
        Map<TConsensusGroupId, Integer> result = new ConcurrentHashMap<>();
        rNodeMap.forEach(
                (regionGroupId, rNode) -> {
                    boolean matchLeader = false;
                    for (int currentEdge = nodeHeadEdge[rNode];
                            currentEdge >= 0;
                            currentEdge = minCostFlowEdges.get(currentEdge).nextEdge) {
                        MinCostFlowEdge edge = minCostFlowEdges.get(currentEdge);
                        // The edge back to S_NODE is the residual edge; a forward edge to a
                        // DataNode with no capacity left carries this RegionGroup's unit of flow
                        if (edge.destNode != S_NODE && edge.capacity == 0) {
                            result.put(regionGroupId, dNodeReflect.get(edge.destNode));
                            matchLeader = true;
                            break;
                        }
                    }
                    if (!matchLeader) {
                        Integer currentLeader = regionLeaderMap.get(regionGroupId);
                        result.put(regionGroupId, currentLeader == null ? -1 : currentLeader);
                    }
                });
        return result;
    }

    /** @return total flow of the most recent run, i.e. the number of RegionGroups routed */
    public int getMaximumFlow() {
        return maximumFlow;
    }

    /** @return total edge cost of the flow found by the most recent run */
    public int getMinimumCost() {
        return minimumCost;
    }

    /** One directed edge of the residual graph, linked into its source node's adjacency list. */
    private static class MinCostFlowEdge {
        private final int destNode;
        private int capacity;
        private final int cost;
        // Index of the next edge leaving the same node, or -1
        private final int nextEdge;

        private MinCostFlowEdge(int destNode, int capacity, int cost, int nextEdge) {
            this.destNode = destNode;
            this.capacity = capacity;
            this.cost = cost;
            this.nextEdge = nextEdge;
        }
    }
}

