当前位置: 首页 > news >正文

【TensorFlow2 之015】 在 TF 2.0 中实现 AlexNet

一、说明

       在这篇文章中,我们将展示如何在 TensorFlow 2.0 中实现基本的卷积神经网络 \(AlexNet\)。AlexNet 架构由 Alex Krizhevsky 设计,并与 Ilya Sutskever 和 Geoffrey Hinton 一起发布。并获得Image Net2012竞赛中冠军。

教程概述:

  1. 理论回顾
  2. 在 TensorFlow 2.0 中的实现

二 理论回顾

        现实生活中的计算机视觉问题需要大量高质量数据进行训练。过去,人们使用 CIFAR 和 NORB 数据集作为计算机视觉问题的基准数据集。然而,ImageNet竞赛改变了这一点。该数据集需要比以前更复杂的网络才能获得良好的结果。

        AlexNet 是 2012 年取得最佳结果的一种网络架构。它的 Top-5 错误率为 15.3%。第二好的成绩远远落后(26.2%)。

        该架构有大约 6000 万个参数,由以下层组成。

图层类型特征图尺寸内核大小跨步激活
图像1227×227
卷积9655×5511×114ReLU
最大池化9627×273×32
卷积25627×275×51ReLU
最大池化25613×133×32
卷积第384章13×133×31ReLU
卷积第384章13×133×31ReLU
卷积25613×133×31ReLU
最大池化2566×63×32
完全连接4096ReLU
完全连接4096ReLU
完全连接1000软最大

        在我们的例子中,我们将仅在 ImageNet 数据集中的两个类上训练模型,因此我们的最后一个全连接层将只有两个具有 Softmax 激活函数的神经元。

        有一些变化使得 AlexNet 与当时的其他网络不同。让我们看看是什么改变了历史!

2.1  重叠的池化层

        标准池化层汇总同一内核图中相邻神经元组的输出。传统上,相邻池单元总结的邻域不重叠。重叠池化层与标准池化层类似,只是计算 Max 的相邻窗口彼此重叠。

重叠池化与非重叠池化

2.2 ReLU 非线性

        评估神经元输出的传统方法是使用 sigmoid 或 tanh 激活函数。这两个函数固定在最小值和最大值之间,因此它们是饱和非线性的。然而,在 AlexNet 中,使用了修正线性单位函数,或者简称为 \(ReLU\)。该函数的阈值为\(0\)。这是一个非饱和激活函数。

        \(ReLU\) 函数需要更少的计算并允许更快的学习,这对在大型数据集上训练的大型模型的性能有很大影响。

2.3  局部响应标准化

        局部响应归一化 (LRN) 首次在 AlexNet 架构中引入,其中选择的激活函数是 \(ReLU\)。使用 LRN 的原因是为了鼓励 侧向抑制。 这是指神经元减少其邻居活动的能力。当我们使用 ReLU 激活函数处理神经元时,这非常有用。具有 \(ReLU\) 激活函数的神经元具有无界激活,我们需要 LRN 对其进行标准化。

三. TensorFlow 2.0中的实现

        交互式 Colab 笔记本可在以下链接找到

        让我们从导入所有必需的库开始

# Load the TensorBoard notebook extension
%load_ext tensorboard
import datetime
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as pltfrom tensorflow.keras import Model
from tensorflow.keras.models import Sequential
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D, Dropout

        导入后,我们需要准备数据。在这里,我们将仅使用 ImageNet 数据集的一小部分。使用以下代码,您可以下载所有图像并将它们存储在文件夹中。

