当前位置: 首页 > news >正文

PyTorch官网demo解读——第一个神经网络(1)

神经网络如此神奇,feel the magic

今天分享一下学习PyTorch官网demo的心得,原来实现一个神经网络可以如此简单/简洁/高效,同时也感慨PyTorch如此强大。

这个demo的目的是训练一个识别手写数字的模型!

先上源码:
from pathlib import Path
import requests   # http请求库
import pickle
import gzipfrom matplotlib import pyplot   # 显示图像库import math
import numpy as np
import torch###########下载训练/验证数据######################################################
# 这里加载的是mnist数据集
DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"
PATH.mkdir(parents=True, exist_ok=True)URL = "https://github.com/pytorch/tutorials/raw/main/_static/"
FILENAME = "mnist.pkl.gz"if not (PATH / FILENAME).exists():content = requests.get(URL + FILENAME).content(PATH / FILENAME).open("wb").write(content)###########解压并加载训练数据######################################################
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")# 通过pyplot显示数据集中的第一张图片
# 显示过程会中断运行,看到效果之后可以屏蔽掉,让调试更顺畅
#print("x_train[0]: ", x_train[0])
#pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
#pyplot.show()# 将加载的数据转成tensor
x_train, y_train, x_valid, y_valid = map(torch.tensor, (x_train, y_train, x_valid, y_valid)
)
n, c = x_train.shape   # n是函数,c是列数
print("x_train.shape: ", x_train.shape)
print("y_train.min: {0}, y_train.max: {1}".format(y_train.min(), y_train.max()))# 初始化权重和偏差值,权重是随机出来的784*10的矩阵,偏差初始化为0
weights = torch.randn(784, 10) / math.sqrt(784)
weights.requires_grad_()
bias = torch.zeros(10, requires_grad=True)# 激活函数
def log_softmax(x):return x - x.exp().sum(-1).log().unsqueeze(-1)# 定义模型:y = wx + b
# 实际上就是单层的Linear模型
def model(xb):return log_softmax(xb @ weights + bias)# 丢失函数 loss function
def nll(input, target):return -input[range(target.shape[0]), target].mean()
loss_func = nll# 计算精度函数
def accuracy(out, yb):preds = torch.argmax(out, dim=1)return (preds == yb).float().mean()###########开始训练##################################################################
bs = 64  # 每一批数据的大小
lr = 0.5  # 学习率
epochs = 2  # how many epochs to train forfor epoch in range(epochs):for i in range((n - 1) // bs + 1):start_i = i * bsend_i = start_i + bsxb = x_train[start_i:end_i]yb = y_train[start_i:end_i]pred = model(xb) # 通过模型预测loss = loss_func(pred, yb) # 通过与实际结果比对,计算丢失值loss.backward() # 反向传播with torch.no_grad():weights -= weights.grad * lr  # 调整权重值bias -= bias.grad * lr  # 调整偏差值weights.grad.zero_()bias.grad.zero_()##########对比一下预测结果############################################################
xb = x_train[0:bs]  # 加载一批数据,这里用的是训练的数据,在实际应用中最好使用没训练过的数据来验证
yb = y_train[0:bs]  # 训练数据对应的正确结果
preds = model(xb)  # 使用训练之后的模型进行预测
print("################## after training ###################")
print("accuracy: ", accuracy(preds, yb))   # 打印出训练之后的精度
# print(preds[0])
print("pred value: ", torch.argmax(preds, dim=1))   # 打印预测的数字
print("real value: ", yb)   # 实际正确的数据,可以直观地和上一行打印地数据进行对比
运行结果:

可以看到训练后模型地预测精度达到了0.9531,已经不错了,毕竟只使用了一个单层地Linear模型;从输出地对比数据中可以看出有三个地方预测错了(红框标记地数字)

ok,今天先到这里,下一篇再来解读代码中地细节

附:

PyTorch官方源码:https://github.com/pytorch/tutorials/blob/main/beginner_source/nn_tutorial.py

天地一逆旅,同悲万古愁!

相关文章:

PyTorch官网demo解读——第一个神经网络(1)

神经网络如此神奇,feel the magic 今天分享一下学习PyTorch官网demo的心得,原来实现一个神经网络可以如此简单/简洁/高效,同时也感慨PyTorch如此强大。 这个demo的目的是训练一个识别手写数字的模型! 先上源码: fr…...

升华 RabbitMQ:解锁一致性哈希交换机的奥秘【RabbitMQ 十】

