当前位置: 首页 > news >正文

mnist数据集

训练模型

import tensorflow as tfimport keras
from keras.models import Sequential
from keras.layers import Dense,Dropout, Flatten,Conv2D, MaxPooling2D
# from keras.optimizers import SGD
from tensorflow.keras.optimizers import Adam,Nadam, SGDfrom PIL import Image
import numpy as np
import matplotlib.pyplot as pltprint('tf',tf.__version__)
print('keras',keras.__version__)# batch大小,每处理128个样本进行一次梯度更新
batch_size = 64
# 训练素材类别数
num_classes = 10
# 迭代次数
epochs = 5f = np.load("mnist.npz")
x_train, y_train = f['x_train'], f['y_train']
x_test, y_test = f['x_test'], f['y_test']
f.close()
print(x_train.shape,"  ",y_train.shape)
print(x_test.shape,"  ",y_test.shape)# im=plt.imshow(x_train[0],cmap="gray")
# plt.show()## 维度合并784 = 28*28
x_train = x_train.reshape(60000, 784).astype('float32')
x_test = x_test.reshape(10000, 784).astype('float32')## 归一化,像素点的值 转成 0-1 之间的数字
x_train /= 255
x_test /= 255# print(x_train[0])# 标签转换为独热码
y_train = tf.keras.utils.to_categorical(y_train, num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes)
# print(y_train[0]) ## 类似 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]print(x_train.shape,"  ",y_train.shape)
print(x_test.shape,"  ",y_test.shape)# 构建模型
model = Sequential()
model.add(Dense(512, activation='relu',input_shape=(784,)))
model.add(Dense(256, activation='relu'))
model.add(Dense(num_classes, activation='softmax'))
model.summary()# [编译模型] 配置模型,损失函数采用交叉熵,优化采用Adadelta,将识别准确率作为模型评估
model.compile(loss=tf.keras.losses.categorical_crossentropy, optimizer=tf.keras.optimizers.Adadelta(), metrics=['accuracy'])
#  validation_data为验证集
model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=1, validation_data=(x_test, y_test))# 开始评估模型效果 # verbose=0为不输出日志信息
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1]) # 准确度model.save('mnist_model_weights.h5') # 保存训练模型

训练结果

Epoch 4/5
938/938 [==============================] - 7s 7ms/step - loss: 1.5926 - accuracy: 0.7292 - val_loss: 1.4802 - val_accuracy: 0.7653
Epoch 5/5
938/938 [==============================] - 6s 6ms/step - loss: 1.4047 - accuracy: 0.7686 - val_loss: 1.2988 - val_accuracy: 0.7918
Test loss: 1.2988097667694092
Test accuracy: 0.7918000221252441Process finished with exit code 0

测试模型


import tensorflow as tffrom PIL import Image
import numpy as np
from keras.models import load_model# 构建模型
model = load_model('mnist_model_weights.h5') # 加载训练模型
# model.summary()def read_image(img_name):im = Image.open(img_name).resize((28,28),Image.ANTIALIAS).convert('L') # 将要识别的图缩放到训练图一样的大小,并且灰度化data = np.array(im)return dataimages=[]
images.append(read_image("test.png"))
# print(images)X = np.array(images)
print(X.shape)
X=X.reshape(1, 784).astype('float32')
print(X.shape)
X /=255
# print(X[0:1])
result=model.predict(X[0:1])[0] # 识别出第一张图的结果,多张图的时候,把后面的[0] 去掉,返回的就是多张图结果
num=0 # 用来分析预测的结果
for i in range(len(result)): # result的长度是10# print(result[i]*255)if result[i]*255>result[num]*255: # 值越大,就越可能是结果num=iprint("预测结果",num)

将数据集转换为图片