import cv2
import urllib
import requests
import PIL.Image
import numpy as np
from bs4 import BeautifulSoup#ship synset
page = requests.get("http://www.image-net.org/api/text/imagenet.synset.geturls?wnid=n04194289")
soup = BeautifulSoup(page.content, 'html.parser')
#bicycle synset
bikes_page = requests.get("http://www.image-net.org/api/text/imagenet.synset.geturls?wnid=n02834778")
bikes_soup = BeautifulSoup(bikes_page.content, 'html.parser')str_soup=str(soup)
split_urls=str_soup.split('\r\n')bikes_str_soup=str(bikes_soup)
bikes_split_urls=bikes_str_soup.split('\r\n')!mkdir /content/train
!mkdir /content/train/ships
!mkdir /content/train/bikes
!mkdir /content/validation
!mkdir /content/validation/ships
!mkdir /content/validation/bikesimg_rows, img_cols = 32, 32
input_shape = (img_rows, img_cols, 3)def url_to_image(url):resp = urllib.request.urlopen(url)image = np.asarray(bytearray(resp.read()), dtype="uint8")image = cv2.imdecode(image, cv2.IMREAD_COLOR)return imagen_of_training_images=100
for progress in range(n_of_training_images):if not split_urls[progress] == None:try:I = url_to_image(split_urls[progress])if (len(I.shape))==3:save_path = '/content/train/ships/img'+str(progress)+'.jpg'cv2.imwrite(save_path,I)except:Nonefor progress in range(n_of_training_images):if not bikes_split_urls[progress] == None:try:I = url_to_image(bikes_split_urls[progress])if (len(I.shape))==3:save_path = '/content/train/bikes/img'+str(progress)+'.jpg'cv2.imwrite(save_path,I)except:Nonefor progress in range(50):if not split_urls[progress] == None:try:I = url_to_image(split_urls[n_of_training_images+progress])if (len(I.shape))==3:save_path = '/content/validation/ships/img'+str(progress)+'.jpg'cv2.imwrite(save_path,I)except:Nonefor progress in range(50):if not bikes_split_urls[progress] == None:try:I = url_to_image(bikes_split_urls[n_of_training_images+progress])if (len(I.shape))==3:save_path = '/content/validation/bikes/img'+str(progress)+'.jpg'cv2.imwrite(save_path,I)except:None

        现在我们可以创建一个网络。原始 AlexNet 的最后一层有 1000 个神经元,但这里我们只使用一个。这是因为我们只将图像用于两个类。为了构建我们的卷积神经网络,我们将使用 Sequential API。

num_classes = 2# AlexNet model
class AlexNet(Sequential):def __init__(self, input_shape, num_classes):super().__init__()self.add(Conv2D(96, kernel_size=(11,11), strides= 4,padding= 'valid', activation= 'relu',input_shape= input_shape,kernel_initializer= 'he_normal'))self.add(MaxPooling2D(pool_size=(3,3), strides= (2,2),padding= 'valid', data_format= None))self.add(Conv2D(256, kernel_size=(5,5), strides= 1,padding= 'same', activation= 'relu',kernel_initializer= 'he_normal'))self.add(MaxPooling2D(pool_size=(3,3), strides= (2,2),padding= 'valid', data_format= None)) self.add(Conv2D(384, kernel_size=(3,3), strides= 1,padding= 'same', activation= 'relu',kernel_initializer= 'he_normal'))self.add(Conv2D(384, kernel_size=(3,3), strides= 1,padding= 'same', activation= 'relu',kernel_initializer= 'he_normal'))self.add(Conv2D(256, kernel_size=(3,3), strides= 1,padding= 'same', activation= 'relu',kernel_initializer= 'he_normal'))self.add(MaxPooling2D(pool_size=(3,3), strides= (2,2),padding= 'valid', data_format= None))self.add(Flatten())self.add(Dense(4096, activation= 'relu'))self.add(Dense(4096, activation= 'relu'))self.add(Dense(1000, activation= 'relu'))self.add(Dense(num_classes, activation= 'softmax'))self.compile(optimizer= tf.keras.optimizers.Adam(0.001),loss='categorical_crossentropy',metrics=['accuracy'])model = AlexNet((227, 227, 3), num_classes)

        创建模型后,我们定义一些重要的参数以供以后使用。此外,让我们创建图像数据生成器。\(AlexNet\)的参数非常多,有6000万个,这是一个巨大的数字。如果没有足够的数据,这将很可能导致过度拟合。因此,在这里,我们将利用数据增强技术,您可以在此处找到更多相关信息。

        出于同样的原因,AlexNet 中使用了 dropout 层。该技术包括以预定概率“关闭”神经元。这迫使每个神经元具有更强大的特征,可以与其他神经元一起使用。我们不会在这里使用 dropout 层,因为我们不会使用整个数据集。

