Source code for ocr_stringdist.levenshtein

from __future__ import annotations

from collections.abc import Iterable
from typing import Any, Optional

from ._rust_stringdist import (
    _batch_weighted_levenshtein_distance,
    _explain_weighted_levenshtein_distance,
    _weighted_levenshtein_distance,
)
from .default_ocr_distances import ocr_distance_map
from .edit_operation import EditOperation


[docs] class WeightedLevenshtein: """ Calculates Levenshtein distance with custom, configurable costs. This class is initialized with cost dictionaries and settings that define how the distance is measured. Once created, its methods can be used to efficiently compute distances and explain the edit operations. :param substitution_costs: Maps (str, str) tuples to their substitution cost. Defaults to costs based on common OCR errors. :param insertion_costs: Maps a string to its insertion cost. :param deletion_costs: Maps a string to its deletion cost. :param symmetric_substitution: If True, a cost defined for, e.g., ('0', 'O') will automatically apply to ('O', '0'). If False, both must be defined explicitly. :param default_substitution_cost: Default cost for single-char substitutions not in the map. :param default_insertion_cost: Default cost for single-char insertions not in the map. :param default_deletion_cost: Default cost for single-char deletions not in the map. :raises TypeError, ValueError: If the provided arguments are invalid. """ substitution_costs: dict[tuple[str, str], float] insertion_costs: dict[str, float] deletion_costs: dict[str, float] symmetric_substitution: bool default_substitution_cost: float default_insertion_cost: float default_deletion_cost: float def __init__( self, substitution_costs: Optional[dict[tuple[str, str], float]] = None, insertion_costs: Optional[dict[str, float]] = None, deletion_costs: Optional[dict[str, float]] = None, *, symmetric_substitution: bool = True, default_substitution_cost: float = 1.0, default_insertion_cost: float = 1.0, default_deletion_cost: float = 1.0, ) -> None: # Validate default costs for cost_name, cost_val in [ ("default_substitution_cost", default_substitution_cost), ("default_insertion_cost", default_insertion_cost), ("default_deletion_cost", default_deletion_cost), ]: if not isinstance(cost_val, (int, float)): raise TypeError(f"{cost_name} must be a number, but got: {type(cost_val).__name__}") if cost_val < 0: raise ValueError(f"{cost_name} must be non-negative, got value: {cost_val}") # Validate substitution_costs dictionary sub_costs = ocr_distance_map if substitution_costs is None else substitution_costs for key, cost in sub_costs.items(): if not ( isinstance(key, tuple) and len(key) == 2 and isinstance(key[0], str) and isinstance(key[1], str) ): raise TypeError( f"substitution_costs keys must be tuples of two strings, but found: {key}" ) if not isinstance(cost, (int, float)): raise TypeError( f"Cost for substitution {key} must be a number, but got: {type(cost).__name__}" ) if cost < 0: raise ValueError(f"Cost for substitution {key} cannot be negative, but got: {cost}") self.substitution_costs = sub_costs self.insertion_costs = {} if insertion_costs is None else insertion_costs self.deletion_costs = {} if deletion_costs is None else deletion_costs self.symmetric_substitution = symmetric_substitution self.default_substitution_cost = default_substitution_cost self.default_insertion_cost = default_insertion_cost self.default_deletion_cost = default_deletion_cost
[docs] @classmethod def unweighted(cls) -> WeightedLevenshtein: """Creates an instance with all operations having equal cost of 1.0.""" return cls(substitution_costs={}, insertion_costs={}, deletion_costs={})
[docs] def distance(self, s1: str, s2: str) -> float: """Calculates the weighted Levenshtein distance between two strings.""" return _weighted_levenshtein_distance(s1, s2, **self.__dict__) # type: ignore[no-any-return]
[docs] def explain(self, s1: str, s2: str, filter_matches: bool = True) -> list[EditOperation]: """ Returns the list of edit operations to transform s1 into s2. :param s1: First string (interpreted as the string read via OCR) :param s2: Second string (interpreted as the target string) :param filter_matches: If True, 'match' operations are excluded from the result. :return: List of :class:`EditOperation` instances. """ raw_path = _explain_weighted_levenshtein_distance(s1, s2, **self.__dict__) parsed_path = [EditOperation(*op) for op in raw_path] if filter_matches: return list(filter(lambda op: op.op_type != "match", parsed_path)) return parsed_path
[docs] def batch_distance(self, s: str, candidates: list[str]) -> list[float]: """Calculates distances between a string and a list of candidates.""" return _batch_weighted_levenshtein_distance(s, candidates, **self.__dict__) # type: ignore[no-any-return]
[docs] @classmethod def learn_from(cls, pairs: Iterable[tuple[str, str]]) -> WeightedLevenshtein: """ Creates an instance by learning costs from a dataset of (OCR, ground truth) string pairs. For more advanced learning configuration, see the :class:`ocr_stringdist.learner.CostLearner` class. :param pairs: An iterable of (ocr_string, ground_truth_string) tuples. Correct pairs are not intended to be filtered; they are needed to learn well-aligned costs. :return: A new `WeightedLevenshtein` instance with the learned costs. Example:: from ocr_stringdist import WeightedLevenshtein training_data = [ ("8N234", "BN234"), # read '8' instead of 'B' ("BJK18", "BJK18"), # correct ("ABC0.", "ABC0"), # extra '.' ] wl = WeightedLevenshtein.learn_from(training_data) print(wl.substitution_costs) # learned cost for substituting '8' with 'B' print(wl.deletion_costs) # learned cost for deleting '.' """ from .learner import CostLearner return CostLearner().fit(pairs)
def __eq__(self, other: object) -> bool: if not isinstance(other, WeightedLevenshtein): return NotImplemented return ( self.substitution_costs == other.substitution_costs and self.insertion_costs == other.insertion_costs and self.deletion_costs == other.deletion_costs and self.symmetric_substitution == other.symmetric_substitution and self.default_substitution_cost == other.default_substitution_cost and self.default_insertion_cost == other.default_insertion_cost and self.default_deletion_cost == other.default_deletion_cost )
[docs] def to_dict(self) -> dict[str, Any]: """ Serializes the instance's configuration to a dictionary. The result can be written to, say, JSON. For the counterpart, see :meth:`WeightedLevenshtein.from_dict`. """ # Convert tuple keys to a list of lists/objects for broader compatibility (e.g., JSON) sub_costs_serializable = [ {"from": k[0], "to": k[1], "cost": v} for k, v in self.substitution_costs.items() ] return { "substitution_costs": sub_costs_serializable, "insertion_costs": self.insertion_costs, "deletion_costs": self.deletion_costs, "symmetric_substitution": self.symmetric_substitution, "default_substitution_cost": self.default_substitution_cost, "default_insertion_cost": self.default_insertion_cost, "default_deletion_cost": self.default_deletion_cost, }
[docs] @classmethod def from_dict(cls, data: dict[str, Any]) -> WeightedLevenshtein: """ Deserialize from a dictionary. For the counterpart, see :meth:`WeightedLevenshtein.to_dict`. :param data: A dictionary with (not necessarily all of) the following keys: - "substitution_costs": {"from": str, "to": str, "cost": float} - "substitution_costs": dict[str, float] - "deletion_costs": dict[str, float] - "symmetric_substitution": bool - "default_substitution_cost": float - "default_insertion_cost": float - "default_deletion_cost": float """ # Convert the list of substitution costs back to the required dict format sub_costs: dict[tuple[str, str], float] = { (item["from"], item["to"]): item["cost"] for item in data.get("substitution_costs", {}) } return cls( substitution_costs=sub_costs, insertion_costs=data.get("insertion_costs"), deletion_costs=data.get("deletion_costs"), symmetric_substitution=data.get("symmetric_substitution", True), default_substitution_cost=data.get("default_substitution_cost", 1.0), default_insertion_cost=data.get("default_insertion_cost", 1.0), default_deletion_cost=data.get("default_deletion_cost", 1.0), )