当前位置: 首页 > news >正文

深度学习中简易FC和CNN搭建

  • TensorFlow是由谷歌开发的
  • PyTorch是由Facebook人工智能研究院(Facebook AI Research)开发的

Torch和cuda版本的对应,手动安装较好

全连接FC(Batch*Num)

搭建建议网络:

from torch import nnclass Mnist_NN(nn.Module):def __init__(self):super().__init__()self.hidden1 = nn.Linear(784, 128)self.hidden2 = nn.Linear(128, 256)self.out  = nn.Linear(256, 10)def forward(self, x):x = F.relu(self.hidden1(x))x = F.relu(self.hidden2(x))x = self.out(x)return x

封装数据

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoadertrain_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)def get_data(train_ds, valid_ds, bs):return (DataLoader(train_ds, batch_size=bs, shuffle=True),DataLoader(valid_ds, batch_size=bs * 2),)

训练模型:

import numpy as npdef fit(steps, model, loss_func, opt, train_dl, valid_dl):for step in range(steps):model.train()for xb, yb in train_dl:loss_batch(model, loss_func, xb, yb, opt)model.eval()with torch.no_grad():losses, nums = zip(*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)print('当前step:'+str(step), '验证集损失:'+str(val_loss))

一般在训练模型时加上model.train(),这样会正常使用Batch NormalizationDropout;测试的时候一般选择model.eval(),这样就不会使用Batch NormalizationDropout

批量损失函数

from torch import optim
def get_model():model = Mnist_NN()return model, optim.SGD(model.parameters(), lr=0.001)def loss_batch(model, loss_func, xb, yb, opt=None):loss = loss_func(model(xb), yb)if opt is not None:loss.backward()opt.step()opt.zero_grad()return loss.item(), len(xb)

优化器
SGD是一种简单且易于实现的优化算法,但在大规模数据集和复杂模型上收敛缓慢。
Adam是一种自适应学习率调整的优化算法,能够更快地收敛,但可能会占用更多的内存。
在实践中,根据具体问题和数据集的特点,选择适合的优化算法可以提高训练效果。

卷积神经网络CNN(Batch * C * H * W)

Channel First
引入py库

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms 
import matplotlib.pyplot as plt
import numpy as np

预处理

# 定义超参数 
input_size = 28  #图像的总尺寸28*28
num_classes = 10  #标签的种类数
num_epochs = 3  #训练的总循环周期
batch_size = 64  #一个撮(批次)的大小,64张图片# 训练集
train_dataset = datasets.MNIST(root='./data',  train=True,   transform=transforms.ToTensor(),  download=True) # 测试集
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())# 构建batch数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

构建CNN

