Source code for torchpack.callbacks.progress

import time
from collections import deque
from typing import List, Union

import numpy as np
import tqdm

from torchpack.callbacks.callback import Callback
from torchpack.utils import humanize
from torchpack.utils.logging import logger
from torchpack.utils.matching import NameMatcher

__all__ = ['ProgressBar', 'EstimatedTimeLeft']


class ProgressBar(Callback):
    """
    Show per-epoch training progress with a `tqdm` bar.

    A fresh bar is opened before each epoch, advanced by one tick per
    step, and closed after the epoch. Scalars from the trainer summary
    whose names match `scalars` are shown in the bar description.

    Args:
        scalars: name pattern(s) passed to `NameMatcher` to choose which
            summary entries are displayed. Defaults to `'*'`.
    """
    # NOTE(review): presumably read by the `Callback` base class to limit
    # this callback to the master process -- confirm against `Callback`.
    master_only: bool = True

    def __init__(self, scalars: Union[str, List[str]] = '*') -> None:
        self.matcher = NameMatcher(patterns=scalars)

    def _before_epoch(self) -> None:
        """Open a new bar sized to the trainer's steps per epoch."""
        total_steps = self.trainer.steps_per_epoch
        self.pbar = tqdm.trange(total_steps, ncols=0)

    def _trigger_step(self) -> None:
        """Refresh the description with current scalars and tick the bar."""
        summary = self.trainer.summary
        parts = []
        for key in sorted(summary.keys()):
            # Only the most recent (step, value) record of each entry is
            # considered.
            last_step, value = summary[key][-1]
            is_current = (self.matcher.match(key)
                          and last_step == self.trainer.global_step)
            # Skip entries not written at this step and non-numeric values.
            if is_current and isinstance(value, (int, float)):
                parts.append(f'[{key}] = {value:.3g}')
        # Leave the previous description in place when nothing qualifies.
        if parts:
            self.pbar.set_description(', '.join(parts))
        self.pbar.update()

    def _after_epoch(self) -> None:
        """Close the bar opened for this epoch."""
        self.pbar.close()
class EstimatedTimeLeft(Callback):
    """
    Log an estimate of the remaining training time after each epoch.

    The estimate is the number of epochs still to run multiplied by the
    mean wall-clock duration of the most recent epochs.

    Args:
        last_k_epochs: maximum number of recent epoch durations kept for
            the mean (keyword-only). Defaults to 8.
    """
    # NOTE(review): presumably read by the `Callback` base class to limit
    # this callback to the master process -- confirm against `Callback`.
    master_only: bool = True

    def __init__(self, *, last_k_epochs: int = 8) -> None:
        self.last_k_epochs = last_k_epochs

    def _before_train(self) -> None:
        """Reset the duration window and start timing."""
        # A bounded deque drops the oldest duration once it is full.
        self.times = deque(maxlen=self.last_k_epochs)
        self.last_time = time.perf_counter()

    def _trigger_epoch(self) -> None:
        """Record the elapsed epoch time and log the estimate."""
        trainer = self.trainer
        # Nothing is recorded or logged once no epochs remain.
        if trainer.epoch_num >= trainer.num_epochs:
            return

        elapsed = time.perf_counter() - self.last_time
        self.times.append(elapsed)
        self.last_time = time.perf_counter()

        epochs_left = trainer.num_epochs - trainer.epoch_num
        estimated_time = epochs_left * np.mean(self.times)
        remaining = humanize.naturaldelta(estimated_time)
        logger.info(f'Estimated time left: {remaining}.')