from collections import defaultdict, deque
from typing import Any, Deque, Iterable, Optional, Tuple, Union
import numpy as np
import torch
from torchpack.callbacks.writers import SummaryWriter
from torchpack.utils.typing import Trainer
# Public API of this module: ``Summary`` is the only name exported by
# ``from <module> import *``.
__all__ = ['Summary']
class Summary:
    """Record scalar and image summaries and forward them to writers.

    Each recorded value is stored in ``history`` under its name as a
    ``(global_step, value)`` pair, where ``global_step`` is read from the
    attached trainer at recording time, and is then passed to every
    ``SummaryWriter`` found among the trainer's callbacks.

    ``set_trainer`` must be called before any ``add_*`` method: both
    ``self.trainer`` and ``self.writers`` are only created there.
    """

    def __init__(self) -> None:
        # name -> deque of (global_step, value) pairs, oldest first.
        self.history = defaultdict(deque)

    def set_trainer(self, trainer: 'Trainer') -> None:
        """Attach ``trainer`` and collect its summary writers.

        Args:
            trainer: Object providing ``callbacks`` (iterated here) and
                ``global_step`` (read whenever a value is recorded).
        """
        self.trainer = trainer
        self._set_trainer(trainer)

    def _set_trainer(self, trainer: 'Trainer') -> None:
        """Rebuild ``self.writers`` from the trainer's callbacks."""
        # Only callbacks that are ``SummaryWriter`` instances receive
        # the recorded values.
        self.writers = [
            callback for callback in trainer.callbacks
            if isinstance(callback, SummaryWriter)
        ]

    def add_scalar(self,
                   name: str,
                   scalar: Union[int, float, np.integer, np.floating],
                   *,
                   max_to_keep: Optional[int] = None) -> None:
        """Record a scalar under ``name``.

        NumPy integer/floating scalars are converted to built-in ``int`` /
        ``float`` before being stored and written.

        Args:
            name: Key under which the value is stored in ``history``.
            scalar: Value to record.
            max_to_keep: If given, only the most recent ``max_to_keep``
                entries for ``name`` are kept in ``history``.

        Raises:
            TypeError: If ``scalar`` is not an int/float (or a NumPy
                integer/floating scalar).
            ValueError: If ``max_to_keep`` is negative.
        """
        if isinstance(scalar, np.integer):
            scalar = int(scalar)
        if isinstance(scalar, np.floating):
            scalar = float(scalar)
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``.
        if not isinstance(scalar, (int, float)):
            raise TypeError(
                'scalar must be an int or float, got {}'.format(type(scalar)))
        self._add_scalar(name, scalar, max_to_keep=max_to_keep)

    def _add_scalar(self, name: str, scalar: Union[int, float], *,
                    max_to_keep: Optional[int]) -> None:
        """Store an already-validated scalar and send it to all writers."""
        self._record(name, scalar, max_to_keep)
        for writer in self.writers:
            writer.add_scalar(name, scalar)

    def add_image(self,
                  name: str,
                  tensor: Union[np.ndarray, torch.Tensor],
                  *,
                  max_to_keep: Optional[int] = None) -> None:
        """Record an image under ``name``.

        The image is normalised to a 3-D array whose first axis has size
        1, 3 or 4:

        * a 2-D input gets a leading axis of size 1;
        * a 3-D input whose last axis has size 1, 3 or 4 has that axis
          moved to the front (this also applies when the first axis
          already has such a size).

        Args:
            name: Key under which the image is stored in ``history``.
            tensor: Image as a NumPy array or a torch tensor.
            max_to_keep: If given, only the most recent ``max_to_keep``
                entries for ``name`` are kept in ``history``.

        Raises:
            TypeError: If ``tensor`` is neither a NumPy array nor a
                torch tensor.
            ValueError: If the normalised array is not 3-D with a first
                axis of size 1, 3 or 4, or if ``max_to_keep`` is negative.
        """
        if isinstance(tensor, torch.Tensor):
            # ``detach()`` first: ``numpy()`` raises on tensors that
            # require grad.
            tensor = tensor.detach().cpu().numpy()
        if not isinstance(tensor, np.ndarray):
            raise TypeError(
                'tensor must be a numpy array or torch tensor, got {}'.format(
                    type(tensor)))

        if tensor.ndim == 2:
            tensor = tensor[np.newaxis, ...]
        elif tensor.ndim == 3 and tensor.shape[-1] in [1, 3, 4]:
            tensor = tensor.transpose(2, 0, 1)
        if tensor.ndim != 3 or tensor.shape[0] not in [1, 3, 4]:
            raise ValueError(
                'unsupported image shape: {}'.format(tensor.shape))
        self._add_image(name, tensor, max_to_keep=max_to_keep)

    def _add_image(self, name: str, tensor: np.ndarray, *,
                   max_to_keep: Optional[int]) -> None:
        """Store an already-normalised image and send it to all writers."""
        self._record(name, tensor, max_to_keep)
        for writer in self.writers:
            writer.add_image(name, tensor)

    def _record(self, name: str, value: Any,
                max_to_keep: Optional[int]) -> None:
        """Append ``(global_step, value)`` to ``history[name]`` and trim it."""
        # Reject a negative limit before touching any state; otherwise the
        # trimming loop would drain the deque and then fail on an empty pop.
        if max_to_keep is not None and max_to_keep < 0:
            raise ValueError(
                'max_to_keep must be non-negative, got {}'.format(max_to_keep))
        entries = self.history[name]
        entries.append((self.trainer.global_step, value))
        if max_to_keep is not None:
            # Drop the oldest entries until the limit is respected.
            while len(entries) > max_to_keep:
                entries.popleft()

    def keys(self) -> Iterable[str]:
        """Yield the names that currently have an entry in ``history``."""
        yield from self.history.keys()

    def values(self) -> Iterable[Deque[Tuple[int, Any]]]:
        """Yield the ``(global_step, value)`` deque of every name."""
        yield from self.history.values()

    def items(self) -> Iterable[Tuple[str, Deque[Tuple[int, Any]]]]:
        """Yield ``(name, deque)`` pairs for every name in ``history``."""
        yield from self.history.items()

    def __contains__(self, key: str) -> bool:
        """Return whether ``key`` has an entry in ``history``."""
        return key in self.history

    def __getitem__(self, key: str) -> Deque[Tuple[int, Any]]:
        """Return the deque of ``(global_step, value)`` pairs for ``key``.

        ``history`` is a ``defaultdict``, so looking up an unknown name
        inserts and returns an empty deque rather than raising ``KeyError``.
        """
        return self.history[key]