/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.codegen.template;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.IndexingOp;
import org.apache.sysml.hops.codegen.SpoofCompiler;
import org.apache.sysml.hops.codegen.opt.InterestingPoint;
import org.apache.sysml.hops.codegen.opt.PlanSelection;
import org.apache.sysml.hops.codegen.template.TemplateBase;
import org.apache.sysml.hops.codegen.template.TemplateUtils;
import org.apache.sysml.runtime.util.UtilFunctions;

public class CPlanMemoTable {
    private static final Log LOG = LogFactory.getLog((String)CPlanMemoTable.class.getName());
    protected HashMap<Long, List<MemoTableEntry>> _plans = new HashMap();
    protected HashMap<Long, Hop> _hopRefs = new HashMap();
    protected HashSet<Long> _plansBlacklist = new HashSet();

    public HashMap<Long, List<MemoTableEntry>> getPlans() {
        return this._plans;
    }

    public HashSet<Long> getPlansBlacklisted() {
        return this._plansBlacklist;
    }

    public HashMap<Long, Hop> getHopRefs() {
        return this._hopRefs;
    }

    public void addHop(Hop hop) {
        this._hopRefs.put(hop.getHopID(), hop);
    }

    public boolean containsHop(Hop hop) {
        return this._hopRefs.containsKey(hop.getHopID());
    }

    public boolean contains(long hopID) {
        return this._plans.containsKey(hopID) && !this._plans.get(hopID).isEmpty();
    }

    public boolean contains(long hopID, TemplateBase.TemplateType type) {
        return this.contains(hopID) && this.get(hopID).stream().anyMatch(p -> p.type == type);
    }

    public boolean contains(long hopID, MemoTableEntry me, TemplateBase.TemplateType type) {
        return this.contains(hopID) && this.get(hopID).stream().anyMatch(p -> p.type == type && p.equalPlanRefs(me));
    }

    public boolean contains(long hopID, boolean checkClose, TemplateBase.TemplateType ... type) {
        if (!checkClose && type.length == 1) {
            return this.contains(hopID, type[0]);
        }
        Set<TemplateBase.TemplateType> probe = UtilFunctions.asSet(type);
        return this.contains(hopID) && this.get(hopID).stream().anyMatch(p -> (!checkClose || !p.isClosed()) && probe.contains((Object)p.type));
    }

    public boolean containsNotIn(long hopID, Collection<TemplateBase.TemplateType> types, boolean checkChildRefs) {
        return this.contains(hopID) && this.get(hopID).stream().anyMatch(p -> (!checkChildRefs || p.hasPlanRef()) && p.isValid() && !types.contains((Object)p.type));
    }

    public boolean hasOnlyExactMatches(long hopID, TemplateBase.TemplateType type1, TemplateBase.TemplateType type2) {
        List<MemoTableEntry> l1 = this.get(hopID, type1);
        List<MemoTableEntry> l2 = this.get(hopID, type2);
        boolean ret = l1.size() == l2.size();
        for (MemoTableEntry me : l1) {
            ret &= l2.stream().anyMatch(p -> p.equalPlanRefs(me));
        }
        return ret;
    }

    public int countEntries(long hopID) {
        return this.get(hopID).size();
    }

    public int countEntries(long hopID, TemplateBase.TemplateType type) {
        return (int)this.get(hopID).stream().filter(p -> p.type == type).count();
    }

    public boolean containsTopLevel(long hopID) {
        return !this._plansBlacklist.contains(hopID) && this.getBest(hopID) != null;
    }

    public void add(Hop hop, TemplateBase.TemplateType type) {
        this.add(hop, type, -1L, -1L, -1L);
    }

    public void add(Hop hop, TemplateBase.TemplateType type, long in1) {
        this.add(hop, type, in1, -1L, -1L);
    }

    public void add(Hop hop, TemplateBase.TemplateType type, long in1, long in2) {
        this.add(hop, type, in1, in2, -1L);
    }

    public void add(Hop hop, TemplateBase.TemplateType type, long in1, long in2, long in3) {
        int size = hop instanceof IndexingOp ? 1 : hop.getInput().size();
        this.add(hop, new MemoTableEntry(type, in1, in2, in3, size));
    }

