当前位置: 首页 > news >正文

【TensorFlow2 之015】 在 TF 2.0 中实现 AlexNet

一、说明

       在这篇文章中,我们将展示如何在 TensorFlow 2.0 中实现基本的卷积神经网络 \(AlexNet\)。AlexNet 架构由 Alex Krizhevsky 设计,并与 Ilya Sutskever 和 Geoffrey Hinton 一起发布。并获得Image Net2012竞赛中冠军。

教程概述:

  1. 理论回顾
  2. 在 TensorFlow 2.0 中的实现

二 理论回顾

        现实生活中的计算机视觉问题需要大量高质量数据进行训练。过去,人们使用 CIFAR 和 NORB 数据集作为计算机视觉问题的基准数据集。然而,ImageNet竞赛改变了这一点。该数据集需要比以前更复杂的网络才能获得良好的结果。

        AlexNet 是 2012 年取得最佳结果的一种网络架构。它的 Top-5 错误率为 15.3%。第二好的成绩远远落后(26.2%)。

        该架构有大约 6000 万个参数,由以下层组成。

图层类型特征图尺寸内核大小跨步激活
图像1227×227
卷积9655×5511×114ReLU
最大池化9627×273×32
卷积25627×275×51ReLU
最大池化25613×133×32
卷积第384章13×133×31ReLU
卷积第384章13×133×31ReLU
卷积25613×133×31ReLU
最大池化2566×63×32
完全连接4096ReLU
完全连接4096ReLU
完全连接1000软最大

        在我们的例子中,我们将仅在 ImageNet 数据集中的两个类上训练模型,因此我们的最后一个全连接层将只有两个具有 Softmax 激活函数的神经元。

        有一些变化使得 AlexNet 与当时的其他网络不同。让我们看看是什么改变了历史!

2.1  重叠的池化层

        标准池化层汇总同一内核图中相邻神经元组的输出。传统上,相邻池单元总结的邻域不重叠。重叠池化层与标准池化层类似,只是计算 Max 的相邻窗口彼此重叠。

重叠池化与非重叠池化

2.2 ReLU 非线性

        评估神经元输出的传统方法是使用 sigmoid 或 tanh 激活函数。这两个函数固定在最小值和最大值之间,因此它们是饱和非线性的。然而,在 AlexNet 中,使用了修正线性单位函数,或者简称为 \(ReLU\)。该函数的阈值为\(0\)。这是一个非饱和激活函数。

        \(ReLU\) 函数需要更少的计算并允许更快的学习,这对在大型数据集上训练的大型模型的性能有很大影响。

2.3  局部响应标准化

        局部响应归一化 (LRN) 首次在 AlexNet 架构中引入,其中选择的激活函数是 \(ReLU\)。使用 LRN 的原因是为了鼓励 侧向抑制。 这是指神经元减少其邻居活动的能力。当我们使用 ReLU 激活函数处理神经元时,这非常有用。具有 \(ReLU\) 激活函数的神经元具有无界激活,我们需要 LRN 对其进行标准化。

三. TensorFlow 2.0中的实现

        交互式 Colab 笔记本可在以下链接找到

        让我们从导入所有必需的库开始

# Load the TensorBoard notebook extension
%load_ext tensorboard
import datetime
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as pltfrom tensorflow.keras import Model
from tensorflow.keras.models import Sequential
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D, Dropout

        导入后,我们需要准备数据。在这里,我们将仅使用 ImageNet 数据集的一小部分。使用以下代码,您可以下载所有图像并将它们存储在文件夹中。

