当前位置: 首页 > news >正文

llama3源码解读之推理-infer

文章目录

  • 前言
  • 一、整体源码解读
    • 1、完整main源码
    • 2、tokenizer加载
    • 3、llama3模型加载
    • 4、llama3测试数据文本加载
    • 5、llama3模型推理模块
      • 1、模型推理模块的数据处理
      • 2、模型推理模块的model.generate预测
      • 3、模型推理模块的预测结果处理
    • 6、多轮对话
  • 二、llama3推理数据处理
    • 1、完整数据处理源码
    • 2、使用prompt方式询问数据加载
    • 3、推理处理数据
  • 三、llama3推理generate调用_sample方法
    • 1、GenerationMode.SAMPLE方法源码
    • 2、huggingface的_sample源码
    • 3、_sample的初始化与准备
    • 4、_sample的初始化(attention / hidden states / scores)
    • 5、encoder-decoder模式模型处理
    • 6、保持相应变量内容
    • 7、进入while循环模块
    • 8、结果返回处理
  • 四、_sample的while循环模块内容
    • 1、模型输入加工
      • 1、首次迭代数据加工
      • 2、再次迭代数据加工
    • 2、模型推理
      • 1、首次迭代数据加工
      • 2、再次迭代数据加工
    • 3、预测结果取值
    • 4、预测logits惩罚处理
    • 5、预测softmax处理
    • 6、预测token选择处理
    • 7、停止条件更新处理
    • 8、再次循环迭代input_ids与attention_mask处理(生成式预测)
    • 8、停止条件再次更新与处理
  • 五、模型推理(self)
    • 1、模型预测输入参数
    • 2、进入包装选择调用方法
    • 3、进入forward函数--class LlamaForCausalLM(LlamaPreTrainedModel)
    • 4、llama3推理forward的前提说明
    • 5、llama3推理forward的设置参数
    • 6、llama3推理forward的推理(十分重要)
    • 7、llama3推理forward的的lm_head转换
    • 8、llama3推理forward的结果包装(CausalLMOutputWithPast)
  • 六、llama3模型forward的self.model方法
    • 1、模型输入内容
      • 1、self.model模型初次输入内容
      • 2、self.model模型再次输入内容
    • 2、进入包装选择调用方法
    • 3、进入forward函数--class LlamaModel(LlamaPreTrainedModel)
    • 4、llama3模型forward的源码
    • 5、llama3模型forward的输入准备与embedding
      • 1、初次输入llama3模型内容
      • 2、再此输入llama3模型内容
      • 3、说明
    • 6、llama3模型forward的cache
      • 1、DynamicCache.from_legacy_cache(past_key_values)函数
      • 2、cache = cls()类方法
        • 该类位置
        • 该类方法
      • 2、cache.update的更新
    • 7、llama3模型decoder_layer方法
      • 1、decoder_layer调用源码
      • 2、decoder_layer源码方法与调用self.self_attn方法
      • 3、self.self_attn源码方法
    • 8、hidden_states的LlamaRMSNorm
    • 9、next_cache保存与输出结果
  • 七、llama3模型atten方法
    • 1、模型位置
    • 2、llama3推理模型初始化
    • 3、llama3推理的forward方法
    • 4、llama3模型结构图


前言

本项目是解读开源github的代码,该项目基于Meta最新发布的新一代开源大模型Llama-3开发,是Chinese-LLaMA-Alpaca开源大模型相关系列项目(一期、二期)的第三期。而本项目开源了中文Llama-3基座模型和中文Llama-3-Instruct指令精调大模型。这些模型在原版Llama-3的基础上使用了大规模中文数据进行增量预训练,并且使用精选指令数据进行精调,进一步提升了中文基础语义和指令理解能力,相比二代相关模型获得了显著性能提升。因此,我是基于该项目解读训练与推理相关原理与内容,并以代码形式带领读者一步一步解读,理解其大语言模型运行机理。而该博客首先给出llama3推理源码相关内容解读,我将按照源码流程给出解读。


一、整体源码解读

1、完整main源码

我先给出完整的源码,后面推理使用哪些部分代码,我在深度解读。而一些较为简单内容我不在解读了。

