import time
from typing import List
import torch
import tqdm
from torch.utils.data import DataLoader
from torchpack.callbacks.callback import Callback, Callbacks
from torchpack.utils import humanize
from torchpack.utils.logging import logger
from torchpack.utils.typing import Trainer
__all__ = ['InferenceRunner']
[docs]class InferenceRunner(Callback):
"""
A callback that runs inference with a list of :class:`Callback`.
"""
def __init__(self, dataflow: DataLoader, *,
callbacks: List[Callback]) -> None:
self.dataflow = dataflow
self.callbacks = Callbacks(callbacks)
def _set_trainer(self, trainer: Trainer) -> None:
self.callbacks.set_trainer(trainer)
def _trigger_epoch(self) -> None:
self._trigger()
def _trigger(self) -> None:
start_time = time.perf_counter()
self.callbacks.before_epoch()
with torch.no_grad():
for feed_dict in tqdm.tqdm(self.dataflow, ncols=0):
self.callbacks.before_step(feed_dict)
output_dict = self.trainer.run_step(feed_dict)
self.callbacks.after_step(output_dict)
self.callbacks.after_epoch()
logger.info('Inference finished in {}.'.format(
humanize.naturaldelta(time.perf_counter() - start_time)))