import cv2
import urllib
import requests
import PIL.Image
import numpy as np
from bs4 import BeautifulSoup#ship synset
page = requests.get("http://www.image-net.org/api/text/imagenet.synset.geturls?wnid=n04194289")
soup = BeautifulSoup(page.content, 'html.parser')
#bicycle synset
bikes_page = requests.get("http://www.image-net.org/api/text/imagenet.synset.geturls?wnid=n02834778")
bikes_soup = BeautifulSoup(bikes_page.content, 'html.parser')str_soup=str(soup)
split_urls=str_soup.split('\r\n')bikes_str_soup=str(bikes_soup)
bikes_split_urls=bikes_str_soup.split('\r\n')!mkdir /content/train
!mkdir /content/train/ships
!mkdir /content/train/bikes
!mkdir /content/validation
!mkdir /content/validation/ships
!mkdir /content/validation/bikesimg_rows, img_cols = 32, 32
input_shape = (img_rows, img_cols, 3)def url_to_image(url):resp = urllib.request.urlopen(url)image = np.asarray(bytearray(resp.read()), dtype="uint8")image = cv2.imdecode(image, cv2.IMREAD_COLOR)return imagen_of_training_images=100
for progress in range(n_of_training_images):if not split_urls[progress] == None:try:I = url_to_image(split_urls[progress])if (len(I.shape))==3:save_path = '/content/train/ships/img'+str(progress)+'.jpg'cv2.imwrite(save_path,I)except:Nonefor progress in range(n_of_training_images):if not bikes_split_urls[progress] == None:try:I = url_to_image(bikes_split_urls[progress])if (len(I.shape))==3:save_path = '/content/train/bikes/img'+str(progress)+'.jpg'cv2.imwrite(save_path,I)except:Nonefor progress in range(50):if not split_urls[progress] == None:try:I = url_to_image(split_urls[n_of_training_images+progress])if (len(I.shape))==3:save_path = '/content/validation/ships/img'+str(progress)+'.jpg'cv2.imwrite(save_path,I)except:Nonefor progress in range(50):if not bikes_split_urls[progress] == None:try:I = url_to_image(bikes_split_urls[n_of_training_images+progress])if (len(I.shape))==3:save_path = '/content/validation/bikes/img'+str(progress)+'.jpg'cv2.imwrite(save_path,I)except:None

        现在我们可以创建一个网络。原始 AlexNet 的最后一层有 1000 个神经元,但这里我们只使用一个。这是因为我们只将图像用于两个类。为了构建我们的卷积神经网络,我们将使用 Sequential API。

num_classes = 2# AlexNet model
class AlexNet(Sequential):def __init__(self, input_shape, num_classes):super().__init__()self.add(Conv2D(96, kernel_size=(11,11), strides= 4,padding= 'valid', activation= 'relu',input_shape= input_shape,kernel_initializer= 'he_normal'))self.add(MaxPooling2D(pool_size=(3,3), strides= (2,2),padding= 'valid', data_format= None))self.add(Conv2D(256, kernel_size=(5,5), strides= 1,padding= 'same', activation= 'relu',kernel_initializer= 'he_normal'))self.add(MaxPooling2D(pool_size=(3,3), strides= (2,2),padding= 'valid', data_format= None)) self.add(Conv2D(384, kernel_size=(3,3), strides= 1,padding= 'same', activation= 'relu',kernel_initializer= 'he_normal'))self.add(Conv2D(384, kernel_size=(3,3), strides= 1,padding= 'same', activation= 'relu',kernel_initializer= 'he_normal'))self.add(Conv2D(256, kernel_size=(3,3), strides= 1,padding= 'same', activation= 'relu',kernel_initializer= 'he_normal'))self.add(MaxPooling2D(pool_size=(3,3), strides= (2,2),padding= 'valid', data_format= None))self.add(Flatten())self.add(Dense(4096, activation= 'relu'))self.add(Dense(4096, activation= 'relu'))self.add(Dense(1000, activation= 'relu'))self.add(Dense(num_classes, activation= 'softmax'))self.compile(optimizer= tf.keras.optimizers.Adam(0.001),loss='categorical_crossentropy',metrics=['accuracy'])model = AlexNet((227, 227, 3), num_classes)

        创建模型后,我们定义一些重要的参数以供以后使用。此外,让我们创建图像数据生成器。\(AlexNet\)的参数非常多,有6000万个,这是一个巨大的数字。如果没有足够的数据,这将很可能导致过度拟合。因此,在这里,我们将利用数据增强技术,您可以在此处找到更多相关信息。

        出于同样的原因,AlexNet 中使用了 dropout 层。该技术包括以预定概率“关闭”神经元。这迫使每个神经元具有更强大的特征,可以与其他神经元一起使用。我们不会在这里使用 dropout 层,因为我们不会使用整个数据集。

# some training parameters
EPOCHS = 100
BATCH_SIZE = 32
image_height = 227
image_width = 227
train_dir = "train"
valid_dir = "validation"
model_dir = "my_model.h5"

train_datagen = ImageDataGenerator(rescale=1./255,rotation_range=10,width_shift_range=0.1,height_shift_range=0.1,shear_range=0.1,zoom_range=0.1)train_generator = train_datagen.flow_from_directory(train_dir,target_size=(image_height, image_width),color_mode="rgb",batch_size=BATCH_SIZE,seed=1,shuffle=True,class_mode="categorical")valid_datagen = ImageDataGenerator(rescale=1.0/255.0)
valid_generator = valid_datagen.flow_from_directory(valid_dir,target_size=(image_height, image_width),color_mode="rgb",batch_size=BATCH_SIZE,seed=7,shuffle=True,class_mode="categorical")
train_num = train_generator.samples
valid_num = valid_generator.samples

        现在我们可以设置TensorBoard并开始训练我们的模型。这样我们就可以实时跟踪模型性能。

