from typing import Any, Optional, Type
from os.path import isdir
from numpy import typeDict
from DeepPhysX.Core.Network.BaseNetwork import BaseNetwork
from DeepPhysX.Core.Network.BaseOptimization import BaseOptimization
from DeepPhysX.Core.Network.BaseTransformation import BaseTransformation
from DeepPhysX.Core.Utils.configs import make_config, namedtuple
class BaseNetworkConfig:
    """
    Configuration class used to parameterize and create the BaseNetwork, BaseOptimization and BaseTransformation
    components for the NetworkManager.
    """

    def __init__(self,
                 network_class: Type[BaseNetwork] = BaseNetwork,
                 optimization_class: Type[BaseOptimization] = BaseOptimization,
                 data_transformation_class: Type[BaseTransformation] = BaseTransformation,
                 network_dir: Optional[str] = None,
                 network_name: str = 'Network',
                 network_type: str = 'BaseNetwork',
                 which_network: int = -1,
                 save_each_epoch: bool = False,
                 data_type: str = 'float32',
                 lr: Optional[float] = None,
                 require_training_stuff: bool = True,
                 loss: Optional[Any] = None,
                 optimizer: Optional[Any] = None):
        """
        Validate the given parameters and build the configurations of the network, the optimization and the data
        transformation.

        :param network_class: BaseNetwork class from which an instance will be created.
        :param optimization_class: BaseOptimization class from which an instance will be created.
        :param data_transformation_class: BaseTransformation class from which an instance will be created.
        :param network_dir: Name of an existing network repository.
        :param network_name: Name of the network.
        :param network_type: Type of the network.
        :param which_network: If several networks in network_dir, load the specified one.
        :param save_each_epoch: If True, network state will be saved at each epoch end; if False, network state
                                will be saved at the end of the training. Forced to False when the training
                                materials are not available (see 'require_training_stuff').
        :param data_type: Type of the training data. Must be a key of numpy's 'sctypeDict'.
        :param lr: Learning rate.
        :param require_training_stuff: If False, the training materials are considered available even when 'loss'
                                       or 'optimizer' is not given.
        :param loss: Loss class.
        :param optimizer: Network's parameters optimizer class.
        :raises TypeError: If 'network_dir', 'network_name', 'network_type', 'which_network' or 'save_each_epoch'
                           has the wrong type.
        :raises ValueError: If 'network_dir' is not an existing directory or if 'data_type' is not a numpy type.
        """

        # 'typeDict' was a deprecated alias of 'sctypeDict' and has been removed in NumPy 1.24: rely on the
        # canonical name, imported locally so that this class does not depend on the removed alias.
        from numpy import sctypeDict

        self.name = self.__class__.__name__

        # Check network_dir type and existence
        if network_dir is not None:
            if not isinstance(network_dir, str):
                raise TypeError(
                    f"[{self.__class__.__name__}] Wrong 'network_dir' type: str required, get {type(network_dir)}")
            if not isdir(network_dir):
                raise ValueError(f"[{self.__class__.__name__}] Given 'network_dir' does not exists: {network_dir}")
        # Check network_name type
        if not isinstance(network_name, str):
            raise TypeError(
                f"[{self.__class__.__name__}] Wrong 'network_name' type: str required, get {type(network_name)}")
        # Check network_type type
        if not isinstance(network_type, str):
            raise TypeError(
                f"[{self.__class__.__name__}] Wrong 'network_type' type: str required, get {type(network_type)}")
        # Check which_network type (bool is a subclass of int and must still be rejected)
        if isinstance(which_network, bool) or not isinstance(which_network, int):
            raise TypeError(
                f"[{self.__class__.__name__}] Wrong 'which_network' type: int required, get {type(which_network)}")
        # Check save_each_epoch type
        if not isinstance(save_each_epoch, bool):
            raise TypeError(
                f"[{self.__class__.__name__}] Wrong 'save each epoch' type: bool required, get {type(save_each_epoch)}")
        # Check data type
        if data_type not in sctypeDict:
            raise ValueError(
                f"[{self.__class__.__name__}] The following data type is not a numpy type: {data_type}")

        # BaseNetwork parameterization
        self.network_class: Type[BaseNetwork] = network_class
        self.network_config: namedtuple = make_config(configuration_object=self,
                                                      configuration_name='network_config',
                                                      network_name=network_name,
                                                      network_type=network_type,
                                                      data_type=data_type)

        # BaseOptimization parameterization
        self.optimization_class: Type[BaseOptimization] = optimization_class
        self.optimization_config: namedtuple = make_config(configuration_object=self,
                                                           configuration_name='optimization_config',
                                                           loss=loss,
                                                           lr=lr,
                                                           optimizer=optimizer)
        # Training materials are available if both loss and optimizer are given, or if they are not required
        self.training_stuff: bool = ((loss is not None) and (optimizer is not None)) or (not require_training_stuff)

        # BaseTransformation parameterization
        self.data_transformation_class: Type[BaseTransformation] = data_transformation_class
        self.data_transformation_config: namedtuple = make_config(configuration_object=self,
                                                                  configuration_name='data_transformation_config')

        # NetworkManager parameterization
        self.network_dir: Optional[str] = network_dir
        self.which_network: int = which_network
        # Saving at each epoch only makes sense when the training materials are available
        self.save_each_epoch: bool = save_each_epoch and self.training_stuff

    def create_network(self) -> BaseNetwork:
        """
        Create an instance of network_class with given parameters.

        :return: BaseNetwork object from network_class and its parameters.
        :raises TypeError: If the created object is not a BaseNetwork instance.
        """

        # Create instance
        network = self.network_class(config=self.network_config)
        if not isinstance(network, BaseNetwork):
            raise TypeError(f"[{self.name}] The given 'network_class'={self.network_class} must be a BaseNetwork.")
        return network

    def create_optimization(self) -> BaseOptimization:
        """
        Create an instance of optimization_class with given parameters.

        :return: BaseOptimization object from optimization_class and its parameters.
        :raises TypeError: If the created object is not a BaseOptimization instance.
        """

        # Create instance
        optimization = self.optimization_class(config=self.optimization_config)
        if not isinstance(optimization, BaseOptimization):
            raise TypeError(f"[{self.name}] The given 'optimization_class'={self.optimization_class} must be a "
                            f"BaseOptimization.")
        return optimization

    def __str__(self) -> str:
        """
        :return: Multi-line description of the configuration, starting with an empty line.
        """

        lines = [
            "",
            f"{self.__class__.__name__}",
            f" Network class: {self.network_class.__name__}",
            f" Optimization class: {self.optimization_class.__name__}",
            f" Training materials: {self.training_stuff}",
            f" Network directory: {self.network_dir}",
            f" Which network: {self.which_network}",
            f" Save each epoch: {self.save_each_epoch}",
        ]
        # Every line, including the last one, is newline-terminated
        return "\n".join(lines) + "\n"