当前位置: 首页 > news >正文

基于chatGPT设计卷积神经网络

1. 简介

本文主要介绍基于chatGPT,设计一个针对骁龙855芯片设计的友好型神经网络。

提问->跑通总共花了5min左右,最终得到的网络在Cifar100数据集上与ResNet18的精度对比如下。

模型flopsparamstrain acc1/5test acc1/5
ResNet18(timm)1.8211.18~98%/100%63.11%/86.11%
CNN-GPT1.516.66~98%/100%68.44%/89.14%

此外,GPT生成的模型训练速度更快,ResNet18才训练170个epoch,CNN-GPT已经完整地训练完了,具体的模型结构在文末,欢迎尝试使用~

个人认为比较有意思的一点是,GPT生成的模型并没有在stem block中就进行降采样,而是卷了一会之后才开始,这个与之前在文本任务中的经验结果类似,特征图越大,会避免信息损失。
基于chatGPT生成模型结构的方式或许在某些情况下可以解决我们的问题,但是基础的神经网络知识还是需要的,否则它给出的答案有些小bug,我们也不知道怎么修改。
模型训练结果与超参设置强相关,在这里仅设置默认的训练策略作为对比,根据训练结果看出,两个模型均过拟合严重。而且,不管是ResNet18或者CNN-GPT,均与Cifar100数据集的SOTA精度相差很远,此处需要注意。

2. 具体步骤

Step1 骁龙855芯片上可以正常运行的算子

基于上述问题,chatGPT给出的回答如下所示。

以下是在骁龙855芯片上比较友好的神经网络算子:1. Conv2d
2. Depthwise Conv2d
3. Maxpool2d
4. Avgpool2d
5. BatchNorm2d
6. ReLU
7. Flatten
8. Linear

整体看起来还算比较靠谱。

Step2 构建一个模型大小在10M左右的模型

针对上面的问题,chatGPT首先给出了思路。

如果想设计一个在骁龙855芯片上运行速度很快,模型大小在10M以内的模型,可以考虑使用以下方法来进行优化:1. 使用深度可分离卷积代替标准卷积,以获得更快的速度和更小的模型尺寸。
2. 使用深度方向的滤波器来替换全连接层,以减少参数数量并减少内存使用。
3. 对于每个卷积层,使用Batch Norm层来加速收敛。
4. 使用ReLU代替传统的Sigmoid或Tanh激活函数,因为ReLU的计算效率更高。

然后给出了一个模型示例。

以下是一个可以作为参考的模型:Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
BatchNorm2d(32)
ReLU()DepthwiseConv2d(32, kernel_size=3, stride=1, padding=1)
BatchNorm2d(32)
ReLU()Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
BatchNorm2d(64)
ReLU()DepthwiseConv2d(64, kernel_size=3, stride=1, padding=1)
BatchNorm2d(64)
ReLU()Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
BatchNorm2d(128)
ReLU()DepthwiseConv2d(128, kernel_size=3, stride=1, padding=1)
BatchNorm2d(128)
ReLU()Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
BatchNorm2d(256)
ReLU()DepthwiseConv2d(256, kernel_size=3, stride=1, padding=1)
BatchNorm2d(256)
ReLU()Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
BatchNorm2d(512)
ReLU()DepthwiseConv2d(512, kernel_size=3, stride=1, padding=1)
BatchNorm2d(512)
ReLU()Conv2d(512, 1024, kernel_size=3, stride=2, padding=1)
BatchNorm2d(1024)
ReLU()Avgpool2d(kernel_size=4)
Flatten()
Linear(1024, 10)

上面生成的模型有几个小问题。

  1. DepthwiseConv2d没有被定义
  2. DepthwiseConv2d的参数列表中没有out channel的参数,需要手动补充。

Step3 生成一个DepthwiseConv2d结构

chatGPT回答如下

以下是一个非常简单的torch的DepthwiseConv2d实现:

