当前位置: 首页 > news >正文

[pai-diffusion]pai的easynlp的clip模型训练

EasyNLP带你玩转CLIP图文检索 - 知乎作者:熊兮、章捷、岑鸣、临在导读随着自媒体的不断发展,多种模态数据例如图像、文本、语音、视频等不断增长,创造了互联网上丰富多彩的世界。为了准确建模用户的多模态内容,跨模态检索是跨模态理解的重要任务,…icon-default.png?t=N7T8https://zhuanlan.zhihu.com/p/528476134

initialize_easynlp()->train_dataset = CLIPDataset(pretrained_model_name_or_path=get_pretrain_model_path("alibaba-pai/clip_chinese_roberta_base_vit_base"),data_file="MUGE_MR_train_base64_part.tsv",max_seq_length=32,input_schema="text:str:1,image:str:1",first_sequence="text",second_sequence="image",is_training=True)
valid_dataset = CLIPDataset()model = get_application_model(app_name='clip',...)
- easynlp.appzoo.api.ModelMapping->CLIPApp
- easynlp.appzoo.clip.model.py->CLIPApp
- CHINESE_CLIP->
- self.visual = VisualTransformer()
- self.bert = BertModel()trainer = Trainer(model,train_dataset,user_defined_parameters,  evaluator=get_application_evaluator(app_name="clip",valid_dataset=valid_dataset,user_defined_parameters=user_defined_parameters,eval_batch_size=32))trainer.train()
- for _epoch in range(self._first_epoch,int(args.epoch_num)):for _step,batch in enumerate(self._train_loader):    label_ids = batch.pop()forward_outputs = self._model(batch)loss_dict = self.model_module.compute_loss(forward_outputs,label_ids)_loss = loss_dict('loss')_loss.backward()model = get_application_model_evaluation()
evaluator = get_application_evaluator()
evaluator.evaluate(model)

数据处理:

import os
import base64
import multiprocessing
from tqdm import tqdmdef process_image(image_path):# 从图片路径中提取中文描述image_name = os.path.basename(image_path)description = os.path.splitext(image_name)[0]# 将图片转换为 Base64 编码with open(image_path, 'rb') as f:image_data = f.read()base64_data = base64.b64encode(image_data).decode('utf-8')return description, base64_datadef generate_tsv(directory):image_paths = [os.path.join(directory, filename) for filename in os.listdir(directory) iffilename.endswith(('.jpg', '.png'))]with multiprocessing.Pool() as pool, tqdm(total=len(image_paths), desc='Processing Images') as pbar:results = []for result in pool.imap_unordered(process_image, image_paths):results.append(result)pbar.update(1)with open('/home/image_team/image_team_docker_home/lgd/e_commerce_sd/data/vcg_furnitures_text_image/vcg_furnitures_train.tsv','w', encoding='utf-8') as f:for description, base64_data in results:line = f"{description}\t{base64_data}\n"f.write(line)if __name__ == '__main__':target_directory = "/home/image_team/image_team_docker_home/lgd/e_commerce_sd/data/vcg_furnitures_text_image/vcg_furnitures_train/img_download/"# import pdb;pdb.set_trace()generate_tsv(target_directory)

训练代码:

