当前位置: 首页 > news >正文

pytorch-多分类实战之手写数字识别

目录

  • 1. 网络设计
  • 2. 代码实现
    • 2.1 网络代码
    • 2.2 train
  • 3. 完整代码

1. 网络设计

输入是手写数字图片28x28,输出是10个分类0~9,有两个隐藏层,如下图所示:
在这里插入图片描述

2. 代码实现

2.1 网络代码

第一层将784降维到200,第二次使用200不降维,输出层200降维到10,每一层之后加一个激活函数relu,每一层都需要梯度信息所以requires_grad=True;
forward函数最后不要加softmax,因为后面CrossEntropyLoss中包含了softmax操作。
在这里插入图片描述

2.2 train

优化目标是w1、b1、w2、b2、w3、b3,使用SGD优化器,使用CrossEntropyLoss计算loss
在这里插入图片描述

3. 完整代码

import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transformsbatch_size=200
learning_rate=0.01
epochs=10train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)w1, b1 = torch.randn(200, 784, requires_grad=True),\torch.zeros(200, requires_grad=True)
w2, b2 = torch.randn(200, 200, requires_grad=True),\torch.zeros(200, requires_grad=True)
w3, b3 = torch.randn(10, 200, requires_grad=True),\torch.zeros(10, requires_grad=True)# torch.nn.init.kaiming_normal_(w1)
# torch.nn.init.kaiming_normal_(w2)
# torch.nn.init.kaiming_normal_(w3)def forward(x):x = x@w1.t() + b1x = F.relu(x)x = x@w2.t() + b2x = F.relu(x)x = x@w3.t() + b3x = F.relu(x)return xoptimizer = optim.SGD([w1, b1, w2, b2, w3, b3], lr=learning_rate)
criteon = nn.CrossEntropyLoss()for epoch in range(epochs):for batch_idx, (data, target) in enumerate(train_loader):data = data.view(-1, 28*28)logits = forward(data)loss = criteon(logits, target)optimizer.zero_grad()loss.backward()# print(w1.grad.norm(), w2.grad.norm())optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))test_loss = 0correct = 0for data, target in test_loader:data = data.view(-1, 28 * 28)logits = forward(data)test_loss += criteon(logits, target).item()pred = logits.data.max(1)[1]correct += pred.eq(target.data).sum()test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

如下图:
未使用torch.nn.init.kaiming_normal_(w1)初始化参数的情况,可以看出Loss在2.302585后就不下降了。
在这里插入图片描述
如下图:使用了torch.nn.init.kaiming_normal_(w1)初始化参数的情况下,Loss下降还是比较快的。
在这里插入图片描述
因此使用好的初始化参数对网络的训练起到至关重要的作用

相关文章:

pytorch-多分类实战之手写数字识别

目录 1. 网络设计2. 代码实现2.1 网络代码2.2 train 3. 完整代码 1. 网络设计 输入是手写数字图片28x28,输出是10个分类0~9,有两个隐藏层,如下图所示: 2. 代码实现 2.1 网络代码 第一层将784降维到200,第二次使用…...

httpsok-快速申请谷歌SSL免费证书

🔥httpsok-快速申请谷歌SSL免费证书 使用场景: 部署CDN证书、OSS云存储证书证书类型: 单域名 多域名 通配符域名 混合域名证书厂商: ZeroSSL Lets Encrypt Google证书加密类型: ECC、 RSA 一、证书管理 进入 证书管…...

LiveGBS流媒体平台GB/T28181功能-国标级联中如何自定义通道国标编号编辑通道编号保持唯一性

LiveGBS国标级联中如何自定义通道国标编号编辑通道编号保持唯一性 1、国标级联选择通道修改2、通道编辑修改3、分屏展示设备树修改3.1、编辑名称中修改 4、分屏展示分组修改4.1、编辑名称中修改4.2、选择通道中修改 5、搭建GB28181视频直播平台 1、国标级联选择通道修改 国标级…...

mysql 大表凌晨定时删除数据

有几张表数据量非常大,一次维护量有点大(一个月有500多万条数据,并且还在往上涨), 于是想了个定时删除数据,每天凌晨执行,这样每天删除数据量就小, 循环删除,每次删除5…...