    public void add(Hop hop, MemoTableEntry me) {
        this._hopRefs.put(hop.getHopID(), hop);
        if (!this._plans.containsKey(hop.getHopID())) {
            this._plans.put(hop.getHopID(), new ArrayList());
        }
        this._plans.get(hop.getHopID()).add(me);
    }

    public void addAll(Hop hop, MemoTableEntrySet P) {
        this._hopRefs.put(hop.getHopID(), hop);
        if (!this._plans.containsKey(hop.getHopID())) {
            this._plans.put(hop.getHopID(), new ArrayList());
        }
        this._plans.get(hop.getHopID()).addAll(P.plans);
    }

    public void remove(Hop hop, Set<MemoTableEntry> blackList) {
        this._plans.get(hop.getHopID()).removeIf(p -> blackList.contains(p));
    }

    public void remove(Hop hop, TemplateBase.TemplateType type) {
        this._plans.get(hop.getHopID()).removeIf(p -> p.type == type);
    }

    public void removeAllRefTo(long hopID) {
        this.removeAllRefTo(hopID, null);
    }

    public void removeAllRefTo(long hopID, TemplateBase.TemplateType type) {
        for (Map.Entry<Long, List<MemoTableEntry>> e : this._plans.entrySet()) {
            if (e.getValue().isEmpty() || e.getKey() == hopID) continue;
            e.getValue().removeIf(p -> p.hasPlanRefTo(hopID) && (type == null || p.type == type));
        }
    }

    public void setDistinct(long hopID, List<MemoTableEntry> plans) {
        this._plans.put(hopID, plans.stream().distinct().collect(Collectors.toList()));
    }

    public void pruneRedundant(long hopID, boolean pruneDominated, InterestingPoint[] matPoints) {
        if (!this.contains(hopID)) {
            return;
        }
        this.setDistinct(hopID, this._plans.get(hopID));
        this._plans.get(hopID).removeIf(p -> p.isClosed() && !p.hasPlanRef());
        if (pruneDominated) {
            HashSet<MemoTableEntry> rmList = new HashSet<MemoTableEntry>();
            List<MemoTableEntry> list = this._plans.get(hopID);
            Hop hop = this._hopRefs.get(hopID);
            for (MemoTableEntry e1 : list) {
                for (MemoTableEntry e2 : list) {
                    if (e1 == e2 || !e1.subsumes(e2)) continue;
                    boolean rmSafe = true;
                    for (int i = 0; i <= 2; ++i) {
                        rmSafe &= e1.isPlanRef(i) && !e2.isPlanRef(i) ? matPoints != null && !InterestingPoint.isMatPoint(matPoints, hopID, e1.input(i)) || hop.getInput().get(i).getParent().size() == 1 : true;
                    }
                    if (!rmSafe) continue;
                    rmList.add(e2);
                }
            }
            this.remove(hop, rmList);
        }
    }

    public void pruneSuboptimal(ArrayList<Hop> roots) {
        if (LOG.isTraceEnabled()) {
            LOG.trace((Object)("#1: Memo before plan selection (" + this.size() + " plans)\n" + this));
        }
        HashSet<Long> ix = new HashSet<Long>();
        for (Map.Entry<Long, List<MemoTableEntry>> e : this._plans.entrySet()) {
            for (MemoTableEntry me : e.getValue()) {
                ix.add(me.input1);
                ix.add(me.input2);
                ix.add(me.input3);
            }
        }
        Iterator<Map.Entry<Long, List<MemoTableEntry>>> iter = this._plans.entrySet().iterator();
        while (iter.hasNext()) {
            Map.Entry<Long, List<MemoTableEntry>> e;
            e = iter.next();
            if (ix.contains(e.getKey())) continue;
            ((List)e.getValue()).removeIf(p -> !p.hasPlanRef());
            if (!((List)e.getValue()).isEmpty()) continue;
            iter.remove();
        }
        if (SpoofCompiler.PLAN_SEL_POLICY.isHeuristic()) {
            for (Map.Entry entry : this._plans.entrySet()) {
                for (MemoTableEntry me : (List)entry.getValue()) {
                    for (int i = 0; i <= 2; ++i) {
                        if (!me.isPlanRef(i) || this._hopRefs.get(me.input(i)).getParent().size() != 1) continue;
                        this._plansBlacklist.add(me.input(i));
                    }
                }
            }
        }
        PlanSelection selector = SpoofCompiler.createPlanSelector();
        selector.selectPlans(this, roots);
        if (LOG.isTraceEnabled()) {
            LOG.trace((Object)("#2: Memo after plan selection (" + this.size() + " plans)\n" + this));
        }
    }

