/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.naturalli;

import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.IndexedWord;
import edu.stanford.nlp.naturalli.NaturalLogicAnnotations;
import edu.stanford.nlp.naturalli.NaturalLogicRelation;
import edu.stanford.nlp.naturalli.NaturalLogicWeights;
import edu.stanford.nlp.naturalli.OperatorSpec;
import edu.stanford.nlp.naturalli.Polarity;
import edu.stanford.nlp.naturalli.SentenceFragment;
import edu.stanford.nlp.naturalli.Util;
import edu.stanford.nlp.semgraph.SemanticGraph;
import edu.stanford.nlp.semgraph.SemanticGraphEdge;
import edu.stanford.nlp.util.Lazy;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.StringUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Stack;
import java.util.stream.Collectors;

/**
 * A search problem that enumerates shortened versions of a dependency tree by
 * deleting subtrees whose removal is judged to be entailed.
 *
 * <p>Deleted vertices are tracked in a {@code long} bit mask (one bit per vertex
 * of the original tree), which is why {@link #search()} refuses trees with more
 * than {@value #MAX_MASKABLE_VERTICES} vertices.
 */
public class ForwardEntailerSearchProblem {
    /** Largest number of vertices that {@link #search()} will accept. */
    private static final int MAX_MASKABLE_VERTICES = 63;

    /** The tree to search over; it is copied before any modification. */
    public final SemanticGraph parseTree;
    /** Upper bound on the number of search states popped from the fringe. */
    public final int maxTicks;
    /** Upper bound on the number of results collected. */
    public final int maxResults;
    /** Source of the deletion probabilities used to score results. */
    public final NaturalLogicWeights weights;
    /**
     * Maps {@code vertex.index() - 1} to the bit position of that vertex in a
     * deletion mask. Entries for indices with no vertex are left at 0.
     */
    private final byte[] indexToMaskIndex;

    /**
     * Creates a search problem over the given tree.
     *
     * @param parseTree  the tree to shorten
     * @param maxResults stop once this many results have been collected
     * @param maxTicks   stop once this many search states have been expanded
     * @param weights    the deletion probabilities used for scoring
     */
    protected ForwardEntailerSearchProblem(SemanticGraph parseTree, int maxResults, int maxTicks, NaturalLogicWeights weights) {
        this.parseTree = parseTree;
        this.maxResults = maxResults;
        this.maxTicks = maxTicks;
        this.weights = weights;
        List<IndexedWord> vertices = this.parseTree.vertexListSorted();
        // The table is sized by the index of the last vertex in the sorted list
        // (presumably the largest index -- the original code relies on this too).
        // An empty graph gets an empty table instead of failing on get(-1).
        int tableSize = vertices.isEmpty() ? 0 : vertices.get(vertices.size() - 1).index();
        this.indexToMaskIndex = new byte[tableSize];
        byte maskBit = 0;
        for (IndexedWord vertex : vertices) {
            this.indexToMaskIndex[vertex.index() - 1] = maskBit;
            maskBit = (byte)(maskBit + 1);
        }
    }

    /**
     * Runs the search and converts every result into a {@link SentenceFragment}
     * scored with the result's confidence. Fragments without words are dropped.
     *
     * @return the fragments found; an empty list if the tree has more than
     *         {@value #MAX_MASKABLE_VERTICES} vertices
     */
    public List<SentenceFragment> search() {
        if (this.parseTree.vertexSet().size() > MAX_MASKABLE_VERTICES) {
            return Collections.emptyList();
        }
        return this.searchImplementation().stream()
                .map(result -> new SentenceFragment(result.tree, false).changeScore(result.confidence))
                .filter(fragment -> !fragment.words.isEmpty())
                .collect(Collectors.toList());
    }