log_dir="logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)
callback_list = [tensorboard_callback]# start training
model.fit(train_generator,epochs=EPOCHS,steps_per_epoch=train_num // BATCH_SIZE,validation_data=valid_generator,validation_steps=valid_num // BATCH_SIZE,callbacks=callback_list,verbose=0)# save the whole model
model.save(model_dir)%tensorboard --logdir logs/fit

        让我们使用我们的模型进行一些预测并将其可视化。

class_names = ['bike', 'ship']x_valid, label_batch  = next(iter(valid_generator))prediction_values = model.predict_classes(x_valid)# set up the figure
fig = plt.figure(figsize=(10, 6))
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)# plot the images: each image is 227x227 pixels
for i in range(8):ax = fig.add_subplot(2, 4, i + 1, xticks=[], yticks=[])ax.imshow(x_valid[i,:],cmap=plt.cm.gray_r, interpolation='nearest')if prediction_values[i] == np.argmax(label_batch[i]):# label the image with the blue textax.text(3, 17, class_names[prediction_values[i]], color='blue', fontsize=14)else:# label the image with the red textax.text(3, 17, class_names[prediction_values[i]], color='red', fontsize=14)

 

四、概括

        在这篇文章中,我们展示了如何在 TensorFlow 2.0 中实现 \(AlexNet\)。我们只使用了 ImageNet 数据集的一部分,这就是为什么我们没有得到最好的结果。为了获得更高的准确性,需要更多的数据和更长的训练时间。

参考资料:

 数据黑客变种rs    深度学习 机器学习 TensorFlow    2020 年 2 月 29 日  |  0

相关文章:

【TensorFlow2 之015】 在 TF 2.0 中实现 AlexNet

一、说明 在这篇文章中,我们将展示如何在 TensorFlow 2.0 中实现基本的卷积神经网络 \(AlexNet\)。AlexNet 架构由 Alex Krizhevsky 设计,并与 Ilya Sutskever 和 Geoffrey Hinton 一起发布。并获得Image Net2012竞赛中冠军。 教程概述: 理论…...

Python进阶之迭代器

文章目录 前言一、迭代器介绍及作用1.可迭代对象2. 迭代器 二、常用函数和迭代器1.常用函数2.迭代器 三、总结结束语 💂 个人主页:风间琉璃🤟 版权: 本文由【风间琉璃】原创、在CSDN首发、需要转载请联系博主💬 如果文章对你有帮助、欢迎关注…...

Vue鼠标右键画矩形和Ctrl按键多选组件

效果图 说明 下面会贴出组件代码以及一个Demo&#xff0c;上面的效果图即为Demo的效果&#xff0c;建议直接将两份代码拷贝到自己的开发环境直接运行调试。 组件代码 <template><!-- 鼠标画矩形选择对象 --><div class"objects" ref"objectsR…...

【MySQL JDBC】使用Java连接MySQL数据库

一、什么是JDBC&#xff1f; 理解API的概念 API&#xff1a;Application Programing Interface -- 应用程序编程接口写好一个程序&#xff0c;这个程序需要给别人提供哪些功能&#xff1f;这些功能就是通过一些 函数/类 这样的方式来提供的。例如 Random、Scanner、ArrayList..…...

字节码学习之常见java语句的底层原理

文章目录 前言1. if语句字节码的解析 2. for循环字节码的解析 3. while循环4. switch语句5. try-catch语句6. i 和i的字节码7. try-catch-finally8. 参考文档 前言 上一章我们聊了《JVM字节码指令详解》 。本章我们学以致用&#xff0c;聊一下我们常见的一些java语句的特性底层…...

Godot C#连接信号不能像GDScirpt一样自动添加代码

前言 我网上找了好久&#xff0c;发现Godot 对于C# 的支持还有待增强 使用c#脚本有办法像gds那样连接节点自带信号时自动生成信号吗&#xff1f; 百度贴吧 Godot C# How To, Episode 9. Signals With Parameters | Godot Mono 解决方案 把信号拉长&#xff0c;看他的属性 修…...

快速自动化处理JavaScript渲染页面

在进行网络数据抓取时&#xff0c;许多网站使用了JavaScript来动态加载内容&#xff0c;这给传统的网络爬虫带来了一定的挑战。本文将介绍如何使用Selenium和ChromeDriver来实现自动化处理JavaScript渲染页面&#xff0c;并实现有效的数据抓取。 1、Selenium和ChromeDriver简介…...