import torch.cuda
from easynlp.appzoo import CLIPDataset
from easynlp.appzoo import get_application_predictor, get_application_model, get_application_evaluator, \get_application_model_for_evaluation
from easynlp.core import Trainer, PredictorManager
from easynlp.utils import initialize_easynlp, get_args, get_pretrain_model_path
from easynlp.utils.global_vars import parse_user_defined_parametersdef main():# /root/.easynlp/modelzoo中train_dataset = CLIPDataset(pretrained_model_name_or_path=get_pretrain_model_path(args.pretrained_model_name_or_path),data_file=args.tables.split(",")[0],max_seq_length=args.sequence_length,input_schema=args.input_schema,first_sequence=args.first_sequence,second_sequence=args.second_sequence,is_training=True)valid_dataset = CLIPDataset(# 预训练模型名称路径,这里我们使用封装好的get_pretrain_model_path函数,来处理模型名称"alibaba-pai/clip_chinese_roberta_base_vit_base"以得到其路径,并自动下载模型pretrained_model_name_or_path=get_pretrain_model_path(args.pretrained_model_name_or_path),data_file=args.tables.split(",")[-1],# "data/pai/MUGE_MR_valid_base64_part.tsv"max_seq_length=args.sequence_length,  # 文本最大长度,超过将截断,不足将paddinginput_schema=args.input_schema,  # 输入tsv数据的格式,逗号分隔的每一项对应数据文件中每行以\t分隔的一项,每项开头为其字段标识,如label、sent1等first_sequence=args.first_sequence,  # 用于说明input_schema中哪些字段作为第一/第二列输入数据second_sequence=args.second_sequence,is_training=False)  # 是否为训练过程,train_dataset为True,valid_dataset为Falsemodel = get_application_model(app_name=args.app_name,  # 任务名称,这里选择文本分类"clip"pretrained_model_name_or_path=get_pretrain_model_path(args.pretrained_model_name_or_path),user_defined_parameters=user_defined_parameters# user_defined_parameters:用户自定义参数,直接填入刚刚处理好的自定义参数user_defined_parameters)trainer = Trainer(model=model,train_dataset=train_dataset,user_defined_parameters=user_defined_parameters,evaluator=get_application_evaluator(app_name=args.app_name,valid_dataset=valid_dataset,user_defined_parameters=user_defined_parameters,eval_batch_size=32))trainer.train()# 模型评估model = get_application_model_for_evaluation(app_name=args.app_name,pretrained_model_name_or_path=args.checkpoint_dir,user_defined_parameters=user_defined_parameters)evaluator = get_application_evaluator(app_name=args.app_name,valid_dataset=valid_dataset,user_defined_parameters=user_defined_parameters,eval_batch_size=32)model.to(torch.cuda.current_device())evaluator.evaluate(model=model)# 模型预测if test:predictor = get_application_predictor(app_name="clip",model_dir="./outputs/clip_model/",first_sequence="text",second_sequence="image",sequence_length=32,user_defined_parameters=user_defined_parameters)predictor_manager = PredictorManager(predictor=predictor,input_file="data/vcg_furnitures_text_image/vcg_furnitures_test.tsv",input_schema="text:str:1",output_file="text_feat.tsv",output_schema="text_feat",append_cols="text",batch_size=2)predictor_manager.run()if __name__ == "__main__":initialize_easynlp()args = get_args()user_defined_parameters = parse_user_defined_parameters('pretrain_model_name_or_path=alibaba-pai/clip_chinese_roberta_base_vit_base')args.checkpoint_dir = "./outputs/clip_model/"args.pretrained_model_name_or_path = "alibaba-pai/clip_chinese_roberta_base_vit_base"# args.n_gpu = 3# args.worker_gpu = "1,2,3"args.app_name = "clip"args.tables = "data/pai/MUGE_MR_train_base64_part.tsv,data/pai/MUGE_MR_valid_base64_part.tsv"# "data/vcg_furnitures_text_image/vcg_furnitures_train.tsv," \#               "data/vcg_furnitures_text_image/vcg_furnitures_test.tsv"# "data/pai/MUGE_MR_train_base64_part.tsv,data/pai/MUGE_MR_valid_base64_part.tsv"args.input_schema = "text:str:1,image:str:1"args.first_sequence = "text"args.second_sequence = "image"args.learning_rate = 1e-4args.epoch_num = 1000args.random_seed = 42args.save_checkpoint_steps = 200args.sequence_length = 32# args.train_batch_size = 2args.micro_batch_size = 32test = Falsemain()# python -m torch.distributed.launch --nproc_per_node 4 tools/train_pai_chinese_clip.py

说一点自己的想法,在我自己工作之初,我很喜欢去拆解一些框架,例如openmm系列,但其实大部分在训练过程上都是相似的,大可不必,在改动上,也没有必要对其进行流程上的大改动,兼具百家之长,了解整体pipeline,更加专注在pipeline实现和效果导向型的结果提交更加有效。

相关文章:

[pai-diffusion]pai的easynlp的clip模型训练

EasyNLP带你玩转CLIP图文检索 - 知乎作者:熊兮、章捷、岑鸣、临在导读随着自媒体的不断发展,多种模态数据例如图像、文本、语音、视频等不断增长,创造了互联网上丰富多彩的世界。为了准确建模用户的多模态内容,跨模态检索是跨模态…...

期权如何交易?期权如何做模拟交易?

买卖期权的第一步就是要有期权账户,国内的期权品种有商品期权和ETF期权以及股指期权,每种的开户方式和要求都不同,下文为大家介绍期权如何交易?期权如何做模拟交易? 一、期权交易需要开立一个期权账户,可以…...

【新书推荐】大模型赛道如何实现华丽的弯道超车 —— 《分布式统一大数据虚拟文件系统 Alluxio原理、技术与实践》

