pytorch-01
加载mnist数据集
one-hot编码实现
import numpy as np
import torch
x_train = np.load("../dataset/mnist/x_train.npy") # 从网站提前下载数据集,并解压缩
y_train_label = np.load("../dataset/mnist/y_train_label.npy")
x = torch.tensor(y_train_label[:5],dtype=torch.int64) # 获取前5个样本的标签数据
# 定义一个张量输入,因为此时有 5 个数值,且最大值为9,类别数为10
# 所以我们可以得到 y 的输出结果的形状为 shape=(5,10),即5行12列
y = torch.nn.functional.one_hot(x, 10) # 一个参数张量x,10为类别数
print(y)
对于拥有6000个样本的MNIST数据集来说,标签就是一个大小的矩阵张量。
多层感知机模型
#设定的多层感知机网络模型
class NeuralNetwork(torch.nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = torch.nn.Flatten() # 拉平图像矩阵self.linear_relu_stack = torch.nn.Sequential(torch.nn.Linear(28*28,312), # 输入大小为28*28,输出大小为312维的线性变换层torch.nn.ReLU(), # 激活函数层torch.nn.Linear(312, 256),torch.nn.ReLU(),torch.nn.Linear(256, 10) # 最终输出大小为10,对应one-hot标签维度)def forward(self, input): # 构建网络x = self.flatten(input) #拉平矩阵为1维logits = self.linear_relu_stack(x) # 多层感知机return logits
损失函数
优化函数
model = NeuralNetwork()
loss_fu = torch.nn.CrossEntropyLoss() # 交叉熵损失函数,内置了softmax函数,
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5) #设定优化函数loss = loss_fu(pred,label_batch) # 计算损失
完整模型
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0' #指定GPU编
import torch
import numpy as npbatch_size = 320 #设定每次训练的批次数
epochs = 1024 #设定训练次数#device = "cpu" #Pytorch的特性,需要指定计算的硬件,如果没有GPU的存在,就使用CPU进行计算
device = "cuda" #在这里读者默认使用GPU,如果读者出现运行问题可以将其改成cpu模式#设定的多层感知机网络模型
class NeuralNetwork(torch.nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = torch.nn.Flatten()self.linear_relu_stack = torch.nn.Sequential(torch.nn.Linear(28*28,312),torch.nn.ReLU(),torch.nn.Linear(312, 256),torch.nn.ReLU(),torch.nn.Linear(256, 10))def forward(self, input):x = self.flatten(input)logits = self.linear_relu_stack(x)return logitsmodel = NeuralNetwork()
model = model.to(device) #将计算模型传入GPU硬件等待计算
torch.save(model, './model.pth')
#model = torch.compile(model) #Pytorch2.0的特性,加速计算速度
loss_fu = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5) #设定优化函数#载入数据
x_train = np.load("../../dataset/mnist/x_train.npy")
y_train_label = np.load("../../dataset/mnist/y_train_label.npy")train_num = len(x_train)//batch_size#开始计算
for epoch in range(20):train_loss = 0for i in range(train_num):start = i * batch_sizeend = (i + 1) * batch_sizetrain_batch = torch.tensor(x_train[start:end]).to(device)label_batch = torch.tensor(y_train_label[start:end]).to(device)pred = model(train_batch)loss = loss_fu(pred,label_batch)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item() # 记录每个批次的损失值# 计算并打印损失值train_loss /= train_numaccuracy = (pred.argmax(1) == label_batch).type(torch.float32).sum().item() / batch_sizeprint("epoch:",epoch,"train_loss:", round(train_loss,2),"accuracy:",round(accuracy,2))
可视化模型结构和参数
model = NeuralNetwork()
print(model)
是对模型具体使用的函数及其对应的参数进行打印。
格式化显示:
param = list(model.parameters())
k=0
for i in param:l = 1print('该层结构:'+str(list(i.size())))for j in i.size():l*=jprint('该层参数和:'+str(l))k = k+l
print("总参数量:"+str(k))
模型保存
model = NeuralNetwork()
torch.save(model, './model.pth')
netron可视化
安装:pip install netron
运行:命令行输入netron
打开:通过网址http://localhost:8080打开
打开保存的模型文件model.pth:
点击颜色块,可以显示详细信息:
相关文章:
pytorch-01
加载mnist数据集 one-hot编码实现 import numpy as np import torch x_train np.load("../dataset/mnist/x_train.npy") # 从网站提前下载数据集,并解压缩 y_train_label np.load("../dataset/mnist/y_train_label.npy") x torch.tensor(y…...
梦想CAD二次开发
1.mxdraw简介 mxdraw是一个HTML5 Canvas JavaScript框架,它在THREE.js的基础上扩展开发,为用户提供了一套在前端绘图更为方便,快捷,高效率的解决方案,mxdraw的实质为一个前端二维绘图平台。你可以使用mxdraw在画布上绘…...
Eureka的介绍与使用
Eureka 是 Netflix 开源的一款服务注册与发现组件,在微服务架构中扮演着重要的角色。 一、Eureka 的介绍 工作原理 服务注册:各个微服务在启动时,会向 Eureka Server 发送注册请求,将自身的服务名、实例名、IP 地址、端口等信息注…...
ChatGPT之母:AI自动化将取代人类,创意性工作或将消失
目录 01 AI取代创意性工作的担忧 1.1 CTO说了啥 02 AI已开始大范围取代人类 01 AI取代创意性工作的担忧 几天前的采访中,OpenAI的CTO直言,AI可能会扼杀一些本来不应该存在的创意性工作。 近来一篇报道更是印证了这一观点。国外科技媒体的老板Miller用…...
【深度学习驱动流体力学】湍流仿真到深度学习湍流预测
目录 一、湍流项目结构二、三个OpenFOAM湍流算例1. motorBike背景和目的文件结构和关键文件使用和应用湍流仿真深度学习湍流预测深度学习湍流预测的挑战和应用结合湍流仿真与深度学习2. pitzDaily背景和目的文件结构和关键文件使用和应用3. pitzDailyMapped背景和目的文件结构和…...
如何从0构建一款类似pytest的工具
Pytest主要模块 Pytest 是一个强大且灵活的测试框架,它通过一系列步骤来发现和运行测试。其核心工作原理包括以下几个方面:测试发现:Pytest 会遍历指定目录下的所有文件,找到以 test_ 开头或 _test.py 结尾的文件,并且…...
6.27-6.29 旧c语言
#include<stdio.h> struct stu {int num;float score;struct stu *next; }; void main() {struct stu a,b,c,*head;//静态链表a.num 1;a.score 10;b.num 2;b.score 20;c.num 3;c.score 30;head &a;a.next &b;b.next &c;do{printf("%d,%5.1f\n&…...
Unidbg调用-补环境V3-Hook
结合IDA和unidbg,可以在so的执行过程进行Hook,这样可以让我们了解并分析具体的执行步骤。 应用场景:基于unidbg调试执行步骤 或 还原算法(以Hookzz为例)。 1.大姨妈 1.1 0x1DA0 public void hook1() {...
从AICore到TensorCore:华为910B与NVIDIA A100全面分析
华为NPU 910B与NVIDIA GPU A100性能对比,从AICore到TensorCore,展现各自计算核心优势。 AI 2.0浪潮汹涌而来,若仍将其与区块链等量齐观,视作炒作泡沫,则将错失新时代的巨大机遇。现在,就是把握AI时代的关键…...
Edge 浏览器退出后,后台占用问题
Edge 浏览器退出后,后台占用问题 环境 windows 11 Microsoft Edge版本 126.0.2592.68 (正式版本) (64 位)详情 在关闭Edge软件后,查看后台,还占用很多系统资源。实在不明白,关了浏览器还不能全关了,微软也学流氓了。…...
实验八 T_SQL编程
题目 以电子商务系统数据库ecommerce为例 1、在ecommerce数据库,针对会员表member首先创建一个“呼和浩特地区”会员的视图view_hohhot,然后通过该视图查询来自“呼和浩特”地区的会员信息,用批处理命令语句将问题进行分割,并分…...
【爆肝34万字】从零开始学Python第2天: 判断语句【入门到放弃】
目录 前言判断语句True、False简单使用作用 比较运算符引入比较运算符的分类比较运算符的结果示例代码总结 逻辑运算符引入逻辑运算符的简单使用逻辑运算符与比较运算符一起使用特殊情况下的逻辑运算符 if 判断语句引入基本使用案例演示案例补充随堂练习 else 判断子句引入else…...
React 19 新特性集合
前言:https://juejin.cn/post/7337207433868197915 新 React 版本信息 伴随 React v19 Beta 的发布,React v18.3 也一并发布。 React v18.3相比最后一个 React v18 的版本 v18.2 ,v18.3 添加了一些警告提示,便于尽早发现问题&a…...
耐高温水位传感器有哪些
耐高温水位传感器在现代液位检测技术中扮演着重要角色,特别适用于需要高温环境下稳定工作的应用场合。这类传感器的设计和材质选择对其性能和可靠性至关重要。 一种典型的耐高温水位传感器是FS-IR2016D,它采用了PPSU作为主要材质。PPSU具有优良的耐高温…...
Symfony国际化与本地化:打造多语言应用的秘诀
标题:Symfony国际化与本地化:打造多语言应用的秘诀 摘要 Symfony是一个高度灵活的PHP框架,用于创建Web应用程序。它提供了强大的国际化(i18n)和本地化(l10n)功能,允许开发者轻松创…...
ApolloClient GraphQL 与 ReactNative
要在 React Native 应用程序中设置使用 GraphQL 的简单示例,您需要遵循以下步骤: 设置一个 React Native 项目。安装 GraphQL 必要的依赖项。创建一个基本的 GraphQL 服务器(或使用公共 GraphQL 端点)。从 React Native 应用中的…...
【贡献法】2262. 字符串的总引力
本文涉及知识点 贡献法 LeetCode2262. 字符串的总引力 字符串的 引力 定义为:字符串中 不同 字符的数量。 例如,“abbca” 的引力为 3 ,因为其中有 3 个不同字符 ‘a’、‘b’ 和 ‘c’ 。 给你一个字符串 s ,返回 其所有子字符…...
C#基于SkiaSharp实现印章管理(3)
本系列第一篇文章中创建的基本框架限定了印章形状为矩形,但常用的印章有方形、圆形等多种形状,本文调整程序以支持定义并显示矩形、圆角矩形、圆形、椭圆等4种形式的印章背景形状。 定义印章背景形状枚举类型,矩形、圆形、椭圆相关的尺寸…...
如何理解泛型的编译期检查
既然说类型变量会在编译的时候擦除掉,那为什么我们往 ArrayList 创建的对象中添加整数会报错呢?不是说泛型变量String会在编译的时候变为Object类型吗?为什么不能存别的类型呢?既然类型擦除了,如何保证我们只能使用泛型…...
计算机组成原理:海明校验
在上图中,对绿色的7比特数据进行海明校验,需要添加紫色的4比特校验位,总共是蓝色的11比特。紫色的校验位pi分布于蓝色的hi的1, 2, 4, 8, 16, 32, 64位,是2i-1位。绿色的数据位bi分布于剩下的位。 在下图中,b1位于h3&a…...
信息学奥赛初赛天天练-39-CSP-J2021基础题-哈夫曼树、哈夫曼编码、贪心算法、满二叉树、完全二叉树、前中后缀表达式转换
PDF文档公众号回复关键字:20240629 2022 CSP-J 选择题 单项选择题(共15题,每题2分,共计30分:每题有且仅有一个正确选项) 5.对于入栈顺序为a,b,c,d,e的序列,下列( )不合法的出栈序列 A. a,b&a…...
第11章 规划过程组(收集需求)
第11章 规划过程组(一)11.3收集需求,在第三版教材第377~378页; 文字图片音频方式 第一个知识点:主要输出 1、需求跟踪矩阵 内容 业务需要、机会、目的和目标 项目目标 项目范围和 WBS 可…...
探索WebKit的守护神:深入Web安全策略
探索WebKit的守护神:深入Web安全策略 在数字化时代,网络已成为我们生活的一部分,而网页浏览器作为我们探索网络世界的窗口,其安全性至关重要。WebKit作为众多流行浏览器的内核,例如Safari,其安全性策略是保…...
unity ScrollRect裁剪ParticleSystem粒子
搜了下大概有这几种方法 通过模板缓存通过shader裁剪区域:案例一,案例二,案例三,三个案例都是类似的方法,需要在c#传入数据到shader通过插件 某乎上的模板缓存方法link,(没有登录看不到全文&a…...
凤仪亭 | 第7集 | 大丈夫生居天地之间,岂能郁郁久居人下 | 司徒一言,令我拨云见日,茅塞顿开 | 三国演义 | 逐鹿群雄
🙋大家好!我是毛毛张! 🌈个人首页: 神马都会亿点点的毛毛张 📌这篇博客分享的是《三国演义》文学剧本第Ⅰ部分《群雄逐鹿》的第7️⃣集《凤仪亭》的经典语句和文学剧本全集台词 文章目录 1.经典语句2.文学剧本台词 …...
React实战学习(一)_棋盘设计
需求: 左上侧:状态左下侧:棋盘,保证胜利就结束 和 下过来的不能在下右侧:“时光机”,保证可以回顾,索引 语法: 父子之间属性传递(props)子父组件传递(写法上&…...
【LeetCode】每日一题:三数之和
解题思路 最开始是打算沿着二数之和的思路做,即固定了最大的,然后小的开始遍历,因为这种遍历方式只需要遍历一轮就能完成,所以复杂度应该是O(n2),但是最后几个示例还是超时了,可能进…...
逆风而行:提升逆商,让困难成为你前进的动力
一、引言 生活,总是充满了未知与变数。有时,我们会遇到阳光明媚的日子,享受着宁静与和谐;但更多时候,我们却不得不面对那些突如其来的坏事件,如工作的挫折、人际关系的困扰、健康的挑战等。这些事件如同突…...
新能源汽车CAN总线故障定位与干扰排除的几个方法
CAN总线是目前最受欢迎的现场总线之一,在新能源车中有广泛应用。新能源车的CAN总线故障和隐患将影响驾驶体验甚至行车安全,如何进行CAN总线故障定位及干扰排除呢? 目前,国内机动车保有量已经突破三亿大关。由于大量的燃油车带来严峻的环境问题,因此全面禁售燃油车的日程在…...
【涵子来信】——社交宝典:克服你心中的内向,世界总有缺陷
内向,你是内向的吗?想必每个人不同,面对的情形也是不同的。 暑假是一个很好的机会,我是可以去多社交社交。但是,面对着CSDN上这么多技术人er,那么,我的宝典,对于大家,有…...
saharan wordpress/市场营销策划案例经典大全
用户放弃购物车是B2C电商系统平台的噩梦。消费者由于各种原因中途放弃他们的购物车。这些原因包括:过程中断,繁琐的UI和其他简单的改变了主意。我们将看看用户放弃购物车的三大原因以及如何减少购物车放弃。一、缺乏精心设计的UX 不太注意用户在购物过程…...
图标不显示wordpress/什么是搜索引擎营销?
我们都知道,管理信息系统类的项目报表的位置是何等重要,业务运营数据最后给领导的反应就是那么几张综合的业务数据报表,我从事软件开发的这八、九年中,98%的项目都是管理信息系统项目,都时时被报表纠结着,早年用VB开发…...
佛山企业网站/免费b2b网站推广渠道
#include <stdio.h> #include <stdlib.h> #include <malloc.h>int main(int argc,char* argv[]) {int* ptr;ptr (int*)malloc(sizeof(int) * 128);printf("%x \n", ptr);return 0; }...
网站病毒怎么做/产品推广方案
css中的z-index用法详解...
网站设计报价/灵宝seo公司
原标题:CAD实习心得体会cad实习心得体会一、课程实习的目的:把握autocad用于工程制图的基本操作,了解工程图纸绘制的格式和要求,能够用autocad绘制二维的工程图纸和简单的三维图纸。并简单绘制校本部大门十字路口地下通道。二、课程实习的任务…...
郑州网站建设最便宜/拼多多seo怎么优化
Intellil IDEA新建空项目选择第一个Java大家都知道,这里就不介绍了,我们选择最后一个Empty Project 选择在这个窗口打开还是新的窗口打开,这里自己选择就行 这时候又来到了熟悉的界面,可以看到最下面没有Empty Project了&am…...