    /**
     * The search itself: a depth-first walk over the vertices in topological
     * order, where each state either keeps the current word or (when allowed)
     * deletes everything {@code descendants(currentWord)} returns.
     *
     * @return the results found before the fringe emptied or a limit was hit
     */
    private List<SearchResult> searchImplementation() {
        // Work on a copy so that the public parseTree field is never mutated.
        SemanticGraph tree = new SemanticGraph(this.parseTree);
        assert (Util.isTree(tree));

        // (1) Strip leaf determiners, recording one "det" deletion per removal.
        List<String> determinerRemovals = new ArrayList<String>();
        tree.getLeafVertices().stream()
                .filter(vertex -> vertex.word().equalsIgnoreCase("the") || vertex.word().equalsIgnoreCase("a") || vertex.word().equalsIgnoreCase("an"))
                .forEach(vertex -> {
                    tree.removeVertex(vertex);
                    assert (Util.isTree(tree));
                    determinerRemovals.add("det");
                });

        // (2) For vertices with several incoming edges, cut a conj_and edge and
        // remember it so it can be re-attached to the results later.
        HashSet<SemanticGraphEdge> andsToAdd = new HashSet<SemanticGraphEdge>();
        for (IndexedWord vertex : tree.vertexSet()) {
            if (tree.inDegree(vertex) <= 1) {
                continue;
            }
            SemanticGraphEdge conjAnd = null;
            for (SemanticGraphEdge edge : tree.incomingEdgeIterable(vertex)) {
                if (edge.getRelation().toString().equals("conj_and")) {
                    conjAnd = edge;
                }
            }
            if (conjAnd == null) {
                continue;
            }
            tree.removeEdge(conjAnd);
            assert (Util.isTree(tree));
            andsToAdd.add(conjAnd);
        }
        Util.cleanTree(tree);
        assert (Util.isTree(tree));

        // (3) Mark every vertex that has a *subj edge somewhere on its path to
        // the root. Sized from the index table rather than a fixed 65 so that a
        // short fragment carrying large token indices cannot overflow it.
        boolean[] isSubject = new boolean[this.indexToMaskIndex.length];
        for (IndexedWord vertex : tree.vertexSet()) {
            if (ForwardEntailerSearchProblem.isUnderSubject(tree, vertex)) {
                isSubject[vertex.index() - 1] = true;
            }
        }

        // (4) If determiners were removed, the determiner-free tree is a result.
        List<SearchResult> results = new ArrayList<SearchResult>();
        if (!determinerRemovals.isEmpty()) {
            if (andsToAdd.isEmpty()) {
                double score = Math.pow(this.weights.deletionProbability("det"), determinerRemovals.size());
                assert (!Double.isNaN(score));
                assert (!Double.isInfinite(score));
                results.add(new SearchResult(tree, determinerRemovals, score));
            } else {
                SemanticGraph treeWithAnds = new SemanticGraph(tree);
                assert (Util.isTree(treeWithAnds));
                for (SemanticGraphEdge and : andsToAdd) {
                    treeWithAnds.addEdge(and.getGovernor(), and.getDependent(), and.getRelation(), Double.NEGATIVE_INFINITY, false);
                }
                assert (Util.isTree(treeWithAnds));
                results.add(new SearchResult(treeWithAnds, determinerRemovals, Math.pow(this.weights.deletionProbability("det"), determinerRemovals.size())));
            }
        }

        // (5) Fix the visiting order.
        assert (Util.isTree(tree));
        List<IndexedWord> topologicalVertices;
        try {
            topologicalVertices = tree.topologicalSort();
        }
        catch (IllegalStateException e) {
            System.err.println("Could not topologically sort the vertices! Using left-to-right traversal.");
            topologicalVertices = tree.vertexListSorted();
        }
        if (topologicalVertices.isEmpty()) {
            return results;
        }

        // (6) Depth-first search over keep/delete decisions.
        Stack<SearchState> fringe = new Stack<SearchState>();
        fringe.push(new SearchState(0L, 0, tree, null, null, 1.0));
        int numTicks = 0;
        while (!fringe.isEmpty()) {
            if (numTicks >= this.maxTicks) {
                return results;
            }
            ++numTicks;
            if (results.size() >= this.maxResults) {
                return results;
            }
            SearchState state = fringe.pop();
            assert (state.score > 0.0);
            IndexedWord currentWord = topologicalVertices.get(state.currentIndex);

            // Option A: keep the current word and move on to the next live one.
            int nextIfKept = this.nextUndeletedIndex(topologicalVertices, state.currentIndex + 1, state.deletionMask);
            if (nextIfKept >= 0) {
                fringe.push(new SearchState(state.deletionMask, nextIfKept, state.tree, null, state, state.score));
            }

            // Option B: delete the current word, if that is permitted.
            if (!ForwardEntailerSearchProblem.canDelete(state, currentWord, isSubject)) {
                continue;
            }
            double newScore = state.score;
            for (SemanticGraphEdge edge : state.tree.incomingEdgeIterable(currentWord)) {
                double multiplier = this.weights.deletionProbability(edge, state.tree.outgoingEdgeIterable(edge.getGovernor()));
                assert (!Double.isNaN(multiplier));
                assert (!Double.isInfinite(multiplier));
                newScore *= multiplier;
            }
            // Written as a negation so that a NaN score is rejected as well.
            if (!(newScore > 0.0)) {
                continue;
            }

            // The deletion is only materialised once the score is known to be usable.
            Pair<SemanticGraph, Long> deletion = this.deleteSubtree(state, currentWord);
            SemanticGraph treeWithDeletions = deletion.first;
            long newMask = deletion.second;

            // The reported tree gets back those conj_and edges whose endpoints both survived.
            SemanticGraph resultTree = new SemanticGraph(treeWithDeletions);
            for (SemanticGraphEdge and : andsToAdd) {
                if (resultTree.containsVertex(and.getGovernor()) && resultTree.containsVertex(and.getDependent())) {
                    resultTree.addEdge(and.getGovernor(), and.getDependent(), and.getRelation(), Double.NEGATIVE_INFINITY, false);
                }
            }
            results.add(new SearchResult(resultTree, ForwardEntailerSearchProblem.aggregateDeletedEdges(state, state.tree.incomingEdgeIterable(currentWord), determinerRemovals), newScore));

            int nextIfDeleted = this.nextUndeletedIndex(topologicalVertices, state.currentIndex + 1, newMask);
            if (nextIfDeleted >= 0) {
                assert (treeWithDeletions.containsVertex(topologicalVertices.get(nextIfDeleted)));
                fringe.push(new SearchState(newMask, nextIfDeleted, treeWithDeletions, null, state, newScore));
            }
        }
        return results;
    }