文章目录 大模型赛道如何实现华丽的弯道超车 —— AI/ML训练赋能解决方案01 具备对海量小文件的频繁数据访问的 I/O 效率02 提高 GPU 利用率,降低成本并提高投资回报率03 支持各种存储系统的原生接口04 支持单云、混合云和多云部署01 通过数据抽象化统一数据孤岛02 …...

Calendar对象获取当前周的bug

项目场景: 双周项目管理,需要获取当前周为一年之中的第几周,原先的代码是用Calendar对象,先用setTime()把当前时间传入,再用get(3)获取一年中的第几周 问题描述 实际发…...

嵌入式环境buildroot的espeak配置与编译

1、在buildroot目录下输入make menuconfig 2、选择Target packages 3、选择Audio and video applications 4、选择espeak、选择alsa via portaudio (新版嵌入式linux一般都是用alsa音频驱动) 5、配置portaudio 选择Library 6、选择Audio/Sound 7、选择…...

物理机环境搭建-linux部署nginx

1、安装nginx部署所需依赖 yum install -y gcc-c pcre pcre-devel zlib zlib-devel openssl openssl-devel2、安装nginx包 wget http://nginx.org/download/nginx-1.8.0.tar.gz 如果没有wget可以安装一下 yum install -y wget下载完成后可以在/usr/local/下放置tar包&#xf…...

删除安装Google Chrome浏览器时捆绑安装的Google 文档、表格、幻灯片、Gmail、Google 云端硬盘、YouTube网址链接(Mac)

删除安装Google Chrome浏览器时捆绑安装的Google 文档、表格、幻灯片、Gmail、Google 云端硬盘、YouTube网址链接(Mac) Mac mini操作系统,安装完 Google Chrome 浏览器以后,单击 启动台 桌面左下角的“显示应用程序”,我们发现捆绑安装了 Goo…...

硬件故障诊断:快速定位问题

🌷🍁 博主猫头虎(🐅🐾)带您 Go to New World✨🍁 🦄 博客首页——🐅🐾猫头虎的博客🎐 🐳 《面试题大全专栏》 🦕 文章图文…...

IP代理与加速器:理解它们的区别与共同点

在网络使用过程中,我们经常会遇到需要提高访问速度或保护隐私的需求。IP代理和加速器都是常见的应对方案,但它们在工作原理和应用场景上存在一些区别。本文将为您深入探讨IP代理和加速器的异同,帮助您更好地理解它们的作用和适用情况&#xf…...

Java中List转字符串的方法

一、使用String.join方法 在Java 8之后&#xff0c;String类增加了一个静态方法join()&#xff0c;可以方便地将列表中的元素连接成字符串。 // 创建List List<String> list Arrays.asList("Google", "Baidu", "Taobao"); // 以逗号分隔…...

PyTorch实战:实现MNIST手写数字识别

前言 PyTorch可以说是三大主流框架中最适合初学者学习的了&#xff0c;相较于其他主流框架&#xff0c;PyTorch的简单易用性使其成为初学者们的首选。这样我想要强调的一点是&#xff0c;框架可以类比为编程语言&#xff0c;仅为我们实现项目效果的工具&#xff0c;也就是我们…...

【计算机网络】深入理解TCP协议二(连接管理机制、WAIT_TIME、滑动窗口、流量控制、拥塞控制)

TCP协议 1.连接管理机制2.再谈WAIT_TIME状态2.1理解WAIT_TIME状态2.2解决TIME_WAIT状态引起的bind失败的方法2.3监听套接字listen第二个参数介绍 3.滑动窗口3.1介绍3.2丢包情况分析 4.流量控制5.拥塞控制5.1介绍5.2慢启动 6.捎带应答、延时应答 1.连接管理机制 正常情况下&…...

springboot整合sentinel完成限流

1、直入正题&#xff0c;下载sentinel的jar包 1.1 直接到Sentinel官网里的releases下即可下载最新版本&#xff0c;Sentinel官方下载地址&#xff0c;直接下载jar包即可。不过慢&#xff0c;可能下载不下来 1.2 可以去gitee去下载jar包 1.3 下载完成后&#xff0c;进行打包…...

signal(SIGPIPE, SIG_IGN)

