Skip to content

Utils

atr_ner_eval.utils

Utility functions.

Attributes

logger module-attribute

logger = logging.getLogger(__name__)

Functions

check_complete

check_complete(labels: list[Path], predictions: list[Path])

Check that each label BIO file has a corresponding prediction BIO file and each prediction BIO file has a corresponding label BIO file. Otherwise raise an error.

Parameters:

Name Type Description Default
labels list[Path]

List of sorted label BIO files.

required
predictions list[Path]

List of sorted prediction BIO files.

required
Source code in atr_ner_eval/utils.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def check_complete(labels: list[Path], predictions: list[Path]):
    """Ensure label and prediction BIO files match one-to-one by file name.

    Only the file names are compared; the parent directories are ignored.

    Args:
        labels: List of sorted label BIO files.
        predictions: List of sorted prediction BIO files.

    Raises:
        FileNotFoundError: If a label file has no prediction file with the
            same name, or a prediction file has no label file with the same
            name. The message lists the missing names for each side.
    """
    expected_names = {path.name for path in labels}
    predicted_names = {path.name for path in predictions}

    # Nothing to report when both sides hold exactly the same names.
    if expected_names == predicted_names:
        return

    labels_without_prediction = expected_names - predicted_names
    predictions_without_label = predicted_names - expected_names

    # Missing predictions are reported first, then missing labels.
    problems = []
    if labels_without_prediction:
        problems.append(f"Missing prediction files: {labels_without_prediction}.")
    if predictions_without_label:
        problems.append(f"Missing label files: {predictions_without_label}.")
    raise FileNotFoundError("\n".join(problems))

check_valid_bio

check_valid_bio(bio_files: list[Path]) -> list[Document]

Check that BIO files exist and are valid.

Parameters:

Name Type Description Default
bio_files list[Path]

List of BIO files to check.

required
Source code in atr_ner_eval/utils.py
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
def check_valid_bio(bio_files: list[Path]) -> list[Document]:
    """Parse every BIO file, stopping at the first missing or invalid one.

    Args:
        bio_files: List of BIO files to check.

    Returns:
        The parsed documents, in the same order as ``bio_files``.

    Raises:
        FileNotFoundError: If a file does not exist, or if
            ``Document.from_file`` raises while parsing it (the original
            error is chained as the cause).
        Exception: If a parsed document contains an entity named
            ``GLOBAL_STAT_NAME``.
    """
    documents = []
    for bio_path in bio_files:
        # A missing file gets its own message rather than a parsing error.
        if not bio_path.exists():
            raise FileNotFoundError(f"BIO file {bio_path} does not exist.")

        # Any parsing failure is reported with the same exception type as a
        # missing file.
        try:
            parsed_document = Document.from_file(bio_path)
        except Exception as error:
            raise FileNotFoundError(
                f"BIO file {bio_path} is not valid: {error}",
            ) from error

        # GLOBAL_STAT_NAME is reserved for global statistics, so no entity
        # may use it as its name (first item of each entity).
        entity_names = {entity[0] for entity in parsed_document.entities}
        if GLOBAL_STAT_NAME in entity_names:
            raise Exception(
                f"Invalid entity name {GLOBAL_STAT_NAME}: reserved for global statistics ({bio_path}).",
            )

        documents.append(parsed_document)
    return documents

load_dataset

load_dataset(
    label_dir: Path, prediction_dir: Path
) -> list[tuple[Document, Document]]

Load BIO files for a given dataset.

Parameters:

Name Type Description Default
label_dir Path

Path to the label directory.

required
prediction_dir Path

Path to prediction directory.

required

Returns:

Type Description
list[tuple[Document, Document]]

list[tuple[Document, Document]]: A list of tuples containing the label and corresponding prediction Documents.

Source code in atr_ner_eval/utils.py
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
def load_dataset(
    label_dir: Path,
    prediction_dir: Path,
) -> list[tuple[Document, Document]]:
    """Load and pair the BIO files of a dataset.

    Only ``*.bio`` files located directly inside each directory are used.
    Both listings are sorted by file name, so once ``check_complete`` has
    confirmed that the two sets of names are identical, the label and the
    prediction at the same position refer to the same file name.

    Args:
        label_dir (Path): Path to the label directory.
        prediction_dir (Path): Path to prediction directory.

    Returns:
        list[tuple[Document, Document]]: A list of tuples containing the label and corresponding prediction Documents.

    Raises:
        FileNotFoundError: If either directory contains no ``*.bio`` file.
            Errors raised by ``check_complete`` and ``check_valid_bio`` are
            propagated unchanged.
    """
    by_name = attrgetter("name")
    label_files = sorted(label_dir.glob("*.bio"), key=by_name)
    prediction_files = sorted(prediction_dir.glob("*.bio"), key=by_name)

    # Report every empty directory in a single error.
    problems = []
    if not label_files:
        problems.append(f"Empty label directory: {label_dir}.")
    if not prediction_files:
        problems.append(f"Empty prediction directory: {prediction_dir}.")
    if problems:
        raise FileNotFoundError("\n".join(problems))

    # Every label must have a prediction with the same name, and vice versa.
    check_complete(label_files, prediction_files)

    logger.info("Loading labels...")
    label_documents = check_valid_bio(label_files)

    logger.info("Loading prediction...")
    prediction_documents = check_valid_bio(prediction_files)

    logger.info("The dataset is complete and valid.")
    # Pair each label Document with the prediction Document at the same rank.
    return [*zip(label_documents, prediction_documents)]

sort_categories

sort_categories(
    categories: list[list[str]],
) -> list[list[str]]

Sort a list of categories with their associated metrics.

All categories are alphabetically sorted except for GLOBAL_STAT_NAME which is appended at the very end of the sorted list.

Parameters:

Name Type Description Default
categories list[list[str]]

List of categories with their metrics.

required

Returns:

Type Description
list[list[str]]

list[list[str]]: A sorted version of the provided list of categories.

Source code in atr_ner_eval/utils.py
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
def sort_categories(categories: list[list[str]]) -> list[list[str]]:
    """Sort a list of categories with their associated metrics.

    Entries are sorted in ascending order (plain list comparison, so the
    category name in first position is compared first), except for entries
    whose first item is GLOBAL_STAT_NAME, which are placed at the very end
    of the sorted list.

    Args:
        categories (list[list[str]]): List of categories with their metrics.
            The first item of each entry is compared against GLOBAL_STAT_NAME.

    Returns:
        list[list[str]]: A sorted version of the provided list of categories.
            The input list itself is not modified.
    """
    # One stable sort on a composite key replaces the former double sort with
    # a linear ``list.index`` lookup per element (quadratic overall).
    # ``False`` orders before ``True``, so regular categories come first and
    # GLOBAL_STAT_NAME entries last; ties are broken by the entry itself.
    return sorted(
        categories,
        key=lambda entry: (entry[0] == GLOBAL_STAT_NAME, entry),
    )