# some training parameters
EPOCHS = 100
BATCH_SIZE = 32
image_height = 227
image_width = 227
train_dir = "train"
valid_dir = "validation"
model_dir = "my_model.h5"

train_datagen = ImageDataGenerator(rescale=1./255,rotation_range=10,width_shift_range=0.1,height_shift_range=0.1,shear_range=0.1,zoom_range=0.1)train_generator = train_datagen.flow_from_directory(train_dir,target_size=(image_height, image_width),color_mode="rgb",batch_size=BATCH_SIZE,seed=1,shuffle=True,class_mode="categorical")valid_datagen = ImageDataGenerator(rescale=1.0/255.0)
valid_generator = valid_datagen.flow_from_directory(valid_dir,target_size=(image_height, image_width),color_mode="rgb",batch_size=BATCH_SIZE,seed=7,shuffle=True,class_mode="categorical")
train_num = train_generator.samples
valid_num = valid_generator.samples

        现在我们可以设置TensorBoard并开始训练我们的模型。这样我们就可以实时跟踪模型性能。

log_dir="logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)
callback_list = [tensorboard_callback]# start training
model.fit(train_generator,epochs=EPOCHS,steps_per_epoch=train_num // BATCH_SIZE,validation_data=valid_generator,validation_steps=valid_num // BATCH_SIZE,callbacks=callback_list,verbose=0)# save the whole model
model.save(model_dir)%tensorboard --logdir logs/fit

        让我们使用我们的模型进行一些预测并将其可视化。

class_names = ['bike', 'ship']x_valid, label_batch  = next(iter(valid_generator))prediction_values = model.predict_classes(x_valid)# set up the figure
fig = plt.figure(figsize=(10, 6))
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)# plot the images: each image is 227x227 pixels
for i in range(8):ax = fig.add_subplot(2, 4, i + 1, xticks=[], yticks=[])ax.imshow(x_valid[i,:],cmap=plt.cm.gray_r, interpolation='nearest')if prediction_values[i] == np.argmax(label_batch[i]):# label the image with the blue textax.text(3, 17, class_names[prediction_values[i]], color='blue', fontsize=14)else:# label the image with the red textax.text(3, 17, class_names[prediction_values[i]], color='red', fontsize=14)

 

四、概括

        在这篇文章中,我们展示了如何在 TensorFlow 2.0 中实现 \(AlexNet\)。我们只使用了 ImageNet 数据集的一部分,这就是为什么我们没有得到最好的结果。为了获得更高的准确性,需要更多的数据和更长的训练时间。

参考资料:

 数据黑客变种rs    深度学习 机器学习 TensorFlow    2020 年 2 月 29 日  |  0

相关文章:

【TensorFlow2 之015】 在 TF 2.0 中实现 AlexNet

一、说明 在这篇文章中,我们将展示如何在 TensorFlow 2.0 中实现基本的卷积神经网络 \(AlexNet\)。AlexNet 架构由 Alex Krizhevsky 设计,并与 Ilya Sutskever 和 Geoffrey Hinton 一起发布。并获得Image Net2012竞赛中冠军。 教程概述: 理论…...

Python进阶之迭代器