import torch.nn as nnclass DepthwiseConv2d(nn.Module):def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=None):super(DepthwiseConv2d, self).__init__()if groups is None:groups = in_channelsself.depthwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups=groups)self.pointwise_conv = nn.Conv2d(out_channels, out_channels, 1, 1, 0, 1, groups=1)def forward(self, x):x = self.depthwise_conv(x)x = self.pointwise_conv(x)return x

加入到模型结构定义文件中,同时修改类别数,就可以进行模型训练了。

设置训练超参数

为保证对比的公平性,下面是训练过程中的一些超参数,

	parser.add_argument("--epoch", type=int, default=240)parser.add_argument("--batch_size", type=int, default=64)parser.add_argument("--lr", type=float, default=0.05)parser.add_argument("--lr_decay_rate", type=float, default=0.1)parser.add_argument("--lr_decay_stages", type=str, default="150,180,210")parser.add_argument("--wd", type=float, default=0.0005)

3. 完整的模型结构

完整的模型结构如下,也欢迎大家在自己的数据集上尝试。

import torch
import torch.nn as nnclass DepthwiseConv2d(nn.Module):def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=None):super(DepthwiseConv2d, self).__init__()if groups is None:groups = in_channelsself.depthwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups=groups)self.pointwise_conv = nn.Conv2d(out_channels, out_channels, 1, 1, 0, 1, groups=1)def forward(self, x):x = self.depthwise_conv(x)x = self.pointwise_conv(x)return xclass CNNGPT(nn.Module):def __init__(self) -> None:super().__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(32),nn.ReLU(),DepthwiseConv2d(32, 32, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(32),nn.ReLU(),nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),nn.BatchNorm2d(64),nn.ReLU(),DepthwiseConv2d(64, 64, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),nn.ReLU(),nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),nn.BatchNorm2d(128),nn.ReLU(),DepthwiseConv2d(128, 128, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(128),nn.ReLU(),nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),nn.BatchNorm2d(256),nn.ReLU(),DepthwiseConv2d(256, 256, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(256),nn.ReLU(),nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),nn.BatchNorm2d(512),nn.ReLU(),DepthwiseConv2d(512, 512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(),nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1),nn.BatchNorm2d(1024),nn.AdaptiveAvgPool2d(1),nn.Flatten(),nn.Linear(1024, 100),)def forward(self, x):y = self.model(x)return ydef get_flops_params(model):from thop import profilemodel.eval()flops, params = profile(model,inputs=[torch.randn([1, 3, 224, 224]),],)print(f"flops: {flops/1000**3} G, params: {params/1000**2} M")return flops, paramsif __name__ == "__main__":model = CNNGPT()get_flops_params(model)

相关文章:

基于chatGPT设计卷积神经网络

1. 简介 本文主要介绍基于chatGPT,设计一个针对骁龙855芯片设计的友好型神经网络。 提问->跑通总共花了5min左右,最终得到的网络在Cifar100数据集上与ResNet18的精度对比如下。 模型flopsparamstrain acc1/5test acc1/5ResNet18(timm)1.8211.18~98…...

java.sql.Date和java.util.Date的区别

参考答案 java.sql.Date 是 java.util.Date 的子类java.util.Date 是 JDK 中的日期类,精确到时、分、秒、毫秒java.sql.Date 与数据库 Date 相对应的一个类型,只有日期部分,时分秒都会设置为 0,如:2019-10-23 00:00:0…...

动态规划---线性dp和区间dp

动态规划(三) 目录动态规划(三)一:线性DP1.数字三角形1.1数字三角形题目1.2代码思路1.3代码实现(正序and倒序)2.最长上升子序列2.1最长上升子序列题目2.2代码思路2.3代码实现3.最长公共子序列3.1最长公共子序列题目3.2代码思路3.3代码实现4.石子合并4.1题目如下4.2代…...

常见的2D与3D碰撞检测算法

分离轴分离轴定理(Separating Axis Theorem)是用于解决2D或3D物体碰撞检测问题的一种方法。其基本思想是,如果两个物体未发生碰撞,那么可以找到一条分离轴(即一条直线或平面),两个物体在该轴上的…...

