package uk.ac.susx.mlcl.byblo.tasks;

import com.google.common.base.Objects;
import com.google.common.base.Predicate;
import com.google.common.base.Predicates;
import it.unimi.dsi.fastutil.ints.Int2DoubleMap;
import it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import uk.ac.susx.mlcl.byblo.io.TokenPair;
import uk.ac.susx.mlcl.byblo.io.Weighted;
import uk.ac.susx.mlcl.byblo.measures.DecomposableMeasure;
import uk.ac.susx.mlcl.byblo.measures.Measure;
import uk.ac.susx.mlcl.byblo.measures.impl.Jaccard;
import uk.ac.susx.mlcl.lib.Checks;
import uk.ac.susx.mlcl.lib.collect.Indexed;
import uk.ac.susx.mlcl.lib.collect.SparseDoubleVector;
import uk.ac.susx.mlcl.lib.events.ProgressDelegate;
import uk.ac.susx.mlcl.lib.events.ProgressListener;
import uk.ac.susx.mlcl.lib.events.ProgressReporting;
import uk.ac.susx.mlcl.lib.io.ObjectIO;
import uk.ac.susx.mlcl.lib.io.ObjectSink;
import uk.ac.susx.mlcl.lib.io.SeekableObjectSource;
import uk.ac.susx.mlcl.lib.tasks.AbstractTask;

/* loaded from: input_file:uk/ac/susx/mlcl/byblo/tasks/NaiveApssTask.class */
/**
 * Naive all-pairs similarity search.
 *
 * <p>Every record read from {@code sourceA} that passes the {@code processRecord}
 * filter is compared against every record of {@code sourceB} that also passes
 * that filter. {@code sourceB} is rewound to its starting position before each
 * outer record. Each comparison is scored with the configured {@link Measure};
 * scored pairs accepted by the {@code producePair} filter are buffered, sorted
 * and copied to the sink in chunks.
 *
 * <p>When the measure is a {@link DecomposableMeasure}, the per-record
 * {@code left}/{@code right} components are pre-computed once for each source
 * during {@link #initialiseTask()}.
 *
 * @param <P> the position type used to seek within the sources
 */
public class NaiveApssTask<P> extends AbstractTask implements ProgressReporting {

    private static final Log LOG = LogFactory.getLog(NaiveApssTask.class);

    /** Measure used when {@link #setMeasure(Measure)} is never called. */
    private static final Measure DEFAULT_MEASURE = new Jaccard();

    /**
     * Number of buffered pairs above which the buffer is sorted and flushed to
     * the sink while the search is still running.
     */
    static final int PAIR_OUTPUT_BUFFER_SIZE = 100000;

    private SeekableObjectSource<Indexed<SparseDoubleVector>, P> sourceA;

    private SeekableObjectSource<Indexed<SparseDoubleVector>, P> sourceB;

    private ObjectSink<Weighted<TokenPair>> sink;

    final ProgressDelegate progress = new ProgressDelegate(this, true);

    private Measure measure = DEFAULT_MEASURE;

    private Predicate<Indexed<SparseDoubleVector>> processRecord = Predicates.alwaysTrue();

    private Predicate<Weighted<TokenPair>> producePair = Predicates.alwaysTrue();

    private ApssStats stats = new ApssStats();

    /** Left-hand pre-computations keyed by record index; null until built. */
    private Int2DoubleMap preCalcA = null;

    /** Right-hand pre-computations keyed by record index; null until built. */
    private Int2DoubleMap preCalcB = null;

    /**
     * Creates a task over the given sources and sink.
     *
     * @param sourceA outer source of indexed vectors
     * @param sourceB inner source of indexed vectors (must be a different object)
     * @param sink    destination for the produced pairs
     * @throws NullPointerException     if any argument is null
     * @throws IllegalArgumentException if {@code sourceA == sourceB}
     */
    public NaiveApssTask(SeekableObjectSource<Indexed<SparseDoubleVector>, P> sourceA,
            SeekableObjectSource<Indexed<SparseDoubleVector>, P> sourceB,
            ObjectSink<Weighted<TokenPair>> sink) {
        setSourceA(sourceA);
        setSourceB(sourceB);
        setSink(sink);
    }

    /** Creates an unconfigured task; sources and sink must be set before running. */
    public NaiveApssTask() {
    }

