2D Detection | precision, recall 그래프 그리기

2023. 11. 8. 00:28DL

Precision-Recall 그래프를 그리려면 다양한 임계값(threshold)에서의 정밀도(Precision) 및 재현율(Recall)을 계산하고 이를 그래프로 표현해야 합니다. 아래는 Python 및 Matplotlib을 사용하여 Precision-Recall 그래프를 그리는 예제 코드입니다.

import matplotlib.pyplot as plt

def calculate_iou(box1, box2):
    # calculate IoU as before

def calculate_precision_recall(predictions, ground_truths, thresholds):
    precisions = []
    recalls = []

    for threshold in thresholds:
        true_positives = 0
        false_positives = 0
        false_negatives = 0

        for pred in predictions:
            is_matched = False
            for gt in ground_truths:
                iou = calculate_iou(pred, gt)
                if iou >= threshold:
                    is_matched = True
                    break
            if is_matched:
                true_positives += 1
            else:
                false_positives += 1

        for gt in ground_truths:
            is_matched = False
            for pred in predictions:
                iou = calculate_iou(pred, gt)
                if iou >= threshold:
                    is_matched = True
                    break
            if not is_matched:
                false_negatives += 1

        precision = true_positives / (true_positives + false_positives)
        recall = true_positives / (true_positives + false_negatives)
        precisions.append(precision)
        recalls.append(recall)

    return precisions, recalls

# 예측(prediction)과 실제 탐지 결과(ground truth)를 정의
predictions = [(2, 2, 6, 6), (7, 7, 11, 11), (12, 12, 16, 16)]
ground_truths = [(1, 1, 5, 5), (6, 6, 10, 10), (12, 12, 16, 16)]

# 다양한 임계값을 생성
thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

# 정밀도 및 재현율 계산
precisions, recalls = calculate_precision_recall(predictions, ground_truths, thresholds)

# Precision-Recall 그래프 그리기
plt.plot(recalls, precisions, marker='o')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curve')
plt.grid(True)
plt.show()

위 코드는 다양한 임계값(thresholds)에 대한 정밀도 및 재현율을 계산하고, Matplotlib을 사용하여 Precision-Recall 그래프를 그립니다. 그래프는 재현율을 x축에, 정밀도를 y축에 표시하며, 각 임계값에서의 정밀도와 재현율을 연결합니다. 이를 통해 Precision-Recall 곡선을 시각화할 수 있습니다.