文章目录 前言一、迭代器介绍及作用1.可迭代对象2. 迭代器 二、常用函数和迭代器1.常用函数2.迭代器 三、总结结束语 💂 个人主页:风间琉璃🤟 版权: 本文由【风间琉璃】原创、在CSDN首发、需要转载请联系博主💬 如果文章对你有帮助、欢迎关注…...

Vue鼠标右键画矩形和Ctrl按键多选组件

效果图 说明 下面会贴出组件代码以及一个Demo&#xff0c;上面的效果图即为Demo的效果&#xff0c;建议直接将两份代码拷贝到自己的开发环境直接运行调试。 组件代码 <template><!-- 鼠标画矩形选择对象 --><div class"objects" ref"objectsR…...

【MySQL JDBC】使用Java连接MySQL数据库

一、什么是JDBC&#xff1f; 理解API的概念 API&#xff1a;Application Programing Interface -- 应用程序编程接口写好一个程序&#xff0c;这个程序需要给别人提供哪些功能&#xff1f;这些功能就是通过一些 函数/类 这样的方式来提供的。例如 Random、Scanner、ArrayList..…...

字节码学习之常见java语句的底层原理

文章目录 前言1. if语句字节码的解析 2. for循环字节码的解析 3. while循环4. switch语句5. try-catch语句6. i 和i的字节码7. try-catch-finally8. 参考文档 前言 上一章我们聊了《JVM字节码指令详解》 。本章我们学以致用&#xff0c;聊一下我们常见的一些java语句的特性底层…...

Godot C#连接信号不能像GDScirpt一样自动添加代码

前言 我网上找了好久&#xff0c;发现Godot 对于C# 的支持还有待增强 使用c#脚本有办法像gds那样连接节点自带信号时自动生成信号吗&#xff1f; 百度贴吧 Godot C# How To, Episode 9. Signals With Parameters | Godot Mono 解决方案 把信号拉长&#xff0c;看他的属性 修…...

快速自动化处理JavaScript渲染页面

在进行网络数据抓取时&#xff0c;许多网站使用了JavaScript来动态加载内容&#xff0c;这给传统的网络爬虫带来了一定的挑战。本文将介绍如何使用Selenium和ChromeDriver来实现自动化处理JavaScript渲染页面&#xff0c;并实现有效的数据抓取。 1、Selenium和ChromeDriver简介…...

通过API接口进行商品价格监控,可以按照以下步骤进行操作

要实现通过API接口进行商品价格监控&#xff0c;可以按照以下步骤进行操作&#xff1a; 申请平台账号并选择API接口&#xff1a;根据需要的功能&#xff0c;选择相应的API接口&#xff0c;例如商品API接口、店铺API接口、订单API接口等&#xff0c;这一步骤通常需要我们在相应…...

(vue3)大事记管理系统 文章管理页

[element-plus进阶] 文章列表渲染&#xff08;带搜索&到分页&#xff09; 表单架设&#xff1a;当前el-form标签配置一个inline属性&#xff0c;里面的元素就会在一行显示了 中英国际化处理&#xff1a;App.vue中el-config-provider标签包裹组件&#xff0c;意味着整个组…...

springboot 使用RocketMQ客户端生产消费消息DEMO

创建springboot项目省略 项目依赖 注意&#xff1a;当前客户端版本是 5.1.3 &#xff0c;安装的rocketmq服务的版本要与其对应 <properties><java.version>11</java.version><rocketmq-client-java-version>5.1.3</rocketmq-client-java-version&…...

第三章 内存管理 四、连续分配管理方式

目录 一、内存空间的分配与回收 1、连续分配管理方式 &#xff08;1&#xff09;、单一连续分配 优点&#xff1a; 缺点&#xff1a; &#xff08;2&#xff09;、固定分区分配 分区大小相等&#xff1a; 分区大小不等&#xff1a; &#xff08;3&#xff09;、动态分区…...

npm install报--4048错误和ERR_SOCKET_TIMEOUT问题解决方法之一