linux查看signal常见信号。 [rootplatform:]# kill -l1) HUP2) INT3) QUIT4) ILL5) TRAP6) ABRT7) BUS8) FPE9) KILL 10) USR1 11) SEGV 12) USR2 13) PIPE 14) ALRM 15) TERM 16) STKFLT 17) CHLD 18) CONT 19) STOP 20) TSTP 21) TTIN 22) TTOU 23) URG 24) XCPU 25) XFSZ 2…...

GAN学习笔记

1.原始的GAN 1.1原始的损失函数 1.1.1写法1参考1&#xff0c;参考2 1.1.2 写法2 where, G Generator D Discriminator Pdata(x) distribution of real data P(z) distribution of generator x sample from Pdata(x) z sample from P(z) D(x) Discriminator network G…...

layui框架学习(45: 工具集模块)

layui的工具集模块util支持固定条、倒计时等组件&#xff0c;同时提供辅助函数处理时间数据、字符转义、批量事件处理等操作。   util模块中的fixbar函数支持设置固定条&#xff08;2.7版本的帮助文档中叫固定块&#xff09;&#xff0c;是指固定在页面一侧的工具条元素&…...

车道检测:Decoupling the Curve Modeling and Pavement Regression for Lane Detection

论文作者&#xff1a;Wencheng Han,Jianbing Shen 作者单位&#xff1a;University of Macau 论文链接&#xff1a;http://arxiv.org/abs/2309.10533v1 内容简介&#xff1a; 1&#xff09;方向&#xff1a;车道检测 2&#xff09;应用&#xff1a;车道检测 3&#xff09…...

【扩散生成模型】Diffusion Generative Models

提出扩散模型思想的论文&#xff1a; 《Deep Unsupervised Learning using Nonequilibrium Thermodynamics》理解 扩散模型综述&#xff1a; “扩散模型”首篇综述论文分类汇总&#xff0c;谷歌&北大最新研究 理论推导、代码实现&#xff1a; What are Diffusion Models?…...

美联储加息步伐“暂停”!BTC凌晨力守27000美元!

美东时间9月20日下午&#xff0c;美联储宣布放缓加息步伐&#xff0c;将联邦基金利率目标维持在5.25%至5.50%的区间不变&#xff0c;保持在22年来的最高点&#xff0c;符合市场预期。 在最新的FOMC声明中&#xff0c;美联储表示最近的指标表明&#xff0c;经济活动一直在稳步扩…...

微信小程序与idea后端如何进行数据交互

交互使用的其实就是调用的req.get(url)方法 进行路径访问&#xff0c;你要先保证自己的springboot项目已经成功运行了&#xff1a; 如下&#xff1a; 如何交互的&#xff1f; 微信小程序&#xff1a;如下为index.js页面 在onLoad()事件中调用方法Project.findAllCities() 要…...

Java 学习路线分享 maven 是什么?

Maven 是一款基于 Java 平台的项目管理和整合工具&#xff0c;它将项目的开发和管理过程抽象成一个项目对象模型&#xff08;POM&#xff09;。开发人员只需要做一些简单的配置&#xff0c;Maven 就可以自动完成项目的编译、测试、打包、发布以及部署等工作。 Maven 是使用 Ja…...

实战演练 | Navicat 常用功能之转储与运行 SQL 文件

数据库管理工作中&#xff0c;"转储 SQL 文件"和"运行 SQL 文件"是两个极为常见操作。一般来说&#xff0c;用户使用数据库管理工具或命令行工具来完成。Navicat 管理开发工具中的“转储 SQL 文件”和“运行 SQL 文件”功能具有直观易用的界面、多种文件格…...

MySQL的备份与恢复

备份与恢复 一、备份1.1 数据备份的必要性1.2 数据备份分类1.2.1 物理备份1.2.2 逻辑备份 1.3 数据库备份策略1.4 常用的备份方法和工具1.5 数据库上云迁移 二、MySQL完全备份2.1 简介2.2 物理冷备份与恢复2.2.1 物理冷备份2.2.2 解压恢复 2.3 mysqldump备份与恢复1&#xff09…...

Python中的函数未定义的错误

前言&#xff1a; 嗨喽~大家好呀&#xff0c;这里是魔王呐 ❤ ~! python更多源码/资料/解答/教程等 点击此处跳转文末名片免费获取 通过这个解释&#xff0c;我们将了解当Python程序显示类似NameError: name ‘’ is not defined的错误时&#xff0c;即使该函数存在于脚本中&…...

AG35学习笔记(二):安装编译SDK、CMakeLists编译app、Scons编译server