    public List<MemoTableEntry> get(long hopID) {
        return this._plans.get(hopID);
    }

    public List<MemoTableEntry> get(long hopID, TemplateBase.TemplateType type) {
        return this._plans.get(hopID).stream().filter(p -> p.type == type).collect(Collectors.toList());
    }

    public List<MemoTableEntry> getDistinct(long hopID) {
        return this._plans.get(hopID).stream().distinct().collect(Collectors.toList());
    }

    public List<TemplateBase> getDistinctTemplates(long hopID) {
        if (!this.contains(hopID)) {
            return Collections.emptyList();
        }
        return this._plans.get(hopID).stream().map(p -> TemplateUtils.createTemplate(p.type, p.ctype)).distinct().collect(Collectors.toList());
    }

    public List<TemplateBase.TemplateType> getDistinctTemplateTypes(long hopID, int refAt) {
        return this.getDistinctTemplateTypes(hopID, refAt, false);
    }

    public List<TemplateBase.TemplateType> getDistinctTemplateTypes(long hopID, int refAt, boolean exclInvalOuter) {
        if (!this.contains(hopID)) {
            return Collections.emptyList();
        }
        return this._plans.get(hopID).stream().filter(p -> p.isPlanRef(refAt) && (!exclInvalOuter || p.type != TemplateBase.TemplateType.OUTER || p.isValid())).map(p -> p.type).distinct().collect(Collectors.toList());
    }

    public MemoTableEntry getBest(long hopID) {
        List<MemoTableEntry> tmp = this.get(hopID);
        if (tmp == null || tmp.isEmpty()) {
            return null;
        }
        return tmp.stream().filter(p -> p.isValid()).min(Comparator.comparing(p -> p.type.getRank())).orElse(null);
    }

    public MemoTableEntry getBest(long hopID, TemplateBase.TemplateType pref) {
        List<MemoTableEntry> tmp = this.get(hopID);
        if (tmp == null || tmp.isEmpty()) {
            return null;
        }
        return Collections.min(tmp, Comparator.comparing(p -> p.type == pref ? -p.countPlanRefs() : p.type.getRank() + 1));
    }

    public MemoTableEntry getBest(long hopID, TemplateBase.TemplateType pref1, TemplateBase.TemplateType pref2) {
        List<MemoTableEntry> tmp = this.get(hopID);
        if (tmp == null || tmp.isEmpty()) {
            return null;
        }
        return Collections.min(tmp, Comparator.comparing(p -> p.type == pref1 ? -p.countPlanRefs() - 4 : (p.type == pref2 ? -p.countPlanRefs() : p.type.getRank() + 1)));
    }

    public long[] getAllRefs(long hopID) {
        long[] refs = new long[3];
        for (MemoTableEntry me : this.get(hopID)) {
            for (int i = 0; i < 3; ++i) {
                if (!me.isPlanRef(i)) continue;
                refs[i] = me.input(i);
            }
        }
        return refs;
    }

