from __future__ import annotations
import math
from collections.abc import Iterable
from typing import Any, Optional
from ._observable_dict import _ObservableDict
from ._rust_stringdist import RustLevenshteinCalculator
from .default_ocr_distances import ocr_distance_map
from .edit_operation import EditOperation
class WeightedLevenshtein:
    """
    Calculates Levenshtein distance with custom, configurable costs.

    This class is initialized with cost dictionaries and settings that define
    how the distance is measured. Once created, its methods can be used to
    efficiently compute distances and explain the edit operations.

    :param substitution_costs: Maps (str, str) tuples to their substitution cost.
                               Defaults to costs based on common OCR errors.
                               An entry with an empty source (``("", "x")``) is
                               treated as an insertion cost and one with an empty
                               target (``("x", "")``) as a deletion cost; the key
                               ``("", "")`` is rejected.
    :param insertion_costs: Maps a string to its insertion cost.
    :param deletion_costs: Maps a string to its deletion cost.
    :param symmetric_substitution: If True, a cost defined for, e.g., ('0', 'O') will automatically
                                   apply to ('O', '0'). If False, both must be defined explicitly.
    :param default_substitution_cost: Default cost for single-char substitutions not in the map.
    :param default_insertion_cost: Default cost for single-char insertions not in the map.
    :param default_deletion_cost: Default cost for single-char deletions not in the map.
    :raises TypeError, ValueError: If the provided arguments are invalid.
    """

    def __init__(
        self,
        substitution_costs: Optional[dict[tuple[str, str], float]] = None,
        insertion_costs: Optional[dict[str, float]] = None,
        deletion_costs: Optional[dict[str, float]] = None,
        *,
        symmetric_substitution: bool = True,
        default_substitution_cost: float = 1.0,
        default_insertion_cost: float = 1.0,
        default_deletion_cost: float = 1.0,
    ) -> None:
        self._symmetric_substitution = symmetric_substitution
        self._default_insertion_cost = self._validate_cost(
            "default_insertion_cost", default_insertion_cost
        )
        self._default_deletion_cost = self._validate_cost(
            "default_deletion_cost", default_deletion_cost
        )
        # Stored exactly as given (after validation), just like the
        # ``default_substitution_cost`` setter does. Silently capping the value
        # here made constructor- and setter-configured instances disagree and
        # overrode the user's explicit choice.
        self._default_substitution_cost = self._validate_cost(
            "default_substitution_cost", default_substitution_cost
        )
        # Lazily built Rust backend; ``None`` means it must be rebuilt on next use.
        self._calculator = None
        # Cost maps notify ``_invalidate_calculator`` on change and validate
        # each entry with the matching ``_validate_*_entry`` callback.
        sub_init = ocr_distance_map if substitution_costs is None else substitution_costs
        self._substitution_costs = _ObservableDict(
            sub_init, self._invalidate_calculator, self._validate_sub_entry
        )
        self._insertion_costs = _ObservableDict(
            insertion_costs or {}, self._invalidate_calculator, self._validate_unary_entry
        )
        self._deletion_costs = _ObservableDict(
            deletion_costs or {}, self._invalidate_calculator, self._validate_unary_entry
        )

    def _invalidate_calculator(self) -> None:
        """Mark the Rust backend as out of sync; it will be rebuilt on next use."""
        self._calculator = None

    def _get_calculator(self) -> RustLevenshteinCalculator:  # type: ignore[no-any-unimported]
        """Return a Rust backend in sync with the current Python-side state."""
        if self._calculator is None:
            substitution_costs, insertion_costs, deletion_costs = (
                self._effective_cost_maps_for_calculator()
            )
            self._calculator = RustLevenshteinCalculator(
                substitution_costs=substitution_costs,
                insertion_costs=insertion_costs,
                deletion_costs=deletion_costs,
                symmetric_substitution=self._symmetric_substitution,
                default_substitution_cost=self._default_substitution_cost,
                default_insertion_cost=self._default_insertion_cost,
                default_deletion_cost=self._default_deletion_cost,
            )
        return self._calculator

    def _effective_cost_maps_for_calculator(
        self,
    ) -> tuple[dict[tuple[str, str], float], dict[str, float], dict[str, float]]:
        """
        Split substitution entries with empty source/target into the
        insertion/deletion maps, taking the minimum where they overlap.

        The instance's own maps are not modified; fresh dicts are returned.

        NOTE(review): ``symmetric_substitution`` is forwarded to the backend
        alongside the substitution map only, so an entry such as ``("", "x")``
        presumably does not also yield a deletion cost for ``"x"`` — confirm
        this is the intended semantics.
        """
        substitution_costs: dict[tuple[str, str], float] = {}
        insertion_costs = dict(self._insertion_costs)
        deletion_costs = dict(self._deletion_costs)
        for (source, target), cost in self._substitution_costs.items():
            if source == "":
                self._set_min_cost(insertion_costs, target, cost)
            elif target == "":
                self._set_min_cost(deletion_costs, source, cost)
            else:
                substitution_costs[(source, target)] = cost
        return substitution_costs, insertion_costs, deletion_costs

    @staticmethod
    def _set_min_cost(costs: dict[str, float], key: str, cost: float) -> None:
        """Store ``cost`` under ``key`` unless a cheaper cost is already present."""
        costs[key] = min(costs.get(key, cost), cost)

    # --- Properties ---
    @property
    def substitution_costs(self) -> dict[tuple[str, str], float]:
        """Live substitution-cost mapping; edits invalidate the backend."""
        return self._substitution_costs

    @substitution_costs.setter
    def substitution_costs(self, value: dict[tuple[str, str], float]) -> None:
        self._substitution_costs = _ObservableDict(
            value, self._invalidate_calculator, self._validate_sub_entry
        )
        self._invalidate_calculator()

    @property
    def insertion_costs(self) -> dict[str, float]:
        """Live insertion-cost mapping; edits invalidate the backend."""
        return self._insertion_costs

    @insertion_costs.setter
    def insertion_costs(self, value: dict[str, float]) -> None:
        self._insertion_costs = _ObservableDict(
            value, self._invalidate_calculator, self._validate_unary_entry
        )
        self._invalidate_calculator()

    @property
    def deletion_costs(self) -> dict[str, float]:
        """Live deletion-cost mapping; edits invalidate the backend."""
        return self._deletion_costs

    @deletion_costs.setter
    def deletion_costs(self, value: dict[str, float]) -> None:
        self._deletion_costs = _ObservableDict(
            value, self._invalidate_calculator, self._validate_unary_entry
        )
        self._invalidate_calculator()

    @property
    def symmetric_substitution(self) -> bool:
        """Whether substitution costs apply in both directions."""
        return self._symmetric_substitution

    @symmetric_substitution.setter
    def symmetric_substitution(self, value: bool) -> None:
        self._symmetric_substitution = value
        self._invalidate_calculator()

    @property
    def default_substitution_cost(self) -> float:
        """Cost for single-char substitutions missing from ``substitution_costs``."""
        return self._default_substitution_cost

    @default_substitution_cost.setter
    def default_substitution_cost(self, value: float) -> None:
        self._default_substitution_cost = self._validate_cost("default_substitution_cost", value)
        self._invalidate_calculator()

    @property
    def default_insertion_cost(self) -> float:
        """Cost for single-char insertions missing from ``insertion_costs``."""
        return self._default_insertion_cost

    @default_insertion_cost.setter
    def default_insertion_cost(self, value: float) -> None:
        self._default_insertion_cost = self._validate_cost("default_insertion_cost", value)
        self._invalidate_calculator()

    @property
    def default_deletion_cost(self) -> float:
        """Cost for single-char deletions missing from ``deletion_costs``."""
        return self._default_deletion_cost

    @default_deletion_cost.setter
    def default_deletion_cost(self, value: float) -> None:
        self._default_deletion_cost = self._validate_cost("default_deletion_cost", value)
        self._invalidate_calculator()

    # --- Validation Helpers ---
    def _validate_cost(self, name: str, val: float) -> float:
        """
        Check that ``val`` is a finite, non-negative number and return it as float.

        :raises TypeError: If ``val`` is not an int or float.
        :raises ValueError: If ``val`` is infinite, NaN, or negative.
        """
        if not isinstance(val, (int, float)):
            raise TypeError(f"{name} must be a number, but got: {type(val).__name__}")
        if not math.isfinite(val):
            raise ValueError(f"{name} must be finite, got value: {val}")
        if val < 0:
            raise ValueError(f"{name} must be non-negative, got value: {val}")
        return float(val)

    def _validate_sub_entry(self, key: Any, cost: Any) -> None:
        """Validate one ``substitution_costs`` entry (key shape and cost)."""
        if not (isinstance(key, tuple) and len(key) == 2 and all(isinstance(k, str) for k in key)):
            raise TypeError(f"substitution_costs keys must be tuples of two strings, found: {key}")
        if key == ("", ""):
            raise ValueError('substitution_costs key ("", "") is not a meaningful edit operation')
        self._validate_cost(f"Cost for {key}", cost)

    def _validate_unary_entry(self, key: Any, cost: Any) -> None:
        """Validate one insertion/deletion entry (string key and cost)."""
        if not isinstance(key, str):
            raise TypeError(f"Cost keys must be strings, found: {key}")
        self._validate_cost(f"Cost for {key}", cost)

    @classmethod
    def unweighted(cls) -> WeightedLevenshtein:
        """Creates an instance with all operations having equal cost of 1.0."""
        return cls(substitution_costs={}, insertion_costs={}, deletion_costs={})

    def transitive_closure(
        self,
        *,
        prune: bool = False,
        max_node_length: Optional[int] = None,
    ) -> WeightedLevenshtein:
        """
        Returns a new instance whose cost dictionaries are filled with effective
        (transitive) edit costs.

        If, for example, `substitution_costs[("a", "b")] = 0.1` and
        `substitution_costs[("b", "c")] = 0.1`, the closed instance's
        `substitution_costs[("a", "c")]` is `0.2` rather than the default.
        Insertion and deletion chains, and chains that cross `ε` (e.g.
        `del("y") + ins("x")` becoming an effective `("y", "x")` substitution),
        are likewise materialized.

        :param prune: If True, remove generated substitutions whose costs are
                      already represented by matches, insertions, deletions, and
                      shorter substitutions. This can make the returned cost map
                      easier to inspect, but it is much more expensive for large
                      closures.
        :param max_node_length: Maximum length (in characters) of intermediate
                                graph nodes the closure may construct. `None`
                                derives a sensible default from the input
                                (twice the longest raw token, with a small
                                floor); pass an `int` to override. The cap is
                                what guarantees termination - without it,
                                the graph could grow without bound. Floyd-Warshall is
                                :math:`O(N^3)` in the resulting node count, so a
                                higher cap can be substantially slower.
        :raises ValueError: If the generated closure graph is too large to
                            process safely.

        ``explain()`` on the closed instance returns flat single-step ops; the
        original chain that produced an effective cost is not preserved.
        For repeated use, save via :meth:`to_dict` and reload via
        :meth:`from_dict` so the closure is computed once.
        """
        sub_dict, ins_dict, del_dict = self._get_calculator().closed_cost_maps(
            prune, max_node_length
        )
        # The closed maps are handed over with symmetric_substitution=False,
        # so each direction is taken exactly as returned by the backend.
        return WeightedLevenshtein(
            substitution_costs=dict(sub_dict),
            insertion_costs=dict(ins_dict),
            deletion_costs=dict(del_dict),
            symmetric_substitution=False,
            default_substitution_cost=self.default_substitution_cost,
            default_insertion_cost=self.default_insertion_cost,
            default_deletion_cost=self.default_deletion_cost,
        )

    def distance(self, s1: str, s2: str) -> float:
        """Calculates the weighted Levenshtein distance between two strings."""
        return self._get_calculator().distance(s1, s2)  # type: ignore[no-any-return]

    def explain(self, s1: str, s2: str, filter_matches: bool = True) -> list[EditOperation]:
        """
        Returns the list of edit operations to transform s1 into s2.

        :param s1: First string (interpreted as the string read via OCR)
        :param s2: Second string (interpreted as the target string)
        :param filter_matches: If True, 'match' operations are excluded from the result.
        :return: List of :class:`EditOperation` instances.
        """
        raw_path = self._get_calculator().explain(s1, s2)
        operations = [EditOperation(*op) for op in raw_path]
        if filter_matches:
            return [op for op in operations if op.op_type != "match"]
        return operations

    def batch_distance(self, s: str, candidates: list[str]) -> list[float]:
        """Calculates distances between a string and a list of candidates."""
        return self._get_calculator().batch_distance(s, candidates)  # type: ignore[no-any-return]

    @classmethod
    def learn_from(cls, pairs: Iterable[tuple[str, str]]) -> WeightedLevenshtein:
        """
        Creates an instance by learning costs from a dataset of (OCR, ground truth) string pairs.

        For more advanced learning configuration, see the
        :class:`ocr_stringdist.learner.CostLearner` class.

        :param pairs: An iterable of (ocr_string, ground_truth_string) tuples. Correct pairs
                      are not intended to be filtered; they are needed to learn well-aligned costs.
        :return: A new `WeightedLevenshtein` instance with the learned costs.

        Example::

            from ocr_stringdist import WeightedLevenshtein

            training_data = [
                ("8N234", "BN234"),  # read '8' instead of 'B'
                ("BJK18", "BJK18"),  # correct
                ("ABC0.", "ABC0"),  # extra '.'
            ]
            wl = WeightedLevenshtein.learn_from(training_data)
            print(wl.substitution_costs)  # learned cost for substituting '8' with 'B'
            print(wl.deletion_costs)  # learned cost for deleting '.'
        """
        # Imported lazily, as in the original code.
        from .learner import CostLearner

        return CostLearner().fit(pairs)

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}("
            f"substitution_costs={dict(self._substitution_costs)!r}, "
            f"insertion_costs={dict(self._insertion_costs)!r}, "
            f"deletion_costs={dict(self._deletion_costs)!r}, "
            f"symmetric_substitution={self._symmetric_substitution!r}, "
            f"default_substitution_cost={self._default_substitution_cost!r}, "
            f"default_insertion_cost={self._default_insertion_cost!r}, "
            f"default_deletion_cost={self._default_deletion_cost!r})"
        )

    def __str__(self) -> str:
        return (
            f"{type(self).__name__}("
            f"substitution_costs=<{len(self._substitution_costs)} entries>, "
            f"insertion_costs=<{len(self._insertion_costs)} entries>, "
            f"deletion_costs=<{len(self._deletion_costs)} entries>, "
            f"symmetric_substitution={self._symmetric_substitution}, "
            f"default_substitution_cost={self._default_substitution_cost}, "
            f"default_insertion_cost={self._default_insertion_cost}, "
            f"default_deletion_cost={self._default_deletion_cost})"
        )

    def __eq__(self, other: object) -> bool:
        """Instances are equal when every cost map and setting is equal."""
        if not isinstance(other, WeightedLevenshtein):
            return NotImplemented
        return (
            self.substitution_costs == other.substitution_costs
            and self.insertion_costs == other.insertion_costs
            and self.deletion_costs == other.deletion_costs
            and self.symmetric_substitution == other.symmetric_substitution
            and self.default_substitution_cost == other.default_substitution_cost
            and self.default_insertion_cost == other.default_insertion_cost
            and self.default_deletion_cost == other.default_deletion_cost
        )

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the instance's configuration to a dictionary.

        The result can be written to, say, JSON. It contains only plain
        ``dict``/``list`` copies, so mutating it does not affect this instance.

        For the counterpart, see :meth:`WeightedLevenshtein.from_dict`.
        """
        # Tuple keys are not representable in e.g. JSON, so substitutions are
        # flattened into a list of {"from", "to", "cost"} records.
        sub_costs_serializable = [
            {"from": source, "to": target, "cost": cost}
            for (source, target), cost in self.substitution_costs.items()
        ]
        return {
            "substitution_costs": sub_costs_serializable,
            # Copies, so callers cannot mutate the live cost maps through the result.
            "insertion_costs": dict(self.insertion_costs),
            "deletion_costs": dict(self.deletion_costs),
            "symmetric_substitution": self.symmetric_substitution,
            "default_substitution_cost": self.default_substitution_cost,
            "default_insertion_cost": self.default_insertion_cost,
            "default_deletion_cost": self.default_deletion_cost,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> WeightedLevenshtein:
        """
        Deserialize from a dictionary.

        For the counterpart, see :meth:`WeightedLevenshtein.to_dict`.

        :param data: A dictionary with (not necessarily all of) the following keys:

            - "substitution_costs": list of {"from": str, "to": str, "cost": float}
              (a missing key yields an empty substitution map)
            - "insertion_costs": dict[str, float]
            - "deletion_costs": dict[str, float]
            - "symmetric_substitution": bool
            - "default_substitution_cost": float
            - "default_insertion_cost": float
            - "default_deletion_cost": float
        """
        # Rebuild the tuple-keyed substitution map from the flattened records.
        sub_costs: dict[tuple[str, str], float] = {
            (item["from"], item["to"]): item["cost"] for item in data.get("substitution_costs", ())
        }
        return cls(
            substitution_costs=sub_costs,
            insertion_costs=data.get("insertion_costs"),
            deletion_costs=data.get("deletion_costs"),
            symmetric_substitution=data.get("symmetric_substitution", True),
            default_substitution_cost=data.get("default_substitution_cost", 1.0),
            default_insertion_cost=data.get("default_insertion_cost", 1.0),
            default_deletion_cost=data.get("default_deletion_cost", 1.0),
        )