package uk.ac.susx.mlcl.byblo.commands;

import com.beust.jcommander.Parameter;
import com.beust.jcommander.Parameters;
import com.beust.jcommander.ParametersDelegate;
import com.google.common.base.Objects;
import com.google.common.base.Predicate;
import com.google.common.base.Predicates;
import it.unimi.dsi.fastutil.ints.Int2DoubleMap;
import it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap;
import it.unimi.dsi.fastutil.ints.IntIterator;
import it.unimi.dsi.fastutil.objects.ObjectIterator;
import java.io.File;
import java.io.IOException;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Map;
import javax.annotation.CheckReturnValue;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import uk.ac.susx.mlcl.byblo.BybloSettings;
import uk.ac.susx.mlcl.byblo.enumerators.DoubleEnumerating;
import uk.ac.susx.mlcl.byblo.enumerators.DoubleEnumeratingDelegate;
import uk.ac.susx.mlcl.byblo.enumerators.EnumeratingDelegates;
import uk.ac.susx.mlcl.byblo.enumerators.EnumeratorType;
import uk.ac.susx.mlcl.byblo.io.BybloIO;
import uk.ac.susx.mlcl.byblo.io.FastWeightedTokenPairVectorSource;
import uk.ac.susx.mlcl.byblo.io.Token;
import uk.ac.susx.mlcl.byblo.io.TokenPair;
import uk.ac.susx.mlcl.byblo.io.Weighted;
import uk.ac.susx.mlcl.byblo.io.WeightedTokenPairSink;
import uk.ac.susx.mlcl.byblo.io.WeightedTokenSource;
import uk.ac.susx.mlcl.byblo.measures.Measure;
import uk.ac.susx.mlcl.byblo.measures.Measures;
import uk.ac.susx.mlcl.byblo.measures.impl.KendallsTau;
import uk.ac.susx.mlcl.byblo.measures.impl.KullbackLeiblerDivergence;
import uk.ac.susx.mlcl.byblo.measures.impl.LambdaDivergence;
import uk.ac.susx.mlcl.byblo.measures.impl.LeeSkewDivergence;
import uk.ac.susx.mlcl.byblo.measures.impl.LpSpaceDistance;
import uk.ac.susx.mlcl.byblo.measures.impl.Weeds;
import uk.ac.susx.mlcl.byblo.tasks.InvertedApssTask;
import uk.ac.susx.mlcl.byblo.tasks.NaiveApssTask;
import uk.ac.susx.mlcl.byblo.tasks.ThreadedApssTask;
import uk.ac.susx.mlcl.byblo.weighings.FeatureMarginalsCarrier;
import uk.ac.susx.mlcl.byblo.weighings.MarginalDistribution;
import uk.ac.susx.mlcl.byblo.weighings.Weighting;
import uk.ac.susx.mlcl.byblo.weighings.Weightings;
import uk.ac.susx.mlcl.byblo.weighings.impl.NullWeighting;
import uk.ac.susx.mlcl.byblo.weighings.impl.Step;
import uk.ac.susx.mlcl.lib.Checks;
import uk.ac.susx.mlcl.lib.commands.AbstractCommand;
import uk.ac.susx.mlcl.lib.commands.DoubleConverter;
import uk.ac.susx.mlcl.lib.commands.FileDelegate;
import uk.ac.susx.mlcl.lib.commands.InputFileValidator;
import uk.ac.susx.mlcl.lib.commands.OutputFileValidator;
import uk.ac.susx.mlcl.lib.events.ReportLoggingProgressListener;
import uk.ac.susx.mlcl.lib.io.ObjectIO;
import uk.ac.susx.mlcl.lib.io.ObjectSource;

/**
 * Command that runs an all-pairs similarity search over the vectors in an
 * events file and writes the resulting weighted pairs to an output file.
 *
 * <p>The similarity measure is selected by name (see {@link #setMeasureName(String)})
 * and instantiated reflectively. Measure specific parameters (Lee alpha, CRMI
 * beta/gamma, Minkowski p, lambda) are pushed onto the measure instance only
 * when it is of the matching type.</p>
 *
 * <p>All options are exposed both as JCommander parameters and as bean style
 * setters so the command can be driven from the command line or from code.</p>
 */
