package org.apache.flink.runtime.resourceestimator;

import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.runtime.jobgraph.JobEdge;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.rescaling.EstimatorMetricsCommitter;
import org.apache.flink.runtime.rescaling.RescalingUtils;
import org.apache.flink.runtime.resourceestimator.predictions.PredictionCoordinator;
import org.apache.flink.runtime.resourceestimator.predictions.VertexEstimations;
import org.apache.flink.runtime.util.LogCallStackUtils;
import org.apache.flink.util.CollectionUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/flink/runtime/resourceestimator/DataFlowEstimator.class */
/**
 * Derives a per-vertex parallelism for a job graph from the data-flow estimations provided by a
 * {@link PredictionCoordinator}.
 *
 * <p>On construction the estimator registers the graph's vertices with the coordinator and starts a
 * single-threaded scheduler that periodically asks the coordinator to update its models for the most
 * recently registered graph. Every access to the coordinator made by this class is synchronized on
 * the coordinator instance. Call {@link #cancel()} to stop the background updates.
 */
public class DataFlowEstimator {
    private static final Logger LOG = LoggerFactory.getLogger(DataFlowEstimator.class);

    /** Initial delay and delay between background model updates, in milliseconds. */
    private static final long MODEL_UPDATE_INTERVAL_MS = 15000L;

    /**
     * Argument passed to {@code PredictionCoordinator#getJobVerticesInformation}. NOTE(review): 86400
     * equals the number of seconds in a day; confirm the unit the coordinator expects.
     */
    private static final long ESTIMATION_WINDOW = 86400L;

    /** Bisection steps used to find a lower rate in (0, 1] when the full rate does not fit. */
    private static final int SCALE_DOWN_SEARCH_STEPS = 30;

    /** Bisection steps used to find a higher rate when the full rate needs too few slots. */
    private static final int SCALE_OUT_SEARCH_STEPS = 90;

    /** Upper bound of the rate multiplier explored when scaling out. */
    private static final double SCALE_OUT_MAX_RATE = 1.0E18d;

    /** An estimated parallelism above this value makes the examined rate infeasible. */
    private static final double MAX_ESTIMATED_PARALLELISM = 32768.0d;

    /** Seed of the virtual-flow product; each input edge multiplies it by (upstream flow / seed). */
    private static final double VIRTUAL_FLOW_SEED = 4.0d;

    /** Lower bound applied to a source vertex's input flow when the rate multiplier exceeds 1. */
    private static final double MIN_SOURCE_INPUT_FLOW = 1.0E-9d;

    @Nonnull
    private final PredictionCoordinator coordinator;
    private final ScheduledExecutorService modelUpdater;
    private final ScheduledFuture<?> modelUpdateFuture;
    private final Optional<EstimatorMetricsCommitter<DataFlowEstimatorMetrics>> committer;
    private final Map<JobVertexID, DataFlowEstimatorMetrics> metricSnapshot = new HashMap<>();

    /** Graph whose models are refreshed in the background; read and written under {@code coordinator}. */
    private JobGraph oldGraph;

    /**
     * Creates an estimator without a metrics committer.
     *
     * @param predictionCoordinator source of per-vertex estimations
     * @param jobGraph graph whose vertices are registered with the coordinator
     */
    public DataFlowEstimator(PredictionCoordinator predictionCoordinator, @Nonnull JobGraph jobGraph) {
        this(predictionCoordinator, jobGraph, null);
    }