#coding: utf-8
import os
import tensorflow as tf
import input_data
from PIL import Image'''
函数功能:按照bmp格式提取mnist数据集中的图片
参数介绍:mnist_dir   mnist数据集存储的路径save_dir    提取结果存储的目录
'''mint=tf.keras.datasets.mnistdef extract_mnist(mnist_dir, save_dir):rows = 28cols = 28# 加载mnist数据集# one_hot = True为默认打开"独热编码"mnist = input_data.read_data_sets(mnist_dir, one_hot=False)# 获取训练图片数量shape = mnist.train.images.shapeimages_train_count = shape[0]pixels_count_per_image = shape[1]# 获取训练标签数量=训练图片数量# 关闭"独热编码"后,labels的类型为[7 3 4 ... 5 6 8]labels = mnist.train.labelsprint(labels)exit(0)labels_train_count = labels.shape[0]if (images_train_count == labels_train_count):print("训练集共包含%d张图片,%d个标签" % (images_train_count, labels_train_count))print("每张图片包含%d个像素" % (pixels_count_per_image))print("数据类型为", mnist.train.images.dtype)# mnist图像数值的范围为[0,1], 需将其转换为[0,255]for current_image_id in range(images_train_count):for i in range(pixels_count_per_image):if mnist.train.images[current_image_id][i] != 0:mnist.train.images[current_image_id][i] = 255if ((current_image_id + 1) % 50) == 0:print("已转换%d张,共需转换%d张" %(current_image_id + 1, images_train_count))# 创建train images的保存目录, 按标签保存for i in range(10):dir = "%s/%s" % (save_dir, i)print(dir)if not os.path.exists(dir):os.mkdir(dir)# indices = [0, 0, 0, ..., 0]用来记录每个标签对应的图片数量indices = [0 for x in range(0, 10)]for i in range(images_train_count):new_image = Image.new("L", (cols, rows))# 遍历new_image 进行赋值for r in range(rows):for c in range(cols):new_image.putpixel((r, c), int(mnist.train.images[i][c + r * cols]))# 获取第i张训练图片对应的标签label = labels[i]image_save_path = "%s/%s/%s.bmp" % (save_dir, label,indices[label])indices[label] += 1new_image.save(image_save_path)# 打印保存进度if ((i + 1) % 50) == 0:print("图片保存进度: 已保存%d张,共需保存%d张" % (i + 1, images_train_count))else:print("图片数量与标签数量不一致!")if __name__ == '__main__':mnist_dir = "Mnist_Data"save_dir = "Mnist_Data_TrainImages"extract_mnist(mnist_dir, save_dir)

利用图片制作mnist格式数据集

import os
from PIL import Image
from array import *
from random import shuffle# # 文件组织架构:
# ├──training-images
# │   └──0(类别为0的图像)
# │   ├──1(类别为1的图像)
# │   ├──2(类别为2的图像)
# │   ├──3(类别为3的图像)
# │   └──4(类别为4的图像)
# ├──test-images
# │   └──0(类别为0的图像)
# │   ├──1(类别为1的图像)
# │   ├──2(类别为2的图像)
# │   ├──3(类别为3的图像)
# │   └──4(类别为4的图像)
# └── mnist数据集制作.py(本脚本)# Load from and save to
Names = [['./training-images', 'train'], ['./test-images', 'test']]for name in Names:data_image = array('B')data_label = array('B')print(os.listdir(name[0]))FileList = []for dirname in os.listdir(name[0])[0:]:  # [1:] Excludes .DS_Store from Mac OS# print(dirname)path = os.path.join(name[0], dirname)# print(path)for filename in os.listdir(path):# print(filename)if filename.endswith(".png"):FileList.append(os.path.join(name[0] + '/', dirname + '/', filename))print(FileList)shuffle(FileList)  # Usefull for further segmenting the validation setfor filename in FileList:label = int(filename.split('/')[2])print(filename)Im = Image.open(filename)# print(Im)pixel = Im.load()width, height = Im.sizefor x in range(0, width):for y in range(0, height):data_image.append(pixel[y, x])data_label.append(label)  # labels start (one unsigned byte each)hexval = "{0:#0{1}x}".format(len(FileList), 6)  # number of files in HEX# header for label arrayheader = array('B')header.extend([0, 0, 8, 1, 0, 0])header.append(int('0x' + hexval[2:][:2], 16))header.append(int('0x' + hexval[2:][2:], 16))data_label = header + data_label# additional header for images arrayif max([width, height]) <= 256:header.extend([0, 0, 0, width, 0, 0, 0, height])else:raise ValueError('Image exceeds maximum size: 256x256 pixels');header[3] = 3  # Changing MSB for image data (0x00000803)data_image = header + data_imageoutput_file = open(name[1] + '-images-idx3-ubyte', 'wb')data_image.tofile(output_file)output_file.close()output_file = open(name[1] + '-labels-idx1-ubyte', 'wb')data_label.tofile(output_file)output_file.close()# 运行脚本得到四个文件test-images-idx3-ubyte、test-labels-idx1-ubyte、train-images-idx3-ubyte、train-labels-idx1-ubyte
# 在cmd中利用gzip -c train-labels-idx1-ubyte > train-labels-idx1-ubyte.gz命令对上述四个文件压缩得到最终的mnist格式数据集