    /** @return the filter applied to each scored pair before it is emitted */
    public Predicate<Weighted<TokenPair>> getProducePair() {
        return this.producePair;
    }

    /**
     * Sets the filter applied to each scored pair before it is emitted.
     *
     * @param producePair the pair filter; must not be null
     */
    public void setProducePair(Predicate<Weighted<TokenPair>> producePair) {
        // Check the argument itself, not the name literal.
        Checks.checkNotNull("producePair", producePair);
        this.producePair = producePair;
    }

    /** @return the filter deciding which input records take part in the search */
    public Predicate<Indexed<SparseDoubleVector>> getProcessRecord() {
        return this.processRecord;
    }

    /**
     * Sets the filter deciding which input records (from either source) take
     * part in the search.
     *
     * @param processRecord the record filter; must not be null
     */
    public void setProcessRecord(Predicate<Indexed<SparseDoubleVector>> processRecord) {
        Checks.checkNotNull("processRecord", processRecord);
        this.processRecord = processRecord;
    }

    /** @return the statistics collector updated during the search */
    public final ApssStats getStats() {
        return this.stats;
    }

    /**
     * Replaces the statistics collector.
     *
     * @param stats the collector; must not be null
     */
    public final void setStats(ApssStats stats) {
        Checks.checkNotNull("stats", stats);
        this.stats = stats;
    }

    /**
     * Sets the outer source.
     *
     * @param sourceA the outer source; must not be null or the same object as source B
     * @throws NullPointerException     if {@code sourceA} is null
     * @throws IllegalArgumentException if {@code sourceA} is the current source B
     */
    public final void setSourceA(SeekableObjectSource<Indexed<SparseDoubleVector>, P> sourceA) {
        if (sourceA == null) {
            throw new NullPointerException("sourceA is null");
        }
        if (sourceA == this.sourceB) {
            throw new IllegalArgumentException("sourceA == sourceB");
        }
        this.sourceA = sourceA;
    }

    /**
     * Sets the inner source, which is rewound for every accepted outer record.
     *
     * @param sourceB the inner source; must not be null or the same object as source A
     * @throws NullPointerException     if {@code sourceB} is null
     * @throws IllegalArgumentException if {@code sourceB} is the current source A
     */
    public final void setSourceB(SeekableObjectSource<Indexed<SparseDoubleVector>, P> sourceB) {
        if (sourceB == null) {
            throw new NullPointerException("sourceB is null");
        }
        if (this.sourceA == sourceB) {
            throw new IllegalArgumentException("sourceA == sourceB");
        }
        this.sourceB = sourceB;
    }

    /** @return the outer source, or null if not set */
    public final SeekableObjectSource<Indexed<SparseDoubleVector>, P> getSourceA() {
        return this.sourceA;
    }

    /** @return the inner source, or null if not set */
    public final SeekableObjectSource<Indexed<SparseDoubleVector>, P> getSourceB() {
        return this.sourceB;
    }

    /** @return the similarity measure in use */
    public final Measure getMeasure() {
        return this.measure;
    }

    /**
     * Sets the similarity measure.
     *
     * @param measure the measure; must not be null
     */
    public final void setMeasure(Measure measure) {
        Checks.checkNotNull("measure", measure);
        this.measure = measure;
    }

    /** @return the destination for produced pairs, or null if not set */
    public final ObjectSink<Weighted<TokenPair>> getSink() {
        return this.sink;
    }

    /**
     * Sets the destination for produced pairs.
     *
     * @param sink the sink; must not be null
     * @throws NullPointerException if {@code sink} is null
     */
    public final void setSink(ObjectSink<Weighted<TokenPair>> sink) {
        if (sink == null) {
            throw new NullPointerException("sink is null");
        }
        this.sink = sink;
    }

    /**
     * Validates the configuration and, for decomposable measures, builds the
     * per-source pre-computation tables.
     */
    @Override
    public void initialiseTask() throws Exception {
        checkState();
        buildPreCalcs();
    }

