Source code for torchpack.utils.device

import os
from typing import List, Union

import torch.cuda

__all__ = ['parse_cuda_devices', 'set_cuda_visible_devices']


def parse_cuda_devices(text: str) -> List[int]:
    """Parse a textual device specification into CUDA device indices.

    The specification is a comma-separated list in which each entry is one of:

    * an integer index, e.g. ``'2'``;
    * an inclusive range, e.g. ``'0-3'`` (expands to 0, 1, 2, 3);
    * either of the above prefixed with ``gpu``, e.g. ``'gpu1'``;
    * ``'cpu'``, which is skipped and contributes no index.

    Entries are case-insensitive and surrounding whitespace is ignored.
    The single wildcard ``'*'`` selects every device reported by
    ``torch.cuda.device_count()``.

    Args:
        text: The device specification to parse.

    Returns:
        The device indices in the order they appear in ``text``. Empty
        input, or input naming only ``'cpu'``, yields an empty list.

    Raises:
        ValueError: If an entry is not an integer, or a range does not
            consist of exactly two integers separated by ``-``.
    """
    text = text.strip()
    if text == '*':
        return list(range(torch.cuda.device_count()))

    devices = []
    for token in text.split(','):
        token = token.strip().lower()
        # Empty tokens come from empty input or stray commas; like the
        # 'cpu' placeholder they name no CUDA device, so skip them rather
        # than letting int('') fail.
        if not token or token == 'cpu':
            continue
        if token.startswith('gpu'):
            token = token[3:]
        if '-' in token:
            # Inclusive range: 'a-b' covers a, a+1, ..., b.
            first, last = token.split('-')
            devices.extend(range(int(first), int(last) + 1))
        else:
            devices.append(int(token))
    return devices
def set_cuda_visible_devices(devices: Union[str, List[int]],
                             *,
                             environ: os._Environ = os.environ) -> List[int]:
    """Write the selected devices to ``CUDA_VISIBLE_DEVICES``.

    Args:
        devices: Either a list of device indices, or a string that is first
            converted to such a list by ``parse_cuda_devices``.
        environ: Mapping that receives the ``CUDA_VISIBLE_DEVICES`` entry.
            Defaults to the process environment (``os.environ``).

    Returns:
        The list of device indices that was written to ``environ``.
    """
    device_ids = (parse_cuda_devices(devices)
                  if isinstance(devices, str) else devices)
    # The variable holds the indices as a comma-separated string.
    environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(index) for index in device_ids)
    return device_ids