package ai.fritz.vision.poseestimation;

import android.graphics.Point;
import android.graphics.PointF;
import android.util.Size;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.PriorityBlockingQueue;

/* loaded from: classes.dex */
/**
 * Decodes multiple poses from heatmap scores, offsets and forward/backward
 * displacement fields.
 *
 * <p>Decoding is greedy: heatmap cells that are local maxima above a score
 * threshold are collected as root candidates, visited from the highest score
 * to the lowest, and each accepted root is expanded into a full pose by
 * following the skeleton edges with the displacement fields.
 *
 * <p>NOTE(review): the instance keeps mutable decoding state in
 * {@link #scoreQueue}; it is presumably intended to be used from one thread at
 * a time — confirm against callers.
 */
public class PoseDecoderWithDisplacements {
    private static final int DEFAULT_QUEUE_SIZE = 25;
    private static final String TAG = "PoseDecoderWithDisplacements";
    private final Size bounds;
    private final Displacements displacementsBwd;
    private final Displacements displacementsFwd;
    private final HeatmapScores heatmapScores;
    private final Offsets offsets;

    /**
     * Root candidates for the decode currently in progress, ordered so that the
     * highest-scoring part is at the head. It is cleared at the start of every
     * decode, so results of earlier calls never leak into later ones.
     */
    PriorityBlockingQueue<PartScore> scoreQueue = new PriorityBlockingQueue<>(DEFAULT_QUEUE_SIZE, new Comparator<PartScore>() {
        @Override // java.util.Comparator
        public int compare(PartScore first, PartScore second) {
            // Descending by score: the queue head is the *least* element per the
            // comparator, so the arguments are swapped to make poll() return the
            // strongest candidate. Float.compare also returns 0 for equal scores,
            // which the Comparator contract requires.
            return Float.compare(second.getScore(), first.getScore());
        }
    });
    private final Skeleton skeleton;

    /**
     * Creates a decoder over one set of model outputs.
     *
     * @param heatmapScores  per-keypoint score grid
     * @param offsets        per-keypoint offset vectors added to strided grid positions
     * @param displacements  displacement field used when walking edges forward
     * @param displacements2 displacement field used when walking edges backward
     * @param size           bounds handed to every created {@code Keypoint} and {@code Pose}
     * @param skeleton       keypoint names and edge definitions
     */
    public PoseDecoderWithDisplacements(HeatmapScores heatmapScores, Offsets offsets, Displacements displacements, Displacements displacements2, Size size, Skeleton skeleton) {
        this.heatmapScores = heatmapScores;
        this.offsets = offsets;
        this.displacementsFwd = displacements;
        this.displacementsBwd = displacements2;
        this.bounds = size;
        this.skeleton = skeleton;
    }

    /** Returns the component-wise sum of two vectors as a new point. */
    private PointF addVectors(PointF first, PointF second) {
        return new PointF(first.x + second.x, first.y + second.y);
    }

    /**
     * Refills {@link #scoreQueue} with every heatmap cell whose score is at least
     * {@code scoreThreshold} and is not exceeded inside its local window.
     *
     * @param scoreThreshold     minimum score for a cell to become a candidate
     * @param localMaximumRadius radius, in heatmap cells, of the local-maximum window
     * @return the (shared) queue holding the candidates, best first
     */
    private PriorityBlockingQueue<PartScore> buildPartWithScoringQueue(float scoreThreshold, int localMaximumRadius) {
        // The queue is an instance field; drop whatever a previous decode left behind.
        this.scoreQueue.clear();
        int width = this.heatmapScores.getWidth();
        int height = this.heatmapScores.getHeight();
        int numKeypoints = this.heatmapScores.getNumKeypoints();
        for (int x = 0; x < width; x++) {
            for (int y = 0; y < height; y++) {
                for (int keypointId = 0; keypointId < numKeypoints; keypointId++) {
                    float score = this.heatmapScores.getScore(keypointId, x, y);
                    if (score >= scoreThreshold && scoreIsMaximumInLocalWindow(keypointId, x, y, score, localMaximumRadius)) {
                        this.scoreQueue.put(new PartScore(keypointId, x, y, score));
                    }
                }
            }
        }
        return this.scoreQueue;
    }

    /** Limits {@code value} to the inclusive range {@code [min, max]}. */
    private int clamp(int value, int min, int max) {
        return Math.max(min, Math.min(max, value));
    }