@Parameters(commandDescription = "Perform all-pair similarity search on the given input frequency files.")
public class AllPairsCommand extends AbstractCommand {
    private static final Log LOG = LogFactory.getLog(AllPairsCommand.class);

    /** Default lower similarity bound; this value disables the lower bound filter. */
    public static final double DEFAULT_MIN_SIMILARITY = Double.NEGATIVE_INFINITY;

    /** Default upper similarity bound; this value disables the upper bound filter. */
    public static final double DEFAULT_MAX_SIMILARITY = Double.POSITIVE_INFINITY;

    /** Name (or alias) of the measure used when none is given. */
    public static final String DEFAULT_MEASURE = "Lin";

    @ParametersDelegate
    private DoubleEnumerating indexDelegate = new DoubleEnumeratingDelegate();

    @ParametersDelegate
    private final FileDelegate fileDelegate = new FileDelegate();

    @Parameter(names = {"-i", "--input"}, description = "Event frequency vectors files.", required = true, validateWith = InputFileValidator.class)
    private File eventsFile;

    @Parameter(names = {"-if", "--input-features"}, description = "Feature frequencies file", validateWith = InputFileValidator.class)
    private File featuresFile;

    @Parameter(names = {"-ie", "--input-entries"}, description = "Entry frequencies file", validateWith = InputFileValidator.class)
    private File entriesFile;

    @Parameter(names = {"-o", "--output"}, description = "Output similarity matrix file.", required = true, validateWith = OutputFileValidator.class)
    private File outputFile;

    @Parameter(names = {"-t", "--threads"}, description = "Number of concurrent processing threads.")
    private int numThreads = Runtime.getRuntime().availableProcessors() + 1;

    @Parameter(names = {"-Smn", "--similarity-min"}, description = "Minimum similarity threshold.", converter = DoubleConverter.class)
    private double minSimilarity = DEFAULT_MIN_SIMILARITY;

    @Parameter(names = {"-Smx", "--similarity-max"}, description = "Maximum similarity threshold.", converter = DoubleConverter.class)
    private double maxSimilarity = DEFAULT_MAX_SIMILARITY;

    @Parameter(names = {"-ip", "--identity-pairs"}, description = "Produce similarity between pair of identical entries.")
    private boolean outputIdentityPairs = false;

    @Parameter(names = {"-m", "--measure"}, description = "Similarity measure to use.")
    private String measureName = DEFAULT_MEASURE;

    @Parameter(names = {"--measure-reversed"}, description = "Swap similarity measure inputs.")
    private boolean measureReversed = false;

    @Parameter(names = {"--lee-alpha"}, description = "Alpha parameter to Lee's alpha-skew divergence measure.", converter = DoubleConverter.class)
    private double leeAlpha = 0.99d;

    @Parameter(names = {"--crmi-beta"}, description = "Beta parameter to Weed's CRMI measure.", converter = DoubleConverter.class)
    private double crmiBeta = 0.5d;

    @Parameter(names = {"--crmi-gamma"}, description = "Gamma parameter to Weed's CRMI measure.", converter = DoubleConverter.class)
    private double crmiGamma = 0.5d;

    @Parameter(names = {"--mink-p"}, description = "P parameter to Minkowski/Lp space measure.", converter = DoubleConverter.class)
    private double minkP = 2.0d;

    @Parameter(names = {"--lambda-lambda"}, description = "lambda parameter to Lambda-Divergence measure.", converter = DoubleConverter.class)
    private double lambdaLambda = 0.5d;

    /**
     * Base weighting that is combined with the weighting the measure expects.
     * No setter is visible in this class, so it is always a {@link NullWeighting};
     * it is never mutated by {@link #runCommand()}.
     */
    private final Weighting weighting = new NullWeighting();

