import itertools
import logging
import re
from collections import defaultdict, Counter
from pathlib import Path
from typing import DefaultDict, Dict, List, Optional, Tuple, Sequence
import progressbar
from . import punctuationRE
from ._cache import PickledLRUCache, cached
from .dictionary import Dictionary
from .fileio import FileIO
from .tokens import KBestItem, TokenList
class HMM(object):
log = logging.getLogger(f'{__name__}.HMM')
@property
def init(self) -> DefaultDict[str, float]:
"""Initial probabilities."""
return self._init
@init.setter
def init(self, initial: Dict[str, float]):
self._init = defaultdict(float)
self._init.update(initial)
@property
def tran(self) -> DefaultDict[str, DefaultDict[str, float]]:
"""Transition probabilities."""
return self._tran
@tran.setter
def tran(self, transition: Dict[str, Dict[str, float]]):
self._tran = defaultdict(lambda: defaultdict(float))
for outer, d in transition.items():
for inner, e in d.items():
self._tran[outer][inner] = e
@property
def emis(self) -> DefaultDict[str, DefaultDict[str, float]]:
"""Emission probabilities."""
return self._emis
@emis.setter
def emis(self, emission: Dict[str, Dict[str, float]]):
self._emis = defaultdict(lambda: defaultdict(float))
for outer, d in emission.items():
for inner, e in d.items():
self._emis[outer][inner] = e
def __init__(self, path: Path, multichars=None, dictionary: Dictionary = None):
"""
:param path: Path for loading and saving.
		:param multichars: A dictionary of possible multi-character substitutions (e.g. 'cr': 'æ', or vice versa).
:param dictionary: The dictionary against which to check validity.
"""
if multichars is None:
multichars = {}
self.multichars = multichars
self.dictionary = dictionary
self.path = path
if self.path:
HMM.log.info(f'Loading HMM parameters from {path}')
(self.init, self.tran, self.emis) = FileIO.load(path)
		else:
			# The property setters expect mappings, so start from empty
			# parameters rather than None when no path is given.
			(self.init, self.tran, self.emis) = ({}, {}, {})
self.states = self.init.keys()
#HMM.log.debug(f'init: {self.init}')
#HMM.log.debug(f'tran: {self.tran}')
#HMM.log.debug(f'emis: {self.emis}')
HMM.log.debug(f'states: {self.states}')
if not self.is_valid():
HMM.log.critical(f'Parameter check failed for {self}')
else:
HMM.log.debug(f'HMM initialized: {self}')
self.cache = PickledLRUCache.by_name(f'{__name__}.HMM.kbest')
def __str__(self):
return f'<{self.__class__.__name__} {"".join(sorted(self.states))}>'
def __repr__(self):
return self.__str__()
	def save(self, path: Path = None):
"""
Save the HMM parameters.
:param path: Optional new path to save to.
"""
if not self.is_valid():
HMM.log.error('Not going to save faulty HMM parameters.')
raise SystemExit(-1)
path = path or self.path
HMM.log.info(f'Saving HMM parameters to {path}')
FileIO.save([self.init, self.tran, self.emis], path)
self.cache.delete() # redoing the model invalidates the cache
	def is_valid(self) -> bool:
"""
		Verify that parameters are valid (i.e. the keys in init/tran/emis match).
"""
all_match = True
if set(self.init) != set(self.tran):
all_match = False
HMM.log.error('Initial keys do not match transition keys.')
if set(self.init) != set(self.emis):
all_match = False
			keys = sorted(set(self.init).symmetric_difference(set(self.emis)))
			HMM.log.error(
				f'Initial keys do not match emission keys:'
				f' diff: {keys}'
				f' init: {[self.init.get(k, None) for k in keys]}'
				f' emis: {[self.emis.get(k, None) for k in keys]}'
			)
for key in self.tran:
if set(self.tran[key]) != set(self.tran):
all_match = False
HMM.log.error(f'Outer transition keys do not match inner keys: {key}')
if all_match:
HMM.log.info('Parameters match.')
return all_match
	def viterbi(self, char_seq: Sequence[str]) -> str:
"""
TODO
:param char_seq:
:return:
"""
# delta[t][j] is probability of max probability path to state j
# at time t given the observation sequence up to time t.
delta: List[Optional[Dict[str, float]]] = [None] * len(char_seq)
		# back_pointers[t][j] is the state at time t-1 on the best path to j.
		back_pointers: List[Optional[Dict[str, str]]] = [None] * len(char_seq)
delta[0] = {i: self.init[i] * self.emis[i][char_seq[0]]
for i in self.states}
for t in range(1, len(char_seq)):
# (preceding state with max probability, value of max probability)
d = {j: max({i: delta[t-1][i] * self.tran[i][j] for i in self.states}.items(),
key=lambda x: x[1]) for j in self.states}
delta[t] = {i: d[i][1] * self.emis[i][char_seq[t]] for i in self.states}
back_pointers[t] = {i: d[i][0] for i in self.states}
best_state = max(delta[-1], key=lambda x: delta[-1][x])
selected_states = [best_state] * len(char_seq)
for t in range(len(char_seq) - 1, 0, -1):
best_state = back_pointers[t][best_state]
selected_states[t-1] = best_state
return ''.join(selected_states)
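	# A minimal usage sketch (the parameter file name is hypothetical):
	#
	#   hmm = HMM(Path('hmm_parameters.pickle'))
	#   decoded = hmm.viterbi('he1lo')
	#
	# `decoded` is the most probable hidden character sequence, i.e. a
	# corrected string of the same length as the observed input.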
def _k_best_beam(self, word: str, k: int) -> List[Tuple[str, float]]:
# Single symbol input is just initial * emission.
if len(word) == 1:
paths = [(i, self.init[i] * self.emis[i][word[0]])
for i in self.states]
paths = sorted(paths, key=lambda x: x[1], reverse=True)
else:
# Create the N*N sequences for the first two characters
# of the word.
paths = [((i, j), (self.init[i] * self.emis[i][word[0]] * self.tran[i][j] * self.emis[j][word[1]]))
for i in self.states for j in self.states]
# Keep the k best sequences.
paths = sorted(paths, key=lambda x: x[1], reverse=True)[:k]
# Continue through the input word, only keeping k sequences at
# each time step.
for t in range(2, len(word)):
temp = [(x[0] + (j,), (x[1] * self.tran[x[0][-1]][j] * self.emis[j][word[t]]))
for j in self.states for x in paths]
paths = sorted(temp, key=lambda x: x[1], reverse=True)[:k]
			#print(t, len(temp), temp[:5], len(paths), paths[:5])
return [(''.join(seq), prob) for seq, prob in paths[:k]]
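	# Note on the beam width: full Viterbi decoding would extend every
	# surviving path through every state pair at each step; the beam
	# keeps only the k most probable paths overall per time step, so for
	# N states and a word of length T roughly T * k * N extensions are
	# scored instead of T * N^2.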
	@cached
def kbest_for_word(self, word: str, k: int) -> DefaultDict[int, KBestItem]:
"""
Generates *k*-best correction candidates for a single word.
		:param word: The word for which to generate candidates.
:param k: How many candidates to generate.
:return: A dictionary with ranked candidates keyed by 1..*k*.
"""
#HMM.log.debug(f'kbest_for_word: {word}')
if len(word) == 0:
return defaultdict(KBestItem, {n: KBestItem('', 0.0) for n in range(1, k+1)})
k_best = self._k_best_beam(word, k)
		# Check for common multi-character errors. If any are present,
		# make substitutions and compare probabilities of the results.
for sub in self.multichars:
# Only perform the substitution if none of the k-best candidates are present in the dictionary
if sub in word and all(punctuationRE.sub('', x[0]) not in self.dictionary for x in k_best):
variant_words = HMM._multichar_variants(word, sub, self.multichars[sub])
for v in variant_words:
if v != word:
k_best.extend(self._k_best_beam(v, k))
# Keep the k best
k_best = sorted(k_best, key=lambda x: x[1], reverse=True)[:k]
return defaultdict(KBestItem, {i: KBestItem(''.join(seq), prob) for (i, (seq, prob)) in enumerate(k_best[:k], 1)})
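	# Sketch of the returned structure (candidate strings and
	# probabilities below are made up for illustration):
	#
	#   hmm.kbest_for_word('he1lo', 2)
	#   -> defaultdict(KBestItem, {1: KBestItem('hello', 1.3e-12),
	#                              2: KBestItem('hejlo', 4.1e-14)})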
@classmethod
def _multichar_variants(cls, word: str, original: str, replacements: List[str]):
variants = [original] + replacements
variant_words = set()
		# Escape the pattern so the literal substring splits correctly.
		pieces = re.split(re.escape(original), word)
# Reassemble the word using original or replacements
for x in itertools.product(variants, repeat=word.count(original)):
variant_words.add(''.join([elem for pair in itertools.zip_longest(
pieces, x, fillvalue='') for elem in pair]))
return variant_words
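	# Worked example: _multichar_variants('crow', 'cr', ['æ']) splits
	# 'crow' into pieces ['', 'ow'] and reassembles one word per element
	# of the product, yielding {'crow', 'æow'}.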
	def generate_kbest(self, tokens: TokenList, k: int = 4, force: bool = False):
"""
Generates *k*-best correction candidates for a list of Tokens and adds them
to each token.
:param tokens: List of tokens.
:param k: How many candidates to generate.
"""
if len(tokens) == 0:
			HMM.log.error('No tokens were supplied.')
raise SystemExit(-1)
HMM.log.info(f'Generating {k}-best suggestions for each token')
for i, token in enumerate(progressbar.progressbar(tokens)):
if force or not token.kbest:
token.kbest = self.kbest_for_word(token.normalized, k)
#HMM.log.debug(vars(token))
HMM.log.debug(f'Generated for {len(tokens)} tokens, first 10: {tokens[:10]}')
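	# Typical call (assuming `tokens` came from .tokens.TokenList):
	#
	#   hmm.generate_kbest(tokens, k=4)
	#
	# Tokens that already carry k-best candidates are left untouched
	# unless force=True is passed.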
##########################################################################################
# TODO: make build a classmethod on HMM instead
class HMMBuilder(object):
log = logging.getLogger(f'{__name__}.HMMBuilder')
def __init__(self, dictionary: Dictionary, smoothingParameter: float, characterSet, readCounts, remove_chars: List[str], gold_words: List[str]):
"""
		Calculates parameters for an HMM based on the input. They can be accessed via the three properties.
:param dictionary: The dictionary to use for generating probabilities.
:param smoothingParameter: Lower bound for probabilities.
:param characterSet: Set of required characters for the final HMM.
:param readCounts: See :class:`Aligner<CorrectOCR.aligner.Aligner>`.
:param remove_chars: List of characters to remove from the final HMM.
:param gold_words: List of known correct words.
"""
self._dictionary = dictionary
self._smoothingParameter = smoothingParameter
self._remove_chars = remove_chars
self._charset = set(characterSet)
confusion = self._generate_confusion(readCounts)
char_counts = self._text_char_counts(gold_words)
self._charset = self._charset | set(char_counts) | set(confusion)
HMMBuilder.log.debug(f'Final characterSet: {sorted(self._charset)}')
# Create the emission probabilities from the read counts and the character counts
emis = self._emission_probabilities(confusion, char_counts)
		self.emis: DefaultDict[str, DefaultDict[str, float]] = emis #: Emission probabilities.
# Create the initial and transition probabilities from the gold documents
init, tran = self._init_tran_probabilities(gold_words)
self.init: DefaultDict[str, float] = init #: Initial probabilities.
self.tran: DefaultDict[str, DefaultDict[str, float]] = tran #: Transition probabilities.
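	# A minimal construction sketch (argument values are hypothetical):
	#
	#   builder = HMMBuilder(dictionary, 0.0001, characterSet, readCounts,
	#                        remove_chars=[], gold_words=gold_words)
	#   FileIO.save([builder.init, builder.tran, builder.emis], path)
	#
	# The saved file can then be loaded by HMM via its `path` argument.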
# Start with read counts, remove any keys which are not single
# characters, remove specified characters, and combine into a single
# dictionary.
def _generate_confusion(self, readCounts: Dict) -> Dict[str, Dict[str, int]]:
# Outer keys are the correct characters. Inner keys are the counts of
# what each character was read as.
confusion = defaultdict(Counter)
confusion.update(readCounts)
# Strip out any outer keys that aren't a single character
confusion = {key: value for key, value in confusion.items()
if len(key) == 1}
for unwanted in self._remove_chars:
if unwanted in confusion:
del confusion[unwanted]
# Strip out any inner keys that aren't a single character.
		# Later, these may be useful; for now, remove them.
for outer in confusion:
wrongsize = [key for key in confusion[outer] if len(key) != 1]
for key in wrongsize:
del confusion[outer][key]
for unwanted in self._remove_chars:
if unwanted in confusion[outer]:
del confusion[outer][unwanted]
#HMMBuilder.log.debug(confusion)
return confusion
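	# Shape of the returned confusion mapping (counts are hypothetical);
	# outer keys are correct characters, inner keys are what the
	# character was read as:
	#
	#   {'e': {'e': 95, 'c': 4, 'o': 1}, ...}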
# Get the character counts of the training docs. Used for filling in
# gaps in the confusion probabilities.
def _text_char_counts(self, words: List[str]) -> Dict[str, int]:
char_count = Counter()
#HMMBuilder.log.debug(f'words: {words}')
for word in words:
char_count.update(list(word))
for word in self._dictionary:
char_count.update(list(word))
for char in set(char_count.keys()):
if char not in self._charset:
del char_count[char]
for unwanted in self._remove_chars:
if unwanted in char_count:
del char_count[unwanted]
return char_count
	# Create the emission probabilities using read counts and character
	# counts. Characters from the required character set are also added
	# as model states whose emission probabilities are set to only
	# output themselves.
def _emission_probabilities(self, confusion, char_counts):
# Add missing dictionary elements.
# Missing outer terms are ones which were always read correctly.
for char in char_counts:
if char not in confusion:
confusion[char] = {char: char_counts[char]}
# Inner terms are just added with 0 probability.
charset = set().union(*[confusion[i].keys() for i in confusion])
for char in confusion:
for missing in charset:
if missing not in confusion[char]:
confusion[char][missing] = 0.0
# Smooth and convert to probabilities.
for i in confusion:
denom = sum(confusion[i].values()) + (self._smoothingParameter * len(confusion[i]))
for j in confusion[i]:
confusion[i][j] = (confusion[i][j] + self._smoothingParameter) / denom
		# Add expected characters from the character set that are not
		# already present as states.
		extra_chars = self._charset - set(self._remove_chars)
# Add them as new states.
for char in extra_chars:
if char not in confusion:
confusion[char] = {i: 0 for i in charset}
# Add them with 0 probability to every state.
for i in confusion:
for char in extra_chars:
if char not in confusion[i]:
confusion[i][char] = 0.0
# Set them to emit themselves
for char in extra_chars:
confusion[char][char] = 1.0
for outer in set(confusion.keys()):
if outer not in self._charset:
del confusion[outer]
else:
for inner in set(confusion[outer].keys()):
if inner not in self._charset:
del confusion[outer][inner]
#logging.getLogger(f'{__name__}.emission_probabilities').debug(confusion)
return confusion
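	# Worked smoothing example (counts are hypothetical): if 'e' was
	# read as 'e' 95 times and as 'c' 5 times, and its inner dictionary
	# has 2 keys, then with smoothingParameter s:
	#
	#   P(emit 'e' | state 'e') = (95 + s) / (100 + 2*s)
	#   P(emit 'c' | state 'e') = (5 + s) / (100 + 2*s)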
# Create the initial and transition probabilities from the gold
# text in the training data.
def _init_tran_probabilities(self, gold_words):
tran = defaultdict(lambda: defaultdict(int))
init = defaultdict(int)
def add_word(_word):
if len(_word) > 0:
init[_word[0]] += 1
# Record each occurrence of character pair ij in tran[i][j]
				for a, b in zip(_word, _word[1:]):
tran[a][b] += 1
for word in gold_words:
add_word(word)
for word in self._dictionary:
add_word(word)
for unwanted in self._remove_chars:
if unwanted in self._charset:
self._charset.remove(unwanted)
if unwanted in init:
del init[unwanted]
if unwanted in tran:
del tran[unwanted]
for i in tran:
if unwanted in tran[i]:
del tran[i][unwanted]
tran_out = defaultdict(lambda: defaultdict(float))
init_out = defaultdict(float)
# Add missing characters to the parameter dictionaries and apply smoothing.
init_denom = sum(init.values()) + (self._smoothingParameter * len(self._charset))
for i in self._charset:
init_out[i] = (init[i] + self._smoothingParameter) / init_denom
tran_denom = sum(tran[i].values()) + (self._smoothingParameter * len(self._charset))
for j in self._charset:
tran_out[i][j] = (tran[i][j] + self._smoothingParameter) / tran_denom
return init_out, tran_out
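	# Worked smoothing example (counts are hypothetical): with a
	# 3-character character set and 'a' starting 40 of 100 words,
	# smoothing gives
	#
	#   init['a'] = (40 + s) / (100 + 3*s)
	#
	# and characters that never start a word still get s / (100 + 3*s),
	# so no initial state has zero probability.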