package org.apache.flink.runtime.iterative.task;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.configuration2.tree.DefaultExpressionEngineSymbols;
import org.apache.flink.api.common.functions.Function;
import org.apache.flink.api.common.operators.util.JoinHashMap;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypeComparatorFactory;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.TypeSerializerFactory;
import org.apache.flink.core.io.IOReadableWritable;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.io.disk.InputViewIterator;
import org.apache.flink.runtime.io.network.TaskEventDispatcher;
import org.apache.flink.runtime.io.network.api.EndOfSuperstepEvent;
import org.apache.flink.runtime.io.network.api.writer.RecordWriter;
import org.apache.flink.runtime.io.network.api.writer.RecordWriterBuilder;
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
import org.apache.flink.runtime.iterative.concurrent.BlockingBackChannel;
import org.apache.flink.runtime.iterative.concurrent.BlockingBackChannelBroker;
import org.apache.flink.runtime.iterative.concurrent.IterationAggregatorBroker;
import org.apache.flink.runtime.iterative.concurrent.SolutionSetBroker;
import org.apache.flink.runtime.iterative.concurrent.SolutionSetUpdateBarrier;
import org.apache.flink.runtime.iterative.concurrent.SolutionSetUpdateBarrierBroker;
import org.apache.flink.runtime.iterative.concurrent.SuperstepBarrier;
import org.apache.flink.runtime.iterative.concurrent.SuperstepKickoffLatch;
import org.apache.flink.runtime.iterative.concurrent.SuperstepKickoffLatchBroker;
import org.apache.flink.runtime.iterative.event.AllWorkersDoneEvent;
import org.apache.flink.runtime.iterative.event.TerminationEvent;
import org.apache.flink.runtime.iterative.event.WorkerDoneEvent;
import org.apache.flink.runtime.iterative.io.SerializedUpdateBuffer;
import org.apache.flink.runtime.operators.BatchTask;
import org.apache.flink.runtime.operators.hash.CompactingHashTable;
import org.apache.flink.runtime.operators.util.TaskConfig;
import org.apache.flink.util.Collector;
import org.apache.flink.util.MutableObjectIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/flink/runtime/iterative/task/IterationHeadTask.class */
/**
 * Head task of an iteration (bulk or workset).
 *
 * <p>In every superstep this task:
 * <ul>
 *   <li>feeds the previous superstep's result, read from a {@link BlockingBackChannel}, into its
 *       feedback input (from the second superstep on);</li>
 *   <li>runs the driver via {@code super.run()};</li>
 *   <li>broadcasts an {@link EndOfSuperstepEvent} to all iteration outputs;</li>
 *   <li>sends a {@link WorkerDoneEvent} to the sync task;</li>
 *   <li>waits on a {@link SuperstepBarrier} for either termination or the next superstep.</li>
 * </ul>
 *
 * <p>After the loop ends, the final result is streamed to the final output collector. For a bulk
 * iteration that is the last back-channel contents; for a workset iteration it is the solution
 * set.
 *
 * <p>All coordination objects are handed in to the iteration brokers under {@code brokerKey()}.
 * They are removed again in a {@code finally} block when {@link #run()} exits.
 *
 * @param <X> record type of the solution set / final output
 * @param <Y> record type of the feedback input (partial solution or workset)
 * @param <S> stub (user function) type, passed through to {@code AbstractIterativeTask}
 * @param <OT> output type, passed through to {@code AbstractIterativeTask}
 */
public class IterationHeadTask<X, Y, S extends Function, OT> extends AbstractIterativeTask<S, OT> {

    private static final Logger log = LoggerFactory.getLogger(IterationHeadTask.class);

    /** Receives the iteration's final result once the superstep loop has ended. */
    private Collector<X> finalOutputCollector;

    /** Serializer factory of the feedback input (partial solution or workset). */
    private TypeSerializerFactory<Y> feedbackTypeSerializer;

    /** Serializer factory of the solution set; same instance as the feedback one for bulk iterations. */
    private TypeSerializerFactory<X> solutionTypeSerializer;