    /**
     * Follows the first incoming edge of each vertex upwards from {@code vertex}
     * and reports whether any edge on that path has a relation ending in "subj".
     */
    private static boolean isUnderSubject(SemanticGraph tree, IndexedWord vertex) {
        Iterator<SemanticGraphEdge> incoming = tree.incomingEdgeIterator(vertex);
        while (incoming.hasNext()) {
            SemanticGraphEdge edge = incoming.next();
            if (edge.getRelation().toString().endsWith("subj")) {
                return true;
            }
            incoming = tree.incomingEdgeIterator(edge.getGovernor());
        }
        return false;
    }

    /**
     * Finds the first position at or after {@code fromIndex} whose vertex is not
     * marked as deleted in {@code deletionMask}.
     *
     * @return that position, or -1 if every remaining vertex is deleted
     */
    private int nextUndeletedIndex(List<IndexedWord> order, int fromIndex, long deletionMask) {
        for (int candidate = fromIndex; candidate < order.size(); ++candidate) {
            int maskBit = this.indexToMaskIndex[order.get(candidate).index() - 1];
            if ((deletionMask >>> maskBit & 1L) == 0L) {
                return candidate;
            }
        }
        return -1;
    }

    /**
     * Decides whether {@code word} may be deleted from {@code state.tree}: it must
     * not be the first root, no governor of an incoming edge may be tagged "CD",
     * and the deletion relation of every other incoming edge, projected through
     * the token's polarity, must be entailed.
     */
    private static boolean canDelete(SearchState state, IndexedWord word, boolean[] isSubject) {
        boolean deletable = !state.tree.getFirstRoot().equals(word);
        for (SemanticGraphEdge edge : state.tree.incomingEdgeIterable(word)) {
            if ("CD".equals(edge.getGovernor().tag())) {
                deletable = false;
                continue;
            }
            CoreLabel token = edge.getDependent().backingLabel();
            Polarity tokenPolarity = token.get(NaturalLogicAnnotations.PolarityAnnotation.class);
            if (tokenPolarity == null) {
                tokenPolarity = Polarity.DEFAULT;
            }
            // A token carrying an operator annotation uses that operator's own
            // deletion relation; otherwise the relation comes from the edge type.
            OperatorSpec operator = token.get(NaturalLogicAnnotations.OperatorAnnotation.class);
            NaturalLogicRelation lexicalRelation = operator != null
                    ? operator.instance.deleteRelation
                    : NaturalLogicRelation.forDependencyDeletion(edge.getRelation().toString(), isSubject[edge.getDependent().index() - 1]);
            NaturalLogicRelation projectedRelation = tokenPolarity.projectLexicalRelation(lexicalRelation);
            if (!projectedRelation.isEntailed) {
                deletable = false;
            }
        }
        return deletable;
    }