欢迎来到我的博客,代码的世界里,每一行都是一个故事 升华 RabbitMQ:解锁一致性哈希交换机的奥秘【RabbitMQ 十】 前言第一:该插件需求为什么需要一种更智能的消息路由方式?一致性哈希的基本概念: 第二&…...

vue3 element-plus 日期选择器 el-date-picker 汉化

vue3 项目中,element-plus 的日期选择器 el-date-picker 默认是英文版的,如下: 页面引入: //引入汉化语言包 import locale from "element-plus/lib/locale/lang/zh-cn" import { ElDatePicker, ElButton, ElConfigP…...

剑指 Offer(第2版)面试题 35:复杂链表的复制

剑指 Offer(第2版)面试题 35:复杂链表的复制 剑指 Offer(第2版)面试题 35:复杂链表的复制解法1:模拟 剑指 Offer(第2版)面试题 35:复杂链表的复制 题目来源&…...

自定义指令Custom Directives

<script setup langts> import { ref } from "vue"const state ref(false)/*** Implement the custom directive* Make sure the input element focuses/blurs when the state is toggled* */ // 以v开头的驼峰式命名的变量都可以作为一个自定义指令 const VF…...

预测性维护对制造企业设备管理的作用

制造企业设备管理和维护对于生产效率和成本控制至关重要。然而&#xff0c;传统的维护方法往往无法准确预测设备故障&#xff0c;导致生产中断和高额维修费用。为了应对这一挑战&#xff0c;越来越多的制造企业开始采用预测性维护技术。 预测性维护是通过传感器数据、机器学习和…...

华为、新华三、锐捷常用命令总结

华为、新华三、锐捷常用命令总结 一、华为交换机基础配置命令二、H3C交换机的基本配置三、锐捷交换机基础命令配置 一、华为交换机基础配置命令 1、创建vlan&#xff1a; <Quidway> //用户视图&#xff0c;也就是在Quidway模式下运行命令。 <Quidway>system-view…...

链路追踪详解(四):分布式链路追踪的事实标准 OpenTelemetry 概述

目录 OpenTelemetry 是什么&#xff1f; OpenTelemetry 的起源和目标 OpenTelemetry 主要特点和功能 OpenTelemetry 的核心组件 OpenTelemetry 的工作原理 OpenTelemetry 的特点 OpenTelemetry 的应用场景 小结 OpenTelemetry 是什么&#xff1f; OpenTelemetry 是一个…...

Node.js 工作线程与子进程:应该使用哪一个

Node.js 工作线程与子进程&#xff1a;应该使用哪一个 并行处理在计算密集型应用程序中起着至关重要的作用。例如&#xff0c;考虑一个确定给定数字是否为素数的应用程序。如果我们熟悉素数&#xff0c;我们就会知道必须从 1 遍历到该数的平方根才能确定它是否是素数&#xff…...

python matplotlib 三维图形添加文字且不随图形变动而变动

要在三维图形中添加文字并使其不随图形变动而变动&#xff0c;可以使用 annotate() 方法。这个方法可以在三维图形中添加文字&#xff0c;并且可以指定文字的位置、对齐方式和字体大小等属性。 下面是一个示例代码&#xff0c;演示如何在三维图形中添加文字&#xff1a; impo…...

Ubuntu设置kubelet启动脚本关闭swap分区

查看swap分区 swapon -s打开swap分区 swapon -a查看/etc/fstab下所有固化的swap分区&#xff0c;注释 vi /etc/fstab修改kubelet.conf文件 vi /etc/systemd/system/kubelet.service.d/10-kubeadm.conf添加 ExecStartPre/sbin/swapoff -a生效 systemctl daemon-reload sys…...

MySQL数据库存储

MySQL数据库存储 MySQL数据库简介MySQL开发环境MySQL安装图形化界面工具Navicat使用 表的操作表的概念3.2 创建表3.3 修改表 数据的操作-增删改查4.1 增加数据4.2 删除数据4.3 修改数据4.4 查询数据4.4.1 基础查询4.4.2 分组查询和聚合函数4.4.4 having语句4.4.5 排序4.5 多表联…...

verilog语法进阶,时钟原语

概述&#xff1a; 内容 1. 时钟缓冲 2. 输入时钟缓冲 3. ODDR2作为输出时钟缓冲 1. 输入时钟缓冲 BUFGP verilog c代码&#xff0c;clk作为触发器的边沿触发&#xff0c;会自动将clk综合成时钟信号。 module primitive1(input clk,input a,output reg y); always (posed…...

