import hashlib
import json
import os
from ast import literal_eval
from typing import Any, Dict, List, Tuple, Union
from multimethod import multimethod
from torchpack.utils import io
# Public names of this module: the Config class and the shared `configs`
# instance created at the bottom of the file.
__all__ = [
    'Config',
    'configs',
]
class Config(dict):
    """A ``dict`` with attribute-style access and recursive merging.

    Nested ``dict`` values merged in through :meth:`update` are converted to
    ``Config`` instances, so ``cfg.model.depth`` works as well as
    ``cfg['model']['depth']``. Values can come from configuration files
    (:meth:`load`) or from command-line style option lists
    (``cfg.update(['--model.depth', '50'])``).
    """

    def __getattr__(self, key: str) -> Any:
        # Only reached when normal attribute lookup fails. Raising
        # AttributeError (not KeyError) keeps hasattr() and
        # getattr(obj, name, default) working.
        if key not in self:
            raise AttributeError(key)
        return self[key]

    def __setattr__(self, key: str, value: Any) -> None:
        # Attribute assignment is redirected to item assignment.
        self[key] = value

    def __delattr__(self, key: str) -> None:
        # Mirror __getattr__: report a missing key as AttributeError.
        try:
            del self[key]
        except KeyError:
            raise AttributeError(key) from None

    def load(self, fpath: str, *, recursive: bool = False) -> None:
        """Merge the configuration file ``fpath`` into this config.

        Args:
            fpath: Path of the file to load. It is read with ``io.load``,
                whose result is passed to :meth:`update` (presumably a dict).
            recursive: If true, ``default.yaml`` / ``default.yml`` files in
                every ancestor directory of ``fpath`` are merged as well.
                Files nearer the filesystem root are applied first, so deeper
                files override shallower ones and ``fpath`` itself is applied
                last. Within one directory ``default.yaml`` is applied after
                ``default.yml`` and therefore wins.

        Raises:
            FileNotFoundError: If ``fpath`` does not exist.
        """
        if not os.path.exists(fpath):
            raise FileNotFoundError(fpath)
        fpaths = [fpath]
        if recursive:
            dirpath = fpath
            # os.path.dirname() reaches a fixed point at the filesystem root
            # ('/' or a drive root) and at '' for relative paths. Stopping
            # there avoids the endless loop `while dirpath:` hits on
            # absolute paths, since dirname('/') == '/'.
            while os.path.dirname(dirpath) != dirpath:
                dirpath = os.path.dirname(dirpath)
                for fname in ['default.yaml', 'default.yml']:
                    fpaths.append(os.path.join(dirpath, fname))
        for path in reversed(fpaths):
            if os.path.exists(path):
                self.update(io.load(path))

    def reload(self, fpath: str, *, recursive: bool = False) -> None:
        """Replace the entire contents of this config with ``fpath``.

        Arguments are the same as for :meth:`load`. The files are loaded into
        a fresh ``Config`` first. If loading raises (e.g. FileNotFoundError),
        the current contents are left untouched.
        """
        fresh = Config()
        fresh.load(fpath, recursive=recursive)
        self.clear()
        # Plain dict.update: move fresh's (already converted) children over.
        super().update(fresh)

    def update(self, other: Union[Dict, List, Tuple]) -> None:
        """Merge ``other`` into this config.

        Args:
            other: Either a ``dict``, which is merged recursively with nested
                dicts converted to ``Config``, or a ``list``/``tuple`` of
                command-line style options such as
                ``['--a.b', '1', '--c=[1, 2]']``.

        Raises:
            TypeError: If ``other`` is neither a dict nor a list/tuple.
            ValueError: If an option without ``=`` has no following value.
        """
        # This overrides dict.update, so keyword arguments and iterables of
        # key/value pairs are intentionally not accepted.
        if isinstance(other, dict):
            self._update_from_dict(other)
        elif isinstance(other, (list, tuple)):
            self._update_from_opts(other)
        else:
            raise TypeError(
                'Config.update() expects a dict, list or tuple, got {}'
                .format(type(other).__name__))

    def _update_from_dict(self, other: Dict) -> None:
        # Recursively merge a mapping. Nested dicts are merged into a Config
        # child instead of replacing it wholesale; any non-Config value
        # already at that key is replaced by a new Config.
        for key, value in other.items():
            if isinstance(value, dict):
                if not isinstance(self.get(key), Config):
                    self[key] = Config()
                self[key].update(value)
            else:
                self[key] = value

    def _update_from_opts(self, opts: Union[List, Tuple]) -> None:
        # Apply options of the form '--key=value' or '--key value'. Dotted
        # keys ('a.b.c') address nested Configs, which are created on
        # demand. Values are parsed as Python literals when possible and
        # otherwise kept as raw strings.
        index = 0
        while index < len(opts):
            opt = opts[index]
            if opt.startswith('--'):
                opt = opt[2:]
            if '=' in opt:
                key, value = opt.split('=', 1)
                index += 1
            else:
                if index + 1 >= len(opts):
                    raise ValueError(
                        'missing value for option: {}'.format(opts[index]))
                key, value = opt, opts[index + 1]
                index += 2
            try:
                value = literal_eval(value)
            except (ValueError, TypeError, SyntaxError, MemoryError,
                    RecursionError):
                # Not a Python literal (e.g. a bare word): keep it as given.
                pass
            *parents, leaf = key.split('.')
            current = self
            for subkey in parents:
                current = current.setdefault(subkey, Config())
            current[leaf] = value

    def dict(self) -> Dict[str, Any]:
        """Return a plain ``dict`` copy, converting nested Configs recursively."""
        return {
            key: value.dict() if isinstance(value, Config) else value
            for key, value in self.items()
        }

    def hash(self) -> str:
        """Return the SHA-256 hex digest of the config's key-sorted JSON form."""
        buffer = json.dumps(self.dict(), sort_keys=True)
        return hashlib.sha256(buffer.encode()).hexdigest()

    def __str__(self) -> str:
        texts = []
        for key, value in self.items():
            # Nested sections start on their own line; other values stay inline.
            separator = '\n' if isinstance(value, Config) else ' '
            text = str(key) + ':' + separator + str(value)
            head, *rest = text.split('\n')
            texts.append(head)
            # Indent continuation lines by two spaces so nesting is visible.
            texts.extend('  ' + line for line in rest)
        return '\n'.join(texts)
# Module-level shared Config instance, exported through the module's __all__.
configs = Config()