/*
 * Decompiled with CFR 0.152.
 */
package com.hankcs.hanlp.dependency.perceptron.transition.trainer;

import com.hankcs.hanlp.corpus.io.IOUtil;
import com.hankcs.hanlp.dependency.perceptron.accessories.Edge;
import com.hankcs.hanlp.dependency.perceptron.accessories.Evaluator;
import com.hankcs.hanlp.dependency.perceptron.accessories.Options;
import com.hankcs.hanlp.dependency.perceptron.accessories.Pair;
import com.hankcs.hanlp.dependency.perceptron.learning.AveragedPerceptron;
import com.hankcs.hanlp.dependency.perceptron.structures.IndexMaps;
import com.hankcs.hanlp.dependency.perceptron.structures.ParserModel;
import com.hankcs.hanlp.dependency.perceptron.transition.configuration.BeamElement;
import com.hankcs.hanlp.dependency.perceptron.transition.configuration.Configuration;
import com.hankcs.hanlp.dependency.perceptron.transition.configuration.Instance;
import com.hankcs.hanlp.dependency.perceptron.transition.configuration.State;
import com.hankcs.hanlp.dependency.perceptron.transition.features.FeatureExtractor;
import com.hankcs.hanlp.dependency.perceptron.transition.parser.Action;
import com.hankcs.hanlp.dependency.perceptron.transition.parser.ArcEager;
import com.hankcs.hanlp.dependency.perceptron.transition.parser.BeamScorerThread;
import com.hankcs.hanlp.dependency.perceptron.transition.parser.KBeamArcEagerParser;
import com.hankcs.hanlp.dependency.perceptron.transition.parser.LabeledAction;
import com.hankcs.hanlp.dependency.perceptron.transition.parser.TransitionBasedParser;
import com.hankcs.hanlp.utility.MathUtility;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Random;
import java.util.TreeSet;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

/**
 * Beam-search trainer for the arc-eager transition-based dependency parser.
 *
 * <p>For every training sentence the trainer runs a beam search with the current perceptron
 * weights while tracking a set of zero-cost ("oracle") configurations in parallel. When the
 * oracle falls off the beam, or does not end on top of it, the {@link AveragedPerceptron} is
 * moved towards the oracle's features and away from the predicted ones. {@code updateMode}
 * selects the strategy:
 * <ul>
 *   <li>{@code "early"} stops the search as soon as the oracle leaves the beam;</li>
 *   <li>{@code "max_violation"} updates on the step with the largest score violation;</li>
 *   <li>any other value updates with the final beam.</li>
 * </ul>
 *
 * <p>Entries of a configuration's {@code actionHistory} are decoded as:
 * <ul>
 *   <li>{@code 0} shift;</li>
 *   <li>{@code 1} reduce;</li>
 *   <li>{@code 2} unshift;</li>
 *   <li>{@code 3 + label} right arc;</li>
 *   <li>{@code 3 + |relations| + label} left arc.</li>
 * </ul>
 * {@link BeamElement} actions instead use {@code 0} shift, {@code 1} reduce, {@code 2} right arc,
 * {@code 3} left arc.
 */