    @Parameter(names = {"--algorithm"}, hidden = true, description = "APPS algorithm to use.")
    private Algorithm algorithm = Algorithm.Inverted;

    /**
     * The available all-pairs search strategies, each backed by a task class
     * that is instantiated reflectively.
     */
    public enum Algorithm {
        Naive(NaiveApssTask.class),
        Inverted(InvertedApssTask.class);

        private final Class<? extends NaiveApssTask> implementation;

        Algorithm(Class<? extends NaiveApssTask> implementation) {
            this.implementation = implementation;
        }

        /**
         * @return the task class implementing this algorithm
         */
        public Class<? extends NaiveApssTask> getImplementation() {
            return this.implementation;
        }

        /**
         * Creates a fresh task through the implementation's no-argument constructor.
         *
         * @return a new task instance
         * @throws InstantiationException if the implementation cannot be instantiated
         * @throws IllegalAccessException if its constructor is not accessible
         */
        public NaiveApssTask newInstance() throws InstantiationException, IllegalAccessException {
            return getImplementation().newInstance();
        }
    }

    /**
     * Creates a command with the given files and otherwise default options.
     *
     * @param entriesFile   entry frequencies file; must not be null
     * @param featuresFile  feature frequencies file; must not be null
     * @param eventsFile    event frequency vectors file; must not be null
     * @param outputFile    destination for the similarity pairs; must not be null
     * @param charset       character set used for all files; must not be null
     * @param indexDelegate enumeration settings shared by the readers and the writer
     */
    public AllPairsCommand(File entriesFile, File featuresFile, File eventsFile, File outputFile, Charset charset, DoubleEnumerating indexDelegate) {
        setEventsFile(eventsFile);
        setEntriesFile(entriesFile);
        setFeaturesFile(featuresFile);
        setOutputFile(outputFile);
        setCharset(charset);
        // Assigned directly (no null check), exactly as before, so existing
        // callers that configure the delegate later keep working.
        this.indexDelegate = indexDelegate;
    }

    /**
     * Creates a command with default options; the files must be supplied via
     * the setters or by command line parsing before running.
     */
    public AllPairsCommand() {
    }