STM32 10个工程篇:1.IAP远程升级(二)

一直提醒自己要更新CSDN博客,但是确实这段时间到了一个项目的关键节点,杂七杂八的事情突然就一涌而至。STM32、FPGA下位机代码和对应Labview的IAP升级助手、波形设置助手上位机代码笔者已经调试通过,因为不想去水博客、凑数量,复制…...

Unity+ChatGpt的联动 AICommand

果然爱是会消失的,对吗 chatGpt没出现之前起码还看人家的文章,现在都是随便你。 本着师夷长技以制夷的思路,既然打不过,那么我就加入 github地址:https://github.com/keijiro/AICommand 文档用chatGpt翻译如下&#…...

STM-32:按键控制LED灯 程序详解

目录一、基本原理二、接线图三、程序思路3.1库函数3.2程序代码注:一、基本原理 左边是STM322里电路每一个端口均可以配置的电路部分,右边部分是外接设备 电路图。 配置为 上拉输入模式的意思就是,VDD开关闭合,VSS开关断开。 浮空…...

北邮22信通:(8)实验1 题目五:大整数加减法(搬运官方代码)

北邮22信通一枚~ 跟随课程进度每周更新数据结构与算法的代码和文章 持续关注作者 解锁更多邮苑信通专属代码~ 上一篇文章: 北邮22信通:(7)实验1 题目四:一元多项式(节省内存版)_青山如…...

Fiddler抓取https史上最强教程

有任何疑问建议观看下面视频 2023最新Fiddler抓包工具实战,2小时精通十年技术!!!对于想抓取HTTPS的测试初学者来说,常用的工具就是fiddler。 但是初学时,大家对于fiddler如何抓取HTTPS难免走歪路&#xff…...

STM32开发基础知识入门

C语言基础 位操作 对基本类型变量可以在位级别进行操作。 1) 不改变其他位的值的状况下,对某几个位进行设值。 先对需要设置的位用&操作符进行清零操作,然后用|操作符设值。 2) 移位操作提高代码的可读性。 3) ~取反操作使用技巧 可用于对某…...

学习操作系统的必备教科书《操作系统:原理与实现》| 文末赠书4本

使用了6年的实时操作系统,是时候梳理一下它的知识点了 摘要: 本文简单介绍了博主学习操作系统的心路历程,同时还给大家总结了一下当下流行的几种实时操作系统,以及在工程中OSAL应该如何设计。希望对大家有所启发和帮助。 文章目录…...

大数据的常用算法(分类、回归分析、聚类、关联规则、神经网络方法、web数据挖掘)

在大数据时代,数据挖掘是最关键的工作。大数据的挖掘是从海量、不完全的、有噪声的、模糊的、随机的大型数据库中发现隐含在其中有价值的、潜在有用的信息和知识的过程,也是一种决策支持过程。其主要基于人工智能,机器学习,模式学…...

【数据结构】详解二叉树与堆与堆排序的关系

🌇个人主页:平凡的小苏 📚学习格言:别人可以拷贝我的模式,但不能拷贝我不断往前的激情 🛸C语言专栏:https://blog.csdn.net/vhhhbb/category_12174730.html 🚀数据结构专栏&#xff…...

【Pandas】数据分析入门

文章目录前言一、Pandas简介1.1 什么是Pandas1.2 Pandas应用二、Series结构2.1 Series简介2.2 基本使用三、DataFrame结构3.1 DataFrame简介3.2 基本使用四、Pandas-CSV4.1 CSV简介4.2 读取CSV文件4.3 数据处理五、数据清洗5.1 数据清洗的方法5.2 清洗案例总结前言 大家好&…...

【c++】:list模拟实现“任意位置插入删除我最强ƪ(˘⌣˘)ʃ“

文章目录 前言一.list的基本功能的使用二.list的模拟实现总结前言 1. list是可以在常数范围内在任意位置进行插入和删除的序列式容器,并且该容器可以前后双向迭代。2. list的底层是双向链表结构,双向链表中每个元素存储在互不相关的独立节点中&#xff0…...