    /**
     * Creates an estimator, registers {@code jobGraph}'s vertices with the coordinator and schedules
     * periodic model updates.
     *
     * @param predictionCoordinator source of per-vertex estimations
     * @param jobGraph graph whose vertices are registered with the coordinator
     * @param estimatorMetricsCommitter optional; when present it is initialised with {@code jobGraph},
     *     intermediate metrics are collected during {@link #evaluate}, and the metric snapshot is
     *     committed to it after every evaluation
     */
    public DataFlowEstimator(
            PredictionCoordinator predictionCoordinator,
            @Nonnull JobGraph jobGraph,
            @Nullable EstimatorMetricsCommitter<DataFlowEstimatorMetrics> estimatorMetricsCommitter) {
        this.modelUpdater = Executors.newScheduledThreadPool(1);
        this.coordinator = predictionCoordinator;
        predictionCoordinator.addPredictableList(jobGraph.getVerticesSortedTopologicallyFromSources());
        this.oldGraph = jobGraph;
        this.modelUpdateFuture =
                this.modelUpdater.scheduleWithFixedDelay(
                        this::refreshModels,
                        MODEL_UPDATE_INTERVAL_MS,
                        MODEL_UPDATE_INTERVAL_MS,
                        TimeUnit.MILLISECONDS);
        this.committer = Optional.ofNullable(estimatorMetricsCommitter);
        this.committer.ifPresent(metricsCommitter -> metricsCommitter.init(jobGraph));
    }

    /** Background task: asks the coordinator to update its models for the current graph. */
    private void refreshModels() {
        synchronized (coordinator) {
            try {
                coordinator.updateModels(oldGraph);
            } catch (Exception e) {
                // Swallowed on purpose: an exception escaping a fixed-delay task suppresses all of
                // its subsequent executions. Pass the throwable so the stack trace is kept.
                LOG.error("Error during metric update", e);
            }
        }
    }

    /**
     * Computes a new parallelism for {@code jobGraph} from the coordinator's current estimations.
     *
     * @param jobGraph graph to rescale; its current vertex parallelism is the baseline
     * @param minSlots minimal number of slots the rescaled job should occupy
     * @param maxSlots maximal number of slots the rescaled job may occupy
     * @return the new parallelism per vertex; an empty map if it equals the current parallelism;
     *     {@code null} or an empty map if no parallelism could be calculated
     */
    public Map<JobVertexID, Integer> evaluate(JobGraph jobGraph, int minSlots, int maxSlots) {
        LOG.info("Evaluating for an old graph at: {}", System.currentTimeMillis());
        Map<JobVertexID, Integer> currentParallelism =
                jobGraph.getVerticesSortedTopologicallyFromSources().stream()
                        .collect(Collectors.toMap(JobVertex::getID, JobVertex::getParallelism));
        Map<JobVertexID, VertexEstimations> estimations;
        synchronized (coordinator) {
            estimations = coordinator.getJobVerticesInformation(ESTIMATION_WINDOW);
        }
        if (LOG.isInfoEnabled()) {
            LOG.info("{}", describeEstimations(jobGraph, estimations));
        }
        boolean collectMetrics = committer.isPresent();
        Map<JobVertexID, Integer> newParallelism =
                calcParallelism(jobGraph, estimations, minSlots, maxSlots, collectMetrics);
        committer.ifPresent(metricsCommitter -> metricsCommitter.commitInformation(metricSnapshot));
        if (CollectionUtil.isNullOrEmpty(newParallelism)) {
            LOG.info("Fail to calculate new parallelism for job {}", jobGraph.getJobID());
            return newParallelism;
        }
        if (LOG.isInfoEnabled()) {
            if (collectMetrics) {
                LOG.info("Intermediate results: {}", describeIntermediateResults(jobGraph));
            }
            LOG.info("Calculated parallelism: {}", describeParallelism(jobGraph, newParallelism));
        }
        if (!currentParallelism.equals(newParallelism)) {
            LOG.info("Rescaled Graph: {}", newParallelism);
            return newParallelism;
        }
        LOG.info("After calculation, parallelism for job {} needs no change", jobGraph.getJobID());
        newParallelism.clear();
        return newParallelism;
    }

    /** Renders "name: estimation" for every vertex, one per line; a missing estimation prints "null". */
    private static String describeEstimations(
            JobGraph jobGraph, Map<JobVertexID, VertexEstimations> estimations) {
        return jobGraph.getVerticesSortedTopologicallyFromSources().stream()
                .map(vertex -> vertex.getName() + ": " + estimations.get(vertex.getID()))
                .collect(Collectors.joining(",\n"));
    }