    /**
     * Runs the search: builds and configures the measure, loads feature
     * marginals when the measure or weighting needs them, executes the search
     * task and finally saves the enumerator if it is open.
     *
     * @return always {@code true} when the method returns normally
     * @throws RuntimeException wrapping any exception raised while running,
     *                          including one trapped by the search task
     */
    @Override
    @CheckReturnValue
    public boolean runCommand() {
        try {
            if (LOG.isInfoEnabled()) {
                LOG.info("Running all-pairs similarity.");
            }

            // Instantiate the measure and push the measure specific parameters.
            Measure measure = getMeasureClass().newInstance();
            if (measure instanceof LpSpaceDistance) {
                ((LpSpaceDistance) measure).setPower(getMinkP());
            }
            if (measure instanceof LeeSkewDivergence) {
                ((LeeSkewDivergence) measure).setAlpha(getLeeAlpha());
            }
            if (measure instanceof Weeds) {
                ((Weeds) measure).setBeta(getCrmiBeta());
                ((Weeds) measure).setGamma(getCrmiGamma());
            }
            if (measure instanceof LambdaDivergence) {
                ((LambdaDivergence) measure).setLambda(getLambdaLambda());
            }

            // Resolve the weighting for THIS run into a local. Writing it back
            // to the field would make a second run compose the measure's
            // expected weighting with itself.
            final Weighting expectedWeighting = measure.getExpectedWeighting().newInstance();
            final Weighting effectiveWeighting;
            if (this.weighting.getClass().equals(NullWeighting.class)) {
                effectiveWeighting = expectedWeighting;
            } else {
                effectiveWeighting = Weightings.compose(this.weighting, expectedWeighting);
            }

            if ((measure instanceof FeatureMarginalsCarrier) || (effectiveWeighting instanceof FeatureMarginalsCarrier)) {
                if (LOG.isInfoEnabled()) {
                    LOG.info("Loading features file " + getFeaturesFile());
                }
                final MarginalDistribution featureMarginals = BybloIO.readMarginalDistribution(openFeaturesSource());
                if (measure instanceof FeatureMarginalsCarrier) {
                    ((FeatureMarginalsCarrier) measure).setFeatureMarginals(featureMarginals);
                }
                if (effectiveWeighting instanceof FeatureMarginalsCarrier) {
                    ((FeatureMarginalsCarrier) effectiveWeighting).setFeatureMarginals(featureMarginals);
                }
            } else if ((measure instanceof KendallsTau) || (measure instanceof KullbackLeiblerDivergence)) {
                if (LOG.isInfoEnabled()) {
                    // It is the features file that is read here, not the entries file.
                    LOG.info("Loading features file for minCardinality: " + getFeaturesFile());
                }
                // Drain the features file only to learn the largest feature id.
                final WeightedTokenSource.WTStatsSource statsSource = new WeightedTokenSource.WTStatsSource(openFeaturesSource());
                ObjectIO.copy(statsSource, ObjectIO.nullSink());
                final int minCardinality = statsSource.getMaxId() + 1;
                if (measure instanceof KendallsTau) {
                    ((KendallsTau) measure).setMinCardinality(minCardinality);
                }
                if (measure instanceof KullbackLeiblerDivergence) {
                    ((KullbackLeiblerDivergence) measure).setMinCardinality(minCardinality);
                }
            }

            if (isMeasureReversed()) {
                measure = Measures.reverse(measure);
            }
            final Measure weightedMeasure = Measures.autoWeighted(measure, effectiveWeighting);

            // Created before any file is opened so an instantiation failure
            // cannot leave open handles behind.
            final NaiveApssTask task = newAlgorithmInstance();

            // Nested try/finally guarantees every opened resource is closed
            // even when opening a later one, running or flushing fails.
            final FastWeightedTokenPairVectorSource sourceA = openEventsSource();
            try {
                final FastWeightedTokenPairVectorSource sourceB = openEventsSource();
                try {
                    final WeightedTokenPairSink sink = openSimsSink();
                    try {
                        task.setSourceA(sourceA);
                        task.setSourceB(sourceB);
                        task.setSink(sink);
                        task.setMeasure(weightedMeasure);
                        task.setProducePair(getProductionFilter());
                        task.addProgressListener(new ReportLoggingProgressListener(LOG));
                        task.run();
                        sink.flush();
                    } finally {
                        sink.close();
                    }
                } finally {
                    sourceB.close();
                }
            } finally {
                sourceA.close();
            }

            // Surface a failure the task trapped only after the files are closed.
            if (task.isExceptionTrapped()) {
                task.throwTrappedException();
            }

            if (this.indexDelegate.isEnumeratorOpen()) {
                this.indexDelegate.saveEnumerator();
                this.indexDelegate.closeEnumerator();
            }

            if (LOG.isInfoEnabled()) {
                LOG.info("Completed all-pairs similarity search.");
            }
            return true;
        } catch (Exception e) {
            // Every failure (checked or unchecked) is reported as a
            // RuntimeException with the original exception as its cause.
            throw new RuntimeException(e);
        }
    }