通过API接口进行商品价格监控,可以按照以下步骤进行操作

要实现通过API接口进行商品价格监控&#xff0c;可以按照以下步骤进行操作&#xff1a; 申请平台账号并选择API接口&#xff1a;根据需要的功能&#xff0c;选择相应的API接口&#xff0c;例如商品API接口、店铺API接口、订单API接口等&#xff0c;这一步骤通常需要我们在相应…...

(vue3)大事记管理系统 文章管理页

[element-plus进阶] 文章列表渲染&#xff08;带搜索&到分页&#xff09; 表单架设&#xff1a;当前el-form标签配置一个inline属性&#xff0c;里面的元素就会在一行显示了 中英国际化处理&#xff1a;App.vue中el-config-provider标签包裹组件&#xff0c;意味着整个组…...

springboot 使用RocketMQ客户端生产消费消息DEMO

创建springboot项目省略 项目依赖 注意&#xff1a;当前客户端版本是 5.1.3 &#xff0c;安装的rocketmq服务的版本要与其对应 <properties><java.version>11</java.version><rocketmq-client-java-version>5.1.3</rocketmq-client-java-version&…...

第三章 内存管理 四、连续分配管理方式

目录 一、内存空间的分配与回收 1、连续分配管理方式 &#xff08;1&#xff09;、单一连续分配 优点&#xff1a; 缺点&#xff1a; &#xff08;2&#xff09;、固定分区分配 分区大小相等&#xff1a; 分区大小不等&#xff1a; &#xff08;3&#xff09;、动态分区…...

npm install报--4048错误和ERR_SOCKET_TIMEOUT问题解决方法之一

一、问题描述 学习vue数字大屏加载动漫效果时&#xff0c;在项目终端页面输入全局下载指令 npm install -g json-server 问题1、报--4048错误 会报如下错误 operation not permitted......errno: -4048code:EPERMsyscall: mkdir......The operation was reiected by your op…...

合并两个有序数组

给你两个按 非递减顺序 排列的整数数组 nums1 和 nums2&#xff0c;另有两个整数 m 和 n &#xff0c;分别表示 nums1 和 nums2 中的元素数目。 请你 合并 nums2 到 nums1 中&#xff0c;使合并后的数组同样按 非递减顺序 排列。 注意&#xff1a;最终&#xff0c;合并后数组…...

自动泊车系统设计学习笔记

1 概述 1.1 自动泊车系统研究现状 目前对于自动泊车系统的研究方法通常有两种实现方式&#xff1a; 整个泊车操作可以分为四个阶段&#xff1a;第一阶段车辆向前行驶进行车位识别&#xff0c;第二阶段车辆行驶到准备泊车时的待泊车区域&#xff0c;第三阶段车辆按照规划好的…...

基于Java的家电销售网站管理系统设计与实现(源码+lw+部署文档+讲解等)

文章目录 前言具体实现截图论文参考论文参考详细视频演示为什么选择我自己的网站自己的小程序&#xff08;小蔡coding&#xff09;有保障的售后福利 代码参考源码获取 前言 &#x1f497;博主介绍&#xff1a;✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师、全栈领域…...

设计模式~备忘录模式(memento)-22

目录  (1)优点&#xff1a; (2)缺点&#xff1a; (3)使用场景&#xff1a; (4)注意事项&#xff1a; (5)应用实例&#xff1a; 代码 备忘录模式(memento) 备忘录模式&#xff08;Memento Pattern&#xff09;保存一个对象的某个状态&#xff0c;以便在适当的时候恢复对…...

【Agora UID 踩坑记录 Java 数据类型】

目录 负数二进制表示Java中32位无符号数的取法项目踩坑记录Java 0xffffffff隐式类型转换的坑 负数二进制表示 由于计算机中数据都以二进制表示&#xff0c;而负数的二级制是根据正数二进制取补码&#xff08;补码就是先取反码&#xff0c;然后加1&#xff09;得到&#xff0c;…...

ESP8285 RTOS SDK OTA

一、官方资源说明 官方指南&#xff1a;空中升级 (OTA) - ESP32 - — ESP-IDF 编程指南 v4.3.6 文档&#xff0c;虽然是正对ESP32的&#xff0c;但是原理是一样的。 官方参考例程&#xff1a;esp-idf\ESP8266_RTOS_SDK\examples\system\ota\&#xff0c;其中包含两个例程&…...

Hadoop3教程(四):HDFS的读写流程及节点距离计算

