# Source code for BaseEnvironment

from typing import Any, Optional, Dict, Union, Tuple, List, Type
from numpy import ndarray
from os.path import isfile, join

from SSD.Core.Storage.Database import Database

from DeepPhysX.Core.Visualization.VedoFactory import VedoFactory
from DeepPhysX.Core.AsyncSocket.AbstractEnvironment import AbstractEnvironment
from DeepPhysX.Core.Database.DatabaseHandler import DatabaseHandler

class BaseEnvironment(AbstractEnvironment):
    """
    Base class for user-defined Environments.

    Holds the training / additional data produced by the simulation, exchanges
    it with a DatabaseHandler, and optionally owns a VedoFactory for
    visualization. Users are expected to override at least ``create``,
    ``init_database`` and ``step``.
    """

    def __init__(self,
                 as_tcp_ip_client: bool = True,
                 instance_id: int = 1,
                 instance_nb: int = 1,
                 **kwargs):
        """
        BaseEnvironment computes simulated data for the Network and its training process.

        :param as_tcp_ip_client: Environment is a TcpIpObject if True, is owned by an EnvironmentManager if False.
        :param instance_id: ID of the instance.
        :param instance_nb: Number of simultaneously launched instances.
        :param kwargs: Extra keyword arguments; accepted but not used by this constructor.
        """

        AbstractEnvironment.__init__(self,
                                     as_tcp_ip_client=as_tcp_ip_client,
                                     instance_id=instance_id,
                                     instance_nb=instance_nb)

        # Training data variables
        self.__data_training: Dict[str, ndarray] = {}
        self.__data_additional: Dict[str, ndarray] = {}
        self.compute_training_data: bool = True

        # Dataset data variables
        self.update_line: Optional[int] = None
        self.sample_training: Optional[Dict[str, Any]] = None
        self.sample_additional: Optional[Dict[str, Any]] = None
        # One "first call" flag per entry point: [set_training_data, get_prediction]
        self.__first_add: List[bool] = [True, True]

        # Connect the Environment to the data Database
        self.__database_handler = DatabaseHandler(on_init_handler=self.__database_handler_init)

        # Connect the Factory to the visualization Database
        self.factory: Optional[VedoFactory] = None

    ##########################################################################################
    ##########################################################################################
    #                                 Environment initialization                             #
    ##########################################################################################
    ##########################################################################################

    def create(self) -> None:
        """
        Create the Environment. Automatically called when Environment is launched.
        Must be implemented by user.
        """

        raise NotImplementedError

    def init(self) -> None:
        """
        Initialize the Environment. Automatically called when Environment is launched.
        Not mandatory.
        """

        pass

    def init_database(self) -> None:
        """
        Define the fields of the training dataset. Automatically called when Environment is launched.
        Must be implemented by user.
        """

        raise NotImplementedError

    def init_visualization(self) -> None:
        """
        Define the visualization objects to send to the Visualizer. Automatically called when Environment is launched.
        Not mandatory.
        """

        pass

    def save_parameters(self, **kwargs) -> None:
        """
        Save a set of parameters in the Database.

        The parameters are written to the 'environment_parameters' Database located in the
        DatabaseHandler directory, in a table named after this instance ID.

        :param kwargs: Parameters to store; each key becomes a field typed after its value.
        """

        # Create a dedicated Database (or reuse the existing one)
        database_dir = self.__database_handler.get_database_dir()
        database = Database(database_dir=database_dir, database_name='environment_parameters')
        if isfile(join(database_dir, 'environment_parameters.db')):
            database = database.load()
        else:
            database = database.new()

        # Create fields and add data; always release the Database, even on failure
        table_name = f'Environment_{self.instance_id}'
        try:
            fields = [(field, type(value)) for field, value in kwargs.items()]
            database.create_table(table_name=table_name, fields=fields)
            database.add_data(table_name=table_name, data=kwargs)
        finally:
            database.close()

    def load_parameters(self) -> Dict[str, Any]:
        """
        Load a set of parameters from the Database.

        :return: Stored parameters without the 'id' entry, or an empty dict if the
                 'environment_parameters' Database file does not exist.
        """

        database_dir = self.__database_handler.get_database_dir()
        if not isfile(join(database_dir, 'environment_parameters.db')):
            return {}

        # Load the dedicated Database and the parameters, then close it as save_parameters does
        database = Database(database_dir=database_dir, database_name='environment_parameters').load()
        try:
            parameters = database.get_line(table_name=f'Environment_{self.instance_id}')
        finally:
            database.close()
        parameters.pop('id', None)
        return parameters

    ##########################################################################################
    ##########################################################################################
    #                                   Environment behavior                                 #
    ##########################################################################################
    ##########################################################################################

    async def step(self) -> None:
        """
        Compute the number of steps in the Environment specified by simulations_per_step in EnvironmentConfig.
        Must be implemented by user.
        """

        raise NotImplementedError

    def check_sample(self) -> bool:
        """
        Check if the current produced sample is usable for training.
        Not mandatory.

        :return: Current data can be used or not
        """

        return True

    def apply_prediction(self, prediction: Dict[str, ndarray]) -> None:
        """
        Apply network prediction in environment.
        Not mandatory.

        :param prediction: Prediction data.
        """

        pass

    def close(self) -> None:
        """
        Close the Environment. Automatically called when Environment is shut down.
        Not mandatory.
        """

        pass

    ##########################################################################################
    ##########################################################################################
    #                                   Defining data samples                                #
    ##########################################################################################
    ##########################################################################################

    def define_training_fields(self, fields: Union[List[Tuple[str, Type]], Tuple[str, Type]]) -> None:
        """
        Specify the training data fields names and types.

        :param fields: Field or list of fields to tag as training data.
        """

        self.__database_handler.create_fields(table_name='Training', fields=fields)

    def define_additional_fields(self, fields: Union[List[Tuple[str, Type]], Tuple[str, Type]]) -> None:
        """
        Specify the additional data fields names and types.

        :param fields: Field or list of Fields to tag as additional data.
        """

        self.__database_handler.create_fields(table_name='Additional', fields=fields)

    def set_training_data(self, **kwargs) -> None:
        """
        Set the training data to send to the TcpIpServer or the EnvironmentManager.

        On the first call only, the given field names are validated against the fields of
        the 'Training' table ('id' and 'env_id' excluded).

        :param kwargs: Training data, one keyword per field.
        :raises ValueError: On first call, if a field is unknown or a required field is missing.
        """

        # Check kwargs (first call only)
        if self.__first_add[0]:
            self.__database_handler.load()
            self.__first_add[0] = False
            required_fields = list(set(self.__database_handler.get_fields(table_name='Training')) - {'id', 'env_id'})
            if len(required_fields) > 0:
                for field in kwargs:
                    if field not in required_fields:
                        raise ValueError(f"[{type(self).__name__}] The field '{field}' is not in the training "
                                         f"Database. Required fields are {required_fields}.")
                for field in required_fields:
                    if field not in kwargs:
                        raise ValueError(f"[{type(self).__name__}] The field '{field}' was not defined in training "
                                         f"data. Required fields are {required_fields}.")

        # Training data is set if the Environment can compute data
        if self.compute_training_data:
            self.__data_training = kwargs
            self.__data_training['env_id'] = self.instance_id

    def set_additional_data(self, **kwargs) -> None:
        """
        Set the additional data to send to the TcpIpServer or the EnvironmentManager.

        :param kwargs: Additional data, one keyword per field.
        """

        # Additional data is also set if the Environment can compute data
        if self.compute_training_data:
            self.__data_additional = kwargs
            self.__data_additional['env_id'] = self.instance_id

    ##########################################################################################
    ##########################################################################################
    #                                   Available requests                                   #
    ##########################################################################################
    ##########################################################################################

    def get_prediction(self, **kwargs) -> Dict[str, ndarray]:
        """
        Request a prediction from Network.

        If no field is given, the fields of the 'Exchange' table that are present in the
        current training data are sent instead.

        :param kwargs: Network fields to send with the request.
        :return: Network prediction.
        :raises ValueError: If prediction requests are not allowed, if no data is available
                            for the request, if a field is not in the 'Exchange' table, or
                            if the Environment has neither a TcpIpClient nor a usable manager.
        """

        # Check kwargs (first call only)
        if self.__first_add[1]:
            manager_forbids = (self.environment_manager is not None
                               and not self.environment_manager.allow_prediction_requests)
            client_forbids = (self.tcp_ip_client is not None
                              and not self.tcp_ip_client.allow_prediction_requests)
            if manager_forbids or client_forbids:
                raise ValueError(f"[{type(self).__name__}] Prediction request is not available in data generation "
                                 f"Pipeline.")
            if len(kwargs) == 0 and len(self.__data_training) == 0:
                raise ValueError(f"[{type(self).__name__}] The prediction request requires the network fields.")
            self.__database_handler.load()
            self.__first_add[1] = False
            required_fields = list(set(self.__database_handler.get_fields(table_name='Exchange')) - {'id'})
            for field in kwargs:
                if field not in required_fields:
                    raise ValueError(f"[{type(self).__name__}] The field '{field}' is not in the training "
                                     f"Database. Required fields are {required_fields}.")

        # Avoid empty sample: fall back on the training data matching the Exchange fields
        if len(kwargs) == 0:
            required_fields = set(self.__database_handler.get_fields(table_name='Exchange')) - {'id'}
            necessary_fields = list(required_fields.intersection(self.__data_training.keys()))
            kwargs = {field: self.__data_training[field] for field in necessary_fields}

        # If Environment is a TcpIpClient, send request to the Server
        if self.tcp_ip_client is not None:
            return self.tcp_ip_client.get_prediction(**kwargs)

        # Otherwise, check the hierarchy of managers
        if self.environment_manager is not None:
            if self.environment_manager.data_manager is None:
                raise ValueError("Cannot request prediction if DataManager does not exist")
            # Get a prediction: write the request in the Exchange table, trigger the
            # DataManager, then read the same line back
            self.__database_handler.update(table_name='Exchange', data=kwargs, line_id=self.instance_id)
            self.environment_manager.data_manager.get_prediction(self.instance_id)
            data_pred = self.__database_handler.get_line(table_name='Exchange', line_id=self.instance_id)
            del data_pred['id']
            return data_pred

        raise ValueError(f"[{type(self).__name__}] This Environment has not Manager.")

    def update_visualisation(self) -> None:
        """
        Triggers the Visualizer update.
        Does nothing if no Factory was created.
        """

        if self.factory is not None:
            # If Environment is a TcpIpClient, request to the Server
            if self.as_tcp_ip_client:
                self.tcp_ip_client.request_update_visualization()
            self.factory.render()

    def _get_prediction(self):
        """
        Request a prediction from Network and apply it to the Environment.
        Should not be used by users.
        """

        # Only forward the training fields that also exist in the Exchange table
        required_fields = self.__database_handler.get_fields(table_name='Exchange')
        training_data = {field: value for field, value in self.__data_training.items()
                         if field in required_fields}
        self.apply_prediction(self.get_prediction(**training_data))

    ##########################################################################################
    ##########################################################################################
    #                                 Database communication                                 #
    ##########################################################################################
    ##########################################################################################

    def get_database_handler(self) -> DatabaseHandler:
        """
        Get the DatabaseHandler of the Environment.
        """

        return self.__database_handler

    def __database_handler_init(self):
        """
        Init event of the DatabaseHandler.
        """

        # Load Database and create basic fields
        self.__database_handler.load()

    def _create_visualization(self, visualization_db: Union[Database, Tuple[str, str], List[str]]) -> None:
        """
        Create a Factory for the Environment.

        :param visualization_db: Either a Database instance, or a list / tuple describing the
                                 Database path, in which case a remote Factory is created.
        """

        # Accept both list and tuple paths (the annotation advertises a tuple)
        if isinstance(visualization_db, (list, tuple)):
            self.factory = VedoFactory(database_path=visualization_db,
                                       idx_instance=self.instance_id,
                                       remote=True)
        else:
            self.factory = VedoFactory(database=visualization_db,
                                       idx_instance=self.instance_id)

    def _send_training_data(self) -> List[int]:
        """
        Add the training data and the additional data in their respective Databases.
        Should not be used by users.

        :return: Index of the samples in the Database.
        """

        line_id = self.__database_handler.add_data(table_name='Training',
                                                   data=self.__data_training)
        self.__database_handler.add_data(table_name='Additional',
                                         data=self.__data_additional)
        return line_id

    def _update_training_data(self, line_id: List[int]) -> None:
        """
        Update the training data and the additional data in their respective Databases.
        Should not be used by users.

        :param line_id: Index of the samples to update.
        """

        if self.__data_training != {}:
            self.__database_handler.update(table_name='Training',
                                           data=self.__data_training,
                                           line_id=line_id)
        if self.__data_additional != {}:
            self.__database_handler.update(table_name='Additional',
                                           data=self.__data_additional,
                                           line_id=line_id)

    def _get_training_data(self, line_id: List[int]) -> None:
        """
        Get the training data and the additional data from their respective Databases.
        Should not be used by users.

        :param line_id: Index of the sample to get.
        """

        self.update_line = line_id
        self.sample_training = self.__database_handler.get_line(table_name='Training',
                                                                line_id=line_id)
        self.sample_additional = self.__database_handler.get_line(table_name='Additional',
                                                                  line_id=line_id)
        # A single-entry line is discarded and replaced by None
        # (presumably it only holds the 'id' field - TODO confirm)
        self.sample_additional = None if len(self.sample_additional) == 1 else self.sample_additional

    def _reset_training_data(self) -> None:
        """
        Reset the training data and the additional data variables.
        Should not be used by users.
        """

        self.__data_training = {}
        self.__data_additional = {}
        self.sample_training = None
        self.sample_additional = None
        self.update_line = None

    def __str__(self):
        """
        Return a short multi-line description of the Environment.
        """

        description = "\n"
        description += f" {type(self).__name__}\n"
        description += f" Name: {type(self).__name__}{self.instance_id}\n"
        description += " Comments:\n"
        description += " Input size:\n"
        description += " Output size:\n"
        return description