package uk.ac.susx.mlcl.byblo.tasks;

import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.objects.ObjectOpenHashSet;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import uk.ac.susx.mlcl.byblo.io.TokenPair;
import uk.ac.susx.mlcl.byblo.io.Weighted;
import uk.ac.susx.mlcl.lib.Checks;
import uk.ac.susx.mlcl.lib.collect.Indexed;
import uk.ac.susx.mlcl.lib.collect.SparseDoubleVector;
import uk.ac.susx.mlcl.lib.events.ProgressReporting;
import uk.ac.susx.mlcl.lib.io.SeekableObjectSource;

/* loaded from: input_file:uk/ac/susx/mlcl/byblo/tasks/InvertedApssTask.class */
public final class InvertedApssTask<S> extends NaiveApssTask<S> {

    private static final Log LOG = LogFactory.getLog(InvertedApssTask.class);

    /**
     * Once the pair buffer grows beyond this many entries it is handed to
     * {@code writeOutPairs} before processing continues. Same value as the
     * previously hard-coded limit.
     */
    private static final int MAX_BUFFERED_PAIRS = 100000;

    /**
     * Inverted index mapping each int key that occurs in a source-A record's
     * {@code SparseDoubleVector.keys} array to the set of source-A records
     * containing that key. Built by {@link #initialiseTask()} unless one was
     * supplied earlier through {@link #setIndex}.
     */
    private Int2ObjectMap<Set<Indexed<SparseDoubleVector>>> index = null;

    /**
     * Runs the superclass initialisation, then builds the inverted index from
     * source A if no index has been supplied.
     *
     * NOTE(review): the decompiler reports this method was declared
     * {@code protected} in bytecode; it is kept {@code public} here so any
     * existing callers of the decompiled source still compile.
     *
     * @throws Exception if superclass initialisation or index building fails
     */
    @Override
    public void initialiseTask() throws Exception {
        super.initialiseTask();
        if (index == null) {
            index = buildIndex();
        }
    }

    /**
     * Compares every accepted record of source B against its candidate
     * records from the inverted index and writes the accepted pairs.
     *
     * For each source-B record accepted by {@code getProcessRecord()}, each
     * candidate that shares at least one key (and is also accepted by
     * {@code getProcessRecord()}) is scored with {@code sim(candidate, query)}.
     * The resulting pair is kept if {@code getProducePair()} accepts it.
     * Source B is restored to its starting position afterwards.
     *
     * @throws IOException if reading source B or writing pairs fails
     */
    @Override
    protected void runTask() throws IOException {
        progress.startAdjusting();
        progress.setState(ProgressReporting.State.RUNNING);
        progress.setMessage("Running inverted all-pairs.");
        progress.setProgressPercent(0);
        progress.endAdjusting();

        final S startB = getSourceB().position();
        final List<Weighted<TokenPair>> pairs = new ArrayList<>();
        while (getSourceB().hasNext()) {
            final Indexed<SparseDoubleVector> query = getSourceB().read();
            if (!getProcessRecord().apply(query)) {
                continue;
            }
            for (Indexed<SparseDoubleVector> candidate : findCandidates(query)) {
                if (!getProcessRecord().apply(candidate)) {
                    continue;
                }
                getStats().incrementCandidatesCount();
                final Weighted<TokenPair> pair = new Weighted<>(
                        new TokenPair(query.key(), candidate.key()),
                        sim(candidate, query));
                if (getProducePair().apply(pair)) {
                    pairs.add(pair);
                    getStats().incrementProductionCount();
                    // NOTE(review): assumes writeOutPairs drains the buffer
                    // after writing; confirm in NaiveApssTask.
                    if (pairs.size() > MAX_BUFFERED_PAIRS) {
                        writeOutPairs(pairs);
                    }
                }
            }
        }
        // Flush whatever remains in the buffer.
        writeOutPairs(pairs);
        getSourceB().position(startB);

        progress.startAdjusting();
        progress.setState(ProgressReporting.State.COMPLETED);
        progress.setProgressPercent(100);
        progress.endAdjusting();
    }

    /**
     * Collects every indexed record that shares at least one key with the
     * given record's vector.
     *
     * Requires the index to have been initialised, either by
     * {@link #initialiseTask()} or by {@link #setIndex}.
     *
     * @param query record whose vector keys are looked up in the index
     * @return a new set holding the union of the index entries for every key
     *         of {@code query}; empty if none of its keys are indexed
     */
    Set<Indexed<SparseDoubleVector>> findCandidates(Indexed<SparseDoubleVector> query) {
        final Set<Indexed<SparseDoubleVector>> candidates = new ObjectOpenHashSet<>();
        for (int key : query.value().keys) {
            // containsKey (rather than a null check on get) keeps exact
            // semantics for externally supplied maps with a non-null
            // default return value.
            if (index.containsKey(key)) {
                candidates.addAll(index.get(key));
            }
        }
        return candidates;
    }

    /**
     * Reads all of source A and builds an inverted index from each key of
     * every record's vector to the set of records containing that key.
     * Source A is restored to its starting position afterwards.
     *
     * @return the newly built index
     * @throws IOException if reading or repositioning source A fails
     */
    Int2ObjectMap<Set<Indexed<SparseDoubleVector>>> buildIndex() throws IOException {
        final SeekableObjectSource<Indexed<SparseDoubleVector>, S> sourceA = getSourceA();
        final Int2ObjectMap<Set<Indexed<SparseDoubleVector>>> result =
                new Int2ObjectOpenHashMap<>();
        final S startA = sourceA.position();
        while (sourceA.hasNext()) {
            final Indexed<SparseDoubleVector> record = sourceA.read();
            for (int key : record.value().keys) {
                // A freshly created Int2ObjectOpenHashMap returns null for
                // absent keys, so one lookup replaces containsKey + get.
                Set<Indexed<SparseDoubleVector>> bucket = result.get(key);
                if (bucket == null) {
                    bucket = new ObjectOpenHashSet<>();
                    result.put(key, bucket);
                }
                bucket.add(record);
            }
        }
        sourceA.position(startA);
        return result;
    }

    /**
     * Supplies a pre-built inverted index, so that
     * {@link #initialiseTask()} does not build one.
     *
     * @param newIndex the index to use; must not be null
     */
    void setIndex(Int2ObjectMap<Set<Indexed<SparseDoubleVector>>> newIndex) {
        Checks.checkNotNull("index is null", newIndex);
        index = newIndex;
    }

    /**
     * Returns the current inverted index (the live instance, not a copy).
     *
     * @return the index, or null if it has been neither built nor supplied
     */
    Int2ObjectMap<Set<Indexed<SparseDoubleVector>>> getIndex() {
        return index;
    }

    /**
     * Returns the name of this task.
     *
     * @return {@code "inverted-allpairs"}
     */
    @Override
    public String getName() {
        return "inverted-allpairs";
    }
}