    public int size() {
        return this._plans.values().stream().map(list -> list.size()).mapToInt(x -> x).sum();
    }

    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("----------------------------------\n");
        sb.append("MEMO TABLE: \n");
        sb.append("----------------------------------\n");
        for (Map.Entry<Long, List<MemoTableEntry>> e : this._plans.entrySet()) {
            sb.append(e.getKey() + " " + this._hopRefs.get(e.getKey()).getOpString() + ": ");
            sb.append(Arrays.toString(e.getValue().toArray(new MemoTableEntry[0])) + "\n");
        }
        sb.append("----------------------------------\n");
        sb.append("Blacklisted Plans: ");
        sb.append(Arrays.toString((Object[])this._plansBlacklist.toArray(new Long[0])) + "\n");
        sb.append("----------------------------------\n");
        return sb.toString();
    }

    public static class MemoTableEntrySet {
        public ArrayList<MemoTableEntry> plans = new ArrayList();

        public MemoTableEntrySet(Hop hop, Hop c, TemplateBase tpl) {
            int pos = c != null ? hop.getInput().indexOf(c) : -1;
            int size = hop instanceof IndexingOp ? 1 : hop.getInput().size();
            this.plans.add(new MemoTableEntry(tpl.getType(), pos == 0 ? c.getHopID() : -1L, pos == 1 ? c.getHopID() : -1L, pos == 2 ? c.getHopID() : -1L, size, tpl.getCType()));
        }

        public void crossProduct(int pos, Long ... refs) {
            if (refs.length == 1 && refs[0] == -1L) {
                return;
            }
            ArrayList<MemoTableEntry> tmp = new ArrayList<MemoTableEntry>();
            for (MemoTableEntry me : this.plans) {
                for (Long ref : refs) {
                    tmp.add(new MemoTableEntry(me.type, pos == 0 ? ref : me.input1, pos == 1 ? ref : me.input2, pos == 2 ? ref : me.input3, me.size));
                }
            }
            this.plans = tmp;
        }

        public String toString() {
            return Arrays.toString(this.plans.toArray(new MemoTableEntry[0]));
        }
    }

    public static class MemoTableEntry {
        public TemplateBase.TemplateType type;
        public final long input1;
        public final long input2;
        public final long input3;
        public final int size;
        public TemplateBase.CloseType ctype;

        public MemoTableEntry(TemplateBase.TemplateType t, long in1, long in2, long in3, int inlen) {
            this(t, in1, in2, in3, inlen, TemplateBase.CloseType.OPEN_VALID);
        }

        public MemoTableEntry(TemplateBase.TemplateType t, long in1, long in2, long in3, int inlen, TemplateBase.CloseType close) {
            this.type = t;
            this.input1 = in1;
            this.input2 = in2;
            this.input3 = in3;
            this.size = inlen;
            this.ctype = close;
        }

        public boolean isClosed() {
            return this.ctype.isClosed();
        }

        public boolean isValid() {
            return this.ctype.isValid();
        }

        public boolean isPlanRef(int index) {
            return index == 0 && this.input1 >= 0L || index == 1 && this.input2 >= 0L || index == 2 && this.input3 >= 0L;
        }

        public boolean hasPlanRef() {
            return this.isPlanRef(0) || this.isPlanRef(1) || this.isPlanRef(2);
        }

        public boolean hasPlanRefTo(long hopID) {
            return this.input1 == hopID || this.input2 == hopID || this.input3 == hopID;
        }

        public int countPlanRefs() {
            return (this.input1 >= 0L ? 1 : 0) + (this.input2 >= 0L ? 1 : 0) + (this.input3 >= 0L ? 1 : 0);
        }

        public int getPlanRefIndex() {
            return this.input1 >= 0L ? 0 : (this.input2 >= 0L ? 1 : (this.input3 >= 0L ? 2 : -1));
        }

        public boolean equalPlanRefs(MemoTableEntry that) {
            return this.input1 == that.input1 && this.input2 == that.input2 && this.input3 == that.input3;
        }

        public long input(int index) {
            return index == 0 ? this.input1 : (index == 1 ? this.input2 : this.input3);
        }

        public boolean subsumes(MemoTableEntry that) {
            return !(this.type != that.type || !this.isPlanRef(0) && that.isPlanRef(0) || !this.isPlanRef(1) && that.isPlanRef(1) || !this.isPlanRef(2) && that.isPlanRef(2));
        }

        public int hashCode() {
            int h = UtilFunctions.intHashCode(this.type.ordinal(), Long.hashCode(this.input1));
            h = UtilFunctions.intHashCode(h, Long.hashCode(this.input2));
            h = UtilFunctions.intHashCode(h, Long.hashCode(this.input3));
            h = UtilFunctions.intHashCode(h, this.size);
            h = UtilFunctions.intHashCode(h, this.ctype.ordinal());
            return h;
        }

        public boolean equals(Object obj) {
            if (!(obj instanceof MemoTableEntry)) {
                return false;
            }
            MemoTableEntry that = (MemoTableEntry)obj;
            return this.type == that.type && this.input1 == that.input1 && this.input2 == that.input2 && this.input3 == that.input3 && this.size == that.size && this.ctype == that.ctype;
        }

        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append(this.type.name());
            sb.append("(");
            for (int i = 0; i < this.size; ++i) {
                if (i > 0) {
                    sb.append(",");
                }
                sb.append(this.input(i));
            }
            if (!this.isValid()) {
                sb.append("|x");
            }
            sb.append(")");
            return sb.toString();
        }
    }
}