QT表格控件实例(Table Widget 、Table View)

欢迎小伙伴的点评✨✨,相互学习🚀🚀🚀 博主🧑🧑 本着开源的精神交流Qt开发的经验、将持续更新续章,为社区贡献博主自身的开源精神👩‍🚀 文章目录前言一、图示实例二、列…...

第二章Vue组件化编程

文章目录模块与组件、模块化与组件化模块组件模块化组件化Vue中的组件含义非单文件组件基本使用组件注意事项使用 kebab-case使用 PascalCase组件的嵌套模板templateVueComponent一个重要的内置功能单文件组件Vue脚手架使用Vue CLI脚手架先配置环境初始化脚手架分析脚手架结构实…...

面试官:vue2和vue3的区别有哪些

目录 多根节点,fragment(碎片) Composition API reactive 函数是用来创建响应式对象 Ref toRef toRefs 去除了管道 v-model的prop 和 event 默认名称会更改 vue2写法 Vue 3写法 vue3组件需要使用v-model时的写法 其他语法 1. 创…...

【TopK问题】——用堆实现

文章目录一、TopK问题是什么二、解决方法三、时间复杂度一、TopK问题是什么 TopK问题就是从1000个数中找出前K个最大的数或者最小的数这样的类似问题。 不过并不要求这k个数字必须是有序的,如果题目有要求,则进行堆排序即可。 还有比如求出全国玩韩信…...

【Spring从成神到升仙系列 四】从源码分析 Spring 事务的来龙去脉

👏作者简介:大家好,我是爱敲代码的小黄,独角兽企业的Java开发工程师,CSDN博客专家,阿里云专家博主📕系列专栏:Java设计模式、数据结构和算法、Kafka从入门到成神、Kafka从成神到升仙…...

使用Nginx反向代理OpenAI API

由于OpenAI的API在国内无法访问,所以可以通过海外服务器利用Nginx实现反向代理。 安装Nginx 这一步就不赘述了,不同的Linux系统安装方式略有不同,根据自己的服务器的系统自行百度即可。 OpenSSL创建证书 因为OpenAI的接口是https协议的&a…...

USB键盘实现——字符串描述符(四)