一、问题描述 学习vue数字大屏加载动漫效果时&#xff0c;在项目终端页面输入全局下载指令 npm install -g json-server 问题1、报--4048错误 会报如下错误 operation not permitted......errno: -4048code:EPERMsyscall: mkdir......The operation was reiected by your op…...

合并两个有序数组

给你两个按 非递减顺序 排列的整数数组 nums1 和 nums2&#xff0c;另有两个整数 m 和 n &#xff0c;分别表示 nums1 和 nums2 中的元素数目。 请你 合并 nums2 到 nums1 中&#xff0c;使合并后的数组同样按 非递减顺序 排列。 注意&#xff1a;最终&#xff0c;合并后数组…...

自动泊车系统设计学习笔记

1 概述 1.1 自动泊车系统研究现状 目前对于自动泊车系统的研究方法通常有两种实现方式&#xff1a; 整个泊车操作可以分为四个阶段&#xff1a;第一阶段车辆向前行驶进行车位识别&#xff0c;第二阶段车辆行驶到准备泊车时的待泊车区域&#xff0c;第三阶段车辆按照规划好的…...

基于Java的家电销售网站管理系统设计与实现(源码+lw+部署文档+讲解等)

文章目录 前言具体实现截图论文参考论文参考详细视频演示为什么选择我自己的网站自己的小程序&#xff08;小蔡coding&#xff09;有保障的售后福利 代码参考源码获取 前言 &#x1f497;博主介绍&#xff1a;✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师、全栈领域…...

设计模式~备忘录模式(memento)-22

目录  (1)优点&#xff1a; (2)缺点&#xff1a; (3)使用场景&#xff1a; (4)注意事项&#xff1a; (5)应用实例&#xff1a; 代码 备忘录模式(memento) 备忘录模式&#xff08;Memento Pattern&#xff09;保存一个对象的某个状态&#xff0c;以便在适当的时候恢复对…...

【Agora UID 踩坑记录 Java 数据类型】

目录 负数二进制表示Java中32位无符号数的取法项目踩坑记录Java 0xffffffff隐式类型转换的坑 负数二进制表示 由于计算机中数据都以二进制表示&#xff0c;而负数的二级制是根据正数二进制取补码&#xff08;补码就是先取反码&#xff0c;然后加1&#xff09;得到&#xff0c;…...

ESP8285 RTOS SDK OTA

一、官方资源说明 官方指南&#xff1a;空中升级 (OTA) - ESP32 - — ESP-IDF 编程指南 v4.3.6 文档&#xff0c;虽然是正对ESP32的&#xff0c;但是原理是一样的。 官方参考例程&#xff1a;esp-idf\ESP8266_RTOS_SDK\examples\system\ota\&#xff0c;其中包含两个例程&…...

Hadoop3教程(四):HDFS的读写流程及节点距离计算

文章目录 &#xff08;55&#xff09;HDFS 写数据流程&#xff08;56&#xff09; 节点距离计算&#xff08;57&#xff09;机架感知&#xff08;副本存储节点选择&#xff09;&#xff08;58&#xff09;HDFS 读数据流程参考文献 &#xff08;55&#xff09;HDFS 写数据流程 …...

[0xGameCTF 2023] web题解

文章目录 [Week 1]signinbaby_phphello_httprepo_leakping [Week 2]ez_sqli方法一&#xff08;十六进制绕过&#xff09;方法二&#xff08;字符串拼接&#xff09; ez_upload [Week 1] signin 打开题目&#xff0c;查看下js代码 在main.js里找到flag baby_php <?php /…...

AI-调查研究-01-正念冥想有用吗?对健康的影响及科学指南

点一下关注吧&#xff01;&#xff01;&#xff01;非常感谢&#xff01;&#xff01;持续更新&#xff01;&#xff01;&#xff01; &#x1f680; AI篇持续更新中&#xff01;&#xff08;长期更新&#xff09; 目前2025年06月05日更新到&#xff1a; AI炼丹日志-28 - Aud…...