    /** Writer towards the sync task, used to send {@link WorkerDoneEvent}s. */
    private RecordWriter<IOReadableWritable> toSync;

    /** Partition id of the sync writer; the superstep barrier subscribes to events on it. */
    private ResultPartitionID toSyncPartitionId;

    /** Index of the input that is replaced by the back-channel data after the first superstep. */
    private int feedbackDataInput;

    public IterationHeadTask(Environment environment) {
        super(environment);
    }

    /**
     * Returns the driver's inputs plus one extra input (the initial solution set) for workset
     * iterations.
     */
    @Override // org.apache.flink.runtime.operators.BatchTask
    protected int getNumTaskInputs() {
        return this.driver.getNumberOfInputs() + (this.config.getIsWorksetIteration() ? 1 : 0);
    }

    /**
     * Creates the regular outputs, the collector for the final iteration result and the record
     * writer towards the sync task.
     *
     * @throws Exception if the sync output does not directly follow the regular and final outputs,
     *     or if creating any of the outputs fails
     */
    @Override // org.apache.flink.runtime.operators.BatchTask
    public void initOutputs() throws Exception {
        super.initOutputs();

        // writers created for the final output; only their count is used below
        final List<RecordWriter<?>> finalOutputWriters = new ArrayList<>();
        final TaskConfig finalOutputConfig = this.config.getIterationHeadFinalOutputConfig();
        this.finalOutputCollector =
                BatchTask.getOutputCollector(
                        this,
                        finalOutputConfig,
                        getUserCodeClassLoader(),
                        finalOutputWriters,
                        this.config.getNumOutputs(),
                        finalOutputConfig.getNumOutputs());

        // the sync gate is expected right after the regular outputs and the final outputs
        final int syncGateIndex = this.config.getIterationHeadIndexOfSyncOutput();
        if (this.eventualOutputs.size() + finalOutputWriters.size() != syncGateIndex) {
            throw new Exception("Error: Inconsistent head task setup - wrong mapping of output gates.");
        }

        this.toSync =
                new RecordWriterBuilder<IOReadableWritable>()
                        .build(getEnvironment().getWriter(syncGateIndex));
        this.toSyncPartitionId = getEnvironment().getWriter(syncGateIndex).getPartitionId();
    }

    /**
     * Allocates the back-channel memory and wraps it in a {@link BlockingBackChannel}. The channel
     * is handed in to the {@link BlockingBackChannelBroker} under this task's broker key.
     *
     * @return the newly created back channel
     */
    private BlockingBackChannel initBackChannel() throws Exception {
        final int numPages =
                getMemoryManager().computeNumberOfPages(this.config.getRelativeBackChannelMemory());
        final List<MemorySegment> segments = new ArrayList<>();
        final int segmentSize = getMemoryManager().getPageSize();
        getMemoryManager().allocatePages(this, segments, numPages);

        final BlockingBackChannel backChannel =
                new BlockingBackChannel(new SerializedUpdateBuffer(segments, segmentSize, getIOManager()));
        BlockingBackChannelBroker.instance().handIn(brokerKey(), backChannel);
        return backChannel;
    }

    /**
     * Creates the managed (memory-segment based) solution set of a workset iteration.
     *
     * <p>If creation fails, the allocated memory is released again. Release errors are only
     * logged, so the original failure propagates.
     *
     * @return a new, not yet opened, hash table
     */
    private <BT> CompactingHashTable<BT> initCompactingHashTable() throws Exception {
        final double relativeMemory = this.config.getRelativeSolutionSetMemory();
        final ClassLoader userCodeClassLoader = getUserCodeClassLoader();

        final TypeSerializerFactory<BT> serializerFactory =
                this.config.getSolutionSetSerializer(userCodeClassLoader);
        final TypeComparatorFactory<BT> comparatorFactory =
                this.config.getSolutionSetComparator(userCodeClassLoader);
        final TypeSerializer<BT> serializer = serializerFactory.getSerializer();
        final TypeComparator<BT> comparator = comparatorFactory.createComparator();

        List<MemorySegment> memSegments = null;
        boolean success = false;
        try {
            final int numPages = getMemoryManager().computeNumberOfPages(relativeMemory);
            memSegments = getMemoryManager().allocatePages(getContainingTask(), numPages);
            final CompactingHashTable<BT> hashTable =
                    new CompactingHashTable<>(serializer, comparator, memSegments);
            success = true;
            return hashTable;
        } finally {
            // on failure no hash table exists yet, so only the memory has to be returned
            if (!success && memSegments != null) {
                try {
                    getMemoryManager().release(memSegments);
                } catch (Throwable t) {
                    log.error("Error freeing memory after error during solution set hash table creation.", t);
                }
            }
        }
    }