    /**
     * Builds the keypoints of one pose starting from a root part.
     *
     * <p>The root keypoint is placed first. The edges are then walked once in
     * reverse order with the backward displacements and once in forward order
     * with the forward displacements; an edge is followed only when its source
     * keypoint is already known and its target is still missing.
     *
     * @param root         the root part the pose grows from
     * @param outputStride multiplier converting heatmap indices to image coordinates
     * @return keypoints indexed by keypoint id; entries never reached stay {@code null}
     */
    private Keypoint[] decodePose(PartScore root, int outputStride) {
        int numKeypoints = this.skeleton.getNumKeypoints();
        int numEdges = this.skeleton.getNumEdges();
        Keypoint[] keypoints = new Keypoint[numKeypoints];

        PointF rootPosition = getImageCoordinates(root, outputStride);
        int rootId = root.getKeypointId();
        keypoints[rootId] = new Keypoint(rootId, this.skeleton.getKeypointName(rootId), rootPosition, root.getScore(), this.bounds);

        Integer[] parentToChildEdges = this.skeleton.getParentToChildEdges();
        Integer[] childToParentEdges = this.skeleton.getChildToParentEdges();

        // Backward pass: last edge first.
        for (int edge = numEdges - 1; edge >= 0; edge--) {
            int sourceId = parentToChildEdges[edge].intValue();
            int targetId = childToParentEdges[edge].intValue();
            if (keypoints[sourceId] != null && keypoints[targetId] == null) {
                keypoints[targetId] = traverseToTargetKeypoint(this.displacementsBwd, edge, keypoints[sourceId], targetId, outputStride);
            }
        }

        // Forward pass: first edge first, with source and target roles swapped.
        for (int edge = 0; edge < numEdges; edge++) {
            int sourceId = childToParentEdges[edge].intValue();
            int targetId = parentToChildEdges[edge].intValue();
            if (keypoints[sourceId] != null && keypoints[targetId] == null) {
                keypoints[targetId] = traverseToTargetKeypoint(this.displacementsFwd, edge, keypoints[sourceId], targetId, outputStride);
            }
        }
        return keypoints;
    }

    /**
     * Converts a part's heatmap cell into image coordinates: the cell index is
     * scaled by {@code outputStride} and the part's offset vector is added.
     */
    private PointF getImageCoordinates(Part part, int outputStride) {
        PointF offset = this.offsets.getOffsetPoint(part.getKeypointId(), part.getHeatMapScoresX(), part.getHeatMapScoresY());
        return new PointF((part.getHeatMapScoresX() * outputStride) + offset.x, (part.getHeatMapScoresY() * outputStride) + offset.y);
    }

    /**
     * Scores a candidate pose: the sum of the scores of its keypoints that do not
     * fall within the NMS radius of the same keypoint in an already accepted pose,
     * divided by the total number of keypoint slots (missing keypoints count as 0).
     *
     * @param existingPoses    poses accepted so far
     * @param squaredNmsRadius squared suppression radius
     * @param keypoints        keypoints of the candidate pose; entries may be {@code null}
     */
    private float getInstanceScore(List<Pose> existingPoses, float squaredNmsRadius, Keypoint[] keypoints) {
        float notOverlappedScoreSum = 0.0f;
        for (Keypoint keypoint : keypoints) {
            if (keypoint != null && !withinNMSRadiusOfCorrespondingPoint(existingPoses, squaredNmsRadius, keypoint.getPosition(), keypoint.getId())) {
                notOverlappedScoreSum += keypoint.getScore();
            }
        }
        return notOverlappedScoreSum / keypoints.length;
    }

    /**
     * Maps an image-space point to the nearest heatmap cell, clamped to the grid.
     *
     * @param point        position in image coordinates
     * @param outputStride divisor converting image coordinates to heatmap indices
     * @param height       number of heatmap rows (bounds the y index)
     * @param width        number of heatmap columns (bounds the x index)
     */
    private Point getStridedIndexNearPoint(PointF point, int outputStride, int height, int width) {
        float stride = outputStride;
        int x = clamp(Math.round(point.x / stride), 0, width - 1);
        int y = clamp(Math.round(point.y / stride), 0, height - 1);
        return new Point(x, y);
    }

