llama3源码解读之推理-infer
文章目录
- 前言
- 一、整体源码解读
- 1、完整main源码
- 2、tokenizer加载
- 3、llama3模型加载
- 4、llama3测试数据文本加载
- 5、llama3模型推理模块
- 1、模型推理模块的数据处理
- 2、模型推理模块的model.generate预测
- 3、模型推理模块的预测结果处理
- 6、多轮对话
- 二、llama3推理数据处理
- 1、完整数据处理源码
- 2、使用prompt方式询问数据加载
- 3、推理处理数据
- 三、llama3推理generate调用_sample方法
- 1、GenerationMode.SAMPLE方法源码
- 2、huggingface的_sample源码
- 3、_sample的初始化与准备
- 4、_sample的初始化(attention / hidden states / scores)
- 5、encoder-decoder模式模型处理
- 6、保持相应变量内容
- 7、进入while循环模块
- 8、结果返回处理
- 四、_sample的while循环模块内容
- 1、模型输入加工
- 1、首次迭代数据加工
- 2、再次迭代数据加工
- 2、模型推理
- 1、首次迭代数据加工
- 2、再次迭代数据加工
- 3、预测结果取值
- 4、预测logits惩罚处理
- 5、预测softmax处理
- 6、预测token选择处理
- 7、停止条件更新处理
- 8、再次循环迭代input_ids与attention_mask处理(生成式预测)
- 8、停止条件再次更新与处理
- 五、模型推理(self)
- 1、模型预测输入参数
- 2、进入包装选择调用方法
- 3、进入forward函数--class LlamaForCausalLM(LlamaPreTrainedModel)
- 4、llama3推理forward的前提说明
- 5、llama3推理forward的设置参数
- 6、llama3推理forward的推理(十分重要)
- 7、llama3推理forward的的lm_head转换
- 8、llama3推理forward的结果包装(CausalLMOutputWithPast)
- 六、llama3模型forward的self.model方法
- 1、模型输入内容
- 1、self.model模型初次输入内容
- 2、self.model模型再次输入内容
- 2、进入包装选择调用方法
- 3、进入forward函数--class LlamaModel(LlamaPreTrainedModel)
- 4、llama3模型forward的源码
- 5、llama3模型forward的输入准备与embedding
- 1、初次输入llama3模型内容
- 2、再此输入llama3模型内容
- 3、说明
- 6、llama3模型forward的cache
- 1、DynamicCache.from_legacy_cache(past_key_values)函数
- 2、cache = cls()类方法
- 该类位置
- 该类方法
- 2、cache.update的更新
- 7、llama3模型decoder_layer方法
- 1、decoder_layer调用源码
- 2、decoder_layer源码方法与调用self.self_attn方法
- 3、self.self_attn源码方法
- 8、hidden_states的LlamaRMSNorm
- 9、next_cache保存与输出结果
- 七、llama3模型atten方法
- 1、模型位置
- 2、llama3推理模型初始化
- 3、llama3推理的forward方法
- 4、llama3模型结构图
前言
本项目是解读开源github的代码,该项目基于Meta最新发布的新一代开源大模型Llama-3开发,是Chinese-LLaMA-Alpaca开源大模型相关系列项目(一期、二期)的第三期。而本项目开源了中文Llama-3基座模型和中文Llama-3-Instruct指令精调大模型。这些模型在原版Llama-3的基础上使用了大规模中文数据进行增量预训练,并且使用精选指令数据进行精调,进一步提升了中文基础语义和指令理解能力,相比二代相关模型获得了显著性能提升。因此,我是基于该项目解读训练与推理相关原理与内容,并以代码形式带领读者一步一步解读,理解其大语言模型运行机理。而该博客首先给出llama3推理源码相关内容解读,我将按照源码流程给出解读。
一、整体源码解读
1、完整main源码
我先给出完整的源码,后面推理使用哪些部分代码,我在深度解读。而一些较为简单内容我不在解读了。
if __name__ == '__main__':load_type = torch.float16# Move the model to the MPS device if availableif torch.backends.mps.is_available():device = torch.device("mps")else:if torch.cuda.is_available():device = torch.device(0)else:device = torch.device('cpu')print(f"Using device: {device}")if args.tokenizer_path is None:args.tokenizer_path = args.base_modeltokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)terminators = [tokenizer.eos_token_id,tokenizer.convert_tokens_to_ids("<|eot_id|>")]if args.use_vllm:model = LLM(model=args.base_model,tokenizer=args.tokenizer_path,tensor_parallel_size=len(args.gpus.split(',')),dtype=load_type)generation_config["stop_token_ids"] = terminatorsgeneration_config["stop"] = ["<|eot_id|>", "<|end_of_text|>"]else:if args.load_in_4bit or args.load_in_8bit:quantization_config = BitsAndBytesConfig(load_in_4bit=args.load_in_4bit,load_in_8bit=args.load_in_8bit,bnb_4bit_compute_dtype=load_type,bnb_4bit_use_double_quant=True,bnb_4bit_quant_type="nf4")model = AutoModelForCausalLM.from_pretrained(args.base_model,torch_dtype=load_type,low_cpu_mem_usage=True,device_map='auto',quantization_config=quantization_config if (args.load_in_4bit or args.load_in_8bit) else None,attn_implementation="flash_attention_2" if args.use_flash_attention_2 else "sdpa")if device==torch.device('cpu'):model.float()model.eval()# test dataif args.data_file is None:examples = sample_dataelse:with open(args.data_file, 'r') as f:examples = [line.strip() for line in f.readlines()]print("first 10 examples:")for example in examples[:10]:print(example)with torch.no_grad():if args.interactive:print("Start inference with instruction mode.")print('='*85)print("+ 该模式下仅支持单轮问答,无多轮对话能力。\n""+ 如要进行多轮对话,请使用llama.cpp")print('-'*85)print("+ This mode only supports single-turn QA.\n""+ If you want to experience multi-turn dialogue, please use llama.cpp")print('='*85)while True:raw_input_text = input("Input:")if len(raw_input_text.strip())==0:breakif args.with_prompt:input_text = generate_prompt(instruction=raw_input_text)else:input_text = raw_input_textif args.use_vllm:output = model.generate([input_text], SamplingParams(**generation_config), use_tqdm=False)response = output[0].outputs[0].textelse:inputs = tokenizer(input_text,return_tensors="pt") #add_special_tokens=False ?generation_output = model.generate(input_ids = inputs["input_ids"].to(device),attention_mask = inputs['attention_mask'].to(device),eos_token_id=terminators,pad_token_id=tokenizer.eos_token_id,generation_config = generation_config)s = generation_output[0]output = tokenizer.decode(s, skip_special_tokens=True)if args.with_prompt:response = output.split("assistant\n\n")[-1].strip()else:response = outputprint("Response: ",response)print("\n")else:print("Start inference.")results = []if args.use_vllm:if args.with_prompt is True:inputs = [generate_prompt(example) for example in examples]else:inputs = examplesoutputs = model.generate(inputs, SamplingParams(**generation_config))for index, (example, output) in enumerate(zip(examples, outputs)):response = output.outputs[0].textprint(f"======={index}=======")print(f"Input: {example}\n")print(f"Output: {response}\n")results.append({"Input":example,"Output":response})else:for index, example in enumerate(examples):if args.with_prompt:input_text = generate_prompt(instruction=example)else:input_text = exampleinputs = tokenizer(input_text,return_tensors="pt") #add_special_tokens=False ?generation_output = model.generate(input_ids = inputs["input_ids"].to(device),attention_mask = inputs['attention_mask'].to(device),eos_token_id=terminators,pad_token_id=tokenizer.eos_token_id,generation_config = generation_config)s = generation_output[0]output = tokenizer.decode(s,skip_special_tokens=True)if args.with_prompt:response = output.split("assistant\n\n")[1].strip()else:response = outputprint(f"======={index}=======")print(f"Input: {example}\n")print(f"Output: {response}\n")results.append({"Input":input_text,"Output":response})dirname = os.path.dirname(args.predictions_file)os.makedirs(dirname,exist_ok=True)with open(args.predictions_file,'w') as f:json.dump(results,f,ensure_ascii=False,indent=2)if args.use_vllm:with open(dirname+'/generation_config.json','w') as f:json.dump(generation_config,f,ensure_ascii=False,indent=2)else:generation_config.save_pretrained('./')
2、tokenizer加载
有关tokenzier相关加载可参考博客这里。这里,我直接给出其源码,如下:
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
terminators = [tokenizer.eos_token_id,tokenizer.convert_tokens_to_ids("<|eot_id|>")]
tokenizer.eos_token_id=128009,而terminators=[128009,128009]。
3、llama3模型加载
huggingface模型加载可参考博客这里。这里,llama3的模型加载不在介绍,如下源码:
model = AutoModelForCausalLM.from_pretrained(args.base_model, # 权重路径文件夹torch_dtype=load_type,low_cpu_mem_usage=True,device_map='auto',quantization_config=quantization_config if (args.load_in_4bit or args.load_in_8bit) else None,attn_implementation="flash_attention_2" if args.use_flash_attention_2 else "sdpa"
)
if device==torch.device('cpu'):model.float()
model.eval()
注意:model.eval()为固定权重方式,这是pytorch评估类似。
4、llama3测试数据文本加载
# test dataif args.data_file is None:examples = sample_data # ["为什么要减少污染,保护环境?","你有什么建议?"]else:with open(args.data_file, 'r') as f:examples 相关文章:
llama3源码解读之推理-infer
文章目录 前言一、整体源码解读1、完整main源码2、tokenizer加载3、llama3模型加载4、llama3测试数据文本加载5、llama3模型推理模块1、模型推理模块的数据处理2、模型推理模块的model.generate预测3、模型推理模块的预测结果处理6、多轮对话二、llama3推理数据处理1、完整数据…...
【教程】Linux安装Redis步骤记录
下载地址 Index of /releases/ Downloads - Redis 安装redis-7.4.0.tar.gz 1.下载安装包 wget https://download.redis.io/releases/redis-7.4.0.tar.gz 2.解压 tar -zxvf redis-7.4.0.tar.gz 3.进入目录 cd redis-7.4.0/ 4.编译 make 5.安装 make install PREFIX/u…...
全球汽车线控制动系统市场规模预测:未来六年CAGR为17.3%
引言: 随着汽车行业的持续发展和对安全性能需求的增加,汽车线控制动系统作为提升车辆安全性和操控性的关键组件,正逐渐受到市场的广泛关注。本文旨在通过深度分析汽车线控制动系统行业的各个维度,揭示行业发展趋势和潜在机会。 【…...
Ubuntu运行深度学习代码,代码随机epoch中断没有任何报错
深度学习运行代码直接中断 文章目录 深度学习运行代码直接中断问题描述设备信息问题补充解决思路问题发现及正确解决思路新问题出现最终问题:ubuntu系统,4090显卡安装英伟达驱动535.x外的驱动会导致开机无法进入桌面问题记录 问题描述 运行深度学习代码…...
只有4%知道的Linux,看了你也能上手Ubuntu桌面系统,Ubuntu简易设置,源更新,root密码,远程服务...
创作不易 只因热爱!! 热衷分享,一起成长! “你的鼓励就是我努力付出的动力” 最近常提的一句话,那就是“但行好事,莫问前程"! 与辉同行的董工说:守正出奇。坚持分享,坚持付出,坚持奉献,…...
Tomcat部署——个人笔记
Tomcat部署——个人笔记 文章目录 [toc]简介安装配置文件WEB项目的标准结构WEB项目部署IDEA中开发并部署运行WEB项目 本学习笔记参考尚硅谷等教程。 简介 Apache Tomcat 官网 Tomcat是Apache 软件基金会(Apache Software Foundation)的Jakarta 项目中…...
常见且重要的用户体验原则
以下是一些常见且重要的用户体验原则: 1. 以用户为中心 - 深入了解用户的需求、期望、目标和行为习惯。通过用户研究、调查、访谈等方法获取真实的用户反馈,以此来设计产品或服务。 - 例如,在设计一款老年手机时,充分考虑老年…...
web基础及nginx搭建
第四周 上午 静态资源 根据开发者保存在项目资源目录中的路径访问静态资源 html 图片 js css 音乐 视频 f12 ,开发者工具,网络 1 、 web 基本概念 web 服务器( web server ):也称 HTTP 服务器( HTTP …...
C++ 布隆过滤器
1. 布隆过滤器提出 我们在使用新闻客户端看新闻时,它会给我们不停地推荐新的内容,它每次推荐时要去重,去掉 那些已经看过的内容。问题来了,新闻客户端推荐系统如何实现推送去重的? 用服务器记录了用 户看过的所有历史…...
使用HTML创建用户注册表单
在当今数字化时代,网页表单对于收集用户信息和促进网站交互至关重要。无论您设计简单的注册表单还是复杂的调查表,了解HTML的基础知识可以帮助您构建有效的用户界面。在本教程中,我们将详细介绍如何使用HTML创建基本的用户注册表单。 第一步…...
Python零基础入门教程
Python零基础详细入门教程可以从以下几个方面进行学习和掌握: 一、Python基础认知 1. Python简介 由来与发展:Python是一种广泛使用的高级编程语言,由Guido van Rossum(吉多范罗苏姆)于1991年首次发布。Python以其简…...
成为git砖家(10): 根据文件内容生成SHA-1
文章目录 1. .git/objects 目录2. git cat-file 命令3. 根据文件内容生成 sha-14. 结语5. References 1. .git/objects 目录 git 是一个根据文件内容进行检索的系统。 当创建 hello.py, 填入 print("hello, world")的内容, 并执行 git add hello.py gi…...
园区导航小程序:一站式解决园区导航问题,释放存储,优化访客体验
随着园区的规模不断扩大,功能区划分日益复杂,导致访客和新员工在没有有效导航的情况下容易迷路。传统APP导航虽能解决部分问题,但其下载安装繁琐、占用手机内存大、且非高频使用导致的闲置,让许多用户望而却步。园区导航小程序的出…...
对于n进制转十进制的解法及代码(干货!)
对于p进制转十进制,我们有:(x)pa[0]*p^0a[1]*p^1a[2]*p^2...a[n]*p^n 举个例子:(11001)21*10*20*41*81*1625 (9FA)1610*16^015*16^19*16^22554 据此,我们可以编出c代码来解决问题 …...
当代互联网打工人的生存现状,看完泪流满面!
欢迎私信小编,了解更多产品信息呦~...
花几千上万学习Java,真没必要!(三十八)
测试代码1: package iotest.com; import java.nio.charset.StandardCharsets; import java.io.UnsupportedEncodingException; public class StringByteConversion { public static void main(String[] args) throws UnsupportedEncodingException { // 原始字…...
Zilliz 2025届校园招聘正式启动,寻找向量数据库内核开发工程师
为了解决非结构化数据处理问题,我们构建了向量数据库-Milvus! Milvus 数据库不仅是顶级开源基金会 LF AI&Data 的毕业项目,还曾登上数据库顶会SIGMOD、VLDB,在全球首届向量检索比赛中夺冠。目前,Milvus 项目已获得超过 2.8w s…...
TwinCAT3 新建项目教程
文章目录 打开TwinCAT 新建项目(通过TcXaeShell) 新建项目(通过VS 2019)...
大模型算法面试题(十九)
本系列收纳各种大模型面试题及答案。 1、SFT(有监督微调)、RM(奖励模型)、PPO(强化学习)的数据集格式? SFT(有监督微调)、RM(奖励模型)、PPO&…...
应用地址信息获取新技巧:Xinstall来助力
在移动互联网时代,应用获取用户地址信息的需求越来越普遍。无论是为了提供个性化服务,还是进行精准营销,地址信息都扮演着至关重要的角色。然而,如何合规、准确地获取这一信息,却是许多开发者面临的挑战。今天…...
国防科技大学计算机基础课程笔记02信息编码
1.机内码和国标码 国标码就是我们非常熟悉的这个GB2312,但是因为都是16进制,因此这个了16进制的数据既可以翻译成为这个机器码,也可以翻译成为这个国标码,所以这个时候很容易会出现这个歧义的情况; 因此,我们的这个国…...
【杂谈】-递归进化:人工智能的自我改进与监管挑战
递归进化:人工智能的自我改进与监管挑战 文章目录 递归进化:人工智能的自我改进与监管挑战1、自我改进型人工智能的崛起2、人工智能如何挑战人类监管?3、确保人工智能受控的策略4、人类在人工智能发展中的角色5、平衡自主性与控制力6、总结与…...
ETLCloud可能遇到的问题有哪些?常见坑位解析
数据集成平台ETLCloud,主要用于支持数据的抽取(Extract)、转换(Transform)和加载(Load)过程。提供了一个简洁直观的界面,以便用户可以在不同的数据源之间轻松地进行数据迁移和转换。…...
Linux --进程控制
本文从以下五个方面来初步认识进程控制: 目录 进程创建 进程终止 进程等待 进程替换 模拟实现一个微型shell 进程创建 在Linux系统中我们可以在一个进程使用系统调用fork()来创建子进程,创建出来的进程就是子进程,原来的进程为父进程。…...
2025年低延迟业务DDoS防护全攻略:高可用架构与实战方案
一、延迟敏感行业面临的DDoS攻击新挑战 2025年,金融交易、实时竞技游戏、工业物联网等低延迟业务成为DDoS攻击的首要目标。攻击呈现三大特征: AI驱动的自适应攻击:攻击流量模拟真实用户行为,差异率低至0.5%,传统规则引…...
Copilot for Xcode (iOS的 AI辅助编程)
Copilot for Xcode 简介Copilot下载与安装 体验环境要求下载最新的安装包安装登录系统权限设置 AI辅助编程生成注释代码补全简单需求代码生成辅助编程行间代码生成注释联想 代码生成 总结 简介 尝试使用了Copilot,它能根据上下文补全代码,快速生成常用…...
mq安装新版-3.13.7的安装
一、下载包,上传到服务器 https://github.com/rabbitmq/rabbitmq-server/releases/download/v3.13.7/rabbitmq-server-generic-unix-3.13.7.tar.xz 二、 erlang直接安装 rpm -ivh erlang-26.2.4-1.el8.x86_64.rpm不需要配置环境变量,直接就安装了。 erl…...
git引用概念(git reference,git ref)(简化对复杂SHA-1哈希值的管理)(分支引用、标签引用、HEAD引用、远程引用、特殊引用)
文章目录 **引用的本质**1. **引用是文件**2. **引用的简化作用** **引用的类型**1. **分支引用(Branch References)**2. **标签引用(Tag References)**3. **HEAD 引用**4. **远程引用(Remote References)*…...
GitHub 趋势日报 (2025年06月07日)
📊 由 TrendForge 系统生成 | 🌐 https://trendforge.devlive.org/ 🌐 本日报中的项目描述已自动翻译为中文 📈 今日获星趋势图 今日获星趋势图 603 netbird 459 dify 440 cognee 352 omni-tools 337 note-gen 239 ragbits 237 …...
【Zephyr 系列 8】构建完整 BLE 产品架构:状态机 + AT 命令 + 双通道通信实战
🧠关键词:Zephyr、BLE、状态机、双向透传、AT 命令、Buffer、主从共存、系统架构 📌适合人群:希望开发 BLE 产品(模块/标签/终端)具备可控、可测、可维护架构的开发者 🧭 引言:从“点功能”到“系统架构” 前面几篇我们已经逐步构建了 BLE 广播、连接、数据透传系统…...