    /**
     * Reads every weighted token from the source into a dense array indexed
     * by token id. Weights of records sharing an id are summed, and a warning
     * is logged for each such duplicate. Slots for ids that never occur are
     * left at {@code 0.0}.
     *
     * @param objectSource source of weighted tokens, read until exhausted
     * @return array of length {@code maxId + 1}; length 1 when the source is empty
     * @throws IOException if reading from the source fails
     */
    public static double[] readAllAsArray(ObjectSource<Weighted<Token>> objectSource) throws IOException {
        final Int2DoubleOpenHashMap frequencies = new Int2DoubleOpenHashMap();
        int maxId = 0;
        while (objectSource.hasNext()) {
            final Weighted<Token> entry = objectSource.read();
            final int id = entry.record().id();
            if (frequencies.containsKey(id)) {
                final double oldFrequency = frequencies.get(id);
                final double newFrequency = oldFrequency + entry.weight();
                if (LOG.isWarnEnabled()) {
                    LOG.warn("Found duplicate Entry \"" + id + "\" (id=" + id + ") in entries file. Merging records. Old frequency = " + oldFrequency + ", new frequency = " + newFrequency + ".");
                }
                frequencies.put(id, newFrequency);
            } else {
                frequencies.put(id, entry.weight());
            }
            // Track the largest id while reading; avoids a second pass over the keys.
            if (id > maxId) {
                maxId = id;
            }
        }
        final double[] result = new double[maxId + 1];
        for (Int2DoubleMap.Entry e : frequencies.int2DoubleEntrySet()) {
            result[e.getIntKey()] = e.getDoubleValue();
        }
        return result;
    }

    /**
     * Creates the search task: the plain algorithm when a single thread is
     * requested, otherwise a threaded task wrapping the algorithm.
     */
    private NaiveApssTask newAlgorithmInstance() throws InstantiationException, IllegalAccessException {
        if (getNumThreads() == 1) {
            return getAlgorithm().newInstance();
        }
        final ThreadedApssTask threadedTask = new ThreadedApssTask();
        threadedTask.setInnerAlgorithm(getAlgorithm().getImplementation());
        threadedTask.setNumThreads(getNumThreads());
        return threadedTask;
    }

    /** Opens the features file using the feature half of the index delegate. */
    private WeightedTokenSource openFeaturesSource() throws IOException {
        return BybloIO.openFeaturesSource(getFeaturesFile(), getCharset(), EnumeratingDelegates.toSingleFeatures(getIndexDelegate()));
    }

    /** Opens a new, independent reader over the events file. */
    private FastWeightedTokenPairVectorSource openEventsSource() throws IOException {
        return BybloIO.openEventsVectorSource(getEventsFile(), getCharset(), getIndexDelegate());
    }

    /** Opens the output file using the entry half of the index delegate. */
    private WeightedTokenPairSink openSimsSink() throws IOException {
        return BybloIO.openSimsSink(getOutputFile(), getCharset(), EnumeratingDelegates.toSingleEntries(getIndexDelegate()));
    }

    /**
     * Builds the predicate deciding which similarity pairs are produced.
     * A bound contributes a filter only when it differs from its infinite
     * default, and identity pairs are excluded unless explicitly enabled.
     *
     * @return the single filter, the conjunction of several, or an
     *         always-true predicate when no filter applies
     */
    // Raw types are kept on purpose: the generic signatures of the Weighted and
    // TokenPair factory methods are declared outside this file.
    @SuppressWarnings({"unchecked", "rawtypes"})
    private Predicate<Weighted<TokenPair>> getProductionFilter() {
        final ArrayList filters = new ArrayList();
        if (getMinSimilarity() != Double.NEGATIVE_INFINITY) {
            filters.add(Weighted.greaterThanOrEqualTo(getMinSimilarity()));
        }
        if (getMaxSimilarity() != Double.POSITIVE_INFINITY) {
            filters.add(Weighted.lessThanOrEqualTo(getMaxSimilarity()));
        }
        if (!isOutputIdentityPairs()) {
            filters.add(Predicates.not(Predicates.compose(TokenPair.identity(), Weighted.recordFunction())));
        }
        if (filters.size() == 1) {
            return (Predicate) filters.get(0);
        }
        if (filters.size() > 1) {
            return Predicates.and(filters);
        }
        return Predicates.alwaysTrue();
    }

