当前位置: 首页 > news >正文

基于BP神经网络对MNIST数据集检测识别(numpy版本)

基于BP神经网络对MNIST数据集检测识别

  • 1.作者介绍
  • 2.BP神经网络介绍
    • 2.1 BP神经网络
  • 3.BP神经网络对MNIST数据集检测实验
    • 3.1 读取数据集
    • 3.2 前向传播
    • 3.3 损失函数
    • 3.4 构建神经网络
    • 3.5 训练
    • 3.6 模型推理
  • 4.完整代码

1.作者介绍

王凯,男,西安工程大学电子信息学院,2022级研究生
研究方向:机器视觉与人工智能
电子邮件:1794240761@qq.com

张思怡,女,西安工程大学电子信息学院,2022级研究生,张宏伟人工智能课题组
研究方向:机器视觉与人工智能
电子邮件:981664791@qq.com

2.BP神经网络介绍

2.1 BP神经网络

搭建一个两层(两个权重矩阵,一个隐藏层)的神经网络,其中输入节点和输出节点的个数是确定的,分别为 784 和 10。而隐藏层节点的个数还未确定,并没有明确要求隐藏层的节点个数,所以在这里取50个。现在神经网络的结构已经确定了,再看一下里面是怎么样的,这里画出了对一个数据的运算过程:
在这里插入图片描述
数学公式:
在这里插入图片描述

3.BP神经网络对MNIST数据集检测实验

3.1 读取数据集

安装numpy :pip install numpy
安装matplotlib pip install matplotlib
mnist是一个包含各种手写数字图片的数据集:其中有60000个训练数据和10000个测试时局,即60000个 train_img 和与之对应的 train_label,10000个 test_img 和 与之对应的test_label。
在这里插入图片描述
其中的 train_img 和 test_img 就是这种图片的形式,train_img 是为了训练神经网络算法的训练数据,test_img 是为了测试神经网络算法的测试数据,每一张图片为2828,将图片转换为2828=784个像素点,每个像素点的值为0到255,像素点值的大小代表灰度,从而构成一个1784的矩阵,作为神经网络的输入,而神经网络的输出形式为110的矩阵,个:eg:[0.01,0.01,0.01,0.04,0.8,0.01,0.1,0.01,0.01,0.01],矩阵里的数字代表神经网络预测值的概率,比如0.8代表第五个数的预测值概率。
其中 train_label 和 test_label 是 对应训练数据和测试数据的标签,可以理解为一个1*10的矩阵,用one-hot-vectors(只有正确解表示为1)表示,one_hot_label为True的情况下,标签作为one-hot数组返回,one-hot数组 例:[0,0,0,0,1,0,0,0,0,0],即矩阵里的数字1代表第五个数为True,也就是这个标签代表数字5。
数据集的读取:
load_mnist(normalize=True, flatten=True, one_hot_label=False):中,
normalize : 是否将图像的像素值正规化为0.0~1.0(将像素值正规化有利于提高精度)。flatten : 是否将图像展开为一维数组。 one_hot_label:是否采用one-hot表示。
在这里插入图片描述
完整代码及数据集下载:https://gitee.com/wang-kai-ya/bp.git

3.2 前向传播

前向传播时,我们可以构造一个函数,输入数据,输出预测。

def predict(self, x):w1, w2 = self.dict['w1'], self.dict['w2']b1, b2 = self.dict['b1'], self.dict['b2']a1 = np.dot(x, w1) + b1z1 = sigmoid(a1)a2 = np.dot(z1, w2) + b2y = softmax(a2)

3.3 损失函数

求出神经网络对一组数据的预测值,是一个1*10的矩阵。
在这里插入图片描述

其中,Yk表示的是第k个节点的预测值,Tk表示标签中第k个节点的one-hot值,举前面的eg:(手写数字5的图片预测值和5的标签)
Yk=[0.01,0.01,0.01,0.04,0.8,0.01,0.1,0.01,0.01,0.01]
Tk=[0, 0, 0, 0, 1, 0, 0, 0, 0, 0]
值得一提的是,在交叉熵误差函数中,Tk的值只有一个1,其余为0,所以对于这个数据的交叉熵误差就为 E = -1(log0.8)。
在这里选用交叉熵误差作为损失函数,代码实现如下:

def loss(self, y, t):t = t.argmax(axis=1)num = y.shape[0]s = y[np.arange(num), t]return -np.sum(np.log(s)) / num

3.4 构建神经网络

前面我们定义了预测值predict, 损失函数loss, 识别精度accuracy, 梯度grad,下面构建一个神经网络的类,把这些方法添加到神经网络的类中:

for i in range(epoch):batch_mask = np.random.choice(train_size, batch_size)  # 从0到60000 随机选100个数x_batch = x_train[batch_mask]y_batch = net.predict(x_batch)t_batch = t_train[batch_mask]grad = net.gradient(x_batch, t_batch)for key in ('w1', 'b1', 'w2', 'b2'):net.dict[key] -= lr * grad[key]loss = net.loss(y_batch, t_batch)train_loss_list.append(loss)# 每批数据记录一次精度和当前的损失值if i % iter_per_epoch == 0:train_acc = net.accuracy(x_train, t_train)test_acc = net.accuracy(x_test, t_test)train_acc_list.append(train_acc)test_acc_list.append(test_acc)print('第' + str(i/600) + '次迭代''train_acc, test_acc, loss :| ' + str(train_acc) + ", " + str(test_acc) + ',' + str(loss))

3.5 训练

import numpy as np
import matplotlib.pyplot as plt
from TwoLayerNet import TwoLayerNet
from mnist import load_mnist(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)
net = TwoLayerNet(input_size=784, hidden_size=50, output_size=10, weight_init_std=0.01)epoch = 20400
batch_size = 100
lr = 0.1train_size = x_train.shape[0]  # 60000
iter_per_epoch = max(train_size / batch_size, 1)  # 600train_loss_list = []
train_acc_list = []
test_acc_list = []

保存权重:

np.save('w1.npy', net.dict['w1'])
np.save('b1.npy', net.dict['b1'])
np.save('w2.npy', net.dict['w2'])
np.save('b2.npy', net.dict['b2'])

结果可视化:
在这里插入图片描述
在这里插入图片描述

3.6 模型推理

import numpy as np
from mnist import load_mnist
from functions import sigmoid, softmax
import cv2
######################################数据的预处理
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)
batch_mask = np.random.choice(100,1)  # 从0到60000 随机选100个数
#print(batch_mask)
x_batch = x_train[batch_mask]#####################################转成图片
arr = x_batch.reshape(28,28)
cv2.imshow('wk',arr)
key = cv2.waitKey(10000)
#np.savetxt('batch_mask.txt',arr)
#print(x_batch)
#train_size = x_batch.shape[0]
#print(train_size)
########################################进入模型预测
w1 = np.load('w1.npy')
b1 = np.load('b1.npy')
w2 = np.load('w2.npy')
b2 = np.load('b2.npy')a1 = np.dot(x_batch,w1) + b1
z1 = sigmoid(a1)
a2 = np.dot(z1,w2) + b2
y = softmax(a2)
p = np.argmax(y, axis=1)print(p)

运行python reasoning.py
可以看到模型拥有较高的准确率。

4.完整代码

训练