字符串描述符 字符串描述符内容解析和 HID鼠标 一致。 获取字符串描述符请求 标准设备请求 typedef struct __attribute__ ((packed)){union {struct __attribute__ ((packed)) {uint8_t recipient : 5; ///< Recipient type usb_request_recipient_t.uint8_t type …...

STM32的中断

目录 一、STM32中断概述 二、外部中断控制器EXTI 三、按键中断 四、串口中断 一、STM32中断概述 处理器中的中断在处理器中&#xff0c;中断是一个过程&#xff0c;即CPU在正常执行程序的过程中&#xff0c;遇到外部/内部的紧急事件需要处理&#xff0c;暂时中止当前程序的…...

Flink进阶篇-CDC 原理、实践和优化采集到Doris中

简介 基于doris官方用doris构建实时仓库的思路&#xff0c;从flinkcdc到doris实时数仓的实践。 原文 Apache Flink X Apache Doris 构建极速易用的实时数仓架构 (qq.com) 前提-Flink CDC 原理、实践和优化 CDC 是什么 CDC 是变更数据捕获&#xff08;Change Data Captur…...

看完这篇 教你玩转渗透测试靶机vulnhub——My File Server: 1

Vulnhub靶机My File Server: 1渗透测试详解Vulnhub靶机介绍&#xff1a;Vulnhub靶机下载&#xff1a;Vulnhub靶机安装&#xff1a;Vulnhub靶机漏洞详解&#xff1a;①&#xff1a;信息收集&#xff1a;②&#xff1a;FTP匿名登入&#xff1a;③&#xff1a;SMB共享服务&#xf…...

OpenHarmony实战STM32MP157开发板 “控制” Hi3861开发板 -- 中篇

一、前言 我们在 OpenHarmony实战STM32MP157开发板 “控制” Hi3861开发板 – 上篇 中介绍到了,App面板的开发,以及JS API接口的开发和调用。 那么本篇文章,会详解:BearPi-HM Nano开发板,如何实现数据上报和指令接收响应的。 看到这里,可能有同学可能已经知道思路了,因…...

【数据结构初阶】单链表

目录一、思路>>>>>>>>>>>>过程<<<<<<<<<<<<<<<1.打印2.尾插3.尾删4.头插5.头删6.查找7.指定位置后插入8.指定位置后删除9.链表的销毁二、整个程序1.SLTlist.c2.SLTlist.c一、思路 #define …...

多线程代码案例-阻塞队列

hi,大家好,今天为大家带来多线程案例--阻塞队列 这块知识点也很重要,要好好掌握呀~~~ &#x1f338;&#x1f338;&#x1f338;&#x1f338;&#x1f338;&#x1f338;&#x1f338;&#x1f338;&#x1f338;&#x1f338;&#x1f338;&#x1f338;&#x1f338;&#x…...

mysql的limit查询竟然有坑?

背景 最近项目联调的时候发现了分页查询的一个bug&#xff0c;分页查询总有数据查不出来或者重复查出。 数据库一共14条记录。 如果按照一页10条。那么第一页和第二页的查询SQL和和结果如下。 .png) 那么问题来了&#xff0c;查询第一页和第二页的时候都出现了11,12,13的记录…...

【Docker】MAC电脑下的Docker操作

文章目录安装Docker部署mysql 一主一从登录ChatGPT搞方案本地创建一个文件夹编辑docker-compose.yml文件启动检查并编排容器验证基于command的my.cnf配置的加载主数据库建一个用户给子数据库用于主从复制启动主从同步安装Docker 官网地址 https://www.docker.com/ 下载安装 验…...

小说网站seo排名怎么做/篮网目前排名

打开这两个地址&#xff0c;就能看到最新的安装包了。。。。 http://www.jianshu.com/p/c67c14b3110c https://static.xamarin.com/installer_assets/v3/Mac/Universal/InstallationManifest.xml...

wordpress自定义登录/网站维护一般都是维护什么

安徽新华学院《局域网组建、管理与维护》的评分标准相关搜索: 安徽, 局域网, 新华, 评分, 学院安徽新华学院2009-2010学年度第一学期《局域网组建、管理与维护》期末考试参考答案及评分标准&#xff08;A卷,考核&#xff09;命题教师 应作斌 审核人&#xff1a; 适用年级…...

音乐网站设计外国/seo营销怎么做

Go Web编程--SecureCookie实现客户端Session管理在Web应用开发中Session是在用户和服务器之间进行交换的非持久化交互信息。当用户登录时&#xff0c;可以在用户和服务器之间生成Session&#xff0c;然后来回交换数据&#xff0c;并在用户登出时销毁Session。gorilla/sessions软…...

wordpress 上传类/信息流推广方式

中新网1月24日电 据外媒报道&#xff0c;为了防止出现“无协议脱欧”&#xff0c;当地时间23日&#xff0c;英国工党一名发言人表示&#xff0c;工党或将在议会中支持推迟脱欧的提案。资料图&#xff1a;当地时间1月15日晚&#xff0c;英国议会下院以432票对202票&#xff0c;投…...

thinkphp网站源码下载/百度搜图

2023年&#xff0c;低代码仍然是热点关键词&#xff0c;毫无疑问&#xff0c;我们“低代码”仍会是2023年软件行业的热门关键字。 今天探讨下传统开发跟低代码开发产品有什么区别&#xff0c;产品交付方面怎么样&#xff1f; 传统开发&#xff1a; 正常一个软件产品的开发周期都…...

视频分销网站建设/网站建设步骤流程详细介绍

一.问题11.需求&#xff1a;最近在用Vue写后台&#xff0c;其中有个需求被广泛使用&#xff1a;点击签收&#xff0c;立即变成 节点办理|办结2.解决办法&#xff1a;思路;在 节点办理|办结 两个button按钮上绑定v-if的三目运算&#xff0c;当点击签收&#xff0c;签收状…...