    /**
     * Resolves the configured measure name to a class. The lower-cased,
     * trimmed name is first looked up in the measure alias table; otherwise
     * the trimmed name is treated as a fully qualified class name.
     *
     * @return the measure class
     * @throws ClassNotFoundException if the name is neither an alias nor a loadable class
     * @throws ClassCastException     if the named class is not a {@link Measure}
     */
    final Class<? extends Measure> getMeasureClass() throws ClassNotFoundException {
        final Map<String, Class<? extends Measure>> aliases = Measures.loadMeasureAliasTable();
        final String alias = getMeasureName().toLowerCase(BybloSettings.getLocale()).trim();
        if (aliases.containsKey(alias)) {
            return aliases.get(alias);
        }
        // asSubclass rejects non-Measure classes here rather than at
        // instantiation time, and removes the need for an unchecked cast.
        return Class.forName(getMeasureName().trim()).asSubclass(Measure.class);
    }

    /**
     * Extends the inherited helper with this command's files and options.
     */
    @Override
    public Objects.ToStringHelper toStringHelper() {
        return super.toStringHelper()
                .add("eventsIn", getEventsFile())
                .add("entriesIn", getEntriesFile())
                .add("featuresIn", getFeaturesFile())
                .add("simsOut", getOutputFile())
                .add("charset", getCharset())
                .add("threads", getNumThreads())
                .add("minSimilarity", getMinSimilarity())
                .add("maxSimilarity", getMaxSimilarity())
                .add("outputIdentityPairs", isOutputIdentityPairs())
                .add("measure", getMeasureName())
                .add("measureReversed", isMeasureReversed())
                .add("leeAlpha", getLeeAlpha())
                .add("crmiBeta", getCrmiBeta())
                .add("crmiGamma", getCrmiGamma())
                .add("minkP", getMinkP())
                .add("lambdaLambda", getLambdaLambda());
    }

    Algorithm getAlgorithm() {
        return this.algorithm;
    }

    /**
     * @param algorithm search strategy to use; must not be null
     */
    public void setAlgorithm(Algorithm algorithm) {
        Checks.checkNotNull("algorithm", algorithm);
        this.algorithm = algorithm;
    }

    final File getEventsFile() {
        return this.eventsFile;
    }

    /**
     * @param file event frequency vectors file; must not be null
     */
    public final void setEventsFile(File file) {
        Checks.checkNotNull("eventsFile", file);
        this.eventsFile = file;
    }

    final File getFeaturesFile() {
        return this.featuresFile;
    }

    /**
     * @param file feature frequencies file; must not be null
     */
    public final void setFeaturesFile(File file) {
        Checks.checkNotNull("featuresFile", file);
        this.featuresFile = file;
    }

    File getEntriesFile() {
        return this.entriesFile;
    }

    /**
     * @param file entry frequencies file; must not be null
     */
    public final void setEntriesFile(File file) {
        Checks.checkNotNull("entriesFile", file);
        this.entriesFile = file;
    }

    final File getOutputFile() {
        return this.outputFile;
    }

    /**
     * @param file destination for the similarity pairs; must not be null
     */
    public final void setOutputFile(File file) {
        Checks.checkNotNull("outputFile", file);
        this.outputFile = file;
    }

    final Charset getCharset() {
        return this.fileDelegate.getCharset();
    }

    /**
     * @param charset character set used for all files; must not be null
     */
    public final void setCharset(Charset charset) {
        Checks.checkNotNull("charset", charset);
        this.fileDelegate.setCharset(charset);
    }

    final int getNumThreads() {
        return this.numThreads;
    }

    /**
     * @param i number of processing threads; must be at least 1. A value of 1
     *          selects the plain algorithm instead of the threaded wrapper.
     */
    public final void setNumThreads(int i) {
        Checks.checkRangeIncl("nThreads", i, 1, Integer.MAX_VALUE);
        this.numThreads = i;
    }

    final double getMinSimilarity() {
        return this.minSimilarity;
    }

    /**
     * @param d inclusive lower similarity bound; negative infinity disables it
     */
    public final void setMinSimilarity(double d) {
        this.minSimilarity = d;
    }

    final double getMaxSimilarity() {
        return this.maxSimilarity;
    }