    /**
     * Sorts the buffered pairs by {@code TokenPair.indexOrder()}, copies them to
     * the sink while holding the sink's monitor, then empties the buffer. Does
     * nothing if the buffer is empty.
     *
     * @param list the pair buffer; sorted and cleared in place
     * @throws IOException if writing to the sink fails
     */
    void writeOutPairs(List<Weighted<TokenPair>> list) throws IOException {
        if (list.isEmpty()) {
            return;
        }
        Collections.sort(list, Weighted.recordOrder(TokenPair.indexOrder()));
        synchronized (getSink()) {
            ObjectIO.copy(list, getSink());
        }
        list.clear();
    }

    /**
     * Runs the nested-loop comparison of source A against source B, emitting
     * accepted pairs to the sink.
     */
    @Override
    protected void runTask() throws Exception {
        final List<Weighted<TokenPair>> buffer = new ArrayList<>();
        final P startOfB = getSourceB().position();

        progress.startAdjusting();
        progress.setState(ProgressReporting.State.RUNNING);
        progress.setMessage("Running all-pairs.");
        progress.setProgressPercent(0);
        progress.endAdjusting();

        while (getSourceA().hasNext()) {
            final Indexed<SparseDoubleVector> a = getSourceA().read();
            if (!processRecord.apply(a)) {
                continue;
            }
            // Rewind B so each accepted A record sees all of B. Reference
            // comparison: a different position object just triggers a seek.
            if (getSourceB().position() != startOfB) {
                getSourceB().position(startOfB);
            }
            while (getSourceB().hasNext()) {
                stats.incrementCandidatesCount();
                final Indexed<SparseDoubleVector> b = getSourceB().read();
                if (!processRecord.apply(b)) {
                    continue;
                }
                // NOTE(review): the pair is keyed (B, A) while the score is
                // sim(A, B); confirm this orientation is intended for
                // asymmetric measures.
                final Weighted<TokenPair> pair = new Weighted<>(
                        new TokenPair(b.key(), a.key()), sim(a, b));
                if (producePair.apply(pair)) {
                    buffer.add(pair);
                    stats.incrementProductionCount();
                    if (buffer.size() > PAIR_OUTPUT_BUFFER_SIZE) {
                        writeOutPairs(buffer);
                    }
                }
            }
        }
        writeOutPairs(buffer);

        progress.startAdjusting();
        progress.setProgressPercent(100);
        progress.setState(ProgressReporting.State.COMPLETED);
        progress.endAdjusting();
    }

    /** Releases the pre-computation tables. */
    @Override
    public void finaliseTask() throws Exception {
        this.preCalcA = null;
        this.preCalcB = null;
    }

    /**
     * Verifies the task is fully configured and both sources have records.
     *
     * @throws IllegalStateException    if a required component is missing or a source is exhausted
     * @throws IllegalArgumentException if both sources are the same object
     * @throws NullPointerException     if a filter is null
     * @throws IOException              if querying a source fails
     */
    void checkState() throws IOException {
        if (this.sourceA == null) {
            throw new IllegalStateException("source A is not set");
        }
        if (!this.sourceA.hasNext()) {
            throw new IllegalStateException("source A is exhausted");
        }
        if (this.sourceB == null) {
            throw new IllegalStateException("source B is not set");
        }
        if (!this.sourceB.hasNext()) {
            throw new IllegalStateException("source B is exhausted");
        }
        if (this.sourceA == this.sourceB) {
            throw new IllegalArgumentException("sourceA == sourceB");
        }
        if (this.sink == null) {
            throw new IllegalStateException("sink (destination) is not set");
        }
        if (this.measure == null) {
            throw new IllegalStateException("measure is not set");
        }
        if (this.processRecord == null) {
            throw new NullPointerException("recordFilter == null");
        }
        if (this.producePair == null) {
            throw new NullPointerException("pairFilter == null");
        }
    }

    /**
     * Builds any missing pre-computation tables when the measure is a
     * {@link DecomposableMeasure}; otherwise does nothing.
     *
     * @throws IOException if reading a source fails
     */
    void buildPreCalcs() throws IOException {
        if (getMeasure() instanceof DecomposableMeasure) {
            if (this.preCalcA == null) {
                this.preCalcA = buildPrecalcA();
            }
            if (this.preCalcB == null) {
                this.preCalcB = buildPrecalcB();
            }
        }
    }

    /** @return the left-hand pre-computation table, or null if not built */
    protected Int2DoubleMap getPreCalcA() {
        return this.preCalcA;
    }

