分享

大模型黑科技。

 天承办公室 2023-10-18 发布于北京

开局一道面试题:不量化,不损失精度,如何用一张16GB的显卡推断 fp16的70B大模型?

简介

我们传统的模型加载到显卡的流程是:

1.创建模型 

2.在内存中加载其权重(通常在一个叫做state_dict的对象中) 

3.在创建的模型中加载这些权重 

4.将模型移动到设备上进行推理

在这个常规步骤中,第二步需要把模型加载到RAM里,用fp16的话的显存消耗量大约是模型参数量的两倍的数量级,比如70B模型,大概需要140GB显存存放模型。要是使用32位全精度的话,需要4倍的数据数量级参数, 计算原理(1 个float32耗费4bytes,1个float16耗费 2bytes),关于参数和显存的推演,更详细的细节看之前的文章:大模型训练为什么用A100不用4090

并且在第三步,模型从RAM挪到显存,还需要一份额外的拷贝,耗费同样的大小的内存。

当你想加载更大的模型,比如BLOOM或OPT-176B(1760亿个参数)的话,按上面推演,你将需要1.4TB的CPU RAM。还是十分夸张的!而所有这些只是为了在第4步将模型移动到GPU上。

所以正常来讲,70B的大模型,至少需要4张40GB的A100来推断,一张16GB的显卡肯定是不够的,另外32GB的RAM也是远远不够的,别说显存,内存都放不下一个模型的参数。

那么我们有什么办法呢?

既然显卡放不下一个模型,那我们能否放一部分到显卡上,边推理边释放呢?

答案是肯定的,并且这个方案在 Kaggle比赛里已经实践过,在一些特定的限制资料的场景,可以用来当一个极其节省现存的推断方案。

具体的流程可以讲解为:

1.创建一个空的(例如,没有权重的)模型

2.决定每一层将要去哪里(当有多个设备可用时)

3.在内存中加载其权重的一部分

4.在空模型中加载这些权重

5.将权重移动到设备上进行推理 

6.从第3步重复,直到所有的权重都被加载


这个过程实现得益得以依赖于pytorch 1.9的一个叫meta device的玩意儿,

PyTorch 1.9引入了一种新的设备,称为元设备(meta device)。

这使我们能够创建没有任何数据附加的张量,元设备上的张量只需要一个shape,只要你在元设备上,你就可以创建任意大的张量,而不必担心CPU(或GPU)的RAM够不够。

比如下面的代码,内存不够的话就会崩掉

import torch
large_tensor = torch.randn(100000100000)

这个大张量需要4 * 10**10字节(默认精度是FP32,所以张量的每个元素占用4字节),因此需要40GB的RAM。然而,在元设备上执行相同的操作就可以正常运行:

import torch
large_tensor = torch.randn(100000100000, device='meta')

这个张量没有关联的数据,只有一个形状。你可以直接在元设备上实例化一个模型:

large_model = torch.nn.Linear(100000100000, device='meta')

但是对于现成的模型来说,这种语法需要你重写所有的建模代码,以便每个模型的子部分都接受并传递一个设备关键字参数。由于这对Transformers库的预训练模型来说不切实际,accelerate库有一个context manager,整合了meta device可以实例化一个空模型。

# Load meta model (no memory used)
with init_empty_weights():
    self.model = AutoModelForCausalLM.from_config(self.config, trust_remote_code=True)
    self.model.tie_weights()

这一步很关键,我们知道每个权重的形状,因此我们可以知道一旦我们完全加载预训练的张量,它们将消耗多少内存。因此,我们可以决定如何在CPU和GPU之间分割我们的模型。

除此之外,定义了两个关键的方法,分别是load_layer_to_cpu,负责把 权重从disk挪到CPU,另外一个是move_layer_to_device,负责把权重从cpu挪到显卡。还有一个释放显存的方法clean_memory,负责清空显存。

def load_layer_to_cpu(self, layer_name):
    self.weights_loader.set_state_dict(layer_name, self.device)
    state_dict = self.weights_loader.get_state_dict(self.device)
    if 'value_head.weight' in state_dict:
        state_dict = {'lm_head.weight' : state_dict['value_head.weight']}
    return state_dict
    
def move_layer_to_device(self, state_dict):
    for param_name, param in state_dict.items():
        assert param.dtype != torch.int8, 'int8 not supported (need to add fp16_statistics)'
        set_module_tensor_to_device(self.model, param_name, self.device, value=param, dtype=self.dtype)

def clean_memory():
    gc.collect()
    ctypes.CDLL('libc.so.6').malloc_trim(0)
    torch.cuda.empty_cache()

好了,对过程大概弄懂了就可以看完整的代码了。注意,下面的代码也包含了题目设定里特定的prefix和suffix,正常的推理可以忽略相关的逻辑,仅保留一个prompt即可。下面展示完整的代码

# For LLM
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModel
from accelerate import init_empty_weights
from accelerate.utils.modeling import set_module_tensor_to_device
from safetensors.torch import load_file
from optimum.bettertransformer import BetterTransformer

N_BATCHES = 3
MAX_LENGTH = 4096

def clean_memory():
    gc.collect()
    ctypes.CDLL('libc.so.6').malloc_trim(0)
    torch.cuda.empty_cache()


