package ai.fritz.vision.imagelabeling;

import ai.fritz.core.OutputTensor;
import ai.fritz.vision.FritzVisionImage;
import ai.fritz.vision.FritzVisionLabel;
import ai.fritz.vision.ImageInputTensor;
import ai.fritz.vision.base.FritzVisionRecordablePredictor;
import android.util.Size;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

/**
 * Predictor that runs a {@link LabelingOnDeviceModel} over a {@link FritzVisionImage} and
 * returns the model's labels whose score meets
 * {@code FritzVisionLabelPredictorOptions.confidenceThreshold}, ordered from highest to
 * lowest confidence.
 */
public class FritzVisionLabelPredictor extends FritzVisionRecordablePredictor {
    private static final String TAG = "FritzVisionLabelPredictor";

    /**
     * Orders labels by descending confidence. It is stateless, so one shared instance is
     * reused for every prediction. Because it is declared in a static context, it does not
     * capture a reference to any predictor instance.
     */
    private static final Comparator<FritzVisionLabel> BY_CONFIDENCE_DESCENDING =
            new Comparator<FritzVisionLabel>() {
                @Override
                public int compare(FritzVisionLabel first, FritzVisionLabel second) {
                    // Arguments are swapped relative to natural order to get descending order.
                    return Float.compare(second.getConfidence(), first.getConfidence());
                }
            };

    private final ImageInputTensor inputTensor;
    private final List<String> labels;
    private final FritzVisionLabelPredictorOptions options;
    private final OutputTensor outputTensor;

    /**
     * Creates a predictor for the given labeling model.
     *
     * <p>Binds input tensor 0 ("Input Image") and output tensor 0 ("Image Label Output") to
     * the inherited interpreter, and takes the input size from the input tensor's image
     * dimensions.
     *
     * @param labelingOnDeviceModel model supplying the interpreter setup and the label names
     * @param fritzVisionLabelPredictorOptions options; {@code confidenceThreshold} filters
     *     the results
     */
    public FritzVisionLabelPredictor(LabelingOnDeviceModel labelingOnDeviceModel, FritzVisionLabelPredictorOptions fritzVisionLabelPredictorOptions) {
        super(labelingOnDeviceModel, fritzVisionLabelPredictorOptions);
        this.inputTensor = new ImageInputTensor("Input Image", 0);
        this.outputTensor = new OutputTensor("Image Label Output", 0);
        this.options = fritzVisionLabelPredictorOptions;
        this.labels = labelingOnDeviceModel.getLabels();
        this.inputTensor.setupInputBuffer(this.interpreter);
        this.outputTensor.setupOutputBuffer(this.interpreter);
        this.inputSize = this.inputTensor.getImageDimensions();
    }

    /**
     * Converts the current contents of the output tensor into labels.
     *
     * <p>Output index {@code i} is paired with {@code labels.get(i)}. Entries scoring below
     * {@code options.confidenceThreshold} are dropped. The survivors are sorted by descending
     * confidence. {@link Collections#sort} is stable, so ties keep their label order.
     *
     * @return a new mutable list of labels meeting the threshold, most confident first
     */
    private List<FritzVisionLabel> getLabelResults() {
        List<FritzVisionLabel> results = new ArrayList<>();
        int labelCount = this.labels.size();
        for (int i = 0; i < labelCount; i++) {
            String labelName = this.labels.get(i);
            float probability = getNormalizedProbability(i);
            if (probability >= this.options.confidenceThreshold) {
                results.add(new FritzVisionLabel(labelName, probability));
            }
        }
        Collections.sort(results, BY_CONFIDENCE_DESCENDING);
        return results;
    }

    /**
     * Returns the input size derived from the input tensor in the constructor.
     *
     * @return the model's expected input image size
     */
    @Override
    public Size getInputSize() {
        return this.inputSize;
    }

    /**
     * Returns the label names obtained from the model at construction time.
     *
     * <p>This is the same list instance the model returned, not a copy.
     *
     * @return the model's label names, indexed like the output tensor
     */
    public List<String> getLabels() {
        return this.labels;
    }

    /**
     * Reads the score for output index {@code i}.
     *
     * <p>For an 8-bit quantized output, the byte is read as unsigned (0..255) and scaled
     * into [0, 1]. Otherwise the raw float value is returned unchanged.
     *
     * @param i output index, which matches the label index
     * @return the score for that index
     */
    protected float getNormalizedProbability(int i) {
        if (this.outputTensor.is8BitQuantized()) {
            // Mask off sign extension so the byte is interpreted as unsigned.
            return (this.outputTensor.getByte(i) & 0xFF) / 255.0f;
        }
        return this.outputTensor.getFloat(i);
    }

    /**
     * Runs the model on an image and returns the labels that meet the confidence threshold.
     *
     * <p>This reuses the predictor's input and output buffers, so calls should not overlap.
     *
     * @param fritzVisionImage the image to label
     * @return a result wrapping labels sorted by descending confidence
     */
    @Override
    public FritzVisionLabelResult predict(FritzVisionImage fritzVisionImage) {
        this.inputTensor.preprocess(fritzVisionImage);
        // Reset the output buffer position before the interpreter writes into it.
        this.outputTensor.rewind();
        this.interpreter.run(this.inputTensor.buffer, this.outputTensor.buffer);
        return new FritzVisionLabelResult(getLabelResults());
    }
}
