分享

Pytorch 数据流中常见Trick总结

 InfoRich 2021-12-06

设为 “星标”,DLCV消息即可送达!

Image


前言 在使用Pytorch建模时,常见的流程为先写Model,再写Dataset,最后写Trainer。Dataset 是整个项目开发中投入时间第二多,也是中间关键的步骤。往往需要事先对于其设计有明确的思考,不然可能会因为Dataset的一些问题又要去调整Model,Trainer。

本文将目前开发中的一些思考以及遇到的问题做一个总结,提供给各位读者一个比较通用的模版,抛砖引玉~


from torch.utils.data import Dataset, DataLoader, RandomSampler
class BaseDataset(Dataset):
    def __init__(self, config):
        self.config = config
        if os.path.isfile(config.file_path) is False:
            raise ValueError(f'Input file path {config.file_path} not found')
        logger.info(f'Creating features from dataset file at {config.file_path}')
        # 一次性全读进内存
        self.data = joblib.load(config.file_path)
        self.nums = len(self.data)

    def __len__(self):
        return self.nums

    def __getitem__(self, i) -> Dict[str, tensor]:
        sample_i = self.data[i]
        return {'f1':torch.tensor(sample_i['f1']).long(),'f2':torch.tensor(sample_i['f2']).long(),torch.LongTensor([sample_i['label']])}
def build_dataset(task_type, features, **kwargs):
    assert task_type in ['task1''task2'], 'task mismatch'

    if task_type == 'task1':
        dataset = task1Dataset(features))
    else:
        dataset = task2Dataset(features)

    return dataset


train_sampler = RandomSampler(train_dataset)
train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.train_batch_size,
                              sampler=train_sampler,
                              shuffle=(train_sampler is None)
                              collate_fn=None# 一般不用设置
                              num_workers=4)
sampler = torch.utils.data.distributed.DistributedSampler(dataset)
def merge_sample(x):
    return zip(*x)

train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.train_batch_size,
                              sampler=train_sampler,
                              shuffle=(train_sampler is None)
                              collate_fn=merge_sample,
                              num_workers=4)
值得注意的是在cpu环境下,如果要自定义collate_fn,num_workers必须设置为0,不然就会有问题..
for step, batch_data in enumerate(train_loader):
    if step < 1:
        print(batch_data)
    else:
        break
for key in batch_data.keys():
    batch_data[key] = batch_data[key].to(device)
loss = model(**batch_data)

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多