Source code for BaseTransformation

from collections import namedtuple
from functools import wraps
from typing import Any, Callable, Dict, Optional, Tuple


class BaseTransformation:
    """
    Base class for data operations applied around network predictions.

    Every ``transform_before_*`` hook is the identity here; subclasses override the hooks they
    need and assign ``self.data_type`` to the type their hooks accept. Subclass hooks can be
    decorated with ``BaseTransformation.check_type`` to validate their inputs.
    """

    def __init__(self, config: namedtuple):
        """
        BaseTransformation manages data operations before and after network predictions.

        :param config: Set of BaseTransformation parameters.
        """

        self.name = self.__class__.__name__
        self.config: Any = config
        # NOTE(review): this is the builtin function ``any``, not ``typing.Any``. No value has this
        # exact type, so ``check_type`` rejects every non-None value until a subclass assigns a
        # real type here. Confirm subclasses always override it.
        self.data_type = any

    @staticmethod
    def check_type(func: Callable[..., Any]):
        """
        Decorator validating the content of the data dicts passed to a transformation hook.

        Every non-None argument, positional or keyword, is treated as a dict. Each of its
        non-None values must have exactly the type ``self.data_type``; instances of subclasses
        are rejected because the comparison uses ``type(value)``, not ``isinstance``.

        :param func: Transformation method to wrap (called as ``func(self, *args, **kwargs)``).
        :return: Wrapped method keeping the name and docstring of ``func``.
        :raises TypeError: If a value does not have exactly the required data type.
        """

        @wraps(func)
        def inner(self, *args, **kwargs):
            # Keyword arguments (e.g. ``data_opt=...``) are forwarded and validated like
            # positional ones; previously they made the call fail outright.
            for data in (*args, *kwargs.values()):
                if data is None:
                    continue
                for value in data.values():
                    if value is not None and type(value) != self.data_type:
                        raise TypeError(f"[{self.name}] Wrong data type: {self.data_type} required, get {type(value)}")
            return func(self, *args, **kwargs)

        return inner

    def transform_before_prediction(self, data_net: Dict[str, Any]) -> Dict[str, Any]:
        """
        Apply data operations before network's prediction.

        :param data_net: Data used by the Network.
        :return: Transformed data_net (identity in the base class).
        """

        return data_net

    def transform_before_loss(self,
                              data_pred: Dict[str, Any],
                              data_opt: Optional[Dict[str, Any]] = None) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]:
        """
        Apply data operations between network's prediction and loss computation.

        :param data_pred: Data produced by the Network.
        :param data_opt: Data used by the Optimizer.
        :return: Transformed data_pred, data_opt (identity in the base class).
        """

        return data_pred, data_opt

    def transform_before_apply(self, data_pred: Dict[str, Any]) -> Dict[str, Any]:
        """
        Apply data operations between loss computation and prediction apply in environment.

        :param data_pred: Data produced by the Network.
        :return: Transformed data_pred (identity in the base class).
        """

        return data_pred

    def __str__(self):
        """Return a multi-line summary of the transformation: class name, data type and hooks."""

        lines = [
            f" {self.__class__.__name__}",
            f" Data type: {self.data_type}",
            " Transformation before prediction: Identity",
            " Transformation before loss: Identity",
            " Transformation before apply: Identity",
        ]
        return "\n" + "".join(line + "\n" for line in lines)