相关文章:

mnist数据集

训练模型 import tensorflow as tfimport keras from keras.models import Sequential from keras.layers import Dense,Dropout, Flatten,Conv2D, MaxPooling2D # from keras.optimizers import SGD from tensorflow.keras.optimizers import Adam,Nadam, SGDfrom PIL import…...

Java之IO概述以及

1.1 什么是IO 生活中&#xff0c;你肯定经历过这样的场景。当你编辑一个文本文件&#xff0c;忘记了ctrls &#xff0c;可能文件就白白编辑了。当你电脑上插入一个U盘&#xff0c;可以把一个视频&#xff0c;拷贝到你的电脑硬盘里。那么数据都是在哪些设备上的呢&#xff1f;键…...

Spring WebFlux—Reactive 核心

一、概述 spring-web 模块包含以下对响应式Web应用的基础支持&#xff1a; 对于服务器请求处理&#xff0c;有两个级别的支持。 HttpHandler: 用于HTTP请求处理的基本约定&#xff0c;具有非阻塞I/O和Reactive Streams背压&#xff0c;以及Reactor Netty、Undertow、Tomcat、…...

由于找不到d3dx9_43.dll,无法继续执行代码要怎么解决

D3DX9_43.dll是一个动态链接库文件&#xff0c;它是DirectX的一个组件&#xff0c;主要用于支持一些旧版本的游戏和软件。当电脑缺少这个文件时&#xff0c;可能会导致这些游戏和软件无法正常运行。例如&#xff0c;一些老游戏可能需要D3DX9_43.dll来支持图形渲染等功能。此外&…...

git安装配置教程

目录 git安装配置1. 安装git2. git 配置3.生成ssh key:4. 获取生产的密钥3. gitee或者github添加ssh-key4.git使用5. git 使用-本地仓库与远程仓库建立连接第一步&#xff1a;进入项目文件夹&#xff0c;初始化本地仓库第二步&#xff1a;建立远程仓库。 建立远程连接的小技巧 …...

要如何选择报修工单管理系统?需要注意哪些核心功能?

现如今&#xff0c;越来越多的企业已经离不开报修工单管理系统&#xff0c;但市面上的产品繁多&#xff0c;很难寻找到一款特别符合企业需求的系统。企业采购报修工单管理系统的主要目的在于利用其核心功能&#xff0c;如工单流转等&#xff0c;来解决工作事件的流程问题&#…...

面对大数据量渲染,前端工程师如何保证页面流畅性?

一、问题背景 在web前端开发中,需要渲染大量数据是很常见的需求。拿一般的业务系统来说,一个模块中往往需要显示成百上千条记录,这已经属于比较大的数据量。而一些大型系统,如数据分析平台、监控系统等,需要同时渲染的 数据量可能达到几十万甚至上百万。 面对大数据量渲染的需…...

