from typing import Any, Optional, Dict, List
from DeepPhysX.Core.Manager.DatabaseManager import DatabaseManager
from DeepPhysX.Core.Manager.EnvironmentManager import EnvironmentManager
from DeepPhysX.Core.Environment.BaseEnvironmentConfig import BaseEnvironmentConfig
from DeepPhysX.Core.Database.BaseDatabaseConfig import BaseDatabaseConfig
from DeepPhysX.Core.Database.DatabaseHandler import DatabaseHandler
class DataManager:
    """
    Coordinate the generation, storage and loading of training data.

    A DataManager always creates a DatabaseManager and, when an Environment configuration is provided, an
    EnvironmentManager as well. It then routes data between them according to the type of the pipeline.
    """

    def __init__(self,
                 pipeline: Any,
                 database_config: Optional[BaseDatabaseConfig] = None,
                 environment_config: Optional[BaseEnvironmentConfig] = None,
                 session: str = 'sessions/default',
                 new_session: bool = True,
                 produce_data: bool = True,
                 batch_size: int = 1):
        """
        DataManager deals with the generation, storage and loading of training data.

        :param pipeline: Pipeline that handles the DataManager.
        :param database_config: Configuration object with the parameters of the Database.
        :param environment_config: Configuration object with the parameters of the Environment.
        :param session: Path to the session repository.
        :param new_session: If True, the session is done in a new repository.
        :param produce_data: If True, this session will store data in the Database.
        :param batch_size: Number of samples in a single batch.
        """

        self.name: str = self.__class__.__name__

        # Session variables
        self.pipeline: Optional[Any] = pipeline
        self.database_manager: Optional[DatabaseManager] = None
        self.environment_manager: Optional[EnvironmentManager] = None

        # Create a DatabaseManager
        self.database_manager = DatabaseManager(database_config=database_config,
                                                data_manager=self,
                                                pipeline=pipeline.type,
                                                session=session,
                                                new_session=new_session,
                                                produce_data=produce_data)

        # Create an EnvironmentManager if required
        if environment_config is not None:
            self.environment_manager = EnvironmentManager(environment_config=environment_config,
                                                          data_manager=self,
                                                          pipeline=pipeline.type,
                                                          session=session,
                                                          produce_data=produce_data,
                                                          batch_size=batch_size)

        # DataManager variables
        self.produce_data = produce_data
        self.batch_size = batch_size
        # Lines of the Database involved in the last batch that was produced or loaded
        self.data_lines: List[List[int]] = []

    @property
    def nb_environment(self) -> Optional[int]:
        """
        Get the number of Environments managed by the EnvironmentManager.

        :return: None without EnvironmentManager, 1 if the EnvironmentManager has no server, otherwise its
                 number of threads.
        """

        if self.environment_manager is None:
            return None
        if self.environment_manager.server is None:
            return 1
        return self.environment_manager.number_of_thread

    @property
    def normalization(self) -> Dict[str, List[float]]:
        """
        Get the normalization coefficients computed by the DatabaseManager.
        """

        return self.database_manager.normalization

    def connect_handler(self,
                        handler: DatabaseHandler) -> None:
        """
        Add a new DatabaseHandler to the list of handlers of the DatabaseManager.

        :param handler: New handler to register.
        """

        self.database_manager.connect_handler(handler)

    def get_data(self,
                 epoch: int = 0,
                 animate: bool = True,
                 load_samples: bool = True) -> None:
        """
        Fetch data from the EnvironmentManager or the DatabaseManager according to the context.

        The lines of the fetched batch are stored in 'data_lines' (except in the data generation case).

        :param epoch: Current epoch number.
        :param animate: Allow EnvironmentManager to trigger a step itself in order to generate a new sample.
        :param load_samples: If True, trigger a sample loading from the Database.
        :raises ValueError: If the pipeline requires an EnvironmentManager but none is available.
        """

        # Data generation case
        if self.pipeline.type == 'data_generation':
            # Fail with an explicit message rather than an AttributeError on None
            if self.environment_manager is None:
                raise ValueError(f"[{self.__class__.__name__}] The 'data_generation' pipeline requires an "
                                 f"EnvironmentManager.")
            self.environment_manager.get_data(animate=animate)
            self.database_manager.add_data()

        # Training case
        elif self.pipeline.type == 'training':

            # Get data from Environment(s) if used and if the data should be created at this epoch
            if self.environment_manager is not None and self.produce_data and \
                    (epoch == 0 or self.environment_manager.always_produce):
                self.data_lines = self.environment_manager.get_data(animate=animate)
                self.database_manager.add_data(self.data_lines)

            # Get data from Dataset
            else:
                self.data_lines = self.database_manager.get_data(batch_size=self.batch_size)
                # Dispatch a batch to clients
                if self.environment_manager is not None:
                    if self.environment_manager.load_samples and \
                            (epoch == 0 or not self.environment_manager.only_first_epoch):
                        self.environment_manager.dispatch_batch(data_lines=self.data_lines,
                                                                animate=animate)
                    # Environment is no longer used
                    else:
                        self.environment_manager.close()
                        # Drop the reference so that the closed manager is never reused nor closed twice
                        self.environment_manager = None

        # Prediction pipeline
        else:
            # Fail with an explicit message rather than an AttributeError on None
            if self.environment_manager is None:
                raise ValueError(f"[{self.__class__.__name__}] The prediction pipeline requires an "
                                 f"EnvironmentManager.")

            # Get data from Dataset
            if self.environment_manager.load_samples:
                # Without a new loading, the previously loaded lines are dispatched again
                if load_samples:
                    self.data_lines = self.database_manager.get_data(batch_size=1)
                self.environment_manager.dispatch_batch(data_lines=self.data_lines,
                                                        animate=animate,
                                                        request_prediction=True,
                                                        save_data=self.produce_data)

            # Get data from Environment
            else:
                self.data_lines = self.environment_manager.get_data(animate=animate,
                                                                    request_prediction=True,
                                                                    save_data=self.produce_data)
                if self.produce_data:
                    self.database_manager.add_data(self.data_lines)

    def load_sample(self) -> List[int]:
        """
        Load a sample from the Database.

        :return: Index of the loaded line.
        """

        self.data_lines = self.database_manager.get_data(batch_size=1)
        return self.data_lines[0]

    def get_prediction(self,
                       instance_id: int) -> None:
        """
        Get a Network prediction for the specified Environment instance.

        :param instance_id: Index of the Environment instance requesting a prediction.
        :raises ValueError: If the DataManager is not attached to a pipeline.
        """

        # Get a prediction
        if self.pipeline is None:
            raise ValueError("Cannot request prediction if Manager (and then NetworkManager) does not exist.")
        self.pipeline.network_manager.compute_online_prediction(instance_id=instance_id,
                                                                normalization=self.normalization)

    def close(self) -> None:
        """
        Launch the closing procedure of the DataManager.

        The DatabaseManager is closed even if closing the EnvironmentManager raises; the exception is then
        propagated to the caller.
        """

        try:
            if self.environment_manager is not None:
                self.environment_manager.close()
        finally:
            if self.database_manager is not None:
                self.database_manager.close()

    def __str__(self) -> str:
        """
        Concatenate the descriptions of the existing managers (EnvironmentManager first).
        """

        managers = (self.environment_manager, self.database_manager)
        return "".join(str(manager) for manager in managers if manager is not None)