Skip to content

Nerval

atr_ner_eval.metrics.nerval

Compute Precision, Recall and F1 from a label/prediction dataset.

Attributes

NERVAL_GLOBAL_STAT module-attribute

NERVAL_GLOBAL_STAT = 'All'

METRICS module-attribute

METRICS = ['predicted', 'matched', 'Support']

Functions

compute_precision

compute_precision(matched: int, predicted: int) -> float

Compute precision.

Source code in atr_ner_eval/metrics/nerval.py
30
31
32
33
34
def compute_precision(matched: int, predicted: int) -> float:
    """Compute precision as a percentage.

    Args:
        matched (int): Number of predicted entities that matched a reference entity.
        predicted (int): Total number of predicted entities.

    Returns:
        Precision in [0, 100]. When nothing was predicted, 100.0 if nothing
        was matched either (vacuously perfect), 0.0 otherwise.
    """
    if predicted == 0:
        # Guard against division by zero; return floats to honor the annotation.
        return 100.0 if matched == 0 else 0.0
    return 100 * matched / predicted

compute_recall

compute_recall(matched: int, support: int) -> float

Compute recall.

Source code in atr_ner_eval/metrics/nerval.py
37
38
39
40
41
def compute_recall(matched: int, support: int) -> float:
    """Compute recall as a percentage.

    Args:
        matched (int): Number of reference entities that were correctly matched.
        support (int): Total number of reference entities.

    Returns:
        Recall in [0, 100]. When there is no support, 100.0 if nothing
        was matched either (vacuously perfect), 0.0 otherwise.
    """
    if support == 0:
        # Guard against division by zero; return floats to honor the annotation.
        return 100.0 if matched == 0 else 0.0
    return 100 * matched / support

compute_f1

compute_f1(precision: float, recall: float) -> float

Compute F1 score.

Source code in atr_ner_eval/metrics/nerval.py
44
45
46
47
48
def compute_f1(precision: float, recall: float) -> float:
    """Compute the F1 score, the harmonic mean of precision and recall."""
    denominator = precision + recall
    # Degenerate case: both scores are zero, so F1 is zero by convention.
    if denominator == 0:
        return 0
    return 2 * precision * recall / denominator

compute_nerval

compute_nerval(
    label_dir: Path,
    prediction_dir: Path,
    threshold: float,
    by_category: bool = False,
) -> None

Read BIO files and compute Precision, Recall and F1 globally or for each NER category.

Parameters:

Name Type Description Default
label_dir Path

Path to the reference BIO file.

required
prediction_dir Path

Path to the prediction BIO file.

required
threshold float

Character Error Rate threshold used to match entities.

required
by_category bool

Whether to display Precision/Recall/F1 by category.

False

Returns:

Type Description
None

Nothing is returned; the evaluation results are printed as a Markdown formatted table.

Source code in atr_ner_eval/metrics/nerval.py
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
def compute_nerval(
    label_dir: Path,
    prediction_dir: Path,
    threshold: float,
    by_category: bool = False,
) -> None:
    """Read BIO files and compute Precision, Recall and F1 globally or for each NER category.

    Args:
        label_dir (Path): Path to the reference BIO file(s).
        prediction_dir (Path): Path to the prediction BIO file(s).
        threshold (float): Character Error Rate threshold used to match entities.
        by_category (bool): Whether to display Precision/Recall/F1 by category.

    Returns:
        None. The evaluation results are printed as a Markdown formatted table.
    """
    # Load paired (label, prediction) documents.
    dataset = load_dataset(label_dir, prediction_dir)

    # Accumulate raw counts (predicted / matched / Support) per entity
    # across every document in the dataset.
    scores = defaultdict(lambda: defaultdict(int))
    for label, prediction in dataset:
        doc_scores = evaluate(
            _format_document(label),
            _format_document(prediction),
            threshold,
        )
        for entity, counts in doc_scores.items():
            # Nerval uses a different name for its global statistic; map it to ours.
            if entity == NERVAL_GLOBAL_STAT:
                entity = GLOBAL_STAT_NAME

            for metric in METRICS:
                # Nerval may report None for a missing count; treat it as 0.
                scores[entity][metric] += counts[metric] or 0

    rows = []

    for entity, counts in scores.items():
        # Keep per-category rows only when explicitly requested.
        if entity != GLOBAL_STAT_NAME and not by_category:
            continue
        precision = compute_precision(counts["matched"], counts["predicted"])
        recall = compute_recall(counts["matched"], counts["Support"])
        rows.append(
            [
                entity,
                round(precision, 2),
                round(recall, 2),
                round(compute_f1(precision, recall), 2),
                counts["Support"],
            ],
        )

    print_markdown_table(
        ["Category", "Precision", "Recall", "F1", "Support"],
        sort_categories(rows),
    )