from typing import Any, Dict, Type
from socket import socket
from asyncio import get_event_loop
from asyncio import AbstractEventLoop as EventLoop
from asyncio import run as async_run
from numpy import ndarray
from DeepPhysX.Core.AsyncSocket.TcpIpObject import TcpIpObject
from DeepPhysX.Core.AsyncSocket.AbstractEnvironment import AbstractEnvironment
class TcpIpClient(TcpIpObject):
    """
    TcpIpObject that connects to a TcpIpServer and drives a local Environment on the server's behalf.
    """

    def __init__(self,
                 environment: Type[AbstractEnvironment],
                 ip_address: str = 'localhost',
                 port: int = 10000,
                 instance_id: int = 0,
                 instance_nb: int = 1):
        """
        TcpIpClient is a TcpIpObject which communicates with a TcpIpServer and manages an Environment to compute data.
        The connection to the server is opened here and the instance ID is sent right away.

        :param environment: Environment class (instantiated later, in initialize()).
        :param ip_address: IP address of the TcpIpObject.
        :param port: Port number of the TcpIpObject.
        :param instance_id: Index of this instance.
        :param instance_nb: Number of simultaneously launched instances.
        :raises OSError: If connecting to the server fails (the socket is closed before re-raising).
        """
        TcpIpObject.__init__(self,
                             ip_address=ip_address,
                             port=port)

        # Environment instance: only declared here, it is created in initialize() with the kwargs sent by the server
        self.environment: AbstractEnvironment
        self.environment_class = environment
        self.environment_instance = (instance_id, instance_nb)

        # Connect to the server address and send ID
        try:
            self.sock.connect((ip_address, port))
        except OSError:
            # self.sock is not created in this class (presumably by TcpIpObject); release it rather than leak it
            self.sock.close()
            raise
        self.sync_send_labeled_data(data_to_send=instance_id, label="instance_ID", receiver=self.sock,
                                    send_read_command=False)
        self.close_client: bool = False

    ##########################################################################################
    ##########################################################################################
    #                                Initializing Environment                                #
    ##########################################################################################
    ##########################################################################################

    def initialize(self) -> None:
        """
        Receive parameters from the server to create environment.
        Blocks until the whole initialization exchange with the server is done.
        """
        async_run(self.__initialize())

    async def __initialize(self) -> None:
        """
        Receive parameters from the server to create environment.

        Messages are read in this order: environment kwargs, prediction requests authorization, number of
        sub-steps, database partitions, visualization database. Then 'done' is sent back and one more message is
        awaited before the database is loaded.
        """
        loop = get_event_loop()

        # Receive additional arguments (an absent 'env_kwargs' entry means no extra kwargs)
        received = {}
        await self.receive_dict(recv_to=received, loop=loop, sender=self.sock)
        env_kwargs = received.get('env_kwargs', {})
        self.environment = self.environment_class(as_tcp_ip_client=True,
                                                  instance_id=self.environment_instance[0],
                                                  instance_nb=self.environment_instance[1],
                                                  **env_kwargs)
        self.environment.tcp_ip_client = self

        # Receive prediction requests authorization
        self.allow_prediction_requests = await self.receive_data(loop=loop, sender=self.sock)

        # Receive number of sub-steps
        self.simulations_per_step = await self.receive_data(loop=loop, sender=self.sock)

        # Receive partitions, encoded as '<prefix>///<name_1>///...///<name_n>%%%<exchange_0>///<exchange_1>'
        partitions_msg = await self.receive_data(loop=loop, sender=self.sock)
        partitions_part, exchange_part = partitions_msg.split('%%%')
        partitions_prefix, *partition_names = partitions_part.split('///')
        partitions = [[partitions_prefix, partition_name] for partition_name in partition_names]
        exchange_fields = exchange_part.split('///')
        exchange = [exchange_fields[0], exchange_fields[1]]
        self.environment.get_database_handler().init_remote(storing_partitions=partitions,
                                                            exchange_db=exchange)

        # Receive visualization database (the literal string 'None' means no visualization)
        visualization_db = await self.receive_data(loop=loop, sender=self.sock)
        visualization_db = None if visualization_db == 'None' else visualization_db.split('///')

        # Initialize the environment
        self.environment.create()
        self.environment.init()
        self.environment.init_database()
        if visualization_db is not None:
            self.environment._create_visualization(visualization_db=visualization_db)
            self.environment.init_visualization()

        # Initialization done
        await self.send_data(data_to_send='done', loop=loop, receiver=self.sock)

        # Synchronize Database: the content of this message is ignored, it only signals that loading may start
        _ = await self.receive_data(loop=loop, sender=self.sock)
        self.environment.get_database_handler().load()

    ##########################################################################################
    ##########################################################################################
    #                                     Running Client                                     #
    ##########################################################################################
    ##########################################################################################

    def launch(self) -> None:
        """
        Trigger the main communication protocol with the server.
        Blocks until the client is asked to shut down.
        """
        async_run(self.__launch())

    async def __launch(self) -> None:
        """
        Trigger the main communication protocol with the server.
        The closing procedure always runs, whether the loop ends normally or is interrupted.
        """
        try:
            # Run the communication protocol with server while Client is not asked to shut down
            while not self.close_client:
                await self.__communicate(server=self.sock)
        except KeyboardInterrupt:
            print(f"[{self.name}] KEYBOARD INTERRUPT: CLOSING PROCEDURE")
        finally:
            # Closing procedure when Client is asked to shut down
            await self.__close()

    async def __communicate(self,
                            server: socket) -> None:
        """
        Communication protocol with a server. First receive a command from the server, then process the appropriate
        actions. Delegates to listen_while_not_done (defined outside this class), which presumably dispatches to the
        action_on_* handlers below.

        :param server: TcpIpServer to communicate with.
        """
        loop = get_event_loop()
        await self.listen_while_not_done(loop=loop, sender=server, data_dict={})

    async def __close(self) -> None:
        """
        Close the environment and shutdown the client.
        The socket is closed even if closing the environment or confirming the exit fails.
        """
        try:
            # Close environment (it does not exist if initialize() was never run)
            environment = getattr(self, 'environment', None)
            if environment is not None:
                try:
                    environment.close()
                except NotImplementedError:
                    pass
            # Confirm exit command to the server
            loop = get_event_loop()
            await self.send_command_exit(loop=loop, receiver=self.sock)
        finally:
            # Close socket
            self.sock.close()

    ##########################################################################################
    ##########################################################################################
    #                              Available requests to Server                              #
    ##########################################################################################
    ##########################################################################################

    def get_prediction(self, **kwargs) -> Dict[str, ndarray]:
        """
        Request a prediction from Network.
        The inputs are written to this instance's line of the 'Exchange' table, the server is notified, and the same
        line is read back once the server answers.

        :param kwargs: Fields to write in the 'Exchange' table line of this instance.
        :return: Prediction of the Network (the 'Exchange' line without its 'id' field).
        """
        database_handler = self.environment.get_database_handler()
        instance_id = self.environment.instance_id
        # Get a prediction
        database_handler.update(table_name='Exchange',
                                data=kwargs,
                                line_id=instance_id)
        self.sync_send_command_prediction()
        # The reply content is ignored: it only signals that the prediction was written to the database
        _ = self.sync_receive_data()
        data_pred = database_handler.get_line(table_name='Exchange',
                                              line_id=instance_id)
        del data_pred['id']
        return data_pred

    def request_update_visualization(self) -> None:
        """
        Triggers the Visualizer update: sends the visualisation command followed by this instance's ID.
        """
        self.sync_send_command_visualisation()
        self.sync_send_labeled_data(data_to_send=self.environment.instance_id,
                                    label='instance')

    ##########################################################################################
    ##########################################################################################
    #                             Actions to perform on commands                             #
    ##########################################################################################
    ##########################################################################################

    async def action_on_exit(self,
                             data: ndarray,
                             client_id: int,
                             sender: socket,
                             loop: EventLoop) -> None:
        """
        Action to run when receiving the 'exit' command: flags the client so the communication loop stops.

        :param data: Data received with the command (unused here).
        :param client_id: ID of the TcpIpClient (unused here).
        :param sender: TcpIpObject sender.
        :param loop: Asyncio event loop.
        """
        # Close client flag set to True
        self.close_client = True

    async def action_on_prediction(self,
                                   data: ndarray,
                                   client_id: int,
                                   sender: socket,
                                   loop: EventLoop) -> None:
        """
        Action to run when receiving the 'prediction' command.

        :param data: Data received with the command (unused here).
        :param client_id: ID of the TcpIpClient (unused here).
        :param sender: TcpIpObject sender.
        :param loop: Asyncio event loop.
        """
        # Receive prediction
        prediction = await self.receive_data(loop=loop, sender=sender)
        # Apply the prediction in Environment
        self.environment.apply_prediction(prediction)

    async def action_on_sample(self,
                               data: ndarray,
                               client_id: int,
                               sender: socket,
                               loop: EventLoop) -> None:
        """
        Action to run when receiving the 'sample' command: forwards the received batch to the environment.

        :param data: Data received with the command (unused here).
        :param client_id: ID of the TcpIpClient (unused here).
        :param sender: TcpIpObject sender.
        :param loop: Asyncio event loop.
        """
        dataset_batch = await self.receive_data(loop=loop, sender=sender)
        self.environment._get_training_data(dataset_batch)

    async def __run_steps(self) -> None:
        """
        Run `simulations_per_step` environment steps, flagging only the last one to compute training data.
        """
        for step in range(self.simulations_per_step):
            # Compute data only on final step.
            # NOTE(review): the flag is stored on the client, not on the environment, and nothing in this class
            # reads it; confirm the environment reads it through `tcp_ip_client` (set in __initialize).
            self.compute_training_data = step == self.simulations_per_step - 1
            await self.environment.step()

    async def action_on_step(self,
                             data: ndarray,
                             client_id: int,
                             sender: socket,
                             loop: EventLoop) -> None:
        """
        Action to run when receiving the 'step' command.
        Steps are re-run until the environment accepts the sample, then the resulting database line index is sent
        to the server after a 'done' command.

        :param data: Data received with the command (unused here).
        :param client_id: ID of the TcpIpClient (unused here).
        :param sender: TcpIpObject sender.
        :param loop: Asyncio event loop.
        """
        # Execute the required number of steps
        await self.__run_steps()
        # If produced sample is not usable, run again
        while not self.environment.check_sample():
            await self.__run_steps()

        # Send training data to Server: either as a new line, or as an update of an existing line
        if self.environment.update_line is None:
            line = self.environment._send_training_data()
        else:
            self.environment._update_training_data(self.environment.update_line)
            line = self.environment.update_line
        self.environment._reset_training_data()
        await self.send_command_done(loop=loop, receiver=sender)
        await self.send_data(data_to_send=line, loop=loop, receiver=sender)

    async def action_on_change_db(self,
                                  data: Dict[Any, Any],
                                  client_id: int, sender: socket,
                                  loop: EventLoop) -> None:
        """
        Action to run when receiving the 'change_db' command: updates the remote partition list of the database.

        :param data: Data received with the command (unused here).
        :param client_id: ID of the TcpIpClient (unused here).
        :param sender: TcpIpObject sender.
        :param loop: Asyncio event loop.
        """
        # Update the partition list in the DatabaseHandler ('///'-separated fields)
        new_database = await self.receive_data(loop=loop, sender=sender)
        self.environment.get_database_handler().update_list_partitions_remote(new_database.split('///'))