import glob
import os
from collections import deque
from typing import Any, ClassVar, Dict, Optional
from torchpack.callbacks.callback import Callback
from torchpack.environ import get_run_dir
from torchpack.utils import fs, io
from torchpack.utils.logging import logger
from torchpack.utils.typing import Trainer
# Names exported by a star-import of this module. `BestSaver` is defined in
# this file but is not listed here.
__all__ = ['Saver', 'MinSaver', 'MaxSaver', 'SaverRestore']
class Saver(Callback):
    """
    Save the checkpoint once triggered.

    Checkpoints are written to ``<save_dir>/step-<global_step>.pt``. Only the
    ``max_to_keep`` most recently registered checkpoints are kept on disk;
    older ones are removed.

    Args:
        max_to_keep: Maximum number of ``step-*.pt`` files to keep. ``None``
            disables removal entirely.
        save_dir: Directory to store checkpoints in. Defaults to
            ``checkpoints`` under ``get_run_dir()``.
    """
    master_only: bool = True

    def __init__(self, *, max_to_keep: Optional[int] = 4,
                 save_dir: Optional[str] = None) -> None:
        self.max_to_keep = max_to_keep
        if save_dir is None:
            save_dir = os.path.join(get_run_dir(), 'checkpoints')
        self.save_dir = fs.normpath(save_dir)

    def _set_trainer(self, trainer: Trainer) -> None:
        """Register checkpoints already present in ``save_dir``."""
        self.checkpoints = deque()
        # Oldest first (by modification time), so that pre-existing files are
        # pruned in the same order as checkpoints saved during this run.
        existing = glob.glob(os.path.join(self.save_dir, 'step-*.pt'))
        for fpath in sorted(existing, key=os.path.getmtime):
            self._add_checkpoint(fpath)

    def _trigger_epoch(self) -> None:
        """Epoch hook: delegate to :meth:`_trigger`."""
        self._trigger()

    def _trigger(self) -> None:
        """Save the trainer state; log (do not raise) on ``OSError``."""
        save_path = os.path.join(self.save_dir,
                                 f'step-{self.trainer.global_step}.pt')
        try:
            io.save(save_path, self.trainer.state_dict())
        except OSError:
            logger.exception(
                f'Error occurred when saving checkpoint "{save_path}".')
        else:
            logger.info(f'Checkpoint saved: "{save_path}".')
            self._add_checkpoint(save_path)

    def _add_checkpoint(self, fpath: str) -> None:
        """Track ``fpath`` as the newest checkpoint and prune old ones."""
        # Saving twice to the same path (e.g. the same global step is
        # triggered again after a restore) overwrites the file. Keep a single
        # entry for it; otherwise the duplicate would count against
        # `max_to_keep` and the file would be deleted while still tracked.
        if fpath in self.checkpoints:
            self.checkpoints.remove(fpath)
        self.checkpoints.append(fpath)

        # The emptiness guard keeps a negative `max_to_keep` from popping an
        # empty deque.
        while self.max_to_keep is not None and self.checkpoints and \
                len(self.checkpoints) > self.max_to_keep:
            stale = self.checkpoints.popleft()
            try:
                fs.remove(stale)
            except OSError:
                logger.exception(
                    f'Error occurred when removing checkpoint "{stale}".')