class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(         # 输入大小 (1, 28, 28)nn.Conv2d(in_channels=1,              # 灰度图out_channels=16,            # 要得到几多少个特征图kernel_size=5,              # 卷积核大小stride=1,                   # 步长padding=2,                  # 如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1),                              # 输出的特征图为 (16, 28, 28)nn.ReLU(),                      # relu层nn.MaxPool2d(kernel_size=2),    # 进行池化操作(2x2 区域), 输出结果为: (16, 14, 14))self.conv2 = nn.Sequential(         # 下一个套餐的输入 (16, 14, 14)nn.Conv2d(16, 32, 5, 1, 2),     # 输出 (32, 14, 14)nn.ReLU(),                      # relu层nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2),                # 输出 (32, 7, 7))self.conv3 = nn.Sequential(         # 下一个套餐的输入 (16, 14, 14)nn.Conv2d(32, 64, 5, 1, 2),     # 输出 (32, 14, 14)nn.ReLU(),             # 输出 (32, 7, 7))self.out = nn.Linear(64 * 7 * 7, 10)   # 全连接层得到的结果def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)           # flatten操作,结果为:(batch_size, 32 * 7 * 7)output = self.out(x)return output

定义准确率

def accuracy(predictions, labels):pred = torch.max(predictions.data, 1)[1] rights = pred.eq(labels.data.view_as(pred)).sum() return rights, len(labels) 

训练网络模型

# 实例化
net = CNN() 
#损失函数
criterion = nn.CrossEntropyLoss() 
#优化器
optimizer = optim.Adam(net.parameters(), lr=0.001) #定义优化器,普通的随机梯度下降算法#开始训练循环
for epoch in range(num_epochs):#当前epoch的结果保存下来train_rights = [] for batch_idx, (data, target) in enumerate(train_loader):  #针对容器中的每一个批进行循环net.train()                             output = net(data) loss = criterion(output, target) optimizer.zero_grad() loss.backward() optimizer.step() right = accuracy(output, target) train_rights.append(right) if batch_idx % 100 == 0: net.eval() val_rights = [] for (data, target) in test_loader:output = net(data) right = accuracy(output, target) val_rights.append(right)#准确率计算train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))print('当前epoch: {} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}%'.format(epoch, batch_idx * batch_size, len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.data, 100. * train_r[0].numpy() / train_r[1], 100. * val_r[0].numpy() / val_r[1]))

相关文章:

深度学习中简易FC和CNN搭建

TensorFlow是由谷歌开发的PyTorch是由Facebook人工智能研究院(Facebook AI Research)开发的 Torch和cuda版本的对应,手动安装较好 全连接FC(Batch*Num) 搭建建议网络: from torch import nnclass Mnist_NN(nn.Module):def __i…...

【多模态】20、OVR-CNN | 使用 caption 来实现开放词汇目标检测

文章目录 一、背景二、方法2.1 学习 视觉-语义 空间2.2 学习开放词汇目标检测 三、效果 论文:Open-Vocabulary Object Detection Using Captions 代码:https://github.com/alirezazareian/ovr-cnn 出处:CVPR2021 Oral 一、背景 目标检测数…...

网络编程 IO多路复用 [select版] (TCP网络聊天室)

//head.h 头文件 //TcpGrpSer.c 服务器端 //TcpGrpUsr.c 客户端 select函数 功能&#xff1a;阻塞函数&#xff0c;让内核去监测集合中的文件描述符是否准备就绪&#xff0c;若准备就绪则解除阻塞。 原型&#xff1a; #include <sys/select.…...

数学建模学习(7):单目标和多目标规划

优化问题描述 优化 优化算法是指在满足一定条件下,在众多方案中或者参数中最优方案,或者参数值,以使得某个或者多个功能指标达到最优,或使得系统的某些性能指标达到最大值或者最小值 线性规划 线性规划是指目标函数和约束都是线性的情况 [x,fval]linprog(f,A,b,Aeq,Beq,LB,U…...

Element UI如何自定义样式

简介 Element UI是一套非常完善的前端组件库&#xff0c;但是如何个性化定制其中的组件样式呢&#xff1f;今天我们就来聊一聊这个 举例 就拿最常见的按钮el-button来举例&#xff0c;一般来说默认是蓝底白字。效果图如下 可是我们想个性化定制&#xff0c;让他成为粉底红字应…...

protobuf入门实践2

如何在proto中定义一个rpc服务? syntax "proto3"; //声明protobuf的版本package fixbug; //声明了代码所在的包 &#xff08;对于C来说就是namespace)//下面的选项&#xff0c;表示生成service服务类和rpc方法描述&#xff0c; 默认是不生成的 option cc_generi…...

adb shell使用总结

文章目录 日志记录系统概览adb 使用方式 adb命令日志过滤按照告警等级进行过滤按照tag进行过滤根据告警等级和tag进行联合过滤屏蔽系统和其他App干扰&#xff0c;仅仅关注App自身日志 查看“当前页面”Activity文件传输截屏和录屏安装、卸载App启动activity其他 日志记录系统概…...

UG NX二次开发(C++)-Tag的含义、Tag类型与其他的转换

文章目录 1、前言2、Tag号的含义3、tag_t转换为int3、TaggedObject与Tag转换3.1 TaggedObject定义3.2 TaggedObject获取Tag3.3 根据Tag获取TaggedObject4.Tag与double类型的转换1、前言 在UG NX中,每个对象对应一个tag号,C++中,其类型是tag_t,一般是5位或者6位的int数字,…...

Informer 论文学习笔记

论文&#xff1a;《Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting》 代码&#xff1a;https://github.com/zhouhaoyi/Informer2020 地址&#xff1a;https://arxiv.org/abs/2012.07436v3 特点&#xff1a; 实现时间与空间复杂度为 O ( …...

c语言位段知识详解

本篇文章带来位段相关知识详细讲解&#xff01; 如果您觉得文章不错&#xff0c;期待你的一键三连哦&#xff0c;你的鼓励是我创作的动力之源&#xff0c;让我们一起加油&#xff0c;一起奔跑&#xff0c;让我们顶峰相见&#xff01;&#xff01;&#xff01; 目录 一.什么是…...

FFmpeg aresample_swr_opts的解析

ffmpeg option的解析 aresample_swr_opts是AVFilterGraph中的option。 static const AVOption filtergraph_options[] {{ "thread_type", "Allowed thread types", OFFSET(thread_type), AV_OPT_TYPE_FLAGS,{ .i64 AVFILTER_THREAD_SLICE }, 0, INT_MA…...

CAN学习笔记3:STM32 CAN控制器介绍

STM32 CAN控制器 1 概述 STM32 CAN控制器&#xff08;bxCAN&#xff09;&#xff0c;支持CAN 2.0A 和 CAN 2.0B Active版本协议。CAN 2.0A 只能处理标准数据帧且扩展帧的内容会识别错误&#xff0c;而CAN 2.0B Active 可以处理标准数据帧和扩展数据帧。 2 bxCAN 特性 波特率…...

软工导论知识框架(二)结构化的需求分析

本章节涉及很多重要图表的制作&#xff0c;如ER图、数据流图、状态转换图、数据字典的书写等&#xff0c;对初学者来说比较生僻&#xff0c;本贴只介绍基础的轮廓&#xff0c;后面会有单独的帖子详解各图表如何绘制。 一.结构化的软件开发方法&#xff1a;结构化的分析、设计、…...

[SQL挖掘机] - 算术函数 - abs

介绍: 当谈到 SQL 中的 abs 函数时&#xff0c;它是一个用于计算数值的绝对值的函数。“abs” 代表 “absolute”&#xff08;绝对&#xff09;&#xff0c;因此 abs 函数的作用是返回一个给定数值的非负值&#xff08;即该数值的绝对值&#xff09;。 abs 函数接受一个参数&a…...

vue拼接html点击事件不生效

vue使用ts&#xff0c;拼接html&#xff0c;点击事件不生效或者报 is not defined 点击事件要用onclick 不是click let data{name:测,id:123} let conHtml <div> "名称&#xff1a;" data.name "<br>" <p class"cursor blue&quo…...

【Spring】Spring之依赖注入源码解析

1 Spring注入方式 1.1 手动注入 xml中定义Bean&#xff0c;程序员手动给某个属性赋值。 set方式注入 <bean name"userService" class"com.firechou.service.UserService"><property name"orderService" ref"orderService"…...

【微软知识】微软相关技术知识分享

微软技术领域 一、微软操作系统&#xff1a; 微软的操作系统主要是 Windows 系列&#xff0c;包括 Windows 10、Windows Server 等。了解 Windows 操作系统的基本使用、配置和故障排除是非常重要的。微软操作系统&#xff08;Microsoft System&#xff09;是美国微软开发的Wi…...

12.python设计模式【观察者模式】

内容&#xff1a;定义对象间的一种一对多的依赖关系&#xff0c;当一个对象的状态发生改变的时候&#xff0c;所有依赖于它的对象得到通知并被自动更新。观者者模式又称为“发布-订阅”模式。比如天气预报&#xff0c;气象局分发气象数据。 角色&#xff1a; 抽象主题&#xf…...

重生之我要学C++第五天

这篇文章主要内容是构造函数的初始化列表以及运算符重载在顺序表中的简单应用&#xff0c;运算符重载实现自定义类型的流插入流提取。希望对大家有所帮助&#xff0c;点赞收藏评论&#xff0c;支持一下吧&#xff01; 目录 构造函数进阶理解 1.内置类型成员在参数列表中的定义 …...

复习之linux高级存储管理

一、lvm----逻辑卷管理 1.lvm定义 LVM是 Logical Volume Manager&#xff08;逻辑卷管理&#xff09;的简写&#xff0c;它是Linux环境下对磁盘分区进行管理的一种机制。 逻辑卷管理器(LogicalVolumeManager)本质上是一个虚拟设备驱动&#xff0c;是在内核中块设备和物理设备…...

HuggingGPT Solving AI Tasks with ChatGPT and its Friends in Hugging Face

总述 HuggingGPT 让LLM发挥向路由器一样的作用&#xff0c;让LLM来选择调用那个专业的模型来执行任务。HuggingGPT搭建LLM和专业AI模型的桥梁。Language is a generic interface for LLMs to connect AI models 四个阶段 Task Planning&#xff1a; 将复杂的任务分解。但是这里…...

java工程重写jar包中class类覆盖问题

结论&#xff1a;直接在程序中复写jar中的类即可 原因&#xff1a;一般我java工程是运行在tomcat容器中&#xff0c;tomcat容易在加载我们工程类和jar包是的优先级为&#xff1a; 我们工程的class 先于 我们工程lib下的jar 重复的类只加载一次&#xff0c;加载我们复写后的类后…...

Mybatis基于注解与XML开发

文章目录 1 关于SpringBoot2 关于MyBatis2.1 MyBatis概述2.2 MyBatis核心思想2.3 MyBatis使用流程3 MyBatis配置SQL方式3.1 基于注解方式3.1.1 说明3.1.2 使用流程3.1.3 常用注解 3.2 基于XML方式3.2.1 相比注解优势3.2.2 使用流程3.2.3 常用标签 1 关于SpringBoot SpringBoot…...

数字化转型导师坚鹏:数字化时代扩大内需的8大具体建议

在日新月异的数字化时代、复杂多变的国际化环境下&#xff0c;扩大内需成为推动经济发展的国家战略&#xff0c;如何真正地扩大内需&#xff1f;结合本人15年的管理咨询经验及目前实际情况的深入研究&#xff0c;提出以下8大具体建议&#xff1a; 1、制定国民收入倍增计划。结…...

M1/M2 通过VM Fusion安装Win11 ARM,解决联网和文件传输

前言 最近新入了Macmini M2&#xff0c;但是以前的老电脑的虚拟机运行不起来了。&#x1f605;&#xff0c;实际上用过K8S的时候&#xff0c;会发现部分镜像也跑不起来&#xff0c;X86的架构和ARM实际上还是有很多隐形兼容问题。所以只能重新安装ARM Win11&#xff0c;幸好微软…...

Linux中显示系统正在运行的进程的命令

2023年7月29日&#xff0c;周六上午 在Linux中&#xff0c;ps命令用于显示当前系统中正在运行的进程&#xff0c; ps应该是processes snapshot&#xff08;进程快照&#xff09;的缩写。 以下是ps命令的常见用法和示例&#xff1a; 显示当前用户的所有进程&#xff1a;ps 显示…...

vite中安装less

使用vite创建的项目&#xff0c;默认是没有安装less的 如果直接在style中书写less 会报下图错误&#xff1a; 解决方案&#xff1a; npm install --save less 在package.json中查看是否安装成功 安装完成刷新页面&#xff0c;问题解决...

Aduino中eps环境搭建

这里只记录Arduino2.0以后版本&#xff1a;如果有外网环境&#xff0c;那么可以轻松搜到ESP32开发板环境并安装&#xff0c;如果没有&#xff0c;那就见下面操作&#xff1a; 进入首选项&#xff0c;将esp8266的国内镜像地址填入&#xff0c;然后保存&#xff0c;在开发板中查…...

python——案例二 求两个数的和

#案例二 求两个数的和 num1input(请输入第一个数字&#xff1a;) num2input(请输入第二个数字&#xff1a;) sumfloat(num1)float(num2) #计算公式 print(sum) #显示结果 输入num11、num22得到结果sum3...

一文了解 Android 车机如何处理中控的旋钮输入?

前言 上篇文章《从实体按键看 Android 车载的自定义事件机制》带大家了解了 Android 车机支持自定义输入的机制 CustomInputService。事实上&#xff0c;除了支持自定义事件&#xff0c;对于中控上常见的音量控制、焦点控制的旋钮事件&#xff0c;Android 车机也是支持的。 那…...

郑州网站建设方案报价/网上做广告推广

在phpmyadmin中显示上传文件最大问8M&#xff0c;可我的单个sql文件超过50M&#xff0c;尝试上传&#xff0c;上传依然可以进行&#xff0c;但上传到8M即操作终止&#xff0c;phpmyadmin提示上传失败。该问题实际上是php的配置问题&#xff0c;在php.ini中有两个参数需要关注&a…...

国内网站服务器/站内免费推广有哪些

开篇介绍 个人背景&#xff1a; 不说太多废话&#xff0c;但起码要让你先对我有一个基本的了解。本人毕业于浙江某二本院校&#xff0c;算是科班出身&#xff0c;毕业后就进了一家外包公司做开发&#xff0c;当然不是阿里的外包&#xff0c;具体什么公司就不透露了&#xff0…...

wordpress区分移动站/企业网站注册

项目场景&#xff1a; element UI 利用指令loading制作局部loading。 问题描述&#xff1a; 需要制作局部loading的页面很长&#xff0c;loading 时可以向下滚动&#xff0c;导致遮罩层遮挡不了页面的下半部分。 解决方案&#xff1a; 在loading时利用overflow: hidden;禁止…...

进一步加强网站内容建设/网站建设优化推广

IDC评述网&#xff08;idcps.com&#xff09;03月26日报道&#xff1a;据国外域名统计机构DailyChanges公布的最新实时数据显示&#xff0c;3月24日&#xff0c;在全球增长最快的十家域名解析服务商中&#xff0c;中国仅占据三个席位。上榜的中国域名解析商分别是&#xff1a;5…...

wordpress设置背景/合肥做网站哪家好

忙&#xff0c;还有两周多就要上线了&#xff0c;可是项目也就开始两周多。我现在估计已经快疯了。忙的时候&#xff0c;觉得自己实在是有心无力。如果说blog是一块镜子&#xff0c;那么我希望逃避。夜深人静的时候&#xff0c;我都不逃避&#xff0c;但是我已经累的倒下了&…...

哪些网站可以医生做兼职/站长工具 seo查询

算法的定义 算法&#xff0c;是为了解决某类问题而规定的一个有限长度的操作序列&#xff0c;是指解题方案的准确而完整的描述&#xff0c;是一系列解决问题的清晰指令&#xff0c;算法代表着用系统的方法描述解决问题的策略机制。   也就是说&#xff0c;能够对一定规范的输…...