目录 一、概述二、安装SDK2.1 网盘SDK - 权限不够2.2 bj41 - 需要交叉source2.3 mullen - relocate_sdk.py路径有误 三、编译SDK3.1 /bin/sh: 1: gcc: not found3.2 curses.h: No such file or directory 四、CMakeLists - 编译app4.1 cmake - 项目构建4.2 make - 项目编译4.3 …...

多台服务器sessionId共享

目录 多台服务器sessionId共享解决方案&#xff1a;ASP.NET Core 参考代码(NET 7):登录处理登录&#xff08;请求&#xff09;过滤器过滤器使用BaseController 多台服务器sessionId共享 session id是服务器首次与浏览器创建连接时&#xff0c;生成的id值&#xff0c;存入浏览器…...

如何在Gazebo中实现多机器人编队仿真

文章目录 前言一、仿真前的配置二、实现步骤1.检查PC和台式机是否通讯成功2.编队中对单个机器人进行独立的控制3、对机器人进行编队控制 前言 实现在gazebo仿真环境中添加多个机器人后&#xff0c;接下来进行编队控制&#xff0c;对具体的实现过程进行记录。 一、仿真前的配置…...

迅为iTOP-iMX6QPLUS-Android6.0下uboot添加网卡驱动

本文档介绍在 iTOP-iMX6Q 和 iTOP-iMX6Q-PLUS 安卓 6.0 的 uboot 上添加网卡驱 动&#xff0c;添加完网卡驱动以后&#xff0c;uboot 就可以正常使用网络了。 1 具体步骤 1.1 修改 mx6sabre_common.h 文件 在 iTOP-iMX6_android6.0.1 源码目录下输入以下命令&#xff0c;打…...

sql server 触发器的使用

看数据库下的所有触发器及状态 SELECT a.name 数据表名 , sysobjects.name AS 触发器名 , sysobjects.crdate AS 创建时间 , sysobjects.info , sysobjects.status FROM sysobjects LEFT JOIN ( SELECT * FROM sysobjects WHERE xtype U ) AS a ON sysobjects.parent_obj a.…...

使用亚马逊云服务器在 G4 实例上运行 Android 应用程序

随着 Android 应用程序和游戏变得越来越丰富&#xff0c;其中有些甚至比 PC 上的软件更易于使用和娱乐&#xff0c;因此许多人希望能够在云上运行 Android 游戏或应用程序&#xff0c;而在 EC2 实例上运行 Android 的解决方案可以让开发人员更轻松地测试和运行 Android 应用程序…...

聊城做网站推广哪家好/广告优化师培训

1、针对布局加载Xml文件的优化&#xff0c;我们使用了异步Inflate的方式&#xff0c;即AsyncLayoutInflater。它的核心原理是在子线程中对我们的Layout进行加载&#xff0c;而加载完成之后会将View通过Handler发送到主线程来使用。所以不会阻塞我们的主线程&#xff0c;加载的时…...

网络营销平台的类型/软媒win7优化大师

一分钟速览新闻点&#xff01; 快手调整员工福利&#xff1a;取消免费三餐&#xff0c;增加生育津贴阿里计划出售微博全部股份&#xff0c;彻底退出投资微信新规来了&#xff01;小程序调用个人敏感信息将需授权阿里CEO张勇辞任滴滴董事&#xff01;滴滴启动香港上市准备新浪升…...

北京万户网络技术有限公司/windows优化

简介 如何让代码执行得更快,如何充分发挥多核CPU的性能,是程序员需要思考的问题. 本文通过简单易懂的实例,让大家快速了解C#多线程的基本方法. 参考文档:http://www.cnblogs.com/yunfeifei/p/3993401.html实例 using System; using System.Diagnostics; using System.Threading…...

wordpress 归档函数/西安网络推广营销公司

MYSQL 内部模块 [Connection Pool] (授权、线程复用、连接限制、内存检测等) >[SQL Interface] (DML、DDL、Views等) [Parser] (Query Translation、Object privilege) [Optimizer] (Access Paths、 统计分析) [Caches & Buffers] >[Pluggable Storage Engines] 复…...

深圳自适应网站建设报价/百度小说app

正则表达式是一个特殊的字符序列&#xff0c;它能帮助你方便的检查一个字符串是否与某种模式匹配。Python 自1.5版本起增加了re 模块&#xff0c;它提供 Perl 风格的正则表达式模式。re 模块使 Python 语言拥有全部的正则表达式功能。compile 函数根据一个模式字符串和可选的标…...

网站建设排名/附近学电脑培训班

人机交互部分的设计 控制驱动和数据管理部分的设计...