    /** Renders the flow metrics collected in the snapshot for every vertex. */
    private String describeIntermediateResults(JobGraph jobGraph) {
        StringBuilder sb = new StringBuilder();
        for (JobVertex vertex : jobGraph.getVerticesSortedTopologicallyFromSources()) {
            DataFlowEstimatorMetrics metrics = metricSnapshot.get(vertex.getID());
            sb.append(vertex.getName())
                    .append(": inputFlow=").append(metrics.getExpectedInputFlow())
                    .append(": virtualFlow=").append(metrics.getExpectedVirtualFlow())
                    .append("; outputFlow=").append(metrics.getExpectedOutputFlow())
                    .append("; throughput=").append(metrics.getThroughput())
                    .append("; extraTasks=").append(metrics.getExtraDistributedTasks())
                    .append("\n");
        }
        return sb.toString();
    }

    /** Renders "name: parallelism/maxParallelism" for every vertex, one per line. */
    private static String describeParallelism(
            JobGraph jobGraph, Map<JobVertexID, Integer> parallelism) {
        StringBuilder sb = new StringBuilder();
        for (JobVertex vertex : jobGraph.getVerticesSortedTopologicallyFromSources()) {
            sb.append(vertex.getName())
                    .append(": ")
                    .append(parallelism.get(vertex.getID()))
                    .append("/")
                    .append(vertex.getMaxParallelism())
                    .append("\n");
        }
        return sb.toString();
    }

    /**
     * Computes a parallelism for every vertex that fits between {@code minSlots} and
     * {@code maxSlots}.
     *
     * <p>If even the minimal slot requirement exceeds {@code maxSlots}, every vertex gets
     * parallelism 1. Otherwise the full estimated rate (multiplier 1) is tried first. If it does not
     * fit, a bisection over (0, 1] finds a lower rate. If it fits but needs fewer than
     * {@code minSlots} slots, a bisection over [1, 1e18] finds a higher rate, and any remaining gap
     * is closed by distributing extra slots.
     *
     * @param collectMetrics whether to rebuild and fill the metric snapshot while calculating
     * @return the parallelism per vertex, or {@code null} if no acceptable assignment was found
     */
    @VisibleForTesting
    public Map<JobVertexID, Integer> calcParallelism(
            JobGraph jobGraph,
            Map<JobVertexID, VertexEstimations> estimations,
            long minSlots,
            long maxSlots,
            boolean collectMetrics) {
        if (collectMetrics) {
            rewriteSnapshots(jobGraph);
            populateEstimationsInformation(jobGraph, estimations);
        }
        if (RescalingUtils.calculateMinimalSlotRequirements(jobGraph).intValue() > maxSlots) {
            LOG.info("Cannot fit the {} limitation", maxSlots);
            Map<JobVertexID, Integer> minimal = new HashMap<>();
            for (JobVertex vertex : jobGraph.getVerticesSortedTopologicallyFromSources()) {
                minimal.put(vertex.getID(), 1);
            }
            return minimal;
        }
        Map<JobVertexID, Integer> fullRate =
                calculateParallelismForRate(jobGraph, estimations, maxSlots, 1.0d, collectMetrics);
        if (fullRate == null) {
            return scaleDownToFit(jobGraph, estimations, maxSlots, collectMetrics);
        }
        try {
            return scaleOutToMinimalSlots(
                    jobGraph, estimations, fullRate, minSlots, maxSlots, collectMetrics);
        } catch (Exception e) {
            // Pass the throwable so its type, message and stack trace are all logged.
            LOG.error("Could not scale out job to fit minimal quota of {}", minSlots, e);
            return null;
        }
    }

    /** Bisects for the largest rate in (0, 1] whose parallelism fits {@code maxSlots}. */
    private Map<JobVertexID, Integer> scaleDownToFit(
            JobGraph jobGraph,
            Map<JobVertexID, VertexEstimations> estimations,
            long maxSlots,
            boolean collectMetrics) {
        double feasibleRate = 0.0d;
        double infeasibleRate = 1.0d;
        for (int step = 0; step < SCALE_DOWN_SEARCH_STEPS; step++) {
            double midRate = (feasibleRate + infeasibleRate) / 2.0d;
            if (calculateParallelismForRate(jobGraph, estimations, maxSlots, midRate, false) == null) {
                infeasibleRate = midRate;
            } else {
                feasibleRate = midRate;
            }
        }
        LOG.info("Estimated data rate: {}", feasibleRate);
        return calculateParallelismForRate(
                jobGraph, estimations, maxSlots, feasibleRate, collectMetrics);
    }