import numpy as np
import matplotlib.pyplot as plt
from TwoLayerNet import TwoLayerNet
from mnist import load_mnist(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)
net = TwoLayerNet(input_size=784, hidden_size=50, output_size=10, weight_init_std=0.01)epoch = 20400
batch_size = 100
lr = 0.1train_size = x_train.shape[0]  # 60000
iter_per_epoch = max(train_size / batch_size, 1)  # 600train_loss_list = []
train_acc_list = []
test_acc_list = []for i in range(epoch):batch_mask = np.random.choice(train_size, batch_size)  # 从0到60000 随机选100个数x_batch = x_train[batch_mask]y_batch = net.predict(x_batch)t_batch = t_train[batch_mask]grad = net.gradient(x_batch, t_batch)for key in ('w1', 'b1', 'w2', 'b2'):net.dict[key] -= lr * grad[key]loss = net.loss(y_batch, t_batch)train_loss_list.append(loss)# 每批数据记录一次精度和当前的损失值if i % iter_per_epoch == 0:train_acc = net.accuracy(x_train, t_train)test_acc = net.accuracy(x_test, t_test)train_acc_list.append(train_acc)test_acc_list.append(test_acc)print('第' + str(i/600) + '次迭代''train_acc, test_acc, loss :| ' + str(train_acc) + ", " + str(test_acc) + ',' + str(loss))np.save('w1.npy', net.dict['w1'])
np.save('b1.npy', net.dict['b1'])
np.save('w2.npy', net.dict['w2'])
np.save('b2.npy', net.dict['b2'])markers = {'train': 'o', 'test': 's'}
x = np.arange(len(train_acc_list))
plt.plot(x, train_acc_list, label='train acc')
plt.plot(x, test_acc_list, label='test acc', linestyle='--')
plt.xlabel("epochs")
plt.ylabel("accuracy")
plt.ylim(0, 1.0)
plt.legend(loc='lower right')
plt.show()

测试

import numpy as np
from mnist import load_mnist
from functions import sigmoid, softmax
import cv2
######################################数据的预处理
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)
batch_mask = np.random.choice(100,1)  # 从0到60000 随机选100个数
#print(batch_mask)
x_batch = x_train[batch_mask]#####################################转成图片
arr = x_batch.reshape(28,28)
cv2.imshow('wk',arr)
key = cv2.waitKey(10000)
#np.savetxt('batch_mask.txt',arr)
#print(x_batch)
#train_size = x_batch.shape[0]
#print(train_size)
########################################进入模型预测
w1 = np.load('w1.npy')
b1 = np.load('b1.npy')
w2 = np.load('w2.npy')
b2 = np.load('b2.npy')a1 = np.dot(x_batch,w1) + b1
z1 = sigmoid(a1)
a2 = np.dot(z1,w2) + b2
y = softmax(a2)
p = np.argmax(y, axis=1)print(p)

相关文章:

基于BP神经网络对MNIST数据集检测识别(numpy版本)

基于BP神经网络对MNIST数据集检测识别 1.作者介绍2.BP神经网络介绍2.1 BP神经网络 3.BP神经网络对MNIST数据集检测实验3.1 读取数据集3.2 前向传播3.3 损失函数3.4 构建神经网络3.5 训练3.6 模型推理 4.完整代码 1.作者…...

HTML5-创建HTML文档

HTML5中的一个主要变化是:将元素的语义与元素对其内容呈现结果的影响分开。从原理上讲这合乎情理。HTML元素负责文档内容的结构和含义,内容的呈现则由应用于元素上的CSS样式控制。下面介绍最基础的HTML元素:文档元素和元数据元素。 一、构建…...

Vue中Axios的封装和API接口的管理

一、axios的封装 在vue项目中,和后台交互获取数据这块,我们通常使用的是axios库,它是基于promise的http库,可运行在浏览器端和node.js中。他有很多优秀的特性,例如拦截请求和响应、取消请求、转换json、客户端防御XSR…...

MLIR面试题

1、请简要解释MLIR的概念和用途,并说明MLIR在编译器领域中的重要性。 MLIR(Multi-Level Intermediate Representation)是一种多级中间表示语言,提供灵活、可扩展和可优化的编译器基础设施。MLIR的主要目标是为不同的编程语言、领域专用语言(DSL)和编译器…...

***杨辉三角_yyds_LeetCode_python***

1.题目描述: 给定一个非负整数 numRows,生成「杨辉三角」的前 numRows 行。 在「杨辉三角」中,每个数是它左上方和右上方的数的和。 示例 1: 输入: numRows 5 输出: [[1],[1,1],[1,2,1],[1,3,3,1],[1,4,6,4,1]] 示例 2: 输入: numRows …...

Mac使用DBeaver连接达梦数据库