    /** Creates the unmanaged (object-based) solution set of a workset iteration. */
    private <BT> JoinHashMap<BT> initJoinHashMap() {
        final ClassLoader userCodeClassLoader = getUserCodeClassLoader();
        final TypeSerializerFactory<BT> serializerFactory =
                this.config.getSolutionSetSerializer(userCodeClassLoader);
        final TypeComparatorFactory<BT> comparatorFactory =
                this.config.getSolutionSetComparator(userCodeClassLoader);
        return new JoinHashMap<>(serializerFactory.getSerializer(), comparatorFactory.createComparator());
    }

    /** Opens the managed solution set and builds it from the initial solution set input. */
    private void readInitialSolutionSet(
            CompactingHashTable<X> compactingHashTable, MutableObjectIterator<X> mutableObjectIterator)
            throws IOException {
        compactingHashTable.open();
        compactingHashTable.buildTableWithUniqueKey(mutableObjectIterator);
    }

    /** Inserts (or replaces) every record of the initial solution set input into the object map. */
    private void readInitialSolutionSet(
            JoinHashMap<X> joinHashMap, MutableObjectIterator<X> mutableObjectIterator) throws IOException {
        final TypeSerializer<X> serializer = this.solutionTypeSerializer.getSerializer();
        X next;
        // a fresh instance per record; presumably the map keeps references to inserted records
        while ((next = mutableObjectIterator.next(serializer.mo6078createInstance())) != null) {
            joinHashMap.insertOrReplace(next);
        }
    }

    /**
     * Creates the superstep barrier. The barrier is subscribed to {@link AllWorkersDoneEvent} and
     * {@link TerminationEvent} on the sync partition.
     */
    private SuperstepBarrier initSuperstepBarrier() {
        final SuperstepBarrier barrier = new SuperstepBarrier(getUserCodeClassLoader());
        final TaskEventDispatcher dispatcher = getEnvironment().getTaskEventDispatcher();
        final ResultPartitionID partitionId = this.toSyncPartitionId;
        dispatcher.subscribeToEvent(partitionId, barrier, AllWorkersDoneEvent.class);
        dispatcher.subscribeToEvent(partitionId, barrier, TerminationEvent.class);
        return barrier;
    }