    /**
     * Starting from a full-rate assignment that fits {@code maxSlots}, raises the rate (and if
     * needed distributes extra slots) until at least {@code minSlots} slots are used.
     *
     * @return the resulting parallelism, or {@code null} if {@code minSlots} cannot be reached
     */
    private Map<JobVertexID, Integer> scaleOutToMinimalSlots(
            JobGraph jobGraph,
            Map<JobVertexID, VertexEstimations> estimations,
            Map<JobVertexID, Integer> fullRate,
            long minSlots,
            long maxSlots,
            boolean collectMetrics) {
        LOG.info("Can process whole data stream using at most {} slots", maxSlots);
        if (RescalingUtils.calculateSlotRequirements(jobGraph, fullRate).intValue() >= minSlots) {
            LOG.info("Rescaled job parallelism fits minimal {} slots requirement.", minSlots);
            return fullRate;
        }
        double lowRate = 1.0d;
        double highRate = SCALE_OUT_MAX_RATE;
        for (int step = 0; step < SCALE_OUT_SEARCH_STEPS; step++) {
            double midRate = (lowRate + highRate) / 2.0d;
            Map<JobVertexID, Integer> candidate =
                    calculateParallelismForRate(jobGraph, estimations, maxSlots, midRate, false);
            // An infeasible rate is treated as "more than maxSlots", i.e. too high.
            long requiredSlots =
                    candidate == null
                            ? maxSlots + 1
                            : RescalingUtils.calculateSlotRequirements(jobGraph, candidate).intValue();
            if (requiredSlots < minSlots) {
                lowRate = midRate;
            } else {
                highRate = midRate;
            }
        }
        LOG.info("Estimated data rate is {}", lowRate);
        Map<JobVertexID, Integer> scaled =
                calculateParallelismForRate(jobGraph, estimations, maxSlots, lowRate, collectMetrics);
        if (scaled == null) {
            LOG.error("Could not scale out job to fit minimal quota of {}", minSlots);
            return null;
        }
        int requiredSlots = RescalingUtils.calculateSlotRequirements(jobGraph, scaled).intValue();
        LOG.info("Now job requires {} slots to run", requiredSlots);
        if (requiredSlots < minSlots) {
            LOG.info("Distributing extra slots to fit the {} requirement", minSlots);
            HashMap<JobVertexID, Integer> extraTasks = collectMetrics ? new HashMap<>() : null;
            boolean distributed =
                    RescalingUtils.distributeExtraSlot(
                            jobGraph, scaled, minSlots, Optional.ofNullable(extraTasks));
            if (collectMetrics) {
                populateExtraDistribution(jobGraph, extraTasks);
            }
            if (!distributed) {
                LOG.info("Could not tune job to fit slot quota limit");
                return null;
            }
        }
        return scaled;
    }