Mac使用DBeaver连接达梦数据库 下载达梦驱动包 达梦数据库 在下载页面随便选择一个系统并下载下来。 下载下来的是zip的压缩包解压出来就是一个ISO文件,然后我们打开ISO文件进入目录:/dameng/source/drivers/jdbc 进入目录后找到这几个驱动包&#x…...

spring.expression 随笔0 概述

0. 我只是个普通码农,不值得挽留 Spring SpEL表达式的使用 常见的应用场景:分布式锁的切面借助SpEL来构建key 比较另类的的应用场景:动态校验 个人感觉可以用作控制程序的走向,除此之外,spring的一些模块的自动配置类,也会在Cond…...

从Cookie到Session: Servlet API中的会话管理详解

文章目录 一. Cookie与Session1. Cookie与Session2. Servlet会话管理操作 二. 登录逻辑的实现 一. Cookie与Session 1. Cookie与Session 首先, 在学习过 HTTP 协议的基础上, 我们需要知道 Cookie 是 HTTP 请求报头中的一个关键字段, 本质上是浏览器在本地存储数据的一种机制,…...

docker数据管理与网络通信

一、管理docker容器中数据 管理Docker 容器中数据主要有两种方式:数据卷(Data Volumes)和数据卷容器( DataVolumes Containers) 。 1、 数据卷 数据卷是一个供容器使用的特殊目录,位于容器中。可将宿主机的目录挂载到数据卷上,对数据卷的修改操作立刻…...

怎么查询电脑的登录记录及密码更改情况?

源头是办公室公用的电脑莫名其妙打不开了,问别人也都不知道密码是多少 因为本来就没设密码啊!(躺倒) 甚至已经想好了如果是50万想攻破电脑,被po抓住要怎么花这笔钱了 是我想太多 当然最后也没解决,莫名…...

《三》TypeScript 中函数的类型

TypeScript 允许指定函数的参数和返回值的类型。 函数声明的类型定义:function 函数名(形参: 形参类型, 形参: 形参类型, ...): 返回值类型 {} function sum(x: number, y: number): number {return x y } sum(1, 2) // 正确 sum(1, 2, 3) // 错误。输入多余的或者…...

深入学习 Mysql 引擎 InnoDB、MyISAM

