package ai.fritz.visionCV.rigidpose;

import ai.fritz.vision.ByteImage;
import ai.fritz.vision.FritzVisionImage;
import ai.fritz.vision.base.FritzVisionPredictor;
import ai.fritz.visionCV.FritzCVImage;
import android.graphics.PointF;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.HashMap;
import org.opencv.core.Mat;
import org.opencv.core.Point;
import org.opencv.core.Size;
import org.opencv.imgproc.Imgproc;
import org.tensorflow.lite.Tensor;

/* loaded from: classes.dex */
/**
 * Predicts the 2D locations of a rigid object's keypoints with a TensorFlow Lite model that emits
 * a per-keypoint heatmap (output {@value #HEATMAP_IDX}) and per-keypoint offsets
 * (output {@value #OFFSETS_IDX}).
 *
 * <p>Every call to {@link #predict(FritzCVImage)} reuses the same input/output buffers held in
 * instance fields without synchronization, so one instance must not be used from several threads
 * at the same time.
 */
public class FritzVisionRigidPosePredictor extends FritzVisionPredictor {
    private static final int HEATMAP_IDX = 0;
    private static final int INPUT_IDX = 0;
    private static final int NUM_CHANNELS = 3;
    private static final int OFFSETS_IDX = 1;
    private static final String TAG = "FritzVisionRigidPosePredictor";

    /** Model input buffer, filled by {@link #preprocess(ByteImage)}. */
    private final ByteBuffer inputByteBuffer;
    /** Model input size as an OpenCV size; used as the resize target. */
    private final Size inputCVSize;
    /** Scratch array that receives the raw bytes of the resized image. */
    private final byte[] inputData;
    /** Minimum number of keypoints whose peak score must reach {@link #scoreThreshold}. */
    private final int minPartsOverThreshold;
    private final int numKeypoints;
    private final FritzVisionRigidPosePredictorOptions options;
    /** Heatmap grid size as an OpenCV size. */
    private final Size outputCVSize;
    private final ByteBuffer outputHeatmaps;
    private final ByteBuffer outputOffsets;
    /** Heatmap grid size; width is iterated as the column index, height as the row index. */
    private final android.util.Size outputSize;
    private final float scoreThreshold;

    /**
     * Creates a predictor for {@code rigidPoseOnDeviceModel} with default options.
     *
     * @param rigidPoseOnDeviceModel the model to run
     */
    public FritzVisionRigidPosePredictor(RigidPoseOnDeviceModel rigidPoseOnDeviceModel) {
        this(rigidPoseOnDeviceModel, new FritzVisionRigidPosePredictorOptions());
    }

    /**
     * Creates a predictor and allocates the input/output buffers sized from the model's tensors.
     *
     * @param rigidPoseOnDeviceModel the model to run; supplies the keypoint count
     * @param fritzVisionRigidPosePredictorOptions supplies the confidence threshold and the
     *     minimum number of keypoints that must reach it
     */
    public FritzVisionRigidPosePredictor(
            RigidPoseOnDeviceModel rigidPoseOnDeviceModel,
            FritzVisionRigidPosePredictorOptions fritzVisionRigidPosePredictorOptions) {
        super(rigidPoseOnDeviceModel, fritzVisionRigidPosePredictorOptions);
        this.options = fritzVisionRigidPosePredictorOptions;
        this.scoreThreshold = fritzVisionRigidPosePredictorOptions.confidenceThreshold;
        this.minPartsOverThreshold = fritzVisionRigidPosePredictorOptions.numKeypointsAboveThreshold;
        this.numKeypoints = rigidPoseOnDeviceModel.getNumKeypoints();

        Tensor inputTensor = this.interpreter.getInputTensor(INPUT_IDX);
        // inputSize is a field inherited from FritzVisionPredictor.
        this.inputSize = getSizeFromTensor(inputTensor);
        this.inputCVSize = new Size(this.inputSize.getWidth(), this.inputSize.getHeight());
        this.inputByteBuffer = allocateBufferFor(inputTensor);
        this.inputData =
                new byte[this.inputSize.getHeight() * this.inputSize.getWidth() * NUM_CHANNELS];

        Tensor heatmapTensor = this.interpreter.getOutputTensor(HEATMAP_IDX);
        this.outputSize = getSizeFromTensor(heatmapTensor);
        this.outputCVSize = new Size(this.outputSize.getWidth(), this.outputSize.getHeight());
        this.outputHeatmaps = allocateBufferFor(heatmapTensor);

        this.outputOffsets = allocateBufferFor(this.interpreter.getOutputTensor(OFFSETS_IDX));
    }

    /**
     * Allocates a direct, native-byte-order buffer big enough for every element of {@code tensor}.
     */
    private static ByteBuffer allocateBufferFor(Tensor tensor) {
        ByteBuffer buffer =
                ByteBuffer.allocateDirect(tensor.dataType().byteSize() * tensor.numElements());
        buffer.order(ByteOrder.nativeOrder());
        return buffer;
    }

    /**
     * Writes {@code byteImage} into {@link #inputByteBuffer} as floats. Each byte {@code b} is
     * read as unsigned (0..255) and stored as {@code b / 255 - 0.5}.
     *
     * <p>NOTE(review): this always writes float32 values, which presumes a float32 model input
     * tensor. Confirm this for every supported model.
     */
    private void preprocess(ByteImage byteImage) {
        this.inputByteBuffer.rewind();
        for (byte value : byteImage.getCopyOfImageData()) {
            this.inputByteBuffer.putFloat(((value & 0xFF) / 255.0f) - 0.5f);
        }
    }