    /**
     * Computes a parallelism for every vertex. Each vertex's estimated extra input per second is
     * scaled by {@code rate}, and the resulting output flows are propagated downstream in
     * topological order.
     *
     * @return the parallelism per vertex, or {@code null} if the rate is infeasible. That is the
     *     case if a scalable vertex would need more than 32768 tasks, a vertex exceeds its max
     *     parallelism, or the total slot requirement exceeds {@code maxSlots}.
     */
    private Map<JobVertexID, Integer> calculateParallelismForRate(
            JobGraph jobGraph,
            Map<JobVertexID, VertexEstimations> estimations,
            long maxSlots,
            double rate,
            boolean collectMetrics) {
        Map<JobVertexID, Double> inputFlows = new HashMap<>();
        Map<JobVertexID, Double> virtualFlows = new HashMap<>();
        Map<JobVertexID, Double> outputFlows = new HashMap<>();
        Map<JobVertexID, Integer> parallelism = new HashMap<>();
        for (JobVertex vertex : jobGraph.getVerticesSortedTopologicallyFromSources()) {
            VertexEstimations estimation = estimations.get(vertex.getID());
            double ownInput = estimation.getExtraInputPerSecond() * rate;
            double inputFlow = ownInput;
            double virtualProduct = VIRTUAL_FLOW_SEED;
            double throughput = estimation.getThroughput();
            // Topological order guarantees every producer's output flow has been computed already.
            for (JobEdge edge : vertex.getInputs()) {
                double upstreamFlow = outputFlows.get(edge.getSource().getProducer().getID());
                inputFlow += upstreamFlow;
                virtualProduct *= upstreamFlow / VIRTUAL_FLOW_SEED;
            }
            double virtualFlow = vertex.getInputs().isEmpty() ? ownInput : virtualProduct + ownInput;
            double outputRate = estimation.getOutputRate();
            if (vertex.isInputVertex()) {
                if (rate > 1.0d) {
                    inputFlow = Math.max(MIN_SOURCE_INPUT_FLOW, inputFlow);
                }
                // A source capped at parallelism 1 gets a throughput of at least its input flow.
                if (vertex.getMaxParallelism() == 1) {
                    throughput = Math.max(throughput, inputFlow);
                    if (collectMetrics) {
                        modifyVertexThroughput(vertex.getID(), throughput);
                    }
                }
            }
            double outputFlow = inputFlow * outputRate;
            int vertexParallelism;
            if (vertex.canScale()) {
                double required = Math.ceil(virtualFlow / throughput);
                if (MAX_ESTIMATED_PARALLELISM < required) {
                    return null;
                }
                vertexParallelism = Math.max(1, (int) required);
            } else {
                // Fixed parallelism: output is limited to the share of the virtual flow that the
                // existing tasks can process.
                vertexParallelism = vertex.getParallelism();
                outputFlow =
                        inputFlow
                                * Math.min(1.0d, (throughput * vertex.getParallelism()) / virtualFlow)
                                * outputRate;
            }
            // -1 is treated here as "no max parallelism limit".
            if (vertex.getMaxParallelism() != -1 && vertex.getMaxParallelism() < vertexParallelism) {
                LOG.info(
                        "Failed to assign parallelism for rate: {}, vertex: {}, estimated: {}, limit: {}",
                        rate,
                        vertex.getName(),
                        vertexParallelism,
                        vertex.getMaxParallelism());
                return null;
            }
            parallelism.put(vertex.getID(), vertexParallelism);
            outputFlows.put(vertex.getID(), outputFlow);
            inputFlows.put(vertex.getID(), inputFlow);
            virtualFlows.put(vertex.getID(), virtualFlow);
        }
        if (RescalingUtils.calculateSlotRequirements(jobGraph, parallelism).intValue() > maxSlots) {
            return null;
        }
        if (collectMetrics) {
            populateFlowStageInformation(jobGraph, inputFlows, virtualFlows, outputFlows, rate);
        }
        return parallelism;
    }

    /** Replaces the snapshot entry of every vertex with a zero-initialised metrics object. */
    private void rewriteSnapshots(JobGraph jobGraph) {
        for (JobVertex vertex : jobGraph.getVerticesSortedTopologicallyFromSources()) {
            metricSnapshot.put(
                    vertex.getID(),
                    new DataFlowEstimatorMetrics(0.0d, 0.0d, 0.0d, 0.0d, 0.0d, 0.0d, 0, 0.0d));
        }
    }

    /** Overrides the recorded throughput of one vertex in the snapshot. */
    private void modifyVertexThroughput(JobVertexID jobVertexID, double throughput) {
        metricSnapshot.get(jobVertexID).setThroughput(throughput);
    }