    /**
     * Runs the superstep loop and then streams out the final result.
     *
     * <p>Broker registrations are removed, and the managed solution set (if any) is closed, in a
     * {@code finally} block. This cleanup therefore runs exactly once, whether the loop succeeds or
     * fails.
     */
    @Override // org.apache.flink.runtime.iterative.task.AbstractIterativeTask, org.apache.flink.runtime.operators.BatchTask
    public void run() throws Exception {
        final String brokerKey = brokerKey();
        final int workerIndex = getEnvironment().getTaskInfo().getIndexOfThisSubtask();
        final boolean objectSolutionSet = this.config.isSolutionSetUnmanaged();

        CompactingHashTable<X> solutionSet = null; // managed solution set (workset iteration only)
        JoinHashMap<X> solutionSetObjectMap = null; // unmanaged solution set (workset iteration only)

        final boolean waitForSolutionSetUpdate = this.config.getWaitForSolutionSetUpdate();
        final boolean isWorksetIteration = this.config.getIsWorksetIteration();

        try {
            final SuperstepKickoffLatch nextStepKickoff = new SuperstepKickoffLatch();
            SuperstepKickoffLatchBroker.instance().handIn(brokerKey, nextStepKickoff);

            final BlockingBackChannel backChannel = initBackChannel();
            final SuperstepBarrier barrier = initSuperstepBarrier();
            SolutionSetUpdateBarrier solutionSetUpdateBarrier = null;

            this.feedbackDataInput = this.config.getIterationHeadPartialSolutionOrWorksetInputIndex();
            this.feedbackTypeSerializer = getInputSerializer(this.feedbackDataInput);
            excludeFromReset(this.feedbackDataInput);

            if (isWorksetIteration) {
                final int initialSolutionSetInput = this.config.getIterationHeadSolutionSetInputIndex();
                this.solutionTypeSerializer = this.config.getSolutionSetSerializer(getUserCodeClassLoader());

                @SuppressWarnings("unchecked")
                final MutableObjectIterator<X> solutionSetInput =
                        (MutableObjectIterator<X>)
                                createInputIterator(
                                        this.inputReaders[initialSolutionSetInput], this.solutionTypeSerializer);

                if (objectSolutionSet) {
                    solutionSetObjectMap = initJoinHashMap();
                    readInitialSolutionSet(solutionSetObjectMap, solutionSetInput);
                    SolutionSetBroker.instance().handIn(brokerKey, solutionSetObjectMap);
                } else {
                    solutionSet = initCompactingHashTable();
                    readInitialSolutionSet(solutionSet, solutionSetInput);
                    SolutionSetBroker.instance().handIn(brokerKey, solutionSet);
                }
            } else {
                // bulk iteration: the feedback records are also the final result records
                @SuppressWarnings("unchecked")
                final TypeSerializerFactory<X> solutionSerializer =
                        (TypeSerializerFactory<X>) this.feedbackTypeSerializer;
                this.solutionTypeSerializer = solutionSerializer;
            }

            if (waitForSolutionSetUpdate) {
                solutionSetUpdateBarrier = new SolutionSetUpdateBarrier();
                SolutionSetUpdateBarrierBroker.instance().handIn(brokerKey, solutionSetUpdateBarrier);
            }

            final RuntimeAggregatorRegistry aggregatorRegistry =
                    new RuntimeAggregatorRegistry(this.config.getIterationAggregators(getUserCodeClassLoader()));
            IterationAggregatorBroker.instance().handIn(brokerKey, aggregatorRegistry);

            DataInputView superstepResult = null;

            while (this.running && !terminationRequested()) {
                if (log.isInfoEnabled()) {
                    log.info(formatLogString("starting iteration [" + currentIteration() + "]"));
                }

                barrier.setup();
                if (waitForSolutionSetUpdate) {
                    solutionSetUpdateBarrier.setup();
                }

                if (!inFirstIteration()) {
                    feedBackSuperstepResult(superstepResult);
                }

                super.run();

                sendEndOfSuperstepToAllIterationOutputs();

                if (waitForSolutionSetUpdate) {
                    solutionSetUpdateBarrier.waitForSolutionSetUpdate();
                }

                // read end of the back channel for the superstep that just ended
                superstepResult = backChannel.getReadEndAfterSuperstepEnded();

                if (log.isInfoEnabled()) {
                    log.info(formatLogString("finishing iteration [" + currentIteration() + "]"));
                }

                sendEventToSync(new WorkerDoneEvent(workerIndex, aggregatorRegistry.getAllAggregators()));

                if (log.isInfoEnabled()) {
                    log.info(formatLogString("waiting for other workers in iteration [" + currentIteration() + "]"));
                }

                barrier.waitForOtherWorkers();

                if (barrier.terminationSignaled()) {
                    if (log.isInfoEnabled()) {
                        log.info(formatLogString(
                                "head received termination request in iteration [" + currentIteration() + "]"));
                    }
                    requestTermination();
                    nextStepKickoff.signalTermination();
                } else {
                    incrementIterationCounter();
                    aggregatorRegistry.updateGlobalAggregatesAndReset(
                            barrier.getAggregatorNames(), barrier.getAggregates());
                    nextStepKickoff.triggerNextSuperstep();
                }
            }

            if (log.isInfoEnabled()) {
                log.info(formatLogString("streaming out final result after [" + currentIteration() + "] iterations"));
            }

            if (!isWorksetIteration) {
                streamOutFinalOutputBulk(
                        new InputViewIterator<>(superstepResult, this.solutionTypeSerializer.getSerializer()));
            } else if (objectSolutionSet) {
                streamSolutionSetToFinalOutput(solutionSetObjectMap);
            } else {
                streamSolutionSetToFinalOutput(solutionSet);
            }

            this.finalOutputCollector.close();
        } finally {
            // unregister everything handed in to the brokers, on success and on failure alike
            IterationAggregatorBroker.instance().remove(brokerKey);
            BlockingBackChannelBroker.instance().remove(brokerKey);
            SuperstepKickoffLatchBroker.instance().remove(brokerKey);
            SolutionSetBroker.instance().remove(brokerKey);
            SolutionSetUpdateBarrierBroker.instance().remove(brokerKey);

            if (solutionSet != null) {
                solutionSet.close();
            }
        }
    }