2023年浙工商MBA新生奖学金名单公布,如何看待?

浙工商MBA项目官方最新公布了2023年的非全日制新生奖学金名单&#xff0c;按照政策约定&#xff0c;共分为特等奖学金1名&#xff0c;一等奖学金10名&#xff0c;二等奖学金15名&#xff0c;三等奖学金30名&#xff0c;额度对应3万、1万、0.8万、0.5万不等&#xff0c;主要名单…...

关于时空数据的培训 GAN:实用指南(第 02/3 部分)

一、说明 在本系列关于训练 GAN 实用指南的第 1 部分中&#xff0c;我们讨论了 a&#xff09; 鉴别器 &#xff08;D&#xff09; 和生成器 &#xff08;G&#xff09; 训练之间的不平衡如何导致模式崩溃和由于梯度消失而导致静音学习&#xff0c;以及 b&#xff09; GAN 对超参…...

UNIAPP利用canvas绘制图片和文字,并跟随鼠标移动

最近有个项目&#xff0c;要触摸组件&#xff0c;产生一条图片跟随移动&#xff0c;并显示相应的文字&#xff0c;在网上找了一些资料&#xff0c;终于完成构想&#xff0c;废话少说&#xff0c;直接上代码&#xff08;测试通过&#xff09; <template> <view>…...

【智能电表数据接入物联网平台实践】

智能电表数据接入物联网平台实践 设备接线准备设备调试代码实现Modbus TCP Client 读取电表数据读取寄存器数据转成32bit Float格式然后使用modbusTCP Client 读取数据 使用mqtt协议接入物联网平台最终代码实现 设备接线准备 设备调试 代码实现 Modbus TCP Client 读取电表数…...

Docker--network命令的用法

原文网址&#xff1a;Docker--network命令的用法_IT利刃出鞘的博客-CSDN博客 简介 说明 本文介绍Docker的network网络命令的用法。 官网网址 docker network | Docker Documentation 命令概述 所有命令 命令名称 说明 docker network connect 将容器连接到网络 dock…...

优维低代码实践:图片和搜索

优维低代码技术专栏&#xff0c;是一个全新的、技术为主的专栏&#xff0c;由优维技术委员会成员执笔&#xff0c;基于优维7年低代码技术研发及运维成果&#xff0c;主要介绍低代码相关的技术原理及架构逻辑&#xff0c;目的是给广大运维人提供一个技术交流与学习的平台。 优维…...

[Qt]控件

文章摘于 爱编程的大丙 文章目录 1. 按钮类型控件1.1 按钮基类 QAbstractButton1.1.1 标题和图标1.1.2 按钮的 Check 属性1.1.3 信号1.1.4 槽函数 1.2 QPushButton1.2.1 常用API1.2.2 按钮的使用 1.3 QToolButton1.3.1 常用API1.3.2 按钮的使用 1.4 QRadioButton1.4.1 常用API…...

GEE:快速实现时间序列线性趋势和变化敏感性计算(斜率、截距)以NDVI时间序列为例

作者:CSDN @ _养乐多_ 本博客将向您介绍如何使用Google Earth Engine(GEE)平台来处理Landsat 5、7和8的卫星图像数据,构建时间序列,以NDVI为例,计算NDVI时间序列的斜率和截距,以及如何导出这些结果供进一步分析使用。 文章目录 一、代码详解1.1 核心代码详解1.2 核心代…...

LC1713. 得到子序列的最少操作次数(java - 动态规划)

LC1713. 得到子序列的最少操作次数 题目描述LIS 动态规划 二分法代码演示 题目描述 难度 - 困难 LC1713.得到子序列的最少操作次数 给你一个数组 target &#xff0c;包含若干 互不相同 的整数&#xff0c;以及另一个整数数组 arr &#xff0c;arr 可能 包含重复元素。 每一次…...

vr飞机驾驶舱模拟流程3D仿真演示加大航飞安全法码