    /** Resets the output buffers' positions before the interpreter writes into them. */
    private void rewindOutputs() {
        this.outputHeatmaps.rewind();
        this.outputOffsets.rewind();
    }

    /**
     * Not supported by this predictor; always returns {@code null}. Use
     * {@link #predict(FritzCVImage)} instead.
     */
    @Override // ai.fritz.vision.base.FritzVisionPredictor
    public RigidPoseResult predict(FritzVisionImage fritzVisionImage) {
        return null;
    }

    /**
     * Runs the model on {@code fritzCVImage} and decodes one 2D point per keypoint.
     *
     * <p>For each keypoint, the heatmap cell with the highest score is chosen. Its grid position
     * is scaled to the model input resolution and refined with the model's offset for that cell.
     *
     * @param fritzCVImage image to analyse; it is rotated via {@link FritzCVImage#rotate()} and
     *     resized to the model input size
     * @return the decoded keypoints and their peak scores, or {@code null} when fewer than the
     *     configured number of keypoints have a peak score at or above the confidence threshold
     */
    public RigidPoseResult predict(FritzCVImage fritzCVImage) {
        loadInput(fritzCVImage);
        runInference();

        int gridWidth = this.outputSize.getWidth();
        int gridHeight = this.outputSize.getHeight();
        HeatmapScores heatmapScores =
                new HeatmapScores(this.outputHeatmaps, gridHeight, gridWidth, this.numKeypoints);
        Offsets offsets = new Offsets(this.outputOffsets, gridHeight, gridWidth, this.numKeypoints);

        float[] peakScores = new float[this.numKeypoints];
        int[] peakRows = new int[this.numKeypoints];
        int[] peakCols = new int[this.numKeypoints];
        findPeaks(heatmapScores, gridWidth, gridHeight, peakScores, peakRows, peakCols);

        if (countAtOrAbove(peakScores, this.scoreThreshold) < this.minPartsOverThreshold) {
            return null;
        }

        // NOTE(review): integer division (kept from the existing behavior). The stride is
        // truncated when the input size is not a whole multiple of the heatmap size; confirm
        // this matches the model's output-stride convention.
        int strideX = this.inputSize.getWidth() / gridWidth;
        int strideY = this.inputSize.getHeight() / gridHeight;
        Point[] points = new Point[this.numKeypoints];
        for (int k = 0; k < this.numKeypoints; k++) {
            int col = peakCols[k];
            int row = peakRows[k];
            PointF offset = offsets.getOffsetPoint(k, col, row);
            points[k] = new Point(col * strideX + offset.x, row * strideY + offset.y);
        }
        return new RigidPoseResult(points, peakScores, this.inputSize);
    }

    /**
     * Rotates and resizes {@code image} to the model input size, copies its bytes into
     * {@link #inputData}, and fills {@link #inputByteBuffer} from them.
     */
    private void loadInput(FritzCVImage image) {
        // NOTE(review): the rotated Mat is not released here because this file does not show
        // whether rotate() returns a fresh Mat or one owned by the image. Release it if owned.
        Mat rotated = image.rotate();
        Mat resized = new Mat();
        try {
            Imgproc.resize(rotated, resized, this.inputCVSize);
            // NOTE(review): assumes `resized` is 8-bit with NUM_CHANNELS channels, so its bytes
            // exactly fill inputData. Confirm the channel layout that rotate() produces.
            resized.get(0, 0, this.inputData);
            preprocess(new ByteImage(this.inputData, resized.width(), resized.height()));
        } finally {
            // Mat wraps native memory; free it now instead of leaking it on every frame.
            resized.release();
        }
    }

    /** Runs the interpreter, writing heatmaps and offsets into their preallocated buffers. */
    private void runInference() {
        rewindOutputs();
        Object[] inputs = {this.inputByteBuffer};
        HashMap<Integer, Object> outputs = new HashMap<>();
        outputs.put(HEATMAP_IDX, this.outputHeatmaps);
        outputs.put(OFFSETS_IDX, this.outputOffsets);
        this.interpreter.runForMultipleInputsOutputs(inputs, outputs);
    }

    /**
     * For every keypoint, stores its highest heatmap score and the row/column where it occurs
     * into the given arrays. Ties keep the first such cell in row-major order.
     */
    private void findPeaks(HeatmapScores heatmapScores, int gridWidth, int gridHeight,
            float[] peakScores, int[] peakRows, int[] peakCols) {
        for (int k = 0; k < this.numKeypoints; k++) {
            // Seed with -infinity so the true maximum is found even when every score is <= 0.
            // A 0 seed would leave such keypoints pinned to cell (0, 0).
            float best = Float.NEGATIVE_INFINITY;
            int bestRow = 0;
            int bestCol = 0;
            for (int row = 0; row < gridHeight; row++) {
                for (int col = 0; col < gridWidth; col++) {
                    float score = heatmapScores.getScore(k, col, row);
                    if (score > best) {
                        best = score;
                        bestRow = row;
                        bestCol = col;
                    }
                }
            }
            peakScores[k] = best;
            peakRows[k] = bestRow;
            peakCols[k] = bestCol;
        }
    }

    /** Counts the entries of {@code scores} that are greater than or equal to {@code threshold}. */
    private static int countAtOrAbove(float[] scores, float threshold) {
        int count = 0;
        for (float score : scores) {
            if (score >= threshold) {
                count++;
            }
        }
        return count;
    }
}