ArcGIS和ArcGIS Pro快速加载ArcGIS历史影像World Imagery Wayback

ArcGIS在线历史影像网站 World Imagery Wayback(网址:https://livingatlas.arcgis.com/wayback/)提供了数期历史影像在线浏览服务,之前不少自媒体作者在文中宣称其能代表Google Earth历史影像。 1、一点对比 (1)同一级别下的版本覆盖面 以下述区域为例,自2014年2月20…...

数据仓库的—数据仓库的体系架构

数据仓库通常采用分层的体系架构设计,作为支撑企业决策分析需求的数据基础设施。典型的数据仓库体系架构由以下三个核心层次组成: 源数据层(Source Layer) 这是数据仓库的数据来源,包括组织内部的各种运营系统,如ERP、CRM、SCM等,以及外部数据源如互联网、社交媒体等。这些系…...

【C/C++基础知识】const 关键字

文章目录 Q&A and 前言const 修饰基本变量初始化const 对象仅在文件内有效 const 的引用应用 指针与 constconst 修饰类成员函数参考写在最后 Q&A and 前言 Q:简要说一说 C 中的 const 关键字,含义以及常见的使用位置 A:const 是 C…...

Docker之数据卷和Dockerfile

Docker之数据卷与Dockerfile的详细使用介绍如下: 一、Docker数据卷 数据卷(volume)是Docker中的一个重要概念,它允许你在容器和宿主机或容器之间共享文件系统。数据卷提供了持久性存储,即使在容器被删除后&#xff0…...

pull拉取最新代码

工作区、暂存区、版本库 工作区:就是你在电脑里能看到的目录。 暂存区:英文叫 stage 或 index。一般存放在 .git 目录下的 index 文件(.git/index)中,所以我们把暂存区有时也叫作索引(index)。 …...

工控 modbusTCP 报文

Tx 发送报文:00 C9 00 00 00 06 01 03 00 00 00 02 Rx 接收报文:00 C9 00 00 00 07 01 03 04 01 4D 00 01 Tx 发送报文:00 C9 00 00 00 06 01 03 00 00 00 02 00 C9 事务处理标识符 2字节 00 00 协议标识符 2字节 固定 00 00 00 06 长度 2字节 表示之后的字节总数 (…...

在Ubuntu服务器上快速安装一个redis并提供远程服务

一、快速安装一个Redis 第一步:更新apt源 sudo apt update第二步:下载Redis sudo apt install redis第三步:查看Redis是否已自启动 systemctl status redis二、配置Redis提供远程服务 第一步:先确保6379端口正常开放 如果是…...

玩机进阶教程------手机定制机 定制系统 解除系统安装软件限制的一些步骤解析

定制机 在于各工作室与商家合作定制rom中有一些定制机。限制用户私自安装第三方软件。或者限制解锁 。无法如正常机登陆账号等等。定制机一般用于固定行业或者一些部门。专机专用。例如很多巴枪扫描机型等等。或者一些小牌机型。对于没有官方包的机型首先要导出各个分区来制作…...

Bilstm双向长短期神经网络多输入单输出回归分析

目录 背影 摘要 LSTM的基本定义 LSTM实现的步骤 BILSTM神经网络 Bilstm双向长短期神经网络多输入单输出回归分析 完整代码: Bilstm双向长短期神经网络多输入单输出回归分析.zip资源-CSDN文库 https://download.csdn.net/download/abc991835105/89087121 效果图 结果分析 展望 …...

ELK+Filebeat日志分析系统

一、ELK基本介绍: 1.ELK 简介: ELK平台是一套完整的日志集中处理解决方案(日志系统)。 将 ElasticSearch、Logstash 和 Kiabana 三个开源工具配合使用, 完成更强大的用户对日志的查询、排序、统计需求。 ELK --> ELFK --> ELFKMQ2.ELK组件介绍…...

flex吃干抹净

Flex 布局是什么? Flex 是 Flexible Box 的缩写,意为"弹性布局",用来为盒状模型提供最大的灵活性。 .box{display: flex;//行内元素也可以使用flex布局//display: inline-flex; }display: flex; 使元素呈现为块级元素,…...

【单片机毕业设计8-基于stm32c8t6的RFID校园门禁系统】

【单片机毕业设计8-基于stm32c8t6的RFID校园门禁系统】 前言一、功能介绍二、硬件部分三、软件部分总结 前言 🔥这里是小殷学长,单片机毕业设计篇8基于stm32的RFID校园门禁系统 🧿创作不易,拒绝白嫖可私 一、功能介绍 -----------…...

uni-app web端使用getUserMedia,摄像头拍照

<template><view><video id"video"></video></view> </template> 摄像头显示在video标签上 var opts {audio: false,video: true }navigator.mediaDevices.getUserMedia(opts).then((stream)> {video document.querySelec…...

2024-简单点-观察者模式

先看代码&#xff1a; # 导入未来模块以支持类型注解 from __future__ import annotations# 导入抽象基类模块和随机数生成器 from abc import ABC, abstractmethod from random import randrange# 导入列表类型注解 from typing import List# 定义观察者模式中的主体接口&…...

STM32—DMA直接存储器访问详解

DMA——直接存储器访问 DMA&#xff1a;Data Memory Access, 直接存储器访问。 DMA和我们之前学过的串口、GPIO都是类似的&#xff0c;都是STM32中的一个外设。串口是用来发送通信数据的&#xff0c;而DMA则是用来把数据从一个地方搬到另一个地方&#xff0c;而且不占用CPU。…...

【JavaEE初阶系列】——网络编程 TCP客户端/服务器 程序实现

目录 &#x1f6a9;TCP流套接字编程 &#x1f36d;ServerSocket API &#x1f36d;Socket API &#x1f36d;TCP服务器 &#x1f36d;TCP客户端 &#x1f6a9;TCP流套接字编程 俩个关键的类 ServerSocket (给服务器使用的类&#xff0c;使用这个类来绑定端口号&#xff0…...

CMake构建OpenCv并导入QT项目过程中出现的问题汇总

前言 再此之前请确保你的环境变量是否配置&#xff0c;这是总共需要配置的环境变量 E:\cmake\bin E:\OpenCv\opencv\build\x64\vc15\bin F:\Qt\Tools\mingw730_64\bin F:\Qt\5.12.4\mingw73_64\bin 问题一&#xff1a; CMake Error: CMake was unable to find a build program…...

AcWing 796. 子矩阵的和——算法基础课题解

AcWing 796. 子矩阵的和 题目描述 输入一个 n 行 m 列的整数矩阵&#xff0c;再输入 q 个询问&#xff0c;每个询问包含四个整数 x1,y1,x2,y2&#xff0c;表示一个子矩阵的左上角坐标和右下角坐标。 对于每个询问输出子矩阵中所有数的和。 输入格式 第一行包含三个整数 n&…...

macos 查看 远程服务器是否开放某个端口

想要使用mac查看远程服务器某个端口是否开发&#xff0c;可通过 nc 命令&#xff0c;如下&#xff1a; nc -zv <服务器IP> <端口号>如果该端口开发&#xff0c;结果为&#xff1a;succeeded! Connection to <服务器IP> port <端口号> [类型] succeed…...

GraphQL注入

GraphQL概述 GraphQL是一种查询语言&#xff0c;用于API设计和数据交互&#xff0c;不仅仅用于查询数据库。GraphQL 允许客户端在一个请求中明确地指定需要的数据&#xff0c;并返回预期的结果&#xff1b;并且将数据查询和数据修改分离开&#xff0c;大大增加灵活性。GraphQL…...

以太坊源码阅读01

正所谓区块链&#xff0c;怎能不熟悉区块的数据结构呢&#xff1f;区块的结构体被保存在core/types/block.go文件中&#xff0c;下面是我截取出来的&#xff1a; type Block struct {header *Headeruncles []*Headertransactions Transactionswithdrawals Withdr…...

Spark-Scala语言实战(15)

在之前的文章中&#xff0c;我们学习了如何在spark中使用键值对中的学习键值对方法中的lookup&#xff0c;cogroup两种方法。想了解的朋友可以查看这篇文章。同时&#xff0c;希望我的文章能帮助到你&#xff0c;如果觉得我的文章写的不错&#xff0c;请留下你宝贵的点赞&#…...

【SpringBoot XSS存储漏洞 拦截器】Java纯后端对于前台输入值的拦截校验实现 一个类加一个注解结束

先看效果&#xff1a; 1.js注入拦截&#xff1a; 2.sql注入拦截 生效只需要两步&#xff1a; 1.创建Filter类&#xff0c;粘贴如下代码&#xff1a; package cn.你的包命.filter; import java.io.BufferedReader; import java.io.ByteArrayInputStream; import java.io.IO…...

【微信小程序】canvas开发笔记

【微信小程序】canvasToTempFilePath:fail fail canvas is empty 看说明书 最好是先看一下官方文档点此前往 如果是canvas 2d 写canvas: this.canvas,&#xff0c;如果是旧版写canvasId: ***, 解决问题 修改对应的代码&#xff0c;如下所示&#xff0c;然后再试试运行&#x…...

TripoSR: Fast 3D Object Reconstruction from a Single Image 论文阅读

1 Abstract TripoSR的核心是一个基于变换器的架构&#xff0c;专为单图像3D重建设计。它接受单张RGB图像作为输入&#xff0c;并输出图像中物体的3D表示。TripoSR的核心包括&#xff1a;图像编码器、图像到三平面解码器和基于三平面的神经辐射场&#xff08;NeRF&#xff09;。…...

u盘为什么一插上电脑就蓝屏,u盘一插电脑就蓝屏

u盘之前还好好的&#xff0c;可以传输文件&#xff0c;使用正常&#xff0c;但是最近使用时却出现问题了。只要将u盘一插入电脑&#xff0c;电脑就显示蓝屏。u盘为什么一插上电脑就蓝屏呢?一般&#xff0c;导致的原因有以下几种。一&#xff0c;主板的SATA或IDE控制器驱动损坏…...

巴彦淖尔市 网站建设/百度宁波运营中心

2017跟着小虎玩着去软考--项目管理师系列 趣味好玩解析2015年下半年信息系统项目管理师上午试题21-25题【小虎梦想】小虎有一个梦想&#xff0c;希望像专业的计算机考试&#xff0c;如全国计算机软件技术与软件专业资格&#xff08;水平&#xff09;考试&#xff08;简称软考&a…...

服装网站的建设与管理/seo网络优化招聘

小编又来了&#xff01;&#xff01;今天给大家带来的是蘑菇街广告投放系统的建设概要。相信大部分需要做流量召回和广告投放的公司都会关注这部分系统的建设。这里面也是不断在效果和成本上进行平衡&#xff0c;这次邀请了蘑菇街广告投放技术负责人腾哲给大家分享他的一些经验…...

建站网站官方/网站建设规划书

frp中文文档&#xff1a;https://github.com/fatedier/frp/blob/v0.14.0/README_zh.md frp配置文件下载&#xff1a;https://github.com/fatedier/frp/releases/tag/v0.21.0 配置文件下载说明&#xff1a; linux服务器&#xff1a;frp_0.21.0_linux_amd64.tar.gz 树莓派&#x…...

武汉网站推广¥做下拉去118cr/宣传软文是什么意思

“我可能干了个假的数据分析师&#xff01;”经常有同学发出这种感慨&#xff0c;然后到处发《数据分析师是干什么的》《数据分析师、数据工程师、数据运营、数据挖掘工程师、商业数据分析师、我随便写个什么分析师之间到底有什么区别》一类的帖子。之所以会这样&#xff0c;是…...

seo 排名/seo观察网

当当当~时光荏苒&#xff0c;岁月如梭&#xff0c;年终将至&#xff0c;有些财年节点设置在自然年底的公司&#xff0c;已经开始一年一度的升职加薪年终总结了。项目经理每天在客户现场背锅填坑沟通协调记录、整理和分析已经不易&#xff0c;再提取亮点和价值可能更无从下手。因…...

大连网站制作报价/seo全网营销的方式

实现网站的深度和运动效果有很多种方式&#xff0c;例如有的网站使用视差滚动&#xff08;Parallax Scrolling&#xff09;&#xff0c;有的是用Flash动画。不管采用什么技术&#xff0c;伪深度&#xff08;或者运动&#xff09;效果能够让网站更具互动性&#xff0c;更有趣。今…...