    /** @return the right-hand pre-computation table, or null if not built */
    protected Int2DoubleMap getPreCalcB() {
        return this.preCalcB;
    }

    /**
     * Computes {@code DecomposableMeasure.left} for every remaining record of
     * source A, keyed by record index. The source position is restored
     * afterwards.
     *
     * @return the table, or null if the measure is not decomposable
     * @throws IOException if reading or seeking source A fails
     */
    Int2DoubleMap buildPrecalcA() throws IOException {
        return buildPrecalc(this.sourceA, true);
    }

    /**
     * Computes {@code DecomposableMeasure.right} for every remaining record of
     * source B, keyed by record index. The source position is restored
     * afterwards.
     *
     * @return the table, or null if the measure is not decomposable
     * @throws IOException if reading or seeking source B fails
     */
    Int2DoubleMap buildPrecalcB() throws IOException {
        return buildPrecalc(this.sourceB, false);
    }

    /**
     * Shared implementation of {@link #buildPrecalcA()} and {@link #buildPrecalcB()}.
     *
     * @param source the source to scan from its current position
     * @param left   true to apply {@code left(..)}, false to apply {@code right(..)}
     * @return the table, or null if the measure is not decomposable
     * @throws IOException if reading or seeking the source fails
     */
    private Int2DoubleMap buildPrecalc(
            SeekableObjectSource<Indexed<SparseDoubleVector>, P> source,
            boolean left) throws IOException {
        if (!(getMeasure() instanceof DecomposableMeasure)) {
            return null;
        }
        final DecomposableMeasure dm = (DecomposableMeasure) getMeasure();
        final P start = source.position();
        final Int2DoubleOpenHashMap table = new Int2DoubleOpenHashMap();
        try {
            while (source.hasNext()) {
                final Indexed<SparseDoubleVector> rec = source.read();
                table.put(rec.key(), left ? dm.left(rec.value()) : dm.right(rec.value()));
            }
        } finally {
            // Always rewind, even on failure, so the source is not left mid-stream.
            source.position(start);
        }
        return table;
    }

    /**
     * Scores a pair of records and increments the comparison count.
     *
     * <p>For a {@link DecomposableMeasure} the score is
     * {@code combine(shared(a, b), preCalcA[a.key], preCalcB[b.key])}, which
     * requires {@link #buildPreCalcs()} to have run. Otherwise
     * {@code measure.similarity(a, b)} is used.
     *
     * @param indexed  record from source A
     * @param indexed2 record from source B
     * @return the similarity score
     */
    public final double sim(Indexed<SparseDoubleVector> indexed, Indexed<SparseDoubleVector> indexed2) {
        this.stats.incrementComparisonCount();
        if (!(this.measure instanceof DecomposableMeasure)) {
            return this.measure.similarity(indexed.value(), indexed2.value());
        }
        final DecomposableMeasure dm = (DecomposableMeasure) this.measure;
        return dm.combine(
                dm.shared(indexed.value(), indexed2.value()),
                this.preCalcA.get(indexed.key()),
                this.preCalcB.get(indexed2.key()));
    }

    @Override
    public void removeProgressListener(ProgressListener progressListener) {
        this.progress.removeProgressListener(progressListener);
    }

    @Override
    public boolean isProgressPercentageSupported() {
        return this.progress.isProgressPercentageSupported();
    }

    @Override
    public String getProgressReport() {
        return this.progress.getProgressReport();
    }

    @Override
    public int getProgressPercent() {
        return this.progress.getProgressPercent();
    }

    @Override
    public ProgressListener[] getProgressListeners() {
        return this.progress.getProgressListeners();
    }

    /** @return the fixed task name {@code "naive-allpairs"} */
    public String getName() {
        return "naive-allpairs";
    }

    @Override
    public void addProgressListener(ProgressListener progressListener) {
        this.progress.addProgressListener(progressListener);
    }

    @Override
    public ProgressReporting.State getState() {
        return this.progress.getState();
    }

    @Override
    public Objects.ToStringHelper toStringHelper() {
        return super.toStringHelper()
                .add("sourceA", this.sourceA)
                .add("sourceB", this.sourceB)
                .add("measure", this.measure)
                .add("sink", this.sink)
                .add("processRecord", this.processRecord)
                .add("producePair", this.producePair)
                .add("stats", this.stats);
    }
}
