Let's open with an interview question: without quantization and without losing any precision, how do you run inference on a 70B fp16 model with a single 16GB GPU?
The conventional workflow for loading a model onto a GPU is:
1. Create the model
2. Load its weights into RAM (usually into an object called a state_dict)
3. Load those weights into the created model
4. Move the model onto the device for inference
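In code, with the Transformers library, this flow typically looks like the sketch below (the checkpoint name is purely illustrative):

import torch
from transformers import AutoModelForCausalLM

# Steps 1-3: instantiate the model and load the full state_dict into RAM
model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-70b-hf', torch_dtype=torch.float16)
# Step 4: move every parameter onto the GPU
model = model.to('cuda:0')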
In this conventional workflow, step 2 loads the model into RAM. In fp16, each parameter takes 2 bytes, so a 70B model needs roughly 140GB just to hold its weights. With full 32-bit precision, each parameter takes 4 bytes, doubling that to about 280GB.
The arithmetic: one float32 costs 4 bytes, one float16 costs 2 bytes. For a more detailed derivation of parameter counts versus memory, see the earlier article 大模型训练为什么用A100不用4090 (why A100s rather than 4090s for large-model training).
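A quick back-of-the-envelope check of those numbers:

n_params = 70e9                    # 70B parameters
print(n_params * 2 / 1e9)          # fp16: 2 bytes/param -> ~140 GB
print(n_params * 4 / 1e9)          # fp32: 4 bytes/param -> ~280 GB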
On top of that, step 3 (loading those weights into the instantiated model) needs a second full-size copy in RAM, costing the same amount of memory again, before the model is finally moved to the GPU.
If you want to load something even bigger, say BLOOM or OPT-176B (176 billion parameters), the same arithmetic says you need 1.4TB of CPU RAM. That is quite extravagant, and all of it just so the model can be moved to the GPU in step 4.
So under the normal workflow, a 70B model needs at least four 40GB A100s for inference. A single 16GB GPU is nowhere near enough, and 32GB of system RAM is also far too little: never mind GPU memory, even main memory cannot hold the model's parameters.
So what can we do?
If the GPU cannot hold the whole model, can we put just part of it on the GPU at a time, running and releasing as we go?
The answer is yes. This scheme has already been proven in a Kaggle competition, and in resource-constrained settings it serves as an extremely GPU-memory-frugal way to run inference.
The workflow can be broken down as follows (a runnable toy sketch follows the list):
1. Create an empty (i.e., weightless) model
2. Decide where each layer will go (when multiple devices are available)
3. Load part of its weights into RAM
4. Load those weights into the empty model
5. Move the weights onto the device for inference
6. Repeat from step 3 until all the weights have been loaded
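To make steps 1 through 6 concrete before the real code, here is a self-contained toy version that streams a small torch.nn.Sequential through the GPU one layer at a time. The toy model and the in-memory 'shards' dict are stand-ins for illustration only, not part of the competition code:

import torch
from accelerate import init_empty_weights
from accelerate.utils.modeling import set_module_tensor_to_device

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Pretend these are the checkpoint shards sitting on disk, one per layer.
full_model = torch.nn.Sequential(*[torch.nn.Linear(1024, 1024) for _ in range(8)])
shards = {name: layer.state_dict() for name, layer in full_model.named_children()}

# Step 1: build an empty copy of the model on the meta device.
with init_empty_weights():
    model = torch.nn.Sequential(*[torch.nn.Linear(1024, 1024) for _ in range(8)])

x = torch.randn(4, 1024, device=device)
with torch.inference_mode():
    for name, layer in model.named_children():
        # Steps 3-5: load this layer's weights and move them onto the device.
        for param_name, tensor in shards[name].items():
            set_module_tensor_to_device(model, f'{name}.{param_name}', device, value=tensor)
        x = layer(x)       # run just this layer
        layer.to('meta')   # step 6: drop it from GPU memory before the next one
print(x.shape)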
This scheme is made possible by a feature introduced in PyTorch 1.9: a new kind of device called the meta device.
It lets us create tensors with no data attached: a tensor on the meta device only needs a shape. As long as you stay on the meta device, you can create arbitrarily large tensors without worrying about whether CPU (or GPU) RAM can hold them.
For example, the following code will crash if there is not enough memory:
import torch
large_tensor = torch.randn(100000, 100000)
This tensor requires 4 * 10**10 bytes (the default dtype is FP32, so every element takes 4 bytes), i.e. 40GB of RAM. The same operation on the meta device, however, runs just fine:
import torch
large_tensor = torch.randn(100000, 100000, device='meta')
This tensor has no data attached to it, only a shape. You can also instantiate a model directly on the meta device:
large_model = torch.nn.Linear(100000, 100000, device='meta')
For off-the-shelf models, however, this syntax would require rewriting all the modeling code so that every submodule accepts and forwards a device keyword argument. Since that is impractical for the pretrained models in the Transformers library, the accelerate library provides a context manager, init_empty_weights, that wraps the meta device and instantiates an empty model:
# Load meta model (no memory used)
with init_empty_weights():
    self.model = AutoModelForCausalLM.from_config(self.config, trust_remote_code=True)
    self.model.tie_weights()
This step is crucial: because we know the shape of every weight, we know exactly how much memory the pretrained tensors will consume once they are fully loaded, and we can therefore decide how to split the model between CPU and GPU.
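As an illustration of that bookkeeping (the checkpoint name below is illustrative and downloading its config may require access approval; any causal LM config works the same way), the fp16 footprint can be tallied from the meta model alone:

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained('meta-llama/Llama-2-70b-hf')
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# All parameters live on the meta device, so only their shapes exist,
# but numel() still tells us what the loaded fp16 tensors would occupy.
total = sum(p.numel() for p in model.parameters())
per_layer = sum(p.numel() for p in model.model.layers[0].parameters())
print(f'whole model      : {total * 2 / 2**30:.1f} GiB in fp16')
print(f'one decoder layer: {per_layer * 2 / 2**30:.2f} GiB in fp16')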
On top of this, two key methods are defined: load_layer_to_cpu, which moves one layer's weights from disk to CPU, and move_layer_to_device, which moves those weights from CPU to the GPU. There is also a helper, clean_memory, that frees memory that is no longer needed.
def load_layer_to_cpu(self, layer_name):
    self.weights_loader.set_state_dict(layer_name, self.device)
    state_dict = self.weights_loader.get_state_dict(self.device)
    if 'value_head.weight' in state_dict:
        state_dict = {'lm_head.weight' : state_dict['value_head.weight']}
    return state_dict

def move_layer_to_device(self, state_dict):
    for param_name, param in state_dict.items():
        assert param.dtype != torch.int8, 'int8 not supported (need to add fp16_statistics)'
        set_module_tensor_to_device(self.model, param_name, self.device, value=param, dtype=self.dtype)

def clean_memory():
    gc.collect()
    ctypes.CDLL('libc.so.6').malloc_trim(0)
    torch.cuda.empty_cache()
With the overall process understood, we can now read the complete code. Note that the code below also contains the prefix/suffix handling specific to the competition setup; for ordinary inference you can ignore that logic and keep a single prompt. The full code follows.
# For LLM
import gc
import ctypes
from pathlib import Path
from functools import partial
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModel
from accelerate import init_empty_weights
from accelerate.utils.modeling import set_module_tensor_to_device
from safetensors.torch import load_file
from optimum.bettertransformer import BetterTransformer

N_BATCHES = 3
MAX_LENGTH = 4096

def clean_memory():
    gc.collect()
    ctypes.CDLL('libc.so.6').malloc_trim(0)
    torch.cuda.empty_cache()
# Class for sharded llama
class ShardedLlama:

    def __init__(self, checkpoint_path, weights_loader, device='cuda:0', dtype=torch.float16):

        # Save parameters
        self.checkpoint_path = Path(checkpoint_path)
        self.weights_loader = weights_loader
        self.device = device
        self.dtype = dtype

        # Create model
        self.config = AutoConfig.from_pretrained(self.checkpoint_path)
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = 'right'
        self.init_model()
        self.layer_names = ['model.embed_tokens'] + [f'model.layers.{i}' for i in range(len(self.model.model.layers))] + ['model.norm', 'value_head']

    def init_model(self):

        # Load meta model (no memory used)
        with init_empty_weights():
            self.model = AutoModelForCausalLM.from_config(self.config)
            self.model.lm_head = torch.nn.Linear(8192, 8, bias=False) # originally 32k
            self.model.eval()
            self.model = BetterTransformer.transform(self.model) # enable flash attention
            self.model.tie_weights()
        self.layers = [self.model.model.embed_tokens] + list(self.model.model.layers) + [self.model.model.norm, self.model.lm_head]

        # Move buffers to device (not that much GPU memory used)
        for buffer_name, buffer in self.model.named_buffers():
            set_module_tensor_to_device(self.model, buffer_name, self.device, value=buffer, dtype=self.dtype)

    def load_layer_to_cpu(self, layer_name):
        self.weights_loader.set_state_dict(layer_name, self.device)
        state_dict = self.weights_loader.get_state_dict(self.device)
        if 'value_head.weight' in state_dict:
            state_dict = {'lm_head.weight' : state_dict['value_head.weight']}
        return state_dict

    def move_layer_to_device(self, state_dict):
        for param_name, param in state_dict.items():
            assert param.dtype != torch.int8, 'int8 not supported (need to add fp16_statistics)'
            set_module_tensor_to_device(self.model, param_name, self.device, value=param, dtype=self.dtype)
    def __call__(self, inputs):
        # inputs = [(prefix, suffix), ...] with prefix.shape[0] = 1 and suffix.shape[0] = 5

        # Reboot the model to make sure buffers are loaded and memory is clean
        del self.model
        clean_memory()
        self.init_model()

        # Send batch to device
        batch = [(prefix.to(self.device), suffix.to(self.device)) for prefix, suffix in inputs]
        n_suffixes = len(batch[0][1])
        suffix_eos = [(suffix != self.tokenizer.pad_token_id).sum(1) - 1 for _, suffix in inputs]

        # Create attention mask for the largest input, and position ids to use KV cache
        attention_mask = torch.ones(MAX_LENGTH, MAX_LENGTH)
        attention_mask = attention_mask.triu(diagonal=1)[None, None, ...] == 0
        attention_mask = attention_mask.to(self.device)
        position_ids = torch.arange(MAX_LENGTH, dtype=torch.long, device=self.device)[None, :]

        with ThreadPoolExecutor() as executor, torch.inference_mode():

            # Load first layer
            future = executor.submit(self.load_layer_to_cpu, 'model.embed_tokens')

            for i, (layer_name, layer) in tqdm(enumerate(zip(self.layer_names, self.layers)), desc=self.device, total=len(self.layers)):

                # Load current layer and prepare next layer
                state_dict = future.result()
                if (i + 1) < len(self.layer_names):
                    future = executor.submit(self.load_layer_to_cpu, self.layer_names[i + 1])
                self.move_layer_to_device(state_dict)

                # Run layer
                for j, (prefix, suffix) in enumerate(batch):
                    if layer_name == 'model.embed_tokens':
                        batch[j] = (layer(prefix), layer(suffix))
                    elif layer_name == 'model.norm':
                        # Only keep the last token at this point
                        batch[j] = (None, layer(suffix[torch.arange(n_suffixes), suffix_eos[j]][:, None]))
                    elif layer_name == 'value_head':
                        batch[j] = layer(suffix)[:, 0].mean(1).detach().cpu().numpy()
                    else:
                        # Run prefix
                        len_p, len_s = prefix.shape[1], suffix.shape[1]
                        new_prefix, (k_cache, v_cache) = layer(prefix, use_cache=True, attention_mask=attention_mask[:, :, -len_p:, -len_p:])
                        # Run suffix
                        pos = position_ids[:, len_p:len_p + len_s].expand(n_suffixes, -1)
                        attn = attention_mask[:, :, -len_s:, -len_p - len_s:].expand(n_suffixes, -1, -1, -1)
                        kv_cache = (k_cache.expand(n_suffixes, -1, -1, -1), v_cache.expand(n_suffixes, -1, -1, -1))
                        new_suffix = layer(suffix, past_key_value=kv_cache, position_ids=pos, attention_mask=attn)[0]
                        batch[j] = (new_prefix, new_suffix)

                # Remove previous layer from memory (including buffers)
                layer.to('meta')
                clean_memory() # proposed by CPMP

        # Get scores
        return batch
def run_model(device, df, weights_loader):
    # checkpoint_path and get_tokens are globals defined elsewhere in the competition notebook
    model = ShardedLlama(checkpoint_path, weights_loader, device=device)
    f = partial(get_tokens, tokenizer=model.tokenizer)
    inputs = df.apply(f, axis=1).values
    batches = np.array_split(inputs, N_BATCHES)
    outputs = []
    for i, batch in enumerate(batches):
        outputs += model(batch)
    return outputs
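One piece is not shown above: the weights_loader object. Its only contract, as used by load_layer_to_cpu, is a set_state_dict(layer_name, device) call followed by get_state_dict(device). A minimal single-device sketch, assuming the checkpoint has been re-saved as one safetensors file per layer (the file layout and class name here are assumptions, not the competition's actual loader):

class SimpleWeightsLoader:
    # Minimal stand-in for weights_loader: one .safetensors file per layer_name.
    def __init__(self, checkpoint_path):
        self.checkpoint_path = Path(checkpoint_path)
        self.state_dict = {}

    def set_state_dict(self, layer_name, device):
        # Read this layer's shard from disk into CPU RAM; device is unused in this sketch.
        self.state_dict = load_file(self.checkpoint_path / f'{layer_name}.safetensors', device='cpu')

    def get_state_dict(self, device):
        return self.state_dict

Why set_state_dict also receives a device argument is not visible from this snippet (in the competition it presumably coordinates loading for more than one GPU worker). What is visible is that the ThreadPoolExecutor in __call__ overlaps the disk-to-CPU load of layer i+1 with the GPU compute of layer i, so the GPU rarely sits idle waiting on disk.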