if __name__ == '__main__':load_type = torch.float16# Move the model to the MPS device if availableif torch.backends.mps.is_available():device = torch.device("mps")else:if torch.cuda.is_available():device = torch.device(0)else:device = torch.device('cpu')print(f"Using device: {device}")if args.tokenizer_path is None:args.tokenizer_path = args.base_modeltokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)terminators = [tokenizer.eos_token_id,tokenizer.convert_tokens_to_ids("<|eot_id|>")]if args.use_vllm:model = LLM(model=args.base_model,tokenizer=args.tokenizer_path,tensor_parallel_size=len(args.gpus.split(',')),dtype=load_type)generation_config["stop_token_ids"] = terminatorsgeneration_config["stop"] = ["<|eot_id|>", "<|end_of_text|>"]else:if args.load_in_4bit or args.load_in_8bit:quantization_config = BitsAndBytesConfig(load_in_4bit=args.load_in_4bit,load_in_8bit=args.load_in_8bit,bnb_4bit_compute_dtype=load_type,bnb_4bit_use_double_quant=True,bnb_4bit_quant_type="nf4")model = AutoModelForCausalLM.from_pretrained(args.base_model,torch_dtype=load_type,low_cpu_mem_usage=True,device_map='auto',quantization_config=quantization_config if (args.load_in_4bit or args.load_in_8bit) else None,attn_implementation="flash_attention_2" if args.use_flash_attention_2 else "sdpa")if device==torch.device('cpu'):model.float()model.eval()# test dataif args.data_file is None:examples = sample_dataelse:with open(args.data_file, 'r') as f:examples = [line.strip() for line in f.readlines()]print("first 10 examples:")for example in examples[:10]:print(example)with torch.no_grad():if args.interactive:print("Start inference with instruction mode.")print('='*85)print("+ 该模式下仅支持单轮问答,无多轮对话能力。\n""+ 如要进行多轮对话,请使用llama.cpp")print('-'*85)print("+ This mode only supports single-turn QA.\n""+ If you want to experience multi-turn dialogue, please use llama.cpp")print('='*85)while True:raw_input_text = input("Input:")if len(raw_input_text.strip())==0:breakif args.with_prompt:input_text = generate_prompt(instruction=raw_input_text)else:input_text = raw_input_textif args.use_vllm:output = model.generate([input_text], SamplingParams(**generation_config), use_tqdm=False)response = output[0].outputs[0].textelse:inputs = tokenizer(input_text,return_tensors="pt")  #add_special_tokens=False ?generation_output = model.generate(input_ids = inputs["input_ids"].to(device),attention_mask = inputs['attention_mask'].to(device),eos_token_id=terminators,pad_token_id=tokenizer.eos_token_id,generation_config = generation_config)s = generation_output[0]output = tokenizer.decode(s, skip_special_tokens=True)if args.with_prompt:response = output.split("assistant\n\n")[-1].strip()else:response = outputprint("Response: ",response)print("\n")else:print("Start inference.")results = []if args.use_vllm:if args.with_prompt is True:inputs = [generate_prompt(example) for example in examples]else:inputs = examplesoutputs = model.generate(inputs, SamplingParams(**generation_config))for index, (example, output) in enumerate(zip(examples, outputs)):response = output.outputs[0].textprint(f"======={index}=======")print(f"Input: {example}\n")print(f"Output: {response}\n")results.append({"Input":example,"Output":response})else:for index, example in enumerate(examples):if args.with_prompt:input_text = generate_prompt(instruction=example)else:input_text = exampleinputs = tokenizer(input_text,return_tensors="pt")  #add_special_tokens=False ?generation_output = model.generate(input_ids = inputs["input_ids"].to(device),attention_mask = inputs['attention_mask'].to(device),eos_token_id=terminators,pad_token_id=tokenizer.eos_token_id,generation_config = generation_config)s = generation_output[0]output = tokenizer.decode(s,skip_special_tokens=True)if args.with_prompt:response = output.split("assistant\n\n")[1].strip()else:response = outputprint(f"======={index}=======")print(f"Input: {example}\n")print(f"Output: {response}\n")results.append({"Input":input_text,"Output":response})dirname = os.path.dirname(args.predictions_file)os.makedirs(dirname,exist_ok=True)with open(args.predictions_file,'w') as f:json.dump(results,f,ensure_ascii=False,indent=2)if args.use_vllm:with open(dirname+'/generation_config.json','w') as f:json.dump(generation_config,f,ensure_ascii=False,indent=2)else:generation_config.save_pretrained('./')

2、tokenizer加载