    /** Forwards every record of the bulk iteration's final partial solution to the final output. */
    private void streamOutFinalOutputBulk(MutableObjectIterator<X> mutableObjectIterator) throws IOException {
        final Collector<X> out = this.finalOutputCollector;
        X record = this.solutionTypeSerializer.getSerializer().mo6078createInstance();
        // each returned record is passed back in as the reuse instance for the next read
        while ((record = mutableObjectIterator.next(record)) != null) {
            out.collect(record);
        }
    }

    /** Forwards every entry of the managed solution set to the final output. */
    private void streamSolutionSetToFinalOutput(CompactingHashTable<X> compactingHashTable) throws IOException {
        final MutableObjectIterator<X> entries = compactingHashTable.getEntryIterator();
        final Collector<X> out = this.finalOutputCollector;
        X record = this.solutionTypeSerializer.getSerializer().mo6078createInstance();
        // each returned record is passed back in as the reuse instance for the next read
        while ((record = entries.next(record)) != null) {
            out.collect(record);
        }
    }

    /** Forwards every value of the unmanaged solution set to the final output. */
    @SuppressWarnings("unchecked")
    private void streamSolutionSetToFinalOutput(JoinHashMap<X> joinHashMap) throws IOException {
        final Collector<X> out = this.finalOutputCollector;
        for (Object record : joinHashMap.values()) {
            out.collect((X) record);
        }
    }

    /** Replaces the feedback input with an iterator over the given superstep result. */
    private void feedBackSuperstepResult(DataInputView dataInputView) {
        this.inputs[this.feedbackDataInput] =
                new InputViewIterator<Y>(dataInputView, this.feedbackTypeSerializer.getSerializer());
    }

    /** Broadcasts an {@link EndOfSuperstepEvent} on every regular (iteration) output. */
    private void sendEndOfSuperstepToAllIterationOutputs() throws IOException, InterruptedException {
        if (log.isDebugEnabled()) {
            log.debug(formatLogString("Sending end-of-superstep to all iteration outputs."));
        }
        for (RecordWriter<?> eventualOutput : this.eventualOutputs) {
            eventualOutput.broadcastEvent(EndOfSuperstepEvent.INSTANCE);
        }
    }

    /** Broadcasts the given {@link WorkerDoneEvent} to the sync task. */
    private void sendEventToSync(WorkerDoneEvent workerDoneEvent) throws IOException, InterruptedException {
        if (log.isInfoEnabled()) {
            log.info(formatLogString("sending " + WorkerDoneEvent.class.getSimpleName() + " to sync"));
        }
        this.toSync.broadcastEvent(workerDoneEvent);
    }
}