tip:作为程序员一定学习编程之道,一定要对代码的编写有追求,不能实现就完事了。我们应该让自己写的代码更加优雅,即使这会费时费力。 💕💕 推荐:体系化学习Java(Java面试专题&#…...

【华为OD统一考试B卷 | 100分】阿里巴巴找黄金宝箱(V)(C++ Java JavaScript Python)

题目描述 一贫如洗的樵夫阿里巴巴在去砍柴的路上,无意中发现了强盗集团的藏宝地,藏宝地有编号从0~N的箱子,每个箱子上面贴有一个数字。 阿里巴巴念出一个咒语数字k(k<N),找出连续k个宝箱数字和的最大值,并输出该最大值。 输入描述 第一行输入一个数字字串,数字之间…...

六步快速搭建个人网站

目录 第一步、选择搭建平台WordPress 第二步、选域名 1&#xff09;域名在哪买&#xff1f; 2&#xff09;域名怎么选&#xff1f; 3&#xff09;以阿里云为例&#xff0c;讲解怎么买域名 第三步、选择服务器 第四步、申请主机、安装WordPress 第五步、选择WordPress模…...

TypeScript 中的 type 关键字有什么用?

创建类型别名 在 TypeScript 中&#xff0c;type 关键字用于创建类型别名&#xff08;Type Alias&#xff09;。类型别名可以给一个类型起一个新的名字&#xff0c;使代码更具可读性和可维护性。 类型别名可以用于定义各种类型&#xff0c;包括基本类型、复合类型和自定义类型…...

27 getcwd 的调试

前言 同样是一个 很常用的 glibc 库函数 不管是 用户业务代码 还是 很多类库的代码, 基本上都会用到 获取当前路径 不过 我们这里是从 具体的实现 来看一下 测试用例 就是简单的使用了一下 getcwd rootubuntu:~/Desktop/linux/HelloWorld# cat Test04Getcwd.c #inc…...

使用IDEA使用Git:Git使用指北——实际操作篇

Git使用指北——实际操作 &#x1f916;:使用IDEA Git插件实际工作流程 &#x1f4a1; 本文从实际使用的角度出发&#xff0c;以IDEA Git插件为基座讲述了如果使用IDEA的Git插件来解决实际开发中的协作开发问题。本文从 远程仓库中拉取项目&#xff0c;在本地分支进行开发&…...

java boot将一组yml配置信息装配在一个对象中

其实将一组yml数据封进一个对象中才是以后的主流开发方式 我们创建一个springboot项目 找到项目中的启动类所在目录 在同目录下创建一个类 名字你们可以随便取 我这里直接叫 dataManager 然后 在yml中定义这样一组数据信息 然后 我们在类中定义三个和这个配置信息相同的字段…...

【裸机开发】链接脚本(.lds文件)的基本语法

目录 一、什么是链接脚本&#xff1f; 二、链接脚本的基本语法格式 1、常用命令 2、内置变量 三、链接脚本的简单案例 一、什么是链接脚本&#xff1f; 一段程序的编译需要经历四个阶段&#xff08;预处理—编译—汇编—链接&#xff09;&#xff0c;而链接脚本管理的就是…...

Java 进阶 -- 集合(三)

4、实现 实现是用于存储集合的数据对象&#xff0c;它实现了接口部分中描述的接口。本课描述了以下类型的实现: 通用实现是最常用的实现&#xff0c;是为日常使用而设计的。它们在标题为“通用实现”的表格中进行了总结。特殊目的实现是为在特殊情况下使用而设计的&#xff0…...

【华为OD机试真题 C语言】5、TLV解析 | 机试真题+思路参考+代码解析

文章目录 一、题目&#x1f383;题目描述&#x1f383;输入输出&#x1f383;样例1 二、思路参考三、代码参考&#x1f3c6;C语言 作者&#xff1a;KJ.JK &#x1f342;个人博客首页&#xff1a; KJ.JK &#x1f342;专栏介绍&#xff1a; 华为OD机试真题汇总&#xff0c;定期…...

(七)CSharp-刘铁锰版-事件

一、初步了解事件 定义&#xff1a;单词 Event &#xff0c;译为“事件” 《牛津词典》中的解释是“a thing that happens,especially something important”通顺的解释就是“能够发生的什么事情” 角色&#xff1a; 使对象或类具备通知能力的成员 &#xff08;中译&#x…...

【ROS】郭老二博文之:ROS目录

1、ROS2 【ROS】Ubuntu22.04安装ROS2&#xff08;Humble Hawksbill&#xff09; 【ROS】ROS2命令行工具详解 【ROS】ROS2中的概念和名词解释 【ROS】ROS2编程示例&#xff1a;话题订阅-发布-C版 【ROS】ROS2编程示例&#xff1a;服务和客户端-C版 【ROS】ROS2编程示例&#xf…...

Android应用程序进程的启动过程

Android应用程序进程的启动过程 导语 到这篇文章为止&#xff0c;我们已经简要地了解过了Android系统的启动流程了&#xff0c;其中比较重要的内容有Zygote进程的启动和SystemService以及Launcher的启动&#xff0c;接下来我们将要学习的是Android应用程序的启动过程&#xff…...

【2】Midjourney注册

随着AI技术的问世&#xff0c;2023年可以说是AI爆炸性成长的一年&#xff0c;近期最广为人知的AI服务除了chatgpt外&#xff0c;就是从去年五月就已经问世的AI绘画工具mid journey了。 ▲几个AI工具也代表了人工智能的热门阶段 只要输入一段文字&#xff0c;AI就会根据语意计算…...

第六十八天学习记录:高等数学:导数(宋浩板书)

导数是微积分中的一个概念&#xff0c;描述了函数在某一个点上的变化率。具体地说&#xff0c;函数 f ( x ) f(x) f(x)在 x a xa xa处的导数为 f ′ ( a ) f(a) f′(a)&#xff0c;表示当 x x x在 a a a处发生微小的变化 Δ x \Delta x Δx时&#xff0c; f ( x ) f(x) f(x)对…...

unreal 5 实现角色拾取功能

要实现角色拾取功能&#xff0c;我们需要实现蓝图接口功能&#xff0c;蓝图接口主要提供的是蓝图和蓝图之间可以通信&#xff0c;接下来&#xff0c;跟着教程&#xff0c;实现一下角色的拾取功能。 首先&#xff0c;我们要实现一个就是可视区的物品在朝向它的时候&#xff0c;会…...

chatgpt赋能python:如何使用Python升序排列一个列表?

如何使用Python升序排列一个列表&#xff1f; 在Python编程中&#xff0c;我们经常需要对列表进行排序。列表排序是一种常见的操作&#xff0c;可以帮助我们对数据进行分析和管理。在这篇文章中&#xff0c;我们将学习如何使用Python对一个列表进行升序排列。 什么是升序排列…...

Lecture 20 Topic Modelling

目录 Topic ModellingA Brief History of Topic ModelsLDAEvaluationConclusion Topic Modelling makeingsense of text English Wikipedia: 6M articlesTwitter: 500M tweets per dayNew York Times: 15M articlesarXiv: 1M articlesWhat can we do if we want to learn somet…...

ThreadPoolExecutor线程池

文章目录 一、ThreadPool线程池状态二、ThreadPoolExecutor构造方法三、Executors3.1 固定大小线程池3.2 带缓冲线程池3.3 单线程线程池 四、ThreadPoolExecutor4.1 execute(Runnable task)方法使用4.2 submit()方法4.3 invokeAll()4.4 invokeAny()4.5 shutdown()4.6 shutdownN…...

小程序推广网站/百度网络推广怎么收费

中石油校内的比赛&#xff0c;只能后续补题了&#xff0c;题目来说还是比较的不错 A 数方格&#xff08;思维&#xff09; 规定了都是正方形&#xff0c;那么枚举就可以一行的按照规律也好枚举列数也好枚举 代码 #include <bits/stdc.h> using namespace std; int main()…...

上海做网站的费用/站长之家是什么

C是一种编程语言&#xff0c;但又不是一种单一的编程语言&#xff0c;它可以包含以下四种子语言&#xff0c;也即C的四个组成部分&#xff1a; 1、C部分。C语言的基本语法&#xff0c;内置类型、预处理、数组、指针等。 2、面向对象部分。类&#xff0c;封装、继承、多态、虚…...

德阳网站制作/seo建设招商

圈套模式之投资 在触发器篇&#xff0c;我们讨论了内部触发器的重要性。通过外部触发器产品设计者可以引诱用户做出下一步行动。在行动篇&#xff0c;我们知道用户的一小步行动都是想立即实现潜在的期望。在奖励篇&#xff0c;我们知道不确定的奖励会使用户不断来使用产品。 投…...

寻找做项目的网站/网站建设找哪家公司好

目前越来越多的浏览器兼容CSS3标准了&#xff0c;就连IE浏览器老大哥也开始向CSS3低头&#xff0c;微软宣布IE9浏览器支持更多的CSS3属性&#xff0c;IE9更注重HTML5标准。不过CSS3里有一个使对象旋转的属性transform rotate&#xff0c;号称兼容CSS3的浏览器对它的支持也不算好…...

网站首页滚动图怎么做/抖音seo排名优化公司

第十三周 所花时间&#xff08;包括上课&#xff09; 9小时&#xff08;上课7小时&#xff09; 代码量&#xff08;行&#xff09; 360 博客量&#xff08;篇&#xff09; 3 了解到的知识点 数据库连接的复习&#xff0c;以及对增删改查的复习 转载于:https://www.cnbl…...

wordpress统计插件WP/seo排名查询软件

来源&#xff5c;新熵编辑&#xff5c;于松叶盲盒市场的不确定性正在加大。泡泡玛特的雷款滞销、利用福袋去库存等问题只是头部品牌透支用户信任、损失用户好感度的行为缩影&#xff0c;进入更大的视野&#xff0c;会发现整个盲盒市场已经开始趋于冷静和理智。盲盒圈就像一个围…...