    /** Copies the raw estimations (extra input, throughput, output rate) into the snapshot. */
    private void populateEstimationsInformation(
            JobGraph jobGraph, Map<JobVertexID, VertexEstimations> estimations) {
        for (JobVertex vertex : jobGraph.getVerticesSortedTopologicallyFromSources()) {
            JobVertexID id = vertex.getID();
            VertexEstimations estimation = estimations.get(id);
            DataFlowEstimatorMetrics metrics = metricSnapshot.get(id);
            metrics.setExtraInputPerSec(estimation.getExtraInputPerSecond());
            metrics.setThroughput(estimation.getThroughput());
            metrics.setOutputRate(estimation.getOutputRate());
        }
    }

    /** Records the derived input/virtual/output flows and the rate multiplier in the snapshot. */
    private void populateFlowStageInformation(
            JobGraph jobGraph,
            Map<JobVertexID, Double> inputFlows,
            Map<JobVertexID, Double> virtualFlows,
            Map<JobVertexID, Double> outputFlows,
            double rate) {
        for (JobVertex vertex : jobGraph.getVerticesSortedTopologicallyFromSources()) {
            JobVertexID id = vertex.getID();
            DataFlowEstimatorMetrics metrics = metricSnapshot.get(id);
            metrics.setExpectedInputFlow(inputFlows.get(id));
            metrics.setExpectedVirtualFlow(virtualFlows.get(id));
            metrics.setExpectedOutputFlow(outputFlows.get(id));
            metrics.setSourceInputMultiplier(rate);
        }
    }

    /** Records how many extra tasks were distributed to each vertex (0 when absent). */
    private void populateExtraDistribution(JobGraph jobGraph, Map<JobVertexID, Integer> extraTasks) {
        for (JobVertex vertex : jobGraph.getVerticesSortedTopologicallyFromSources()) {
            JobVertexID id = vertex.getID();
            metricSnapshot.get(id).setExpectedExtraDistributedTasks(extraTasks.getOrDefault(id, 0));
        }
    }

    /**
     * Reports {@code jobGraph}'s parallelism to the coordinator. Asks the coordinator to clean the
     * metrics of the previously tracked job, then tracks {@code jobGraph} for future model updates.
     */
    public void updateParallelism(JobGraph jobGraph) {
        HashMap<JobVertexID, Integer> parallelism = new HashMap<>();
        for (JobVertex vertex : jobGraph.getVerticesSortedTopologicallyFromSources()) {
            parallelism.put(vertex.getID(), vertex.getParallelism());
        }
        synchronized (coordinator) {
            coordinator.updateParallelism(parallelism);
            coordinator.cleanMetricsForJob(oldGraph.getJobID());
            oldGraph = jobGraph;
        }
    }

    /** Cancels the periodic model update (interrupting it if running) and shuts the scheduler down. */
    public void cancel() {
        modelUpdateFuture.cancel(true);
        modelUpdater.shutdown();
    }

    /** Returns the live (mutable) metric snapshot; exposed for tests. */
    @VisibleForTesting
    public Map<JobVertexID, DataFlowEstimatorMetrics> getMetricSnapshot() {
        return metricSnapshot;
    }

    /**
     * Returns the snapshot metrics of one vertex as string key/value pairs, or an empty map if the
     * vertex has no snapshot entry.
     */
    public Map<String, String> getJobVertexRescalingInfo(JobVertexID jobVertexID) {
        synchronized (this) {
            DataFlowEstimatorMetrics metrics = metricSnapshot.get(jobVertexID);
            if (metrics == null) {
                return Collections.emptyMap();
            }
            Map<String, String> info = new HashMap<>();
            info.put("EstimatedExtraInputPerSecond", String.valueOf(metrics.getExtraInputPerSec()));
            info.put("EstimatedThroughput", String.valueOf(metrics.getThroughput()));
            info.put("EstimatedOutputRate", String.valueOf(metrics.getOutputRate()));
            info.put("DerivedInputFlow", String.valueOf(metrics.getExpectedInputFlow()));
            info.put("DerivedVirtualFlow", String.valueOf(metrics.getExpectedVirtualFlow()));
            info.put("DerivedOutputFlow", String.valueOf(metrics.getExpectedOutputFlow()));
            info.put("ExtraDistributedTasks", String.valueOf(metrics.getExtraDistributedTasks()));
            return info;
        }
    }
}