    /**
     * Copies {@code state.tree}, removes every vertex returned by
     * {@code state.tree.descendants(subtreeRoot)} from the copy, and sets the
     * matching bits in a copy of the state's deletion mask.
     *
     * @return the pruned copy together with the updated mask
     */
    private Pair<SemanticGraph, Long> deleteSubtree(SearchState state, IndexedWord subtreeRoot) {
        SemanticGraph pruned = new SemanticGraph(state.tree);
        long newMask = state.deletionMask;
        for (IndexedWord vertex : state.tree.descendants(subtreeRoot)) {
            pruned.removeVertex(vertex);
            newMask |= 1L << this.indexToMaskIndex[vertex.index() - 1];
            assert (this.indexToMaskIndex[vertex.index() - 1] < 64);
            assert ((newMask >>> this.indexToMaskIndex[vertex.index() - 1] & 1L) == 1L);
        }
        return Pair.makePair(pruned, newMask);
    }

    /**
     * Builds the list of deleted relation names for a result: the relations of
     * the edges just deleted, then {@code otherEdges}, then every non-null
     * {@code lastDeletedEdge} found by walking the {@code source} chain from
     * {@code state}.
     */
    private static List<String> aggregateDeletedEdges(SearchState state, Iterable<SemanticGraphEdge> justDeleted, Iterable<String> otherEdges) {
        List<String> rtn = new ArrayList<String>();
        for (SemanticGraphEdge edge : justDeleted) {
            rtn.add(edge.getRelation().toString());
        }
        for (String relation : otherEdges) {
            rtn.add(relation);
        }
        for (SearchState ancestor = state; ancestor != null; ancestor = ancestor.source) {
            if (ancestor.lastDeletedEdge != null) {
                rtn.add(ancestor.lastDeletedEdge);
            }
        }
        return rtn;
    }

    /** An immutable node of the search. */
    private static class SearchState {
        /** Bit mask of the vertices deleted so far. */
        public final long deletionMask;
        /** Position in the topological order of the vertex this state considers. */
        public final int currentIndex;
        /** The tree as it stands after the deletions in {@link #deletionMask}. */
        public final SemanticGraph tree;
        /** Relation name recorded for this step; may be null. */
        public final String lastDeletedEdge;
        /** The state this one was derived from; null for the initial state. */
        public final SearchState source;
        /** Product of the deletion probabilities applied so far. */
        public final double score;

        private SearchState(long deletionMask, int currentIndex, SemanticGraph tree, String lastDeletedEdge, SearchState source, double score) {
            this.deletionMask = deletionMask;
            this.currentIndex = currentIndex;
            this.tree = tree;
            this.lastDeletedEdge = lastDeletedEdge;
            this.source = source;
            this.score = score;
        }
    }

    /** A shortened tree together with the deletions and score that produced it. */
    private static class SearchResult {
        /** The shortened tree. */
        public final SemanticGraph tree;
        /** Names of the relations deleted to obtain {@link #tree}. */
        public final List<String> deletedEdges;
        /** Score of this result. */
        public final double confidence;

        private SearchResult(SemanticGraph tree, List<String> deletedEdges, double confidence) {
            this.tree = tree;
            this.deletedEdges = deletedEdges;
            this.confidence = confidence;
        }

        /** Renders the words of the tree in sorted vertex order, space separated. */
        @Override
        public String toString() {
            return StringUtils.join(this.tree.vertexListSorted().stream().map(IndexedWord::word), " ");
        }
    }
}

