Preface

When building models with PyTorch, the usual workflow is to write the Model first, then the Dataset, and finally the Trainer. The Dataset takes the second-largest share of development time and is the key step in the middle, so its design is worth thinking through in advance; otherwise problems in the Dataset can force you to go back and adjust the Model and the Trainer. This article summarizes some of the thinking and the problems encountered during development, and offers readers a fairly general template as a starting point for discussion.
import os
import logging
from typing import Dict

import joblib
import torch
from torch import Tensor
from torch.utils.data import Dataset, DataLoader, RandomSampler

logger = logging.getLogger(__name__)


class BaseDataset(Dataset):
    def __init__(self, config):
        self.config = config
        if not os.path.isfile(config.file_path):
            raise ValueError(f'Input file path {config.file_path} not found')
        logger.info(f'Creating features from dataset file at {config.file_path}')
        # Load the whole file into memory at once
        self.data = joblib.load(config.file_path)
        self.nums = len(self.data)

    def __len__(self):
        return self.nums

    def __getitem__(self, i) -> Dict[str, Tensor]:
        sample_i = self.data[i]
        return {
            'f1': torch.tensor(sample_i['f1']).long(),
            'f2': torch.tensor(sample_i['f2']).long(),
            'label': torch.LongTensor([sample_i['label']]),
        }
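A minimal usage sketch: the config here is just a stand-in object, and the file path is hypothetical, assumed to point at a joblib-serialized list of dict samples.

from types import SimpleNamespace

# Hypothetical config object; file_path is assumed to be a joblib file of dict samples
config = SimpleNamespace(file_path='data/train_features.pkl')
train_dataset = BaseDataset(config)
print(len(train_dataset))        # number of samples
print(train_dataset[0]['f1'])    # LongTensor built from the raw feature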
def build_dataset(task_type, features, **kwargs):
    assert task_type in ['task1', 'task2'], 'task mismatch'
    if task_type == 'task1':
        dataset = task1Dataset(features)
    else:
        dataset = task2Dataset(features)
    return dataset
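task1Dataset and task2Dataset themselves are not shown here. As a hedged sketch, such a task-specific Dataset might simply wrap the in-memory feature list it receives and pick out only the fields that task needs:

class task1Dataset(Dataset):
    # Hypothetical task-specific dataset: `features` is an in-memory list of dicts
    def __init__(self, features):
        self.features = features

    def __len__(self):
        return len(self.features)

    def __getitem__(self, i) -> Dict[str, Tensor]:
        sample_i = self.features[i]
        return {
            'f1': torch.tensor(sample_i['f1']).long(),
            'label': torch.LongTensor([sample_i['label']]),
        }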
train_sampler = RandomSampler(train_dataset)
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=args.train_batch_size,
                          sampler=train_sampler,
                          shuffle=(train_sampler is None),
                          collate_fn=None,  # usually no need to set this
                          num_workers=4)
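With collate_fn=None the default collation is used: for the dict samples returned by __getitem__ above, it groups the batch field by field and stacks the tensors, so each batch is a dict of tensors with a leading batch dimension. A simplified sketch of that behavior (assuming every sample has the same shape per field):

# Rough sketch of what the default collate_fn does with dict samples
def dict_collate(samples):
    batch = {}
    for key in samples[0]:
        batch[key] = torch.stack([s[key] for s in samples], dim=0)
    return batch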
sampler = torch.utils.data.distributed.DistributedSampler(dataset)
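In distributed data-parallel training, each process reads a different shard of the data through the DistributedSampler. A sketch of the usual pattern, assuming the process group is initialized elsewhere and args.num_epochs is defined by the caller:

train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=args.train_batch_size,
                          sampler=train_sampler,
                          num_workers=4)

for epoch in range(args.num_epochs):
    # re-seed the sampler so each epoch gets a different shuffle across processes
    train_sampler.set_epoch(epoch)
    for batch_data in train_loader:
        ...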
def merge_sample(x):
    return zip(*x)
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=args.train_batch_size,
                          sampler=train_sampler,
                          shuffle=(train_sampler is None),
                          collate_fn=merge_sample,
                          num_workers=4)
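A more common reason to write a custom collate_fn is variable-length samples. As a sketch, assuming f1 is a variable-length sequence of token ids, it can be padded to the longest sample in the batch with pad_sequence:

from torch.nn.utils.rnn import pad_sequence

def pad_collate(samples):
    # samples: list of dicts produced by __getitem__
    f1 = pad_sequence([s['f1'] for s in samples], batch_first=True, padding_value=0)
    labels = torch.cat([s['label'] for s in samples], dim=0)
    return {'f1': f1, 'label': labels}

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=args.train_batch_size,
                          sampler=train_sampler,
                          collate_fn=pad_collate,
                          num_workers=4)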
Note that in a CPU-only environment, if you define a custom collate_fn, num_workers must be set to 0, otherwise you will run into problems.

for step, batch_data in enumerate(train_loader):
    if step < 1:
        print(batch_data)
    else:
        break
for key in batch_data.keys():
    batch_data[key] = batch_data[key].to(device)
loss = model(**batch_data)
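Putting the pieces together, a minimal training-step sketch; model, optimizer, and device are assumed to have been built elsewhere, and the model is assumed to return the loss directly, as above:

model.train()
for step, batch_data in enumerate(train_loader):
    # move the whole batch onto the target device
    batch_data = {k: v.to(device) for k, v in batch_data.items()}
    loss = model(**batch_data)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()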