众所周知&#xff0c;航空航天飞行是一项耗资大、变量参数很多、非常复杂的系统工程&#xff0c;因此可利用虚拟仿真技术经济、安全及可重复性等特点&#xff0c;进行飞行任务或操作的模拟&#xff0c;以代替某些费时、费力、费钱的真实试验或者真实试验无法开展的场合&#xf…...

一、八大排序(sort)

文章目录 一、时间复杂度&#xff08;一&#xff09;定义&#xff1a;常数操作 二、空间复杂度&#xff08;一&#xff09;定义&#xff1a; 三、排序&#xff08;一&#xff09;选择排序1.定义2.代码3.特性 &#xff08;二&#xff09;冒泡排序1.定义2.代码3.特性 &#xff08…...

【AWS】AI 代码生成器—Amazon CodeWhisperer初体验 | 开启开挂编程之旅

使用 AI 编码配套应用程序更快、更安全地构建应用程序 文章目录 1.1 Amazon CodeWhisperper简介1.2 Amazon CodeWhisperer 定价2.1 打开VS Code2.2 安装AWS ToolKit插件 一、前言 1.1 Amazon CodeWhisperper简介 1️⃣更快地完成更多工作 CodeWhisperer 经过数十亿行代码的训…...

【Mysql主从配置方法---单主从】

Mysql主从 主服务器 创建用户 create user “for_rep”“从服务器IP地址” IDENTIFIED by “123456” 授权 grant replication slave on . to “for_rep”“从服务器IP地址” IDENTIFIED by “123456” 查看用户权限 SHOW GRANTS FOR “for_rep”“从服务器IP地址”; 修改M…...

⼀⽂读懂加密资产交易赛道的新锐⼒量Bitdu

交易所&#xff0c;仍然是加密资产赛道的皇冠级赛道。围绕这个领域展开的商业竞争&#xff0c;最能引起⼴⼤⽤⼾的关注。 经历了数轮资产价格涨跌的⽜熊之后&#xff0c;⼀批批创业者也在不断地思考这⼀议题 — 如何在去中⼼化的世界中&#xff0c;最⾼效率地集结流量、资本和…...

万里牛与金蝶云星空对接集成查询调拨单连通调拨单新增(万里牛调拨单-金蝶【直接调拨单】)

万里牛与金蝶云星空对接集成查询调拨单连通调拨单新增(万里牛调拨单-金蝶【直接调拨单】) 源系统:万里牛 万里牛是杭州湖畔网络技术有限公司旗下SaaS软件品牌&#xff0c;主要针对电商、外贸、实体门店等业务群体&#xff0c;帮助企业快速布局新零售&#xff0c;提升订单处理效…...

虚拟DOM与diff算法

虚拟DOM与diff算法 snabbdom虚拟DOMdiff算法 snabbdom 是什么&#xff1a;snabbdom是著名的虚拟DOM库&#xff0c;是diff算法的鼻祖&#xff0c;Vue源码借鉴了snabbdom 虚拟DOM 是什么&#xff1a;本质上是存在内存里的 JavaScript 对象 作用&#xff1a;用来描述真实DOM的层…...

K8S:pod资源限制及探针

文章目录 一.pod资源限制1.pod资源限制方式2.pod资源限制指定时指定的参数&#xff08;1&#xff09;request 资源&#xff08;2&#xff09; limit 资源&#xff08;3&#xff09;两种资源匹配方式 3.资源限制的示例&#xff08;1&#xff09;官网示例&#xff08;2&#xff0…...

CSS中的定位

position 的属性与含义 CSS 中的 position 属性用于控制元素在页面中的定位方式&#xff0c;有四个主要的取值&#xff0c;每个取值都会影响元素的布局方式&#xff0c;它们是&#xff1a; static&#xff08;默认值&#xff09;&#xff1a; 这是所有元素的初始定位方式。在静…...

二、链表(linked-list)

文章目录 一、定义二、经典例题&#xff08;一&#xff09;[21.合并两个有序链表](https://leetcode.cn/problems/merge-two-sorted-lists/description/)1.思路2.复杂度分析3.注意4.代码 &#xff08;二&#xff09;[86.分割链表](https://leetcode.cn/problems/partition-list…...