案例069:基于微信小程序的计算机实验室排课与查询系统

文末获取源码 开发语言&#xff1a;Java 框架&#xff1a;SSM JDK版本&#xff1a;JDK1.8 数据库&#xff1a;mysql 5.7 开发软件&#xff1a;eclipse/myeclipse/idea Maven包&#xff1a;Maven3.5.4 小程序框架&#xff1a;uniapp 小程序开发软件&#xff1a;HBuilder X 小程序…...

C语言:将三个数从大到小输出

#include<stdio.h> int main() {int a 0;int b 0;int c 0;printf("请输入abc的值&#xff1a;");scanf_s("%d%d%d", &a, &b, &c);if (b > a){int tmp a;a b;b tmp;}if (c > a){int tmp a;a c;c tmp;}if (b < c){int t…...

基于Hadoop的铁路货运大数据平台设计与应用

完整下载&#xff1a;基于Hadoop的铁路货运大数据平台设计与应用 基于Hadoop的铁路货运大数据平台设计与应用 Design and Application of Railway Freight Big Data Platform based on Hadoop 目录 目录 2 摘要 3 关键词 4 第一章 绪论 4 1.1 研究背景 4 1.2 研究目的与意义 5 …...

Java基础题2:类和对象

1.下面代码的运行结果是&#xff08;&#xff09; public static void main(String[] args){String s;System.out.println("s"s);}A.代码编程成功&#xff0c;并输出”s” B.代码编译成功&#xff0c;并输出”snull” C.由于String s没有初始化&#xff0c;代码不能…...

冒泡排序学习

冒泡排序&#xff08;Bubble Sort&#xff09;是一种简单的排序算法&#xff0c;它通过重复地交换相邻的元素来排序。具体实现如下&#xff1a; 1. 从待排序的数组中的第一个元素开始&#xff0c;依次比较相邻的两个元素。 2. 如果前一个元素大于后一个元素&#xff0c;则交换…...

LeetCode(65)LRU 缓存【链表】【中等】

目录 1.题目2.答案3.提交结果截图 链接&#xff1a; LRU 缓存 1.题目 请你设计并实现一个满足 LRU (最近最少使用) 缓存 约束的数据结构。 实现 LRUCache 类&#xff1a; LRUCache(int capacity) 以 正整数 作为容量 capacity 初始化 LRU 缓存int get(int key) 如果关键字 k…...

网站提示“不安全”

当你在浏览网站时&#xff0c;有时可能会遇到浏览器提示网站不安全的情况。这通常是由于网站缺乏SSL证书所致。那么&#xff0c;从SSL证书的角度出发&#xff0c;我们应该如何解决这个问题呢&#xff1f; 首先&#xff0c;让我们简单了解一下SSL证书。SSL证书是一种用于保护网站…...

【Linux】驱动

驱动 驱动程序过程 系统调用 用户空间 内核空间 添加驱动和调用驱动 驱动程序是如何调用设备硬件 驱动 在计算机领域&#xff0c;驱动&#xff08;Driver&#xff09;是一种软件&#xff0c;它充当硬件设备与操作系统之间的桥梁&#xff0c;允许它们进行通信和协同工作。驱动程…...

Java研学-HTML

HTML 1 介绍 HTML(Hypertext Markup Language) 超文本标记语言。静态网页&#xff0c;用于在浏览器上显示数据 超文本: 指页面内可以包含图片、链接&#xff0c;甚至音乐、程序等非文字元素。 标记语言: 使用 < > 括起来的语言 超文本标记语言的结构, 包括“头”部分&am…...

SpringBoot之响应的详细解析

2. 响应 前面我们学习过HTTL协议的交互方式&#xff1a;请求响应模式&#xff08;有请求就有响应&#xff09; 那么Controller程序呢&#xff0c;除了接收请求外&#xff0c;还可以进行响应。 2.1 ResponseBody 在我们前面所编写的controller方法中&#xff0c;都已经设置了…...

redis:四、双写一致性的原理和解决方案(延时双删、分布式锁、异步通知MQ/canal)、面试回答模板

双写一致性 场景导入 如果现在有个数据要更新&#xff0c;是先删除缓存&#xff0c;还是先操作数据库呢&#xff1f;当多个线程同时进行访问数据的操作&#xff0c;又是什么情况呢&#xff1f; 以先删除缓存&#xff0c;再操作数据库为例 多个线程运行的正常的流程应该如下…...

