import itertools
import math
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Callable, Optional

if TYPE_CHECKING:
    from .edit_operation import EditOperation
    from .levenshtein import WeightedLevenshtein
    from .protocols import Aligner

CostFunction = Callable[[float], float]


def negative_log_likelihood(probability: float) -> float:
    if probability <= 0.0:
        raise ValueError("Probability must be positive to compute negative log likelihood.")
    return -math.log(probability)
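
# For intuition: a certain event costs nothing, while rarer events are
# increasingly "surprising", e.g.
#     negative_log_likelihood(1.0)  == 0.0
#     negative_log_likelihood(0.5)  ≈  0.693
#     negative_log_likelihood(0.01) ≈  4.605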


@dataclass
class TallyCounts:
    substitutions: defaultdict[tuple[str, str], int] = field(
        default_factory=lambda: defaultdict(int)
    )
    insertions: defaultdict[str, int] = field(default_factory=lambda: defaultdict(int))
    deletions: defaultdict[str, int] = field(default_factory=lambda: defaultdict(int))
    source_chars: defaultdict[str, int] = field(default_factory=lambda: defaultdict(int))
    target_chars: defaultdict[str, int] = field(default_factory=lambda: defaultdict(int))
    vocab: set[str] = field(default_factory=set)


@dataclass
class _Costs:
    substitutions: dict[tuple[str, str], float]
    insertions: dict[str, float]
    deletions: dict[str, float]


class CostLearner:
    """
    Configures and executes the process of learning Levenshtein costs from data.

    This class uses a builder pattern, allowing configuration methods to be chained
    before running the final calculation with .fit().

    Example::

        from ocr_stringdist import CostLearner

        data = [
            ("Hell0", "Hello"),
        ]
        learner = CostLearner().with_smoothing(1.0)
        wl = learner.fit(data)  # Substitution 0 -> o learned with cost < 1.0
    """

    # Configuration parameters
    _smoothing_k: float

    # These attributes are set during fitting
    counts: Optional[TallyCounts] = None
    vocab_size: Optional[int] = None

    def __init__(self) -> None:
        self._smoothing_k = 1.0

    def with_smoothing(self, k: float) -> "CostLearner":
        r"""
        Sets the smoothing parameter `k`.

        This parameter controls how strongly the model defaults to a uniform
        probability distribution by adding a "pseudo-count" of `k` to every
        possible event.

        :param k: The smoothing factor, which must be a non-negative number.
        :return: The CostLearner instance for method chaining.
        :raises ValueError: If k < 0.

        Notes
        -----
        This parameter allows for a continuous transition between two modes:

        - **k > 0 (recommended):** This enables additive smoothing, with `k = 1.0`
          being Laplace smoothing. It regularizes the model by assuming no event is impossible.
          The final costs are a measure of "relative surprisal," normalized by the vocabulary size.
        - **k = 0:** This corresponds to a normalized Maximum Likelihood Estimation.
          Probabilities are derived from the raw observed frequencies. The final costs are
          normalized using the same logic as the `k > 0` case, making `k=0` the continuous limit
          of the smoothed model. In this mode, costs can only be calculated for events observed in
          the training data. Unseen events will receive the default cost, regardless of
          the value of `calculate_for_unseen` in :meth:`fit`.
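
        Example::

            CostLearner().with_smoothing(0.0)   # raw observed frequencies (MLE mode)
            CostLearner().with_smoothing(1.0)   # Laplace smoothing (the default)
            CostLearner().with_smoothing(-1.0)  # raises ValueError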
"""
        if k < 0:
            raise ValueError("Smoothing parameter k must be non-negative.")
        self._smoothing_k = k
        return self

    def _tally_operations(self, operations: Iterable["EditOperation"]) -> TallyCounts:
        """Tally all edit operations."""
        counts = TallyCounts()
        for op in operations:
            if op.source_token is not None:
                counts.vocab.add(op.source_token)
            if op.target_token is not None:
                counts.target_chars[op.target_token] += 1
                counts.vocab.add(op.target_token)
            if op.op_type == "substitute":
                if op.source_token is None or op.target_token is None:
                    raise ValueError("Tokens cannot be None for 'substitute'")
                counts.substitutions[(op.source_token, op.target_token)] += 1
                counts.source_chars[op.source_token] += 1
            elif op.op_type == "delete":
                if op.source_token is None:
                    raise ValueError("Source token cannot be None for 'delete'")
                counts.deletions[op.source_token] += 1
                counts.source_chars[op.source_token] += 1
            elif op.op_type == "insert":
                if op.target_token is None:
                    raise ValueError("Target token cannot be None for 'insert'")
                counts.insertions[op.target_token] += 1
            elif op.op_type == "match":
                if op.source_token is None:
                    raise ValueError("Source token cannot be None for 'match'")
                counts.source_chars[op.source_token] += 1
        return counts
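
    # Illustration (hypothetical alignment using the default character-level model):
    # for the pair ("Hell0", "Hello"), four matches and one substitution are tallied,
    # giving roughly
    #     counts.substitutions[("0", "o")] == 1
    #     counts.source_chars == {"H": 1, "e": 1, "l": 2, "0": 1}
    #     counts.vocab == {"H", "e", "l", "o", "0"}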

    def _calculate_costs(
        self, counts: TallyCounts, vocab: set[str], calculate_for_unseen: bool = False
    ) -> _Costs:
        """
        Calculates the costs for edit operations based on tallied counts.
        """
        sub_costs: dict[tuple[str, str], float] = {}
        ins_costs: dict[str, float] = {}
        del_costs: dict[str, float] = {}
        k = self._smoothing_k
        if k == 0:
            calculate_for_unseen = False

        # Error space size V for all conditional probabilities.
        # The space of possible outcomes for a given source character (from OCR)
        # includes all vocab characters (for matches/substitutions) plus the empty
        # character (for deletions). This gives V = len(vocab) + 1.
        # Symmetrically, the space of outcomes for a given target character (from GT)
        # includes all vocab characters plus the empty character (for insertions/misses).
        V = len(vocab) + 1
        # Normalization ceiling Z' = -log(1/V).
        normalization_ceiling = math.log(V) if V > 1 else 1.0

        # Substitutions
        sub_iterator = (
            itertools.product(vocab, vocab) if calculate_for_unseen else counts.substitutions.keys()
        )
        for source, target in sub_iterator:
            count = counts.substitutions[(source, target)]
            total_count = counts.source_chars[source]
            prob = (count + k) / (total_count + k * V)
            base_cost = negative_log_likelihood(prob)
            sub_costs[(source, target)] = base_cost / normalization_ceiling
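
        # Worked example (hypothetical numbers): with a vocabulary of 10 characters,
        # V = 11 and the ceiling is ln(11) ≈ 2.398. If the OCR character "0" was seen
        # 5 times and corresponded to the ground-truth character "o" 3 times, then
        # with k = 1:
        #     prob = (3 + 1) / (5 + 1 * 11) = 0.25
        #     cost = -ln(0.25) / ln(11) ≈ 1.386 / 2.398 ≈ 0.58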

        # Deletions
        del_iterator = vocab if calculate_for_unseen else counts.deletions.keys()
        for source in del_iterator:
            count = counts.deletions[source]
            total_count = counts.source_chars[source]
            prob = (count + k) / (total_count + k * V)
            base_cost = negative_log_likelihood(prob)
            del_costs[source] = base_cost / normalization_ceiling

        # Insertions
        ins_iterator = vocab if calculate_for_unseen else counts.insertions.keys()
        for target in ins_iterator:
            count = counts.insertions[target]
            total_target_count = counts.target_chars[target]
            prob = (count + k) / (total_target_count + k * V)
            base_cost = negative_log_likelihood(prob)
            ins_costs[target] = base_cost / normalization_ceiling

        return _Costs(substitutions=sub_costs, insertions=ins_costs, deletions=del_costs)

    def _calculate_operations(
        self, pairs: Iterable[tuple[str, str]], aligner: "Aligner"
    ) -> list["EditOperation"]:
        """Calculate edit operations for all string pairs using the provided aligner."""
        all_ops = [
            op
            for ocr_str, truth_str in pairs
            for op in aligner.explain(ocr_str, truth_str, filter_matches=False)
        ]
        return all_ops

    def fit(
        self,
        pairs: Iterable[tuple[str, str]],
        *,
        initial_model: "Aligner | None" = None,
        calculate_for_unseen: bool = False,
    ) -> "WeightedLevenshtein":
        """
        Fits the costs of a WeightedLevenshtein instance to the provided data.

        Note that learning multi-character tokens is only supported if the provided
        initial alignment model can handle such multi-character tokens.

        This method analyzes pairs of strings to learn the costs of edit operations
        based on their observed frequencies. The underlying model calculates costs
        based on the principle of relative information cost.
        For a detailed explanation of the methodology, please see the
        :doc:`Cost Learning Model <cost_learning_model>` documentation page.

        :param pairs: An iterable of (ocr_string, ground_truth_string) tuples.
        :param initial_model: Optional initial model used to align OCR outputs and ground truth
            strings. By default, an unweighted Levenshtein distance is used.
        :param calculate_for_unseen: If True (and k > 0), pre-calculates costs for all
            possible edit operations based on the vocabulary.
            If False (default), only calculates costs for operations
            observed in the data.
        :return: A `WeightedLevenshtein` instance with the learned costs.
        """
        from .levenshtein import WeightedLevenshtein

        if not pairs:
            return WeightedLevenshtein.unweighted()
        if initial_model is None:
            initial_model = WeightedLevenshtein.unweighted()

        all_ops = self._calculate_operations(pairs, aligner=initial_model)
        self.counts = self._tally_operations(all_ops)
        vocab = self.counts.vocab
        self.vocab_size = len(vocab)
        if not self.vocab_size:
            return WeightedLevenshtein.unweighted()

        costs = self._calculate_costs(self.counts, vocab, calculate_for_unseen=calculate_for_unseen)

        return WeightedLevenshtein(
            substitution_costs=costs.substitutions,
            insertion_costs=costs.insertions,
            deletion_costs=costs.deletions,
            symmetric_substitution=False,
            default_substitution_cost=1.0,
            default_insertion_cost=1.0,
            default_deletion_cost=1.0,
        )
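

# Usage sketch (illustrative pairs; the learned costs depend on the training data):
#
#     pairs = [("Hell0 W0rld", "Hello World"), ("1O0", "100")]
#     wl = CostLearner().with_smoothing(1.0).fit(pairs, calculate_for_unseen=True)
#
# The returned WeightedLevenshtein carries the learned substitution, insertion,
# and deletion costs; operations outside the learned tables fall back to the
# default cost of 1.0.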