class BestSaver(Callback):
    """
    Save the checkpoint with best value of some scalar in `trainer.summary`.

    The direction is given by the class attribute ``extreme`` (``'min'`` or
    ``'max'``), which subclasses are expected to define.

    Args:
        scalar: Name of the scalar to look up in ``trainer.summary``.
        name: File stem of the checkpoint. Defaults to ``extreme`` followed
            by ``-`` and ``scalar`` with every ``/`` replaced by ``-``.
        save_dir: Directory to store the checkpoint in. Defaults to
            ``checkpoints`` under ``get_run_dir()``.
    """
    master_only: bool = True
    extreme: ClassVar[str]

    def __init__(self,
                 scalar: str,
                 *,
                 name: Optional[str] = None,
                 save_dir: Optional[str] = None) -> None:
        self.scalar = scalar
        if name is None:
            name = self.extreme + '-' + scalar.replace('/', '-')
        self.name = name
        if save_dir is None:
            save_dir = os.path.join(get_run_dir(), 'checkpoints')
        self.save_dir = fs.normpath(save_dir)

    def _set_trainer(self, trainer: Trainer) -> None:
        """Reset the last-seen step and the best ``(step, value)`` pair."""
        self.step, self.best = None, None

    def _trigger_epoch(self) -> None:
        """Epoch hook: delegate to :meth:`_trigger`."""
        self._trigger()

    def _is_better(self, value: Any) -> bool:
        """Return whether ``value`` should replace the current best."""
        if self.best is None:
            return True
        best = self.best[1]
        # NaN compares False against everything, so a NaN best would never be
        # replaced by `<` / `>`. Let any non-NaN value take over instead.
        if best != best:
            return bool(value == value)
        if self.extreme == 'min':
            return bool(value < best)
        if self.extreme == 'max':
            return bool(value > best)
        return False

    def _trigger(self) -> None:
        """
        Save a checkpoint if the latest value of the scalar is a new best,
        then publish the best value as ``<scalar>/<extreme>``.
        """
        if self.scalar not in self.trainer.summary:
            logger.warning(
                f'`{self.scalar}` has not been added to `trainer.summary`.')
            return

        step, value = self.trainer.summary[self.scalar][-1]
        # Skip entries that were already processed by a previous trigger.
        if self.step is not None and step <= self.step:
            logger.warning(
                f'`{self.scalar}` has not been updated since last trigger.')
            return
        self.step = step

        if self._is_better(value):
            self.best = (step, value)
            save_path = os.path.join(self.save_dir, self.name + '.pt')
            try:
                io.save(save_path, self.trainer.state_dict())
            except OSError:
                logger.exception(
                    f'Error occurred when saving checkpoint "{save_path}".')
            else:
                logger.info(f'Checkpoint saved: "{save_path}" ({value:.5g}).')

        if self.best is not None:
            self.trainer.summary.add_scalar(self.scalar + '/' + self.extreme,
                                            self.best[1])

    def _state_dict(self) -> Dict[str, Any]:
        """Return the last-seen step and the best ``(step, value)`` pair."""
        return {'step': self.step, 'best': self.best}

    def _load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restore the fields written by :meth:`_state_dict`."""
        self.step, self.best = state_dict['step'], state_dict['best']
class MinSaver(BestSaver):
    """
    `BestSaver` with ``extreme`` set to ``'min'``: keeps the checkpoint for
    the smallest value seen of a scalar in `trainer.summary`.
    """

    extreme: ClassVar[str] = 'min'
class MaxSaver(BestSaver):
    """
    `BestSaver` with ``extreme`` set to ``'max'``: keeps the checkpoint for
    the largest value seen of a scalar in `trainer.summary`.
    """

    extreme: ClassVar[str] = 'max'
class SaverRestore(Callback):
    """
    Load the most recently modified ``step-*.pt`` checkpoint found in
    ``load_dir`` into the trainer before training.

    Args:
        load_dir: Directory to search for checkpoints. Defaults to
            ``checkpoints`` under ``get_run_dir()``.
    """

    def __init__(self, load_dir: Optional[str] = None) -> None:
        self.load_dir = fs.normpath(
            os.path.join(get_run_dir(), 'checkpoints')
            if load_dir is None else load_dir)

    def _before_train(self) -> None:
        """Restore the trainer state; warn if there is nothing to load."""
        pattern = os.path.join(self.load_dir, 'step-*.pt')
        candidates = glob.glob(pattern)
        if len(candidates) == 0:
            logger.warning(f'No checkpoints found: "{self.load_dir}".')
            return

        # The newest file by modification time wins.
        load_path = max(candidates, key=os.path.getmtime)
        try:
            restored = io.load(load_path, map_location='cpu')
            self.trainer.load_state_dict(restored)
        except OSError:
            # Logged rather than raised.
            logger.exception(
                f'Error occurred when loading checkpoint "{load_path}".')
        else:
            logger.info(f'Checkpoint loaded: "{load_path}".')