智能优化算法应用:基于动物迁徙算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于动物迁徙算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于动物迁徙算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.动物迁徙算法4.实验参数设定5.算法结果6.…...

illuminate/database 使用 五

之前文章&#xff1a; illuminate/database 使用 一-CSDN博客 illuminate/database 使用 二-CSDN博客 illuminate/database 使用 三-CSDN博客 illuminate/database 使用 四-CSDN博客 一、原生查询 1.1 原理 根据之前内容调用执行的静态类为Illuminate\Database\Capsule\M…...

武汉灰京文化:益智游戏的教育与娱乐完美结合

随着游戏技术的不断发展&#xff0c;益智类游戏正经历着一场革命性的变革&#xff0c;逐渐融合了教育和娱乐的元素。创新的设计和互动方式使得许多益智游戏成为了知识传递和技能训练的有效工具&#xff0c;同时保持了娱乐体验的趣味性。这种教育与娱乐的完美结合不仅使益智游戏…...

arcgis api for js 中的query实现数据查询

相当于服务地址中的query查询 获取图层范围内的数据4.24 import Query from arcgis/core/rest/support/Query; import * as QueryTask from "arcgis/core/rest/query";//获取图层范围内的数据4.24 _returnFeatureFromWhere(url, where, geo) {const self thisretu…...

AcWing 1250. 格子游戏(并查集)

题目链接 活动 - AcWing本课程系统讲解常用算法与数据结构的应用方式与技巧。https://www.acwing.com/problem/content/1252/ 题解 当两个点已经是在同一个连通块中&#xff0c;再连一条边&#xff0c;就围成一个封闭的圈。一般用x * n y的形式将&#xff08;x, y&#xff0…...

CSS对文本的简单修饰

CSS格式&#xff1a; 格式一&#xff1a;在head中的style标签范围内。 < style> 在style内的只写名字不写 &#xff1a; < > 选择器 { 属性的名称 &#xff1a; 样式&#xff1b; 属性的名称&#xff1a;样式&#xff1b; } < style> style中的注释用/* *…...

wordpress 网店 主题/打开百度网页版

打比方一个类里边有多个内部类&#xff0c;怎样获取该类里边指定的某一个内部类public class FactoryTest {Testpublic void test2(){FactoryTest factoryTest new FactoryTest();Class extends FactoryTest> clazz factoryTest.getClass();Class>[] classes clazz.ge…...

wordpress EDD Alipay/关键词seo排名公司

在亚马逊云科技&#xff0c;有着这么一群人&#xff0c;他们经常被认为只会写代码&#xff0c;而不善言辞。但这只是大家对他们的误解。他们的工作不仅需要懂开发、善沟通&#xff0c;还需要能够dive deep用户的需求。他们就是亚马逊云科技的 Software Dev Engineer&#xff01…...

php做动态网站/seo搜索引擎优化技术

SQL Server 数据库启动过程&#xff0c;以及启动不起来的各种问题的分析及解决技巧参考文章&#xff1a; &#xff08;1&#xff09;SQL Server 数据库启动过程&#xff0c;以及启动不起来的各种问题的分析及解决技巧 &#xff08;2&#xff09;https://www.cnblogs.com/VicL…...

本地wordpress 外网访问不了/百度收录网址

前言从智能单品到全屋智能&#xff0c;随着消费者对生活品质追求的提升&#xff0c;智能化产品逐渐走入大众家庭&#xff0c;从而推动智能家居市场蓬勃发展。从 2017 年开始&#xff0c;智能家居设备已经应用于日常生活各项任务。2017 年其市场规模约为 4.3 亿美元。据 IDC 预测…...

如何在宝塔中安装wordpress/北京网站建设公司哪家好

Runnable的缺陷 不能返回一个返回值也不能抛出checked Exception&#xff0c;只能trycatch Callable接口 类似于 Runnable&#xff0c;被其他线程执行的任务实现call方法有返回值 Future类 在并发编程中&#xff0c;我们经常用到非阻塞的模型&#xff0c;在之前的多线程的三…...

东莞建设网站推广公司地址/各种推广平台

http://blog.csdn.net/jay14/article/details/54074553 TLB原理 转载于:https://www.cnblogs.com/wangdgy/p/8465574.html...