"""Weighted Levenshtein distance with configurable edit costs (module ``ocr_stringdist.levenshtein``)."""

from __future__ import annotations

import math
from collections.abc import Iterable
from typing import Any, Optional

from ._observable_dict import _ObservableDict
from ._rust_stringdist import RustLevenshteinCalculator
from .default_ocr_distances import ocr_distance_map
from .edit_operation import EditOperation


class WeightedLevenshtein:
    """
    Calculates Levenshtein distance with custom, configurable costs.

    This class is initialized with cost dictionaries and settings that define how
    the distance is measured. Once created, its methods can be used to efficiently
    compute distances and explain the edit operations.

    :param substitution_costs: Maps (str, str) tuples to their substitution cost.
        Defaults to costs based on common OCR errors.
    :param insertion_costs: Maps a string to its insertion cost.
    :param deletion_costs: Maps a string to its deletion cost.
    :param symmetric_substitution: If True, a cost defined for, e.g., ('0', 'O') will
        automatically apply to ('O', '0'). If False, both must be defined explicitly.
    :param default_substitution_cost: Default cost for single-char substitutions not
        in the map. Values above ``default_insertion_cost + default_deletion_cost``
        are capped to that sum.
    :param default_insertion_cost: Default cost for single-char insertions not in the map.
    :param default_deletion_cost: Default cost for single-char deletions not in the map.
    :raises TypeError, ValueError: If the provided arguments are invalid.
    """

    def __init__(
        self,
        substitution_costs: Optional[dict[tuple[str, str], float]] = None,
        insertion_costs: Optional[dict[str, float]] = None,
        deletion_costs: Optional[dict[str, float]] = None,
        *,
        symmetric_substitution: bool = True,
        default_substitution_cost: float = 1.0,
        default_insertion_cost: float = 1.0,
        default_deletion_cost: float = 1.0,
    ) -> None:
        self._symmetric_substitution = symmetric_substitution
        self._default_insertion_cost = self._validate_cost(
            "default_insertion_cost", default_insertion_cost
        )
        self._default_deletion_cost = self._validate_cost(
            "default_deletion_cost", default_deletion_cost
        )
        # Must come after the insertion/deletion defaults: the cap is their sum.
        self._default_substitution_cost = self._capped_default_substitution_cost(
            default_substitution_cost
        )
        # Lazily-built Rust backend; None means "rebuild on next use".
        self._calculator = None

        # Each cost map is wrapped with a change callback (which drops the cached
        # backend) and a per-entry validator.
        sub_init = ocr_distance_map if substitution_costs is None else substitution_costs
        self._substitution_costs = _ObservableDict(
            sub_init, self._invalidate_calculator, self._validate_sub_entry
        )
        self._insertion_costs = _ObservableDict(
            insertion_costs or {}, self._invalidate_calculator, self._validate_unary_entry
        )
        self._deletion_costs = _ObservableDict(
            deletion_costs or {}, self._invalidate_calculator, self._validate_unary_entry
        )

    def _invalidate_calculator(self) -> None:
        """Mark the Rust backend as out of sync; it will be rebuilt on next use."""
        self._calculator = None

    def _get_calculator(self) -> RustLevenshteinCalculator:  # type: ignore[no-any-unimported]
        """Return a Rust backend in sync with the current Python-side state."""
        if self._calculator is None:
            substitution_costs, insertion_costs, deletion_costs = (
                self._effective_cost_maps_for_calculator()
            )
            self._calculator = RustLevenshteinCalculator(
                substitution_costs=substitution_costs,
                insertion_costs=insertion_costs,
                deletion_costs=deletion_costs,
                symmetric_substitution=self._symmetric_substitution,
                default_substitution_cost=self._default_substitution_cost,
                default_insertion_cost=self._default_insertion_cost,
                default_deletion_cost=self._default_deletion_cost,
            )
        return self._calculator

    def _effective_cost_maps_for_calculator(
        self,
    ) -> tuple[dict[tuple[str, str], float], dict[str, float], dict[str, float]]:
        """
        Split substitution entries with empty source/target into the
        insertion/deletion maps, taking the minimum where they overlap.

        :return: (substitution_costs, insertion_costs, deletion_costs) as plain dicts.
        """
        substitution_costs: dict[tuple[str, str], float] = {}
        insertion_costs = dict(self._insertion_costs)
        deletion_costs = dict(self._deletion_costs)
        for (source, target), cost in self._substitution_costs.items():
            if source == "":
                # ("", x) means "produce x from nothing", i.e. an insertion.
                self._set_min_cost(insertion_costs, target, cost)
            elif target == "":
                # (x, "") means "x disappears", i.e. a deletion.
                self._set_min_cost(deletion_costs, source, cost)
            else:
                substitution_costs[(source, target)] = cost
        return substitution_costs, insertion_costs, deletion_costs

    @staticmethod
    def _set_min_cost(costs: dict[str, float], key: str, cost: float) -> None:
        """Store ``cost`` under ``key`` unless a cheaper cost is already present."""
        costs[key] = min(costs.get(key, cost), cost)

    # --- Properties ---

    @property
    def substitution_costs(self) -> dict[tuple[str, str], float]:
        """Live substitution cost map; in-place edits invalidate the backend."""
        return self._substitution_costs

    @substitution_costs.setter
    def substitution_costs(self, value: dict[tuple[str, str], float]) -> None:
        self._substitution_costs = _ObservableDict(
            value, self._invalidate_calculator, self._validate_sub_entry
        )
        self._invalidate_calculator()

    @property
    def insertion_costs(self) -> dict[str, float]:
        """Live insertion cost map; in-place edits invalidate the backend."""
        return self._insertion_costs

    @insertion_costs.setter
    def insertion_costs(self, value: dict[str, float]) -> None:
        self._insertion_costs = _ObservableDict(
            value, self._invalidate_calculator, self._validate_unary_entry
        )
        self._invalidate_calculator()

    @property
    def deletion_costs(self) -> dict[str, float]:
        """Live deletion cost map; in-place edits invalidate the backend."""
        return self._deletion_costs

    @deletion_costs.setter
    def deletion_costs(self, value: dict[str, float]) -> None:
        self._deletion_costs = _ObservableDict(
            value, self._invalidate_calculator, self._validate_unary_entry
        )
        self._invalidate_calculator()

    @property
    def symmetric_substitution(self) -> bool:
        """Whether a substitution cost applies in both directions."""
        return self._symmetric_substitution

    @symmetric_substitution.setter
    def symmetric_substitution(self, value: bool) -> None:
        self._symmetric_substitution = value
        self._invalidate_calculator()

    @property
    def default_substitution_cost(self) -> float:
        """Cost for substitutions not in the map (capped at insertion + deletion default)."""
        return self._default_substitution_cost

    @default_substitution_cost.setter
    def default_substitution_cost(self, value: float) -> None:
        # Apply the same cap as __init__ so constructor and setter agree.
        self._default_substitution_cost = self._capped_default_substitution_cost(value)
        self._invalidate_calculator()

    @property
    def default_insertion_cost(self) -> float:
        """Cost for insertions not in the map."""
        return self._default_insertion_cost

    @default_insertion_cost.setter
    def default_insertion_cost(self, value: float) -> None:
        # Note: this does not re-cap an already-stored default_substitution_cost.
        self._default_insertion_cost = self._validate_cost("default_insertion_cost", value)
        self._invalidate_calculator()

    @property
    def default_deletion_cost(self) -> float:
        """Cost for deletions not in the map."""
        return self._default_deletion_cost

    @default_deletion_cost.setter
    def default_deletion_cost(self, value: float) -> None:
        # Note: this does not re-cap an already-stored default_substitution_cost.
        self._default_deletion_cost = self._validate_cost("default_deletion_cost", value)
        self._invalidate_calculator()

    # --- Validation Helpers ---

    def _capped_default_substitution_cost(self, value: float) -> float:
        """
        Validate ``value`` and cap it at ``default_insertion_cost + default_deletion_cost``.

        A substitution can always be expressed as a deletion + insertion, so capping
        keeps the substitution default from being effectively ignored when the user
        supplies a value above the del/ins ceiling. Requires the insertion/deletion
        defaults to be set already.
        """
        return min(
            self._validate_cost("default_substitution_cost", value),
            self._default_insertion_cost + self._default_deletion_cost,
        )

    def _validate_cost(self, name: str, val: float) -> float:
        """
        Check that ``val`` is a finite, non-negative number and return it as float.

        :raises TypeError: If ``val`` is not an int/float (bools are rejected).
        :raises ValueError: If ``val`` is non-finite or negative.
        """
        # bool subclasses int; reject it so True/False aren't silently read as 1.0/0.0.
        if isinstance(val, bool) or not isinstance(val, (int, float)):
            raise TypeError(f"{name} must be a number, but got: {type(val).__name__}")
        if not math.isfinite(val):
            raise ValueError(f"{name} must be finite, got value: {val}")
        if val < 0:
            raise ValueError(f"{name} must be non-negative, got value: {val}")
        return float(val)

    def _validate_sub_entry(self, key: Any, cost: Any) -> None:
        """Validate one substitution-map entry: a (str, str) key and a valid cost."""
        if not (isinstance(key, tuple) and len(key) == 2 and all(isinstance(k, str) for k in key)):
            raise TypeError(f"substitution_costs keys must be tuples of two strings, found: {key}")
        if key == ("", ""):
            raise ValueError('substitution_costs key ("", "") is not a meaningful edit operation')
        self._validate_cost(f"Cost for {key}", cost)

    def _validate_unary_entry(self, key: Any, cost: Any) -> None:
        """Validate one insertion/deletion-map entry: a str key and a valid cost."""
        if not isinstance(key, str):
            raise TypeError(f"Cost keys must be strings, found: {key}")
        self._validate_cost(f"Cost for {key}", cost)

    @classmethod
    def unweighted(cls) -> WeightedLevenshtein:
        """Creates an instance with all operations having equal cost of 1.0."""
        return cls(substitution_costs={}, insertion_costs={}, deletion_costs={})

    def transitive_closure(
        self,
        *,
        prune: bool = False,
        max_node_length: Optional[int] = None,
    ) -> WeightedLevenshtein:
        """
        Returns a new instance whose cost dictionaries are filled with effective
        (transitive) edit costs.

        If, for example, `substitution_costs[("a", "b")] = 0.1` and
        `substitution_costs[("b", "c")] = 0.1`, the closed instance's
        `substitution_costs[("a", "c")]` is `0.2` rather than the default.
        Insertion and deletion chains, and chains that cross `ε` (e.g.
        `del("y") + ins("x")` becoming an effective `("y", "x")` substitution),
        are likewise materialized.

        :param prune: If True, remove generated substitutions whose costs are
            already represented by matches, insertions, deletions, and shorter
            substitutions. This can make the returned cost map easier to inspect,
            but it is much more expensive for large closures.
        :param max_node_length: Maximum length (in characters) of intermediate graph
            nodes the closure may construct. `None` derives a sensible default from
            the input (twice the longest raw token, with a small floor); pass an
            `int` to override. The cap is what guarantees termination - without it,
            the graph could grow without bound. Floyd-Warshall is :math:`O(N^3)` in
            the resulting node count, so a higher cap can be substantially slower.
        :raises ValueError: If the generated closure graph is too large to process safely.

        ``explain()`` on the closed instance returns flat single-step ops; the
        original chain that produced an effective cost is not preserved. For
        repeated use, save via :meth:`to_dict` and reload via :meth:`from_dict` so
        the closure is computed once.
        """
        sub_dict, ins_dict, del_dict = self._get_calculator().closed_cost_maps(
            prune, max_node_length
        )
        # NOTE(review): symmetric_substitution=False presumably because the closed
        # maps already list each direction explicitly — confirm against the backend.
        return WeightedLevenshtein(
            substitution_costs=dict(sub_dict),
            insertion_costs=dict(ins_dict),
            deletion_costs=dict(del_dict),
            symmetric_substitution=False,
            default_substitution_cost=self.default_substitution_cost,
            default_insertion_cost=self.default_insertion_cost,
            default_deletion_cost=self.default_deletion_cost,
        )

    def distance(self, s1: str, s2: str) -> float:
        """Calculates the weighted Levenshtein distance between two strings."""
        return self._get_calculator().distance(s1, s2)  # type: ignore[no-any-return]

    def explain(self, s1: str, s2: str, filter_matches: bool = True) -> list[EditOperation]:
        """
        Returns the list of edit operations to transform s1 into s2.

        :param s1: First string (interpreted as the string read via OCR)
        :param s2: Second string (interpreted as the target string)
        :param filter_matches: If True, 'match' operations are excluded from the result.
        :return: List of :class:`EditOperation` instances.
        """
        parsed_path = [EditOperation(*op) for op in self._get_calculator().explain(s1, s2)]
        if filter_matches:
            return [op for op in parsed_path if op.op_type != "match"]
        return parsed_path

    def batch_distance(self, s: str, candidates: list[str]) -> list[float]:
        """Calculates distances between a string and a list of candidates."""
        return self._get_calculator().batch_distance(s, candidates)  # type: ignore[no-any-return]

    @classmethod
    def learn_from(cls, pairs: Iterable[tuple[str, str]]) -> WeightedLevenshtein:
        """
        Creates an instance by learning costs from a dataset of (OCR, ground truth) string pairs.

        For more advanced learning configuration, see the
        :class:`ocr_stringdist.learner.CostLearner` class.

        :param pairs: An iterable of (ocr_string, ground_truth_string) tuples. Correct
            pairs are not intended to be filtered; they are needed to learn
            well-aligned costs.
        :return: A new `WeightedLevenshtein` instance with the learned costs.

        Example::

            from ocr_stringdist import WeightedLevenshtein

            training_data = [
                ("8N234", "BN234"),  # read '8' instead of 'B'
                ("BJK18", "BJK18"),  # correct
                ("ABC0.", "ABC0"),  # extra '.'
            ]
            wl = WeightedLevenshtein.learn_from(training_data)
            print(wl.substitution_costs)  # learned cost for substituting '8' with 'B'
            print(wl.deletion_costs)  # learned cost for deleting '.'
        """
        # Local import, presumably to avoid an import cycle with the learner module.
        from .learner import CostLearner

        return CostLearner().fit(pairs)

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}("
            f"substitution_costs={dict(self._substitution_costs)!r}, "
            f"insertion_costs={dict(self._insertion_costs)!r}, "
            f"deletion_costs={dict(self._deletion_costs)!r}, "
            f"symmetric_substitution={self._symmetric_substitution!r}, "
            f"default_substitution_cost={self._default_substitution_cost!r}, "
            f"default_insertion_cost={self._default_insertion_cost!r}, "
            f"default_deletion_cost={self._default_deletion_cost!r})"
        )

    def __str__(self) -> str:
        return (
            f"{type(self).__name__}("
            f"substitution_costs=<{len(self._substitution_costs)} entries>, "
            f"insertion_costs=<{len(self._insertion_costs)} entries>, "
            f"deletion_costs=<{len(self._deletion_costs)} entries>, "
            f"symmetric_substitution={self._symmetric_substitution}, "
            f"default_substitution_cost={self._default_substitution_cost}, "
            f"default_insertion_cost={self._default_insertion_cost}, "
            f"default_deletion_cost={self._default_deletion_cost})"
        )

    def __eq__(self, other: object) -> bool:
        # Defining __eq__ without __hash__ leaves instances unhashable, which suits
        # a mutable configuration object.
        if not isinstance(other, WeightedLevenshtein):
            return NotImplemented
        return (
            self.substitution_costs == other.substitution_costs
            and self.insertion_costs == other.insertion_costs
            and self.deletion_costs == other.deletion_costs
            and self.symmetric_substitution == other.symmetric_substitution
            and self.default_substitution_cost == other.default_substitution_cost
            and self.default_insertion_cost == other.default_insertion_cost
            and self.default_deletion_cost == other.default_deletion_cost
        )

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the instance's configuration to a dictionary.

        The result can be written to, say, JSON. It holds plain copies of the cost
        maps, so mutating it does not affect this instance.

        For the counterpart, see :meth:`WeightedLevenshtein.from_dict`.
        """
        # Tuple keys aren't JSON-representable; emit a list of objects instead.
        sub_costs_serializable = [
            {"from": source, "to": target, "cost": cost}
            for (source, target), cost in self.substitution_costs.items()
        ]
        return {
            "substitution_costs": sub_costs_serializable,
            # Copies: returning the live cost maps would let callers mutate this instance.
            "insertion_costs": dict(self.insertion_costs),
            "deletion_costs": dict(self.deletion_costs),
            "symmetric_substitution": self.symmetric_substitution,
            "default_substitution_cost": self.default_substitution_cost,
            "default_insertion_cost": self.default_insertion_cost,
            "default_deletion_cost": self.default_deletion_cost,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> WeightedLevenshtein:
        """
        Deserialize from a dictionary.

        For the counterpart, see :meth:`WeightedLevenshtein.to_dict`.

        :param data: A dictionary with (not necessarily all of) the following keys:

            - "substitution_costs": list of {"from": str, "to": str, "cost": float}
              (if absent, no substitution costs are set — not the OCR defaults)
            - "insertion_costs": dict[str, float]
            - "deletion_costs": dict[str, float]
            - "symmetric_substitution": bool
            - "default_substitution_cost": float
            - "default_insertion_cost": float
            - "default_deletion_cost": float
        """
        # Convert the list of substitution entries back to the tuple-keyed dict.
        sub_costs: dict[tuple[str, str], float] = {
            (item["from"], item["to"]): item["cost"]
            for item in data.get("substitution_costs", ())
        }
        return cls(
            substitution_costs=sub_costs,
            insertion_costs=data.get("insertion_costs"),
            deletion_costs=data.get("deletion_costs"),
            symmetric_substitution=data.get("symmetric_substitution", True),
            default_substitution_cost=data.get("default_substitution_cost", 1.0),
            default_insertion_cost=data.get("default_insertion_cost", 1.0),
            default_deletion_cost=data.get("default_deletion_cost", 1.0),
        )