有关tokenzier相关加载可参考博客这里。这里,我直接给出其源码,如下:

tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
terminators = [tokenizer.eos_token_id,tokenizer.convert_tokens_to_ids("<|eot_id|>")]

tokenizer.eos_token_id=128009,而terminators=[128009,128009]。

3、llama3模型加载

huggingface模型加载可参考博客这里。这里,llama3的模型加载不在介绍,如下源码:

model = AutoModelForCausalLM.from_pretrained(args.base_model,  # 权重路径文件夹torch_dtype=load_type,low_cpu_mem_usage=True,device_map='auto',quantization_config=quantization_config if (args.load_in_4bit or args.load_in_8bit) else None,attn_implementation="flash_attention_2" if args.use_flash_attention_2 else "sdpa"
)
if device==torch.device('cpu'):model.float()
model.eval()

注意:model.eval()为固定权重方式,这是pytorch评估类似。

4、llama3测试数据文本加载

 # test dataif args.data_file is None:examples = sample_data  #  ["为什么要减少污染,保护环境?","你有什么建议?"]else:with open(args.data_file, 'r') as f:examples 

相关文章:

llama3源码解读之推理-infer

文章目录 前言一、整体源码解读1、完整main源码2、tokenizer加载3、llama3模型加载4、llama3测试数据文本加载5、llama3模型推理模块1、模型推理模块的数据处理2、模型推理模块的model.generate预测3、模型推理模块的预测结果处理6、多轮对话二、llama3推理数据处理1、完整数据…...

【教程】Linux安装Redis步骤记录

下载地址 Index of /releases/ Downloads - Redis 安装redis-7.4.0.tar.gz 1.下载安装包 wget https://download.redis.io/releases/redis-7.4.0.tar.gz 2.解压 tar -zxvf redis-7.4.0.tar.gz 3.进入目录 cd redis-7.4.0/ 4.编译 make 5.安装 make install PREFIX/u…...

全球汽车线控制动系统市场规模预测:未来六年CAGR为17.3%

引言&#xff1a; 随着汽车行业的持续发展和对安全性能需求的增加&#xff0c;汽车线控制动系统作为提升车辆安全性和操控性的关键组件&#xff0c;正逐渐受到市场的广泛关注。本文旨在通过深度分析汽车线控制动系统行业的各个维度&#xff0c;揭示行业发展趋势和潜在机会。 【…...

Ubuntu运行深度学习代码,代码随机epoch中断没有任何报错

深度学习运行代码直接中断 文章目录 深度学习运行代码直接中断问题描述设备信息问题补充解决思路问题发现及正确解决思路新问题出现最终问题&#xff1a;ubuntu系统&#xff0c;4090显卡安装英伟达驱动535.x外的驱动会导致开机无法进入桌面问题记录 问题描述 运行深度学习代码…...

只有4%知道的Linux,看了你也能上手Ubuntu桌面系统,Ubuntu简易设置,源更新,root密码,远程服务...

创作不易 只因热爱!! 热衷分享&#xff0c;一起成长! “你的鼓励就是我努力付出的动力” 最近常提的一句话&#xff0c;那就是“但行好事&#xff0c;莫问前程"! 与辉同行的董工说​&#xff1a;​守正出奇。坚持分享&#xff0c;坚持付出&#xff0c;坚持奉献&#xff0c…...

Tomcat部署——个人笔记

Tomcat部署——个人笔记 文章目录 [toc]简介安装配置文件WEB项目的标准结构WEB项目部署IDEA中开发并部署运行WEB项目 本学习笔记参考尚硅谷等教程。 简介 Apache Tomcat 官网 Tomcat是Apache 软件基金会&#xff08;Apache Software Foundation&#xff09;的Jakarta 项目中…...

常见且重要的用户体验原则

以下是一些常见且重要的用户体验原则&#xff1a; 1. 以用户为中心 - 深入了解用户的需求、期望、目标和行为习惯。通过用户研究、调查、访谈等方法获取真实的用户反馈&#xff0c;以此来设计产品或服务。 - 例如&#xff0c;在设计一款老年手机时&#xff0c;充分考虑老年…...

web基础及nginx搭建

第四周 上午 静态资源 根据开发者保存在项目资源目录中的路径访问静态资源 html 图片 js css 音乐 视频 f12 &#xff0c;开发者工具&#xff0c;网络 1 、 web 基本概念 web 服务器&#xff08; web server &#xff09;&#xff1a;也称 HTTP 服务器&#xff08; HTTP …...