设计模式和设计原则回顾

设计模式和设计原则回顾 23种设计模式是设计原则的完美体现,设计原则设计原则是设计模式的理论基石, 设计模式 在经典的设计模式分类中(如《设计模式:可复用面向对象软件的基础》一书中),总共有23种设计模式,分为三大类: 一、创建型模式(5种) 1. 单例模式(Sing…...

springboot 百货中心供应链管理系统小程序

一、前言 随着我国经济迅速发展&#xff0c;人们对手机的需求越来越大&#xff0c;各种手机软件也都在被广泛应用&#xff0c;但是对于手机进行数据信息管理&#xff0c;对于手机的各种软件也是备受用户的喜爱&#xff0c;百货中心供应链管理系统被用户普遍使用&#xff0c;为方…...

云启出海,智联未来|阿里云网络「企业出海」系列客户沙龙上海站圆满落地

借阿里云中企出海大会的东风&#xff0c;以**「云启出海&#xff0c;智联未来&#xff5c;打造安全可靠的出海云网络引擎」为主题的阿里云企业出海客户沙龙云网络&安全专场于5.28日下午在上海顺利举办&#xff0c;现场吸引了来自携程、小红书、米哈游、哔哩哔哩、波克城市、…...

【项目实战】通过多模态+LangGraph实现PPT生成助手

PPT自动生成系统 基于LangGraph的PPT自动生成系统&#xff0c;可以将Markdown文档自动转换为PPT演示文稿。 功能特点 Markdown解析&#xff1a;自动解析Markdown文档结构PPT模板分析&#xff1a;分析PPT模板的布局和风格智能布局决策&#xff1a;匹配内容与合适的PPT布局自动…...

Spring Boot+Neo4j知识图谱实战:3步搭建智能关系网络!

一、引言 在数据驱动的背景下&#xff0c;知识图谱凭借其高效的信息组织能力&#xff0c;正逐步成为各行业应用的关键技术。本文聚焦 Spring Boot与Neo4j图数据库的技术结合&#xff0c;探讨知识图谱开发的实现细节&#xff0c;帮助读者掌握该技术栈在实际项目中的落地方法。 …...

工业自动化时代的精准装配革新:迁移科技3D视觉系统如何重塑机器人定位装配

AI3D视觉的工业赋能者 迁移科技成立于2017年&#xff0c;作为行业领先的3D工业相机及视觉系统供应商&#xff0c;累计完成数亿元融资。其核心技术覆盖硬件设计、算法优化及软件集成&#xff0c;通过稳定、易用、高回报的AI3D视觉系统&#xff0c;为汽车、新能源、金属制造等行…...

基于Java+MySQL实现(GUI)客户管理系统

客户资料管理系统的设计与实现 第一章 需求分析 1.1 需求总体介绍 本项目为了方便维护客户信息为了方便维护客户信息&#xff0c;对客户进行统一管理&#xff0c;可以把所有客户信息录入系统&#xff0c;进行维护和统计功能。可通过文件的方式保存相关录入数据&#xff0c;对…...

排序算法总结(C++)

目录 一、稳定性二、排序算法选择、冒泡、插入排序归并排序随机快速排序堆排序基数排序计数排序 三、总结 一、稳定性 排序算法的稳定性是指&#xff1a;同样大小的样本 **&#xff08;同样大小的数据&#xff09;**在排序之后不会改变原始的相对次序。 稳定性对基础类型对象…...

Linux中《基础IO》详细介绍

目录 理解"文件"狭义理解广义理解文件操作的归类认知系统角度文件类别 回顾C文件接口打开文件写文件读文件稍作修改&#xff0c;实现简单cat命令 输出信息到显示器&#xff0c;你有哪些方法stdin & stdout & stderr打开文件的方式 系统⽂件I/O⼀种传递标志位…...