Android EditText筛选+选择功能开发

在日常开发中经常会遇到这种需求&#xff0c;EditText既需要可以筛选&#xff0c;又可以点击选择。这里筛选功能用的是AutoCompleteTextView&#xff0c;选择功能使用的是第三方库https://github.com/kongzue/DialogX。 Android AutoCompleteTextView(自动完成文本框)的基本使用…...

Linux 信号 alarm函数 setitimer函数

/*#include <unistd.h>unsigned int alarm(unsigned int seconds);功能&#xff1a;设置定时器。函数调用&#xff0c;开始倒计时&#xff0c;0的时候给当前的进程发送SIGALARM信号参数&#xff1a;倒计时的时长。。单位&#xff1a;秒 如果参数为0&#xff0c;无效返回…...

自主设计,模拟实现 RabbitMQ - 实现发送方消息确认机制

目录 一、实现发送方消息确认 1.1、需求分析 什么是发送方的消息确认? 如何实现?...

【数据结构】二叉树的·深度优先遍历(前中后序遍历)and·广度优先(层序遍历)

&#x1f490; &#x1f338; &#x1f337; &#x1f340; &#x1f339; &#x1f33b; &#x1f33a; &#x1f341; &#x1f343; &#x1f342; &#x1f33f; &#x1f344;&#x1f35d; &#x1f35b; &#x1f364; &#x1f4c3;个人主页 &#xff1a;阿然成长日记 …...

广州网页设计工资/seo手机排名软件

百分百题库提供安全员考试试题、建筑安全员考试预测题、建筑安全员ABC考试真题、安全员证考试题库等&#xff0c;提供在线做题刷题&#xff0c;在线模拟考试&#xff0c;助你考试轻松过关。 11.&#xff08;单选题&#xff09;物料提升机附墙架设置要符合设计要求&#xff0c;但…...

网站备案真实性核验委托书/百度网首页官网登录

各位客官&#xff0c;实在不好意思&#xff0c;好久没有和大家唠嗑了&#xff0c;最近由于一直忙于虚拟化的各项认证考试工作&#xff0c;于是荒废了客栈的生意&#xff0c;耽误了各位客官的休息与饮食&#xff0c;深表歉意。今天终于闲下来了&#xff0c;刚刚突然想起了制作U盘…...

wordpress skype插件/前端seo优化

1 /*2 实现同接口下不同类的对象的转移 3 定义类的接口4 定义多个继承该接口的类5 定义管理类&#xff0c;把接口当作类型&#xff0c;6 传入该接口下各种类的对象&#xff0c;进行操作 7 */8 #include<iostream>9 #include<…...

手机免费平面设计软件/南昌seo网站排名

由于信道管理器在客户端和服务端所起的不同作用&#xff0c;分为信道监听器和信道工厂。和服务端的信道监听其相比&#xff0c;处于客户端的信道工厂显得简单。从名称就可以看得出来&#xff0c;信道工厂的作用就是单纯的创建用于消息发送的信道。我们先来看看与信道工厂相关的…...

深圳龙华网站建设/友情链接怎么设置

先来一个简单的实现方法&#xff1a;复制代码代码如下:Document.mask-wrapper {position: relative;overflow: hidden;}.mask-inner {position: absolute;left: 0;top: 100%;width: 100%;height: 100%;-moz-transition: top ease 200ms;-o-transition: top ease 200ms;-webkit-…...

给手机开发网站/免费推广产品平台有哪些

今天继续接着CSS盒子模型补充&#xff0c;CSS基础学习十三&#xff1a;盒子模型和CSS基础学习十四&#xff1a;盒子模型补充之display属 性设置都是介绍了盒子模型中的内容概括。开始今天的主题&#xff1a;外边距合并。 外边距合并指的是&#xff0c;当两个垂直外边距相遇时&…...