C++ 布隆过滤器

1. 布隆过滤器提出 我们在使用新闻客户端看新闻时&#xff0c;它会给我们不停地推荐新的内容&#xff0c;它每次推荐时要去重&#xff0c;去掉 那些已经看过的内容。问题来了&#xff0c;新闻客户端推荐系统如何实现推送去重的&#xff1f; 用服务器记录了用 户看过的所有历史…...

使用HTML创建用户注册表单

在当今数字化时代&#xff0c;网页表单对于收集用户信息和促进网站交互至关重要。无论您设计简单的注册表单还是复杂的调查表&#xff0c;了解HTML的基础知识可以帮助您构建有效的用户界面。在本教程中&#xff0c;我们将详细介绍如何使用HTML创建基本的用户注册表单。 第一步…...

Python零基础入门教程

Python零基础详细入门教程可以从以下几个方面进行学习和掌握&#xff1a; 一、Python基础认知 1. Python简介 由来与发展&#xff1a;Python是一种广泛使用的高级编程语言&#xff0c;由Guido van Rossum&#xff08;吉多范罗苏姆&#xff09;于1991年首次发布。Python以其简…...

成为git砖家(10): 根据文件内容生成SHA-1

文章目录 1. .git/objects 目录2. git cat-file 命令3. 根据文件内容生成 sha-14. 结语5. References 1. .git/objects 目录 git 是一个根据文件内容进行检索的系统。 当创建 hello.py, 填入 print("hello, world")的内容&#xff0c; 并执行 git add hello.py gi…...

园区导航小程序:一站式解决园区导航问题,释放存储,优化访客体验

随着园区的规模不断扩大&#xff0c;功能区划分日益复杂&#xff0c;导致访客和新员工在没有有效导航的情况下容易迷路。传统APP导航虽能解决部分问题&#xff0c;但其下载安装繁琐、占用手机内存大、且非高频使用导致的闲置&#xff0c;让许多用户望而却步。园区导航小程序的出…...

对于n进制转十进制的解法及代码(干货!)

对于p进制转十进制&#xff0c;我们有&#xff1a;(x)pa[0]*p^0a[1]*p^1a[2]*p^2...a[n]*p^n 举个例子&#xff1a;&#xff08;11001&#xff09;21*10*20*41*81*1625 &#xff08;9FA&#xff09;1610*16^015*16^19*16^22554 据此&#xff0c;我们可以编出c代码来解决问题 …...

当代互联网打工人的生存现状,看完泪流满面!

欢迎私信小编&#xff0c;了解更多产品信息呦~...

花几千上万学习Java,真没必要!(三十八)

测试代码1&#xff1a; package iotest.com; import java.nio.charset.StandardCharsets; import java.io.UnsupportedEncodingException; public class StringByteConversion { public static void main(String[] args) throws UnsupportedEncodingException { // 原始字…...

Zilliz 2025届校园招聘正式启动,寻找向量数据库内核开发工程师

为了解决非结构化数据处理问题&#xff0c;我们构建了向量数据库-Milvus! Milvus 数据库不仅是顶级开源基金会 LF AI&Data 的毕业项目&#xff0c;还曾登上数据库顶会SIGMOD、VLDB&#xff0c;在全球首届向量检索比赛中夺冠。目前&#xff0c;Milvus 项目已获得超过 2.8w s…...

TwinCAT3 新建项目教程

文章目录 打开TwinCAT 新建项目&#xff08;通过TcXaeShell&#xff09; 新建项目&#xff08;通过VS 2019&#xff09;...

大模型算法面试题(十九)

本系列收纳各种大模型面试题及答案。 1、SFT&#xff08;有监督微调&#xff09;、RM&#xff08;奖励模型&#xff09;、PPO&#xff08;强化学习&#xff09;的数据集格式&#xff1f; SFT&#xff08;有监督微调&#xff09;、RM&#xff08;奖励模型&#xff09;、PPO&…...

应用地址信息获取新技巧:Xinstall来助力

在移动互联网时代&#xff0c;应用获取用户地址信息的需求越来越普遍。无论是为了提供个性化服务&#xff0c;还是进行精准营销&#xff0c;地址信息都扮演着至关重要的角色。然而&#xff0c;如何合规、准确地获取这一信息&#xff0c;却是许多开发者面临的挑战。今天&#xf…...

