from typing import Dict, Any
from collections import namedtuple
from DeepPhysX.Core.Network.BaseNetwork import BaseNetwork
class BaseOptimization:

    def __init__(self, config: namedtuple):
        """
        BaseOptimization computes the loss between prediction and target and
        optimizes the Network parameters.

        :param config: Set of BaseOptimization parameters; expected to expose
                       'loss', 'optimizer' and 'lr' fields.
        """

        # Set externally (e.g. by a NetworkManager); None until then.
        self.manager: Any = None

        # Loss: the class comes from the config, the instance is created later
        # by set_loss() in concrete subclasses.
        self.loss_class = config.loss
        self.loss: Any = None
        self.loss_value: Any = 0.

        # Optimizer: the class comes from the config, the instance is created
        # later by set_optimizer() in concrete subclasses.
        self.optimizer_class = config.optimizer
        self.optimizer: Any = None
        self.lr = config.lr

    def set_loss(self) -> None:
        """
        Initialize the loss function.

        :raises NotImplementedError: Must be implemented by subclasses.
        """
        raise NotImplementedError

    def compute_loss(self,
                     data_pred: Dict[str, Any],
                     data_opt: Dict[str, Any]) -> Dict[str, Any]:
        """
        Compute loss from prediction / ground truth.

        :param data_pred: Tensor produced by the forward pass of the Network.
        :param data_opt: Ground truth tensor to be compared with prediction.
        :return: Loss value.
        :raises NotImplementedError: Must be implemented by subclasses.
        """
        raise NotImplementedError

    def set_optimizer(self,
                      net: 'BaseNetwork') -> None:
        """
        Define an optimization process.

        :param net: Network whose parameters will be optimized.
        :raises NotImplementedError: Must be implemented by subclasses.
        """
        # NOTE: 'BaseNetwork' is a PEP 484 forward reference so the class can
        # be imported even where the annotation target is not resolvable yet.
        raise NotImplementedError

    def optimize(self) -> None:
        """
        Run an optimization step.

        :raises NotImplementedError: Must be implemented by subclasses.
        """
        raise NotImplementedError

    def __str__(self) -> str:
        """
        Return a human-readable summary of the loss / optimizer configuration.
        """
        # Extract class names up front; 'None' when no class was configured.
        loss_name = self.loss_class.__name__ if self.loss_class else 'None'
        optimizer_name = self.optimizer_class.__name__ if self.optimizer_class else 'None'
        description = "\n"
        description += f"   {self.__class__.__name__}\n"
        description += f"    Loss class: {loss_name}\n"
        description += f"    Optimizer class: {optimizer_name}\n"
        description += f"    Learning rate: {self.lr}\n"
        return description