目录
1. PyTorch 预训练模型Pytorch 提供了许多 Pre-Trained Model on ImageNet,仅需调用 torchvision.models 即可,具体细节可查看官方文档。 往往我们需要对 Pre-Trained Model 进行相应的修改,以适应我们的任务。这种情况下,我们可以先输出 Pre-Trained Model 的结构,确定好对哪些层修改,或者添加哪些层,接着,再将其修改即可。 比如,我需要将 ResNet-50 的 Layer 3 后的所有层去掉,在分别连接十个分类器,分类器由 ResNet-50.layer4 和 AvgPool Layer 和 FC Layer 构成。这里就需要用到 torch.nn.ModuleList 了,比如:: 代码中的 [nn.Linear(10, 10) for i in range(10)] 是一个python列表,必须要把它转换成一个Module Llist列表才可以被 PyTorch 使用,否则在运行的时候会报错: RuntimeError: Input type (CUDAFloatTensor) and weight type (CPUFloatTensor) should be the same 2. 保存模型参数PyTorch 中保存模型的方式有许多种: # 保存整个网络 torch.save(model, PATH) # 保存网络中的参数, 速度快,占空间少 torch.save(model.state_dict(),PATH) # 选择保存网络中的一部分参数或者额外保存其余的参数 torch.save({'state_dict': model.state_dict(), 'fc_dict':model.fc.state_dict(), 'optimizer': optimizer.state_dict(),'alpha': loss.alpha, 'gamma': loss.gamma}, PATH) 3. 读取模型参数同样的,PyTorch 中读取模型参数的方式也有许多种:
4. 冻结部分模型参数,进行 fine-tuning加载完 Pre-Trained Model 后,我们需要对其进行 Finetune。但是在此之前,我们往往需要冻结一部分的模型参数: # 第一种方式 for p in freeze.parameters(): # 将需要冻结的参数的 requires_grad 设置为 False p.requires_grad = False for p in no_freeze.parameters(): # 将fine-tuning 的参数的 requires_grad 设置为 True p.requires_grad = True # 将需要 fine-tuning 的参数放入optimizer 中 optimizer.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
# 第二种方式 optim_param = [] for p in freeze.parameters(): # 将需要冻结的参数的 requires_grad 设置为 False p.requires_grad = False for p in no_freeze.parameters(): # 将fine-tuning 的参数的 requires_grad 设置为 True p.requires_grad = True optim_param.append(p) optimizer.SGD(optim_param, lr=1e-3) # 将需要 fine-tuning 的参数放入optimizer 中 5. 模型训练与测试的设置训练时,应调用 model.train() ;测试时,应调用 model.eval(),以及 with torch.no_grad(): model.train():使 model 变成训练模式,此时 dropout 和 batch normalization 的操作在训练起到防止网络过拟合的问题。 model.eval():PyTorch会自动把 BN 和 DropOut 固定住,不会取平均,而是用训练好的值。不然的话,一旦测试集的 Batch Size 过小,很容易就会被 BN 层导致生成图片颜色失真极大。 with torch.no_grad():PyTorch 将不再计算梯度,这将使得模型 forward 的时候,显存的需求大幅减少,速度大幅提高。 注意:若模型中具有 Batch Normalization 操作,想固定该操作进行训练时,需调用对应的 module 的 eval() 函数。这是因为 BN Module 除了参数以外,还会对输入的数据进行统计,若不调用 eval(),统计量将发生改变!具体代码可以这样写:
在其他地方看到的解释:
6. 利用 torch.nn.DataParallel 进行多 GPU 训练
-完- |
|
来自: 520jefferson > 《待分类》