# Class for sharded llama
class ShardedLlama:
    def __init__(self, checkpoint_path, weights_loader, device='cuda:0', dtype=torch.float16):

        # Save parameters
        self.checkpoint_path = Path(checkpoint_path)
        self.weights_loader = weights_loader
        self.device = device 
        self.dtype = dtype

        # Create model
        self.config = AutoConfig.from_pretrained(self.checkpoint_path)   
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = 'right'
        self.init_model()
        self.layer_names = ['model.embed_tokens'] + [f'model.layers.{i}' for i in range(len(self.model.model.layers))] + ['model.norm''value_head']

    def init_model(self):
    
        # Load meta model (no memory used)
        with init_empty_weights():
            self.model = AutoModelForCausalLM.from_config(self.config)
            self.model.lm_head = torch.nn.Linear(81928, bias=False# originally 32k
            self.model.eval()
            self.model = BetterTransformer.transform(self.model) # enable flash attention
            self.model.tie_weights()
            
        self.layers = [self.model.model.embed_tokens] + list(self.model.model.layers) + [self.model.model.norm, self.model.lm_head]

        # Move buffers to device (note that much GPU memory used)
        for buffer_name, buffer in self.model.named_buffers():
            set_module_tensor_to_device(self.model, buffer_name, self.device, value=buffer, dtype=self.dtype)

    def load_layer_to_cpu(self, layer_name):
        self.weights_loader.set_state_dict(layer_name, self.device)
        state_dict = self.weights_loader.get_state_dict(self.device)
        if 'value_head.weight' in state_dict:
            state_dict = {'lm_head.weight' : state_dict['value_head.weight']}
        return state_dict
        
    def move_layer_to_device(self, state_dict):
        for param_name, param in state_dict.items():
            assert param.dtype != torch.int8, 'int8 not supported (need to add fp16_statistics)'
            set_module_tensor_to_device(self.model, param_name, self.device, value=param, dtype=self.dtype)

    def __call__(self, inputs):
        # inputs = [(prefix, suffix), ...] with prefix.shape[0] = 1 and suffix.shape[0] = 5
        
        # Reboot the model to make sure buffers are loaded and memory is clean
        del self.model
        clean_memory()
        self.init_model()
        
       # Send batch to device
        batch = [(prefix.to(self.device), suffix.to(self.device)) for prefix, suffix in inputs]
        n_suffixes = len(batch[0][1])
        suffix_eos = [(suffix != self.tokenizer.pad_token_id).sum(1) - 1 for _, suffix in inputs]

        # Create attention mask for the largest input, and position ids to use KV cache
        attention_mask = torch.ones(MAX_LENGTH, MAX_LENGTH)
        attention_mask = attention_mask.triu(diagonal=1)[NoneNone, ...] == 0
        attention_mask = attention_mask.to(self.device)
        position_ids = torch.arange(MAX_LENGTH, dtype=torch.long, device=self.device)[None, :]

        with ThreadPoolExecutor() as executor, torch.inference_mode():

            # Load first layer
            future = executor.submit(self.load_layer_to_cpu, 'model.embed_tokens')

            for i, (layer_name, layer) in tqdm(enumerate(zip(self.layer_names, self.layers)), desc=self.device, total=len(self.layers)):

                # Load current layer and prepare next layer
                state_dict = future.result()
                if (i + 1) < len(self.layer_names):
                    future = executor.submit(self.load_layer_to_cpu, self.layer_names[i + 1])
                self.move_layer_to_device(state_dict)
                
                # Run layer
                for j, (prefix, suffix) in enumerate(batch):
                    if layer_name == 'model.embed_tokens':
                        batch[j] = (layer(prefix), layer(suffix))
                    elif layer_name == 'model.norm':
                        # Only keep the last token at this point
                        batch[j] = (None, layer(suffix[torch.arange(n_suffixes), suffix_eos[j]][:, None]))
                    elif layer_name == 'value_head':
                        batch[j] = layer(suffix)[:, 0].mean(1).detach().cpu().numpy()
                    else:
                        # Run prefix
                        len_p, len_s = prefix.shape[1], suffix.shape[1]
                        new_prefix, (k_cache, v_cache) = layer(prefix, use_cache=True, attention_mask=attention_mask[:, :, -len_p:, -len_p:])
                        
                        # Run suffix
                        pos = position_ids[:, len_p:len_p + len_s].expand(n_suffixes, -1)
                        attn = attention_mask[:, :, -len_s:, -len_p - len_s:].expand(n_suffixes, -1-1-1)
                        kv_cache = (k_cache.expand(n_suffixes, -1-1-1), v_cache.expand(n_suffixes, -1-1-1))
                        new_suffix = layer(suffix, past_key_value=kv_cache, position_ids=pos, attention_mask=attn)[0]
                        batch[j] = (new_prefix, new_suffix)

                # Remove previous layer from memory (including buffers)
                layer.to('meta')
                clean_memory() # proposed by CPMP

        # Get scores
        return batch




def run_model(device, df, weights_loader):
    model = ShardedLlama(checkpoint_path, weights_loader, device=device)
    f = partial(get_tokens, tokenizer=model.tokenizer)
    inputs = df.apply(f, axis=1).values
    batches = np.array_split(inputs, N_BATCHES)
    outputs = []
    for i, batch in enumerate(batches):
        outputs += model(batch)
    return outputs
    

完整的代码可以参考链接:https://www./code/simjeg/platypus2-70b-without-wikipedia-rag

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多