public class ArcEagerBeamTrainer
extends TransitionBasedParser {
    /** Training options: beam width, thread count, oracle selection, root position, etc. */
    Options options;
    /** Update strategy: {@code "early"}, {@code "max_violation"}, or anything else for a final-beam update. */
    private final String updateMode;
    /** Random source used when {@code options.useRandomOracleSelection} is set. */
    private final Random randGen;

    /**
     * Creates a trainer.
     *
     * @param updateMode          perceptron update strategy ({@code "early"}, {@code "max_violation"}, or other)
     * @param classifier          the perceptron being trained; its weights are updated in place
     * @param options             training options
     * @param dependencyRelations dependency label ids; each id indexes the classifier's arc-score arrays
     * @param featureLength       number of feature slots produced per configuration
     * @param maps                index maps, passed to every {@link ParserModel} built for saving or evaluation
     */
    public ArcEagerBeamTrainer(String updateMode, AveragedPerceptron classifier, Options options, ArrayList<Integer> dependencyRelations, int featureLength, IndexMaps maps) {
        super(classifier, dependencyRelations, featureLength, maps);
        this.updateMode = updateMode;
        this.options = options;
        this.randGen = new Random();
    }

    /**
     * Trains the model for {@code maxIteration} epochs over {@code trainData}.
     *
     * <p>After each epoch the model is saved unconditionally when there is no dev set. Otherwise
     * it is evaluated on {@code devPath} and saved only when the dev UAS beats the best seen so far.
     * The internal scoring thread pool is shut down on exit, whether training succeeds or throws.
     *
     * @param trainData       training instances, visited in list order every epoch
     * @param devPath         CoNLL dev file for model selection; {@code null} or empty disables evaluation
     * @param maxIteration    number of epochs
     * @param modelPath       model output path; {@code modelPath + ".__tmp__"} is used as scratch space for dev parses
     * @param lowerCased      forwarded to the dev-set parser
     * @param punctuations    forwarded to {@link Evaluator#evaluate}
     * @param partialTreeIter partially annotated sentences are skipped while the epoch number is below this value
     * @throws IOException          if saving, parsing or evaluation fails
     * @throws ExecutionException   if a beam-scoring task fails
     * @throws InterruptedException if interrupted while waiting for beam-scoring tasks
     */
    public void train(ArrayList<Instance> trainData, String devPath, int maxIteration, String modelPath, boolean lowerCased, HashSet<String> punctuations, int partialTreeIter) throws IOException, ExecutionException, InterruptedException {
        ExecutorService executor = Executors.newFixedThreadPool(this.options.numOfThreads);
        ExecutorCompletionService<ArrayList<BeamElement>> pool = new ExecutorCompletionService<ArrayList<BeamElement>>(executor);
        boolean hasDevSet = devPath != null && devPath.length() > 0;
        double bestUAS = -1.0;
        try {
            for (int i = 1; i <= maxIteration; ++i) {
                long start = System.currentTimeMillis();
                int dataCount = 0;
                // Redraw the progress line at most ~10,000 times per epoch.
                int logEvery = (int) Math.ceil((float) trainData.size() / 10000.0f);
                for (Instance instance : trainData) {
                    if (++dataCount % logEvery == 0 || dataCount == trainData.size()) {
                        System.out.printf("\r\u8fed\u4ee3 " + i + "/" + maxIteration + " %.2f%% ", MathUtility.percentage(dataCount, trainData.size()));
                    }
                    this.trainOnOneSample(instance, partialTreeIter, i, dataCount, pool);
                    // Called once per visited sentence, even when trainOnOneSample skipped it.
                    this.classifier.incrementIteration();
                }
                long timeSec = (System.currentTimeMillis() - start) / 1000L;
                System.out.print(" \u8017\u65f6 " + timeSec + " \u79d2\u3002");
                ParserModel parserModel = new ParserModel(this.classifier, this.maps, this.dependencyRelations, this.options);
                if (!hasDevSet) {
                    // Without a dev set, the most recent epoch's model is kept.
                    parserModel.saveModel(modelPath);
                    System.out.println();
                    continue;
                }
                double uas = this.evaluateOnDevSet(parserModel, devPath, modelPath, lowerCased, punctuations);
                if (uas > bestUAS) {
                    bestUAS = uas;
                    System.out.println(" \u6700\u9ad8\u5206\uff01\u4fdd\u5b58\u4e2d...");
                    parserModel.saveModel(modelPath);
                } else {
                    System.out.println();
                }
            }
        } finally {
            // Release the worker threads even when training fails part-way; they are non-daemon
            // threads from Executors.newFixedThreadPool and would otherwise keep the JVM alive.
            shutdownAndAwait(executor);
        }
    }

    /**
     * Parses the dev set with a snapshot of {@code parserModel}, prints UAS/LAS, and returns the UAS.
     *
     * <p>The temporary output file is deleted and the parser's worker threads are shut down even
     * if parsing or evaluation throws.
     */
    private double evaluateOnDevSet(ParserModel parserModel, String devPath, String modelPath, boolean lowerCased, HashSet<String> punctuations) throws IOException, ExecutionException, InterruptedException {
        AveragedPerceptron averagedPerceptron = new AveragedPerceptron(parserModel);
        KBeamArcEagerParser parser = new KBeamArcEagerParser(averagedPerceptron, this.dependencyRelations, this.featureLength, this.maps, this.options.numOfThreads, this.options);
        String outputFile = modelPath + ".__tmp__";
        try {
            parser.parseConllFile(devPath, outputFile, this.options.rootFirst, this.options.beamWidth, true, lowerCased, this.options.numOfThreads, false, "");
            double[] score = Evaluator.evaluate(devPath, outputFile, punctuations);
            System.out.printf("UAS=%.2f LAS=%.2f", score[0], score[1]);
            return score[0];
        } finally {
            IOUtil.deleteFile(outputFile);
            parser.shutDownLiveThreads();
        }
    }

    /**
     * Interrupts the pool's workers and blocks until they have exited.
     *
     * <p>If the calling thread is interrupted while waiting, its interrupt flag is restored and
     * the method returns without waiting further.
     */
    private static void shutdownAndAwait(ExecutorService executor) {
        executor.shutdownNow();
        try {
            while (!executor.awaitTermination(1L, TimeUnit.SECONDS)) {
                executor.shutdownNow();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Runs one beam-search training pass over {@code instance} and updates the weights if needed.
     *
     * <p>Partial (incompletely annotated) sentences are skipped while {@code i < partialTreeIter}.
     * They always use the zero-cost dynamic oracle. Weights are updated unless the oracle is still
     * in the beam at the end and equals the top beam item.
     *
     * @param instance        the training sentence
     * @param partialTreeIter epoch from which partial sentences are used
     * @param i               current epoch (1-based)
     * @param dataCount       position of the sentence in the epoch (unused here)
     * @param pool            completion service used when scoring the beam with several threads
     */
    private void trainOnOneSample(Instance instance, int partialTreeIter, int i, int dataCount, CompletionService<ArrayList<BeamElement>> pool) throws InterruptedException, ExecutionException {
        boolean isPartial = instance.isPartial(this.options.rootFirst);
        if (partialTreeIter > i && isPartial) {
            return;
        }
        Configuration initialConfiguration = new Configuration(instance.getSentence(), this.options.rootFirst);
        Configuration firstOracle = initialConfiguration.clone();
        ArrayList<Configuration> beam = new ArrayList<Configuration>(this.options.beamWidth);
        beam.add(initialConfiguration);
        HashSet<Configuration> oracles = new HashSet<Configuration>();
        oracles.add(firstOracle);
        float maxViol = Float.NEGATIVE_INFINITY;
        Pair<Configuration, Configuration> maxViolPair = null;
        Configuration bestScoringOracle = null;
        boolean oracleInBeam = false;
        while (!ArcEager.isTerminal(beam) && beam.size() > 0) {
            // 1. Advance the oracle set by one zero-cost transition. Fall back to the static
            //    oracle if the chosen oracle produced no successors.
            HashSet<Configuration> newOracles = new HashSet<Configuration>();
            bestScoringOracle = this.options.useDynamicOracle || isPartial
                    ? this.zeroCostDynamicOracle(instance, oracles, newOracles)
                    : this.staticOracle(instance, oracles, newOracles);
            if (newOracles.isEmpty()) {
                bestScoringOracle = this.staticOracle(instance, oracles, newOracles);
            }
            oracles = newOracles;

            // 2. Score every possible transition of every beam item; keep the best beamWidth.
            TreeSet<BeamElement> beamPreserver = new TreeSet<BeamElement>();
            if (this.options.numOfThreads == 1 || beam.size() == 1) {
                this.beamSortOneThread(beam, beamPreserver);
            } else {
                this.beamSortMultiThread(beam, beamPreserver, pool);
            }
            if (beamPreserver.isEmpty() || beam.isEmpty()) {
                break;
            }

            // 3. Build the next beam (highest score first) and check whether any oracle survived.
            oracleInBeam = false;
            ArrayList<Configuration> repBeam = new ArrayList<Configuration>(this.options.beamWidth);
            for (BeamElement beamElement : beamPreserver.descendingSet()) {
                Configuration newConfig = beam.get(beamElement.index).clone();
                ArcEager.commitAction(beamElement.action, beamElement.label, beamElement.score, this.dependencyRelations, newConfig);
                repBeam.add(newConfig);
                if (!oracleInBeam && oracles.contains(newConfig)) {
                    oracleInBeam = true;
                }
            }
            beam = repBeam;
            if (beam.isEmpty() || oracles.isEmpty()) {
                break;
            }

            // 4. Collapse the oracle set to a single configuration for the next step.
            Configuration bestConfig = beam.get(0);
            Configuration keptOracle;
            if (oracles.contains(bestConfig)) {
                keptOracle = bestConfig;
            } else if (this.options.useRandomOracleSelection) {
                ArrayList<Configuration> keys = new ArrayList<Configuration>(oracles);
                keptOracle = keys.get(this.randGen.nextInt(keys.size()));
                bestScoringOracle = keptOracle;
            } else {
                keptOracle = bestScoringOracle;
            }
            oracles = new HashSet<Configuration>();
            oracles.add(keptOracle);

            // 5. Update-mode bookkeeping.
            if (!oracleInBeam && this.updateMode.equals("early")) {
                break;
            }
            if (!oracleInBeam && this.updateMode.equals("max_violation")) {
                float violation = bestConfig.getScore(true) - bestScoringOracle.getScore(true);
                if (violation > maxViol) {
                    maxViol = violation;
                    maxViolPair = new Pair<Configuration, Configuration>(bestConfig, bestScoringOracle);
                }
            }
        }
        if (!oracleInBeam || !bestScoringOracle.equals(beam.get(0))) {
            this.updateWeights(initialConfiguration, maxViol, isPartial, bestScoringOracle, maxViolPair, beam);
        }
    }

    /**
     * Scores every beam item on {@code pool}, one task per item, and merges the results into
     * {@code beamPreserver}, keeping at most {@code options.beamWidth} elements.
     */
    private void beamSortMultiThread(ArrayList<Configuration> beam, TreeSet<BeamElement> beamPreserver, CompletionService<ArrayList<BeamElement>> pool) throws InterruptedException, ExecutionException {
        for (int b = 0; b < beam.size(); ++b) {
            pool.submit(new BeamScorerThread(false, this.classifier, beam.get(b), this.dependencyRelations, this.featureLength, b, this.options.rootFirst));
        }
        for (int b = 0; b < beam.size(); ++b) {
            for (BeamElement element : pool.take().get()) {
                beamPreserver.add(element);
                if (beamPreserver.size() > this.options.beamWidth) {
                    beamPreserver.pollFirst();
                }
            }
        }
    }

    /**
     * Advances each oracle configuration by one deterministic gold transition.
     *
     * <p>The transition is chosen in this order:
     * <ul>
     *   <li>right arc if the gold head of the buffer head is the stack top;</li>
     *   <li>left arc if the gold head of the stack top is the buffer head;</li>
     *   <li>if the stack top already has a head: reduce when it has no gold dependents or all of
     *       them are attached (count equals {@code valence}), otherwise shift;</li>
     *   <li>reduce when only the root remains on the stack and the buffer is empty;</li>
     *   <li>shift otherwise.</li>
     * </ul>
     * Terminal oracles are copied to {@code newOracles} unchanged.
     *
     * @return the successor built for the last non-terminal oracle, or {@code null} if there was none
     */
    private Configuration staticOracle(Instance instance, Collection<Configuration> oracles, Collection<Configuration> newOracles) {
        Configuration bestScoringOracle = null;
        HashMap<Integer, Edge> goldDependencies = instance.getGoldDependencies();
        HashMap<Integer, HashSet<Integer>> reversedDependencies = instance.getReversedDependencies();
        for (Configuration configuration : oracles) {
            State state = configuration.state;
            if (state.isTerminalState()) {
                newOracles.add(configuration);
                continue;
            }
            // Computed per configuration so values never leak from a previous oracle.
            int top = state.stackEmpty() ? -1 : state.stackTop();
            int first = state.bufferEmpty() ? -1 : state.bufferHead();
            Object[] features = FeatureExtractor.extractAllParseFeatures(configuration, this.featureLength);
            Configuration newConfig = configuration.clone();
            Edge firstGold = first > 0 ? goldDependencies.get(first) : null;
            Edge topGold = top > 0 ? goldDependencies.get(top) : null;
            if (firstGold != null && firstGold.headIndex == top) {
                int dependency = firstGold.relationId;
                this.rightArcWithScore(newConfig, dependency, this.classifier.rightArcScores(features, false)[dependency]);
            } else if (topGold != null && topGold.headIndex == first) {
                int dependency = topGold.relationId;
                this.leftArcWithScore(newConfig, dependency, this.classifier.leftArcScores(features, false)[dependency]);
            } else if (top >= 0 && state.hasHead(top)) {
                HashSet<Integer> goldChildren = reversedDependencies.get(top);
                if (goldChildren == null || goldChildren.size() == state.valence(top)) {
                    reduceWithScore(newConfig, this.classifier.reduceScore(features, false));
                } else {
                    shiftWithScore(newConfig, this.classifier.shiftScore(features, false));
                }
            } else if (state.bufferEmpty() && state.stackSize() == 1 && state.stackTop() == state.rootIndex) {
                reduceWithScore(newConfig, this.classifier.reduceScore(features, false));
            } else {
                // NOTE(review): this is the only training-time score computed with `true` as the
                // second argument; every other oracle/beam score passes `false`. Confirm intended.
                shiftWithScore(newConfig, this.classifier.shiftScore(features, true));
            }
            bestScoringOracle = newConfig;
            newOracles.add(newConfig);
        }
        return bestScoringOracle;
    }

    /**
     * Advances each oracle configuration by every transition with zero cost according to
     * {@link Instance#actionCost}. Arc transitions are only considered when
     * {@link ArcEager#canDo} allows them.
     *
     * <p>All successors are added to {@code newOracles}, and terminal oracles are copied unchanged.
     *
     * @return the successor with the highest {@code getScore(true)}, or {@code null} if none was created
     */
    private Configuration zeroCostDynamicOracle(Instance instance, Collection<Configuration> oracles, Collection<Configuration> newOracles) {
        float bestScore = Float.NEGATIVE_INFINITY;
        Configuration bestScoringOracle = null;
        for (Configuration configuration : oracles) {
            State currentState = configuration.state;
            if (currentState.isTerminalState()) {
                newOracles.add(configuration);
                continue;
            }
            Object[] features = FeatureExtractor.extractAllParseFeatures(configuration, this.featureLength);
            ArrayList<Configuration> successors = new ArrayList<Configuration>();
            if (instance.actionCost(Action.Shift, -1, currentState) == 0) {
                Configuration newConfig = configuration.clone();
                shiftWithScore(newConfig, this.classifier.shiftScore(features, false));
                successors.add(newConfig);
            }
            if (ArcEager.canDo(Action.RightArc, currentState)) {
                float[] rightArcScores = this.classifier.rightArcScores(features, false);
                for (int dependency : this.dependencyRelations) {
                    if (instance.actionCost(Action.RightArc, dependency, currentState) != 0) {
                        continue;
                    }
                    Configuration newConfig = configuration.clone();
                    this.rightArcWithScore(newConfig, dependency, rightArcScores[dependency]);
                    successors.add(newConfig);
                }
            }
            if (ArcEager.canDo(Action.LeftArc, currentState)) {
                float[] leftArcScores = this.classifier.leftArcScores(features, false);
                for (int dependency : this.dependencyRelations) {
                    if (instance.actionCost(Action.LeftArc, dependency, currentState) != 0) {
                        continue;
                    }
                    Configuration newConfig = configuration.clone();
                    this.leftArcWithScore(newConfig, dependency, leftArcScores[dependency]);
                    successors.add(newConfig);
                }
            }
            if (instance.actionCost(Action.Reduce, -1, currentState) == 0) {
                Configuration newConfig = configuration.clone();
                reduceWithScore(newConfig, this.classifier.reduceScore(features, false));
                successors.add(newConfig);
            }
            // Strict '>' keeps the earliest successor on ties.
            for (Configuration successor : successors) {
                newOracles.add(successor);
                float successorScore = successor.getScore(true);
                if (successorScore > bestScore) {
                    bestScore = successorScore;
                    bestScoringOracle = successor;
                }
            }
        }
        return bestScoringOracle;
    }

    /** Applies a shift to {@code config}, records action {@code 0}, and adds {@code score}. */
    private static void shiftWithScore(Configuration config, float score) {
        ArcEager.shift(config.state);
        config.addAction(0);
        config.addScore(score);
    }

    /** Applies a reduce to {@code config}, records action {@code 1}, and adds {@code score}. */
    private static void reduceWithScore(Configuration config, float score) {
        ArcEager.reduce(config.state);
        config.addAction(1);
        config.addScore(score);
    }

    /** Applies a right arc with label {@code dependency}, records {@code 3 + dependency}, and adds {@code score}. */
    private void rightArcWithScore(Configuration config, int dependency, float score) {
        ArcEager.rightArc(config.state, dependency);
        config.addAction(3 + dependency);
        config.addScore(score);
    }

    /** Applies a left arc with label {@code dependency}, records {@code 3 + |relations| + dependency}, and adds {@code score}. */
    private void leftArcWithScore(Configuration config, int dependency, float score) {
        ArcEager.leftArc(config.state, dependency);
        config.addAction(3 + this.dependencyRelations.size() + dependency);
        config.addScore(score);
    }

    /**
     * Scores every allowed transition of every beam item (previous score plus transition score)
     * and keeps the best {@code options.beamWidth} candidates in {@code beamPreserver}.
     */
    private void beamSortOneThread(ArrayList<Configuration> beam, TreeSet<BeamElement> beamPreserver) {
        int beamWidth = this.options.beamWidth;
        for (int b = 0; b < beam.size(); ++b) {
            Configuration configuration = beam.get(b);
            State currentState = configuration.state;
            float prevScore = configuration.score;
            boolean canShift = ArcEager.canDo(Action.Shift, currentState);
            boolean canReduce = ArcEager.canDo(Action.Reduce, currentState);
            boolean canRightArc = ArcEager.canDo(Action.RightArc, currentState);
            boolean canLeftArc = ArcEager.canDo(Action.LeftArc, currentState);
            Object[] features = FeatureExtractor.extractAllParseFeatures(configuration, this.featureLength);
            if (canShift) {
                this.addToBeam(beamPreserver, b, this.classifier.shiftScore(features, false) + prevScore, 0, -1, beamWidth);
            }
            if (canReduce) {
                this.addToBeam(beamPreserver, b, this.classifier.reduceScore(features, false) + prevScore, 1, -1, beamWidth);
            }
            if (canRightArc) {
                float[] rightArcScores = this.classifier.rightArcScores(features, false);
                for (int dependency : this.dependencyRelations) {
                    this.addToBeam(beamPreserver, b, rightArcScores[dependency] + prevScore, 2, dependency, beamWidth);
                }
            }
            if (canLeftArc) {
                float[] leftArcScores = this.classifier.leftArcScores(features, false);
                for (int dependency : this.dependencyRelations) {
                    this.addToBeam(beamPreserver, b, leftArcScores[dependency] + prevScore, 3, dependency, beamWidth);
                }
            }
        }
    }

    /** Inserts a candidate and evicts the lowest-ordered element when the set exceeds {@code beamWidth}. */
    private void addToBeam(TreeSet<BeamElement> beamPreserver, int b, float addedScore, int action, int label, int beamWidth) {
        beamPreserver.add(new BeamElement(addedScore, b, action, label));
        if (beamPreserver.size() > beamWidth) {
            beamPreserver.pollFirst();
        }
    }

    /**
     * Performs the perceptron update between an oracle derivation and a predicted derivation.
     *
     * <p>In {@code "max_violation"} mode the pair with the largest violation is used, with the
     * final beam also considered. Otherwise the top of the final beam is compared with
     * {@code bestScoringOracle}. Both derivations are replayed from {@code initialConfiguration}.
     * For each feature slot, occurrences of (action, feature) pairs are counted. Every pair whose
     * count differs between the two sides gets {@code -predictedCount} and {@code +oracleCount}
     * applied through {@code changeWeight}. For partial sentences, transitions rejected by
     * {@link #isTrueFeature} contribute no features.
     */
    private void updateWeights(Configuration initialConfiguration, float maxViol, boolean isPartial, Configuration bestScoringOracle, Pair<Configuration, Configuration> maxViolPair, ArrayList<Configuration> beam) {
        Configuration predicted;
        Configuration finalOracle;
        if (!this.updateMode.equals("max_violation")) {
            finalOracle = bestScoringOracle;
            predicted = beam.get(0);
        } else {
            float violation = beam.get(0).getScore(true) - bestScoringOracle.getScore(true);
            if (violation > maxViol) {
                maxViolPair = new Pair<Configuration, Configuration>(beam.get(0), bestScoringOracle);
            }
            predicted = (Configuration) maxViolPair.first;
            finalOracle = (Configuration) maxViolPair.second;
        }
        ArrayList<HashMap<Pair<Integer, Object>, Float>> oracleFeatures = this.newFeatureCounts();
        ArrayList<HashMap<Pair<Integer, Object>, Float>> predictedFeatures = this.newFeatureCounts();

        Configuration oracleConfiguration = initialConfiguration.clone();
        for (int action : finalOracle.actionHistory) {
            if (isTrueFeature(isPartial, oracleConfiguration, action)) {
                this.countFeatures(oracleConfiguration, action, oracleFeatures);
            }
            this.replayAction(oracleConfiguration.state, action);
        }

        Configuration predictedConfiguration = initialConfiguration.clone();
        for (int action : predicted.actionHistory) {
            if (action == 2) {
                // Unshift is replayed on the predicted side but never contributes features.
                ArcEager.unShift(predictedConfiguration.state);
                continue;
            }
            if (isTrueFeature(isPartial, predictedConfiguration, action)) {
                this.countFeatures(predictedConfiguration, action, predictedFeatures);
            }
            this.replayAction(predictedConfiguration.state, action);
        }

        for (int f = 0; f < this.featureLength; ++f) {
            HashMap<Pair<Integer, Object>, Float> predictedCounts = predictedFeatures.get(f);
            HashMap<Pair<Integer, Object>, Float> oracleCounts = oracleFeatures.get(f);
            this.applyCountDifference(f, predictedCounts, oracleCounts, true);
            this.applyCountDifference(f, oracleCounts, predictedCounts, false);
        }
    }

    /** Creates one empty (action, feature) -> count table per feature slot. */
    private ArrayList<HashMap<Pair<Integer, Object>, Float>> newFeatureCounts() {
        ArrayList<HashMap<Pair<Integer, Object>, Float>> counts = new ArrayList<HashMap<Pair<Integer, Object>, Float>>(this.featureLength);
        for (int f = 0; f < this.featureLength; ++f) {
            counts.add(new HashMap<Pair<Integer, Object>, Float>());
        }
        return counts;
    }

    /** Extracts the features of {@code configuration} and increments the count of each (action, feature) pair. */
    private void countFeatures(Configuration configuration, int action, ArrayList<HashMap<Pair<Integer, Object>, Float>> counts) {
        Object[] feats = FeatureExtractor.extractAllParseFeatures(configuration, this.featureLength);
        for (int f = 0; f < feats.length; ++f) {
            HashMap<Pair<Integer, Object>, Float> map = counts.get(f);
            Pair<Integer, Object> featName = new Pair<Integer, Object>(action, feats[f]);
            Float value = map.get(featName);
            map.put(featName, value == null ? 1.0f : value + 1.0f);
        }
    }

    /**
     * Replays a recorded shift ({@code 0}), reduce ({@code 1}), left arc or right arc on {@code state}.
     * Any other code (including unshift, {@code 2}) is ignored here.
     */
    private void replayAction(State state, int action) {
        int leftArcBase = 3 + this.dependencyRelations.size();
        if (action == 0) {
            ArcEager.shift(state);
        } else if (action == 1) {
            ArcEager.reduce(state);
        } else if (action >= leftArcBase) {
            ArcEager.leftArc(state, action - leftArcBase);
        } else if (action >= 3) {
            ArcEager.rightArc(state, action - 3);
        }
    }

    /**
     * For each (action, feature) pair in {@code counts} whose count is not matched exactly in
     * {@code others}, changes its weight by the count. The change is negated when
     * {@code penalize} is set. Pairs with a {@code null} feature are skipped.
     */
    private void applyCountDifference(int featureSlot, HashMap<Pair<Integer, Object>, Float> counts, HashMap<Pair<Integer, Object>, Float> others, boolean penalize) {
        for (Pair<Integer, Object> feat : counts.keySet()) {
            if (feat.second == null) {
                continue;
            }
            Float count = counts.get(feat);
            Float otherCount = others.get(feat);
            if (otherCount != null && otherCount.equals(count)) {
                continue;
            }
            LabeledAction labeledAction = new LabeledAction((Integer) feat.first, this.dependencyRelations.size());
            float delta = penalize ? -count.floatValue() : count.floatValue();
            this.classifier.changeWeight(labeledAction.action, featureSlot, feat.second, labeledAction.label, delta);
        }
    }

    /**
     * Decides whether the features of {@code action}, taken from {@code oracleConfiguration},
     * take part in the update.
     *
     * <p>Always {@code true} for complete sentences. For partial sentences, based on
     * {@code state.hasHead}:
     * <ul>
     *   <li>arcs ({@code action >= 3}) require both the stack top and the buffer head to have a head;</li>
     *   <li>shift ({@code 0}) requires the buffer head to have one;</li>
     *   <li>reduce ({@code 1}) requires the stack top to have one.</li>
     * </ul>
     */
    private static boolean isTrueFeature(boolean isPartial, Configuration oracleConfiguration, int action) {
        if (!isPartial) {
            return true;
        }
        State state = oracleConfiguration.state;
        if (action >= 3) {
            return state.hasHead(state.stackTop()) && state.hasHead(state.bufferHead());
        }
        if (action == 0) {
            return state.hasHead(state.bufferHead());
        }
        if (action == 1) {
            return state.hasHead(state.stackTop());
        }
        return true;
    }
}

