Source code for CorrectOCR.tokens.list._super

from __future__ import annotations

import abc
import collections
import logging
import random


class TokenList(collections.abc.MutableSequence):
    """
    Abstract, list-backed mutable sequence of tokens for a single document.

    Storage-specific subclasses announce themselves with :meth:`register`
    and are afterwards looked up by storage type through :meth:`for_type`
    and :meth:`new`. Subclasses must implement :meth:`load` and :meth:`save`.

    The sequence protocol (``len``, indexing, assignment, deletion,
    ``insert``) is delegated to the plain list held in ``self.tokens``.
    """

    log = logging.getLogger(f'{__name__}.TokenList')
    # Maps storage type name -> registered TokenList subclass.
    _subclasses = {}

    @staticmethod
    def register(storagetype: str):
        """
        Decorator which registers a :class:`TokenList` subclass with the base class.

        :param storagetype: `fs` or `db`
        :return: A class decorator that records the class and returns it unchanged.
        """
        def wrapper(cls):
            TokenList._subclasses[storagetype] = cls
            return cls
        return wrapper

    @staticmethod
    def new(config, docid = None, tokens = None) -> TokenList:
        """
        Instantiate the subclass registered for ``config.type``.

        :param config: Configuration object; its ``type`` attribute selects the subclass.
        :param docid: Document identifier handed to the subclass constructor.
        :param tokens: Initial tokens. Only forwarded when truthy, so ``None``
            and an empty list both leave the choice to the subclass constructor.
        :return: A new instance of the registered subclass.
        :raises NameError: If no subclass is registered for ``config.type``.
        """
        cls = TokenList.for_type(config.type)
        if tokens:
            return cls(config, docid=docid, tokens=tokens)
        return cls(config, docid=docid)

    @staticmethod
    def for_type(type: str) -> TokenList.__class__:
        """
        Look up the subclass registered for a storage type.

        :param type: Storage type name, as passed to :meth:`register`.
        :return: The registered class (not an instance).
        :raises NameError: If the storage type has not been registered.
        """
        TokenList.log.debug('_subclasses: %s', TokenList._subclasses)
        if type not in TokenList._subclasses:
            raise NameError(f'Unknown storage type: {type}')
        return TokenList._subclasses[type]

    def __init__(self, config, docid = None, tokens = None):
        """
        Store the configuration, document id and initial tokens.

        :param config: Configuration object, kept as ``self.config``.
        :param docid: Document identifier, kept as ``self.docid``.
        :param tokens: Initial token list. A truthy value is stored as-is
            (not copied); anything else results in a fresh empty list.
        :raises TypeError: If invoked on :class:`TokenList` itself rather than a subclass.
        """
        if type(self) is TokenList:
            raise TypeError("Token base class cannot be directly instantiated")
        self.config = config
        self.docid = docid
        self.tokens = tokens if tokens else []
        TokenList.log.debug('init: %s %s', self.config, self.docid)

    def __str__(self):
        """
        Join the tokens into one space-separated string.

        Each token contributes its ``gold`` text when that is truthy and its
        ``original`` text otherwise. A token flagged ``is_hyphenated`` loses
        its last character and is fused with the token that follows it; the
        following token is consumed and not emitted separately.
        """
        output = []
        ts = iter(self)
        for t in ts:
            output.append(t.gold or t.original)
            if t.is_hyphenated:
                # The default keeps a trailing hyphenated token from leaking
                # StopIteration out of str(); it is then emitted unchanged.
                n = next(ts, None)
                if n is None:
                    break
                output[-1] = output[-1][:-1] + (n.gold or n.original)
        return ' '.join(output)

    def __len__(self):
        """Return the number of tokens held in ``self.tokens``."""
        return len(self.tokens)

    def __delitem__(self, key):
        """Delete the token(s) at ``key`` from ``self.tokens``."""
        return self.tokens.__delitem__(key)

    def __setitem__(self, key, value):
        """Replace the token(s) at ``key`` in ``self.tokens`` with ``value``."""
        return self.tokens.__setitem__(key, value)

    def insert(self, key, value):
        """Insert ``value`` into ``self.tokens`` before position ``key``."""
        return self.tokens.insert(key, value)

    def __getitem__(self, key):
        """Return the token(s) at ``key`` from ``self.tokens``."""
        return self.tokens.__getitem__(key)

    @staticmethod
    def exists(config, docid: str) -> bool:
        """
        Ask the subclass registered for ``config.type`` whether ``docid`` exists.

        The call is forwarded to the ``exists`` attribute of the registered
        class, so that class has to provide its own implementation; otherwise
        the lookup resolves back to this method.

        :param config: Configuration object; its ``type`` attribute selects the subclass.
        :param docid: Document identifier to check.
        :return: Whatever the registered class's ``exists`` returns.
        :raises NameError: If no subclass is registered for ``config.type``.
        """
        return TokenList.for_type(config.type).exists(config, docid)

    @abc.abstractmethod
    def load(self, docid: str):
        """
        Load the tokens for ``docid``. Must be implemented by subclasses.

        :param docid: Document identifier.
        """
        pass

    @abc.abstractmethod
    def save(self, token: 'Token' = None):
        """
        Persist tokens. Must be implemented by subclasses.

        :param token: Optional single token; how it is used is up to the subclass.
        """
        pass

    @property
    def corrected_count(self):
        """Number of tokens whose ``gold`` value is truthy (set and non-empty)."""
        return sum(1 for t in self if t.gold)

    @property
    def discarded_count(self):
        """Number of tokens whose ``is_discarded`` flag is truthy."""
        return sum(1 for t in self if t.is_discarded)

    def random_token_index(self, has_gold=False, is_discarded=False):
        """
        Return the ``index`` attribute of a randomly chosen matching token.

        :param has_gold: If true, only consider tokens with a truthy ``gold``.
        :param is_discarded: Required value of the token's ``is_discarded`` flag.
        :return: The chosen token's ``index``, or ``None`` if no token matches.
        """
        token = self.random_token(has_gold, is_discarded)
        return None if token is None else token.index

    def random_token(self, has_gold=False, is_discarded=False):
        """
        Return a randomly chosen token matching the given filters.

        :param has_gold: If true, only consider tokens with a truthy ``gold``.
        :param is_discarded: Required value of the token's ``is_discarded`` flag.
        :return: A matching token from ``self.tokens``, or ``None`` if none match.
        """
        candidates = [
            t for t in self.tokens
            if t.is_discarded == is_discarded and (not has_gold or t.gold)
        ]
        return random.choice(candidates) if candidates else None
##########################################################################################