基于算法竞赛的c++编程(28)结构体的进阶应用

结构体的嵌套与复杂数据组织 在C中&#xff0c;结构体可以嵌套使用&#xff0c;形成更复杂的数据结构。例如&#xff0c;可以通过嵌套结构体描述多层级数据关系&#xff1a; struct Address {string city;string street;int zipCode; };struct Employee {string name;int id;…...

日语AI面试高效通关秘籍:专业解读与青柚面试智能助攻

在如今就业市场竞争日益激烈的背景下&#xff0c;越来越多的求职者将目光投向了日本及中日双语岗位。但是&#xff0c;一场日语面试往往让许多人感到步履维艰。你是否也曾因为面试官抛出的“刁钻问题”而心生畏惧&#xff1f;面对生疏的日语交流环境&#xff0c;即便提前恶补了…...

MODBUS TCP转CANopen 技术赋能高效协同作业

在现代工业自动化领域&#xff0c;MODBUS TCP和CANopen两种通讯协议因其稳定性和高效性被广泛应用于各种设备和系统中。而随着科技的不断进步&#xff0c;这两种通讯协议也正在被逐步融合&#xff0c;形成了一种新型的通讯方式——开疆智能MODBUS TCP转CANopen网关KJ-TCPC-CANP…...

Mac软件卸载指南,简单易懂!

刚和Adobe分手&#xff0c;它却总在Library里给你写"回忆录"&#xff1f;卸载的Final Cut Pro像电子幽灵般阴魂不散&#xff1f;总是会有残留文件&#xff0c;别慌&#xff01;这份Mac软件卸载指南&#xff0c;将用最硬核的方式教你"数字分手术"&#xff0…...

Rapidio门铃消息FIFO溢出机制

关于RapidIO门铃消息FIFO的溢出机制及其与中断抖动的关系&#xff0c;以下是深入解析&#xff1a; 门铃FIFO溢出的本质 在RapidIO系统中&#xff0c;门铃消息FIFO是硬件控制器内部的缓冲区&#xff0c;用于临时存储接收到的门铃消息&#xff08;Doorbell Message&#xff09;。…...

零基础在实践中学习网络安全-皮卡丘靶场(第九期-Unsafe Fileupload模块)(yakit方式)

本期内容并不是很难&#xff0c;相信大家会学的很愉快&#xff0c;当然对于有后端基础的朋友来说&#xff0c;本期内容更加容易了解&#xff0c;当然没有基础的也别担心&#xff0c;本期内容会详细解释有关内容 本期用到的软件&#xff1a;yakit&#xff08;因为经过之前好多期…...

USB Over IP专用硬件的5个特点

USB over IP技术通过将USB协议数据封装在标准TCP/IP网络数据包中&#xff0c;从根本上改变了USB连接。这允许客户端通过局域网或广域网远程访问和控制物理连接到服务器的USB设备&#xff08;如专用硬件设备&#xff09;&#xff0c;从而消除了直接物理连接的需要。USB over IP的…...

WEB3全栈开发——面试专业技能点P7前端与链上集成

一、Next.js技术栈 ✅ 概念介绍 Next.js 是一个基于 React 的 服务端渲染&#xff08;SSR&#xff09;与静态网站生成&#xff08;SSG&#xff09; 框架&#xff0c;由 Vercel 开发。它简化了构建生产级 React 应用的过程&#xff0c;并内置了很多特性&#xff1a; ✅ 文件系…...

[特殊字符] 手撸 Redis 互斥锁那些坑

&#x1f4d6; 手撸 Redis 互斥锁那些坑 最近搞业务遇到高并发下同一个 key 的互斥操作&#xff0c;想实现分布式环境下的互斥锁。于是私下顺手手撸了个基于 Redis 的简单互斥锁&#xff0c;也顺便跟 Redisson 的 RLock 机制对比了下&#xff0c;记录一波&#xff0c;别踩我踩过…...

对象回调初步研究

_OBJECT_TYPE结构分析 在介绍什么是对象回调前&#xff0c;首先要熟悉下结构 以我们上篇线程回调介绍过的导出的PsProcessType 结构为例&#xff0c;用_OBJECT_TYPE这个结构来解析它&#xff0c;0x80处就是今天要介绍的回调链表&#xff0c;但是先不着急&#xff0c;先把目光…...