    /**
     * @param d inclusive upper similarity bound; positive infinity disables it
     */
    public final void setMaxSimilarity(double d) {
        // NOTE(review): the lower limit is Step.DEFAULT_BOUNDARY - presumably a
        // decompiler substitution for the literal 0.0; confirm against Step.
        Checks.checkRangeIncl("maxSimilarity", d, Step.DEFAULT_BOUNDARY, Double.POSITIVE_INFINITY);
        this.maxSimilarity = d;
    }

    final boolean isOutputIdentityPairs() {
        return this.outputIdentityPairs;
    }

    /**
     * @param z whether pairs of identical entries are written to the output
     */
    public final void setOutputIdentityPairs(boolean z) {
        this.outputIdentityPairs = z;
    }

    final String getMeasureName() {
        return this.measureName;
    }

    /**
     * @param str measure alias or fully qualified measure class name; must not be null
     */
    public final void setMeasureName(String str) {
        Checks.checkNotNull("measureName", str);
        this.measureName = str;
    }

    final boolean isMeasureReversed() {
        return this.measureReversed;
    }

    /**
     * @param z whether the measure is wrapped with {@code Measures.reverse} before use
     */
    public final void setMeasureReversed(boolean z) {
        this.measureReversed = z;
    }

    final double getLeeAlpha() {
        return this.leeAlpha;
    }

    /**
     * @param d alpha passed to a {@link LeeSkewDivergence} measure; upper limit 1.0
     */
    public final void setLeeAlpha(double d) {
        // NOTE(review): the lower limit is Step.DEFAULT_BOUNDARY - presumably a
        // decompiler substitution for the literal 0.0; confirm against Step.
        Checks.checkRangeIncl("leeAlpha", d, Step.DEFAULT_BOUNDARY, 1.0d);
        this.leeAlpha = d;
    }

    final double getCrmiBeta() {
        return this.crmiBeta;
    }

    /**
     * @param d beta passed to a {@link Weeds} measure
     */
    public final void setCrmiBeta(double d) {
        this.crmiBeta = d;
    }

    final double getCrmiGamma() {
        return this.crmiGamma;
    }

    /**
     * @param d gamma passed to a {@link Weeds} measure
     */
    public final void setCrmiGamma(double d) {
        this.crmiGamma = d;
    }

    final double getMinkP() {
        return this.minkP;
    }

    /**
     * @param d power passed to an {@link LpSpaceDistance} measure
     */
    public final void setMinkP(double d) {
        this.minkP = d;
    }

    double getLambdaLambda() {
        return this.lambdaLambda;
    }

    /**
     * @param d lambda passed to a {@link LambdaDivergence} measure
     */
    public void setLambdaLambda(double d) {
        this.lambdaLambda = d;
    }

    final DoubleEnumerating getIndexDelegate() {
        return this.indexDelegate;
    }

    /**
     * @param doubleEnumerating enumeration settings to use; must not be null
     */
    public final void setIndexDelegate(DoubleEnumerating doubleEnumerating) {
        Checks.checkNotNull("indexDelegate", doubleEnumerating);
        this.indexDelegate = doubleEnumerating;
    }

    /** Forwards to the index delegate. */
    public void setEnumeratedFeatures(boolean z) {
        this.indexDelegate.setEnumeratedFeatures(z);
    }

    /** Forwards to the index delegate. */
    public void setEnumeratedEntries(boolean z) {
        this.indexDelegate.setEnumeratedEntries(z);
    }

    /** Forwards to the index delegate. */
    public boolean isEnumeratedFeatures() {
        return this.indexDelegate.isEnumeratedFeatures();
    }

    /** Forwards to the index delegate. */
    public boolean isEnumeratedEntries() {
        return this.indexDelegate.isEnumeratedEntries();
    }

    /** Forwards to the index delegate. */
    public void setEnumeratorType(EnumeratorType enumeratorType) {
        this.indexDelegate.setEnumeratorType(enumeratorType);
    }

    /** Forwards to the index delegate. */
    public EnumeratorType getEnumeratorType() {
        return this.indexDelegate.getEnumeratorType();
    }
}