文章目录 &#xff08;55&#xff09;HDFS 写数据流程&#xff08;56&#xff09; 节点距离计算&#xff08;57&#xff09;机架感知&#xff08;副本存储节点选择&#xff09;&#xff08;58&#xff09;HDFS 读数据流程参考文献 &#xff08;55&#xff09;HDFS 写数据流程 …...

[0xGameCTF 2023] web题解

文章目录 [Week 1]signinbaby_phphello_httprepo_leakping [Week 2]ez_sqli方法一&#xff08;十六进制绕过&#xff09;方法二&#xff08;字符串拼接&#xff09; ez_upload [Week 1] signin 打开题目&#xff0c;查看下js代码 在main.js里找到flag baby_php <?php /…...

Qt之submodule编译

工作中会遇到这样一种情况&#xff1a;qt应用程序在运行时提示找不到某个qt的动态库。我遇到的是缺少libQt5Websocket.so&#xff0c;因为应用程序是在x86平台银河麒麟v10上开发&#xff0c;能够正常编译运行&#xff0c;然后移植到rk3588&#xff08;aarch64架构&#xff09;上…...

Python实现带图形界面的计算器

Python实现带图形界面的计算器 在本文中&#xff0c;我们将使用Python编写一个带有图形用户界面的计算器程序。这个程序将允许用户通过点击按钮或键盘输入数字和操作符&#xff0c;并在显示屏上显示计算结果。 开发环境准备 要运行这个计算器程序&#xff0c;您需要安装Pyth…...

$ vue -Vbash: vue: command not found

$ vue -V bash: vue: command not found报这个错&#xff0c;我们需要找到vue安装路径&#xff0c;添加在环境变量的用户变量中&#xff1a; 1、vue安装路径 2、编辑环境变量 然后重新打开命令框&#xff0c;就可以了...

专业音视频领域中,Pro AV的崛起之路

编者按&#xff1a;在技术进步的加持下&#xff0c;AV行业发展得如何了&#xff1f;本文采访了两位深耕于广播电视行业的技术人&#xff0c;为我们介绍了专业音视频的进展&#xff1a;一位冉冉升起的新星&#xff1a;Pro AV以及FPGA在其中发挥的作用。 美国&#xff0c;拉斯维加…...

vscode 右侧滚动条标记不提示,问题解决纪录

问题描述 用vscode看代码时&#xff0c;我希望在右侧提示一个变量在文件下都在那里使用&#xff0c;在那里赋值&#xff0c;之前该功能是存在的&#xff0c;当我打开一个新的文件夹时这个功能消失了。 解决办法 在setting.json文件下输入 "C_Cpp.intelliSenseEngine&…...

【Java 进阶篇】JavaScript特殊语法详解

JavaScript是一门非常灵活的编程语言&#xff0c;允许开发人员使用多种不同的语法和技巧来解决各种问题。本篇博客将深入探讨JavaScript中的一些特殊语法&#xff0c;这些语法可能不是常规的JavaScript编程知识&#xff0c;但它们对于理解语言的强大之处以及在某些情况下解决问…...

PCL点云处理之配准中的匹配对连线可视化显示 Correspondences(二百一十九)

PCL点云处理之配准中的匹配对连线可视化显示 Correspondences(二百一十九) 一、算法介绍二、算法实现1.可视化代码2.完整代码(特征匹配+可视化)最终效果一、算法介绍 关于点云配准中的匹配对,如果能够可视化将极大提高实验的准确性,还好PCL提供了这样的可视化工具,做法…...

Vue el-table全表搜索,模糊匹配-前端静态查询

后端返回的数据是全部的数据&#xff0c;没有分页&#xff0c;前端需要做的是分页全表模糊查询 代码&#xff1a; //根据关键字对表全局搜索 globalSearch() {//为了拿到对象的列名let filterList Object.keys(this.tableData[0]);if (this.searchWord) {this.tableFilterDat…...

基于html5开发的Win12网页版,抢先体验

据 MSPoweruser 报道&#xff0c;Windows 11虽然刚刚开始步入正轨&#xff0c;但最新爆料称微软已经在开启下一个计划&#xff0c;Windows 12 的开发将在 去年3 月份开始。德国科技网站 Deskmodder.de 称&#xff0c;根据内部消息&#xff0c;微软将在 2022年3 月开始开发 Wind…...

Studio One6.5中文版本下载安装步骤

在唱歌效果调试当中&#xff0c;我们经常给客户安装的几款音频工作站。第一&#xff0c;Studio One 6是PreSonus公司开发的一款功能强大的音频工作平台&#xff0c;具有丰富的音频处理功能和灵活的工作流程。以下是Studio One6的一些主要特点&#xff1a; 1.多轨录音和编辑&…...