    /**
     * Tells whether no cell within {@code localMaximumRadius} of {@code (x, y)}
     * has a strictly higher score for the given keypoint.
     *
     * <p>The window is symmetric, i.e. it spans {@code [c - r, c + r]} on both
     * axes, clipped to the heatmap.
     */
    private boolean scoreIsMaximumInLocalWindow(int keypointId, int x, int y, float score, int localMaximumRadius) {
        int yStart = Math.max(y - localMaximumRadius, 0);
        // "+ 1" because the loop bounds are exclusive; without it the far row and
        // column of the window were never inspected.
        int yEnd = Math.min(y + localMaximumRadius + 1, this.heatmapScores.getHeight());
        int xStart = Math.max(x - localMaximumRadius, 0);
        int xEnd = Math.min(x + localMaximumRadius + 1, this.heatmapScores.getWidth());
        for (int yCurrent = yStart; yCurrent < yEnd; yCurrent++) {
            for (int xCurrent = xStart; xCurrent < xEnd; xCurrent++) {
                if (this.heatmapScores.getScore(keypointId, xCurrent, yCurrent) > score) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Follows one skeleton edge from a known keypoint to the keypoint at its other end.
     *
     * <p>The displacement stored at the source keypoint's heatmap cell is added to
     * the source position; the result is snapped to the nearest heatmap cell, and
     * that cell's offset and score define the target keypoint.
     *
     * @param displacements    displacement field to read (forward or backward)
     * @param edgeId           index of the edge being followed
     * @param sourceKeypoint   already decoded keypoint the edge starts from
     * @param targetKeypointId id of the keypoint to create
     * @param outputStride     multiplier between heatmap indices and image coordinates
     */
    private Keypoint traverseToTargetKeypoint(Displacements displacements, int edgeId, Keypoint sourceKeypoint, int targetKeypointId, int outputStride) {
        int height = this.heatmapScores.getHeight();
        int width = this.heatmapScores.getWidth();

        Point sourceIndex = getStridedIndexNearPoint(sourceKeypoint.getPosition(), outputStride, height, width);
        PointF displacement = displacements.getDisplacement(edgeId, sourceIndex.x, sourceIndex.y);
        PointF displacedPoint = addVectors(sourceKeypoint.getPosition(), displacement);

        Point targetIndex = getStridedIndexNearPoint(displacedPoint, outputStride, height, width);
        PointF offset = this.offsets.getOffsetPoint(targetKeypointId, targetIndex.x, targetIndex.y);
        PointF targetPosition = addVectors(new PointF(targetIndex.x * outputStride, targetIndex.y * outputStride), offset);
        float targetScore = this.heatmapScores.getScore(targetKeypointId, targetIndex.x, targetIndex.y);

        return new Keypoint(targetKeypointId, this.skeleton.getKeypointName(targetKeypointId), targetPosition, targetScore, this.bounds);
    }

    /**
     * Tells whether any of the given poses has its keypoint {@code keypointId}
     * within the squared radius of {@code point}.
     *
     * @param poses            poses to compare against
     * @param squaredNmsRadius squared suppression radius
     * @param point            position to test
     * @param keypointId       index of the keypoint to compare in every pose
     */
    private boolean withinNMSRadiusOfCorrespondingPoint(List<Pose> poses, double squaredNmsRadius, PointF point, int keypointId) {
        for (Pose pose : poses) {
            Keypoint corresponding = pose.getKeypoints()[keypointId];
            // decodePose may leave unreachable keypoints null; such a keypoint
            // cannot suppress anything.
            if (corresponding != null && corresponding.calculateSquaredDistanceFromCoordinates(point) <= squaredNmsRadius) {
                return true;
            }
        }
        return false;
    }

    /**
     * Decodes up to {@code i2} poses from the model outputs given at construction.
     *
     * <p>Root candidates are taken best-score first. A candidate is skipped when
     * it lies within the NMS radius of the same keypoint in an already accepted
     * pose; otherwise it is expanded into a pose and appended to the result.
     *
     * @param i  output stride: multiplier between heatmap indices and image coordinates
     * @param i2 maximum number of poses to return
     * @param f  score threshold for root candidates; also passed to each created {@code Pose}
     * @param f2 NMS radius in image coordinates (squared internally before comparison)
     * @param i3 radius, in heatmap cells, of the local-maximum window for root candidates
     * @return the decoded poses in the order they were accepted; possibly empty
     */
    public List<Pose> decodeMultiplePoses(int i, int i2, float f, float f2, int i3) {
        final int outputStride = i;
        final int maxPoses = i2;
        final float scoreThreshold = f;
        final float squaredNmsRadius = f2 * f2;

        List<Pose> poses = new ArrayList<>();
        PriorityBlockingQueue<PartScore> candidates = buildPartWithScoringQueue(scoreThreshold, i3);
        while (poses.size() < maxPoses && !candidates.isEmpty()) {
            PartScore root = candidates.poll();
            PointF rootPosition = getImageCoordinates(root, outputStride);
            if (withinNMSRadiusOfCorrespondingPoint(poses, squaredNmsRadius, rootPosition, root.getKeypointId())) {
                continue;
            }
            Keypoint[] keypoints = decodePose(root, outputStride);
            float instanceScore = getInstanceScore(poses, squaredNmsRadius, keypoints);
            poses.add(new Pose(this.skeleton, keypoints, instanceScore, scoreThreshold, this.bounds));
        }
        return poses;
    }
}
