当前位置: 首页 > news >正文

卷积神经网络(CNN):乳腺癌识别.ipynb

文章目录

  • 一、前言
  • 一、设置GPU
  • 二、导入数据
    • 1. 导入数据
    • 2. 检查数据
    • 3. 配置数据集
    • 4. 数据可视化
  • 三、构建模型
  • 四、编译
  • 五、训练模型
  • 六、评估模型
    • 1. Accuracy与Loss图
    • 2. 混淆矩阵
    • 3. 各项指标评估

一、前言

我的环境:

  • 语言环境:Python3.6.5
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2.4.1

往期精彩内容:

  • 卷积神经网络(CNN)实现mnist手写数字识别
  • 卷积神经网络(CNN)多种图片分类的实现
  • 卷积神经网络(CNN)衣服图像分类的实现
  • 卷积神经网络(CNN)鲜花识别
  • 卷积神经网络(CNN)天气识别
  • 卷积神经网络(VGG-16)识别海贼王草帽一伙
  • 卷积神经网络(ResNet-50)鸟类识别
  • 卷积神经网络(AlexNet)鸟类识别
  • 卷积神经网络(CNN)识别验证码

来自专栏:机器学习与深度学习算法推荐

一、设置GPU

import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPUtf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用tf.config.set_visible_devices([gpu0],"GPU")import matplotlib.pyplot as plt
import os,PIL,pathlib
import numpy as np
import pandas as pd
import warnings
from tensorflow import keraswarnings.filterwarnings("ignore")             #忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False    # 用来正常显示负号

二、导入数据

1. 导入数据

import pathlibdata_dir = "./32-data"
data_dir = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*')))
print("图片总数为:",image_count)
图片总数为: 13403
batch_size = 16
img_height = 50
img_width  = 50
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=12,image_size=(img_height, img_width),batch_size=batch_size)
Found 13403 files belonging to 2 classes.
Using 10723 files for training.
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=12,image_size=(img_height, img_width),batch_size=batch_size)
Found 13403 files belonging to 2 classes.
Using 2680 files for validation.
class_names = train_ds.class_names
print(class_names)
['0', '1']

2. 检查数据

for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break
(16, 50, 50, 3)
(16,)

3. 配置数据集

AUTOTUNE = tf.data.AUTOTUNEdef train_preprocessing(image,label):return (image/255.0,label)train_ds = (train_ds.cache().shuffle(1000).map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)           # 在image_dataset_from_directory处已经设置了batch_size.prefetch(buffer_size=AUTOTUNE)
)val_ds = (val_ds.cache().shuffle(1000).map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)         # 在image_dataset_from_directory处已经设置了batch_size.prefetch(buffer_size=AUTOTUNE)
)

4. 数据可视化

plt.figure(figsize=(10, 8))  # 图形的宽为10高为5
plt.suptitle("数据展示")class_names = ["乳腺癌细胞","正常细胞"]for images, labels in train_ds.take(1):for i in range(15):plt.subplot(4, 5, i + 1)plt.xticks([])plt.yticks([])plt.grid(False)# 显示图片plt.imshow(images[i])# 显示标签plt.xlabel(class_names[labels[i]-1])plt.show()

在这里插入图片描述

三、构建模型

import tensorflow as tfmodel = tf.keras.Sequential([tf.keras.layers.Conv2D(filters=16,kernel_size=(3,3),padding="same",activation="relu",input_shape=[img_width, img_height, 3]),tf.keras.layers.Conv2D(filters=16,kernel_size=(3,3),padding="same",activation="relu"),tf.keras.layers.MaxPooling2D((2,2)),tf.keras.layers.Dropout(0.5),tf.keras.layers.Conv2D(filters=16,kernel_size=(3,3),padding="same",activation="relu"),tf.keras.layers.MaxPooling2D((2,2)),tf.keras.layers.Conv2D(filters=16,kernel_size=(3,3),padding="same",activation="relu"),tf.keras.layers.MaxPooling2D((2,2)),tf.keras.layers.Flatten(),tf.keras.layers.Dense(2, activation="softmax")
])
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 50, 50, 16)        448       
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 50, 50, 16)        2320      
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 25, 25, 16)        0         
_________________________________________________________________
dropout (Dropout)            (None, 25, 25, 16)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 25, 25, 16)        2320      
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 12, 12, 16)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 12, 12, 16)        2320      
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 6, 6, 16)          0         
_________________________________________________________________
flatten (Flatten)            (None, 576)               0         
_________________________________________________________________
dense (Dense)                (None, 2)                 1154      
=================================================================
Total params: 8,562
Trainable params: 8,562
Non-trainable params: 0
_________________________________________________________________

四、编译

model.compile(optimizer="adam",loss='sparse_categorical_crossentropy',metrics=['accuracy'])

五、训练模型

from tensorflow.keras.callbacks import ModelCheckpoint, Callback, EarlyStopping, ReduceLROnPlateau, LearningRateSchedulerNO_EPOCHS = 100
PATIENCE  = 5
VERBOSE   = 1# 设置动态学习率
annealer = LearningRateScheduler(lambda x: 1e-3 * 0.99 ** (x+NO_EPOCHS))# 设置早停
earlystopper = EarlyStopping(monitor='loss', patience=PATIENCE, verbose=VERBOSE)# 
checkpointer = ModelCheckpoint('best_model.h5',monitor='val_accuracy',verbose=VERBOSE,save_best_only=True,save_weights_only=True)
train_model  = model.fit(train_ds,epochs=NO_EPOCHS,verbose=1,validation_data=val_ds,callbacks=[earlystopper, checkpointer, annealer])

六、评估模型

1. Accuracy与Loss图

acc = train_model.history['accuracy']
val_acc = train_model.history['val_accuracy']loss = train_model.history['loss']
val_loss = train_model.history['val_loss']epochs_range = range(len(acc))plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

2. 混淆矩阵

from sklearn.metrics import confusion_matrix
import seaborn as sns
import pandas as pd# 定义一个绘制混淆矩阵图的函数
def plot_cm(labels, predictions):# 生成混淆矩阵conf_numpy = confusion_matrix(labels, predictions)# 将矩阵转化为 DataFrameconf_df = pd.DataFrame(conf_numpy, index=class_names ,columns=class_names)  plt.figure(figsize=(8,7))sns.heatmap(conf_df, annot=True, fmt="d", cmap="BuPu")plt.title('混淆矩阵',fontsize=15)plt.ylabel('真实值',fontsize=14)plt.xlabel('预测值',fontsize=14)
val_pre   = []
val_label = []for images, labels in val_ds:#这里可以取部分验证数据(.take(1))生成混淆矩阵for image, label in zip(images, labels):# 需要给图片增加一个维度img_array = tf.expand_dims(image, 0) # 使用模型预测图片中的人物prediction = model.predict(img_array)val_pre.append(class_names[np.argmax(prediction)])val_label.append(class_names[label])
plot_cm(val_label, val_pre)

3. 各项指标评估

from sklearn import metricsdef test_accuracy_report(model):print(metrics.classification_report(val_label, val_pre, target_names=class_names)) score = model.evaluate(val_ds, verbose=0)print('Loss function: %s, accuracy:' % score[0], score[1])test_accuracy_report(model)
             precision    recall  f1-score   support乳腺癌细胞       0.92      0.90      0.91      1339正常细胞       0.91      0.92      0.91      1341accuracy                           0.91      2680macro avg       0.91      0.91      0.91      2680
weighted avg       0.91      0.91      0.91      2680Loss function: 0.22688131034374237, accuracy: 0.9138059616088867

pport

   乳腺癌细胞       0.92      0.90      0.91      1339正常细胞       0.91      0.92      0.91      1341accuracy                           0.91      2680

macro avg 0.91 0.91 0.91 2680
weighted avg 0.91 0.91 0.91 2680

Loss function: 0.22688131034374237, accuracy: 0.9138059616088867


相关文章:

卷积神经网络(CNN):乳腺癌识别.ipynb

文章目录 一、前言一、设置GPU二、导入数据1. 导入数据2. 检查数据3. 配置数据集4. 数据可视化 三、构建模型四、编译五、训练模型六、评估模型1. Accuracy与Loss图2. 混淆矩阵3. 各项指标评估 一、前言 我的环境: 语言环境:Python3.6.5编译器&#xf…...

有文件实体的后门无文件实体的后门rootkit后门

有文件实体后门和无文件实体后门&RootKit后门 什么是有文件的实体后门: 在传统的webshell当中,后门代码都是可以精确定位到某一个文件上去的,你可以rm删除它,可以鼠标右键操作它,它是有一个文件实体对象存在的。…...

GPT实战系列-大模型训练和预测,如何加速、降低显存

GPT实战系列-大模型训练和预测,如何加速、降低显存 不做特别处理,深度学习默认参数精度为浮点32位精度(FP32)。大模型参数庞大,10-1000B级别,如果不注意优化,既耗费大量的显卡资源,…...

SQL Sever 基础知识 - 数据排序

SQL Sever 基础知识 - 二 、数据排序 二 、对数据进行排序第1节 ORDER BY 子句简介第2节 ORDER BY 子句示例2.1 按一列升序对结果集进行排序2.2 按一列降序对结果集进行排序2.3 按多列对结果集排序2.4 按多列对结果集不同排序2.5 按不在选择列表中的列对结果集进行排序2.6 按表…...

vscode配置使用 cpplint

标题安装clang-format和cpplint sudo apt-get install clang-format sudo pip3 install cpplint标题以下settings.json文件放置xxx/Code/User目录 settings.json {"sync.forceDownload": false,"workbench.sideBar.location": "right","…...

C++ 系列 第四篇 C++ 数据类型上篇—基本类型

系列文章 C 系列 前篇 为什么学习C 及学习计划-CSDN博客 C 系列 第一篇 开发环境搭建(WSL 方向)-CSDN博客 C 系列 第二篇 你真的了解C吗?本篇带你走进C的世界-CSDN博客 C 系列 第三篇 C程序的基本结构-CSDN博客 前言 面向对象编程(OOP)的…...

C++ 指针详解

目录 一、指针概述 指针的定义 指针的大小 指针的解引用 野指针 指针未初始化 指针越界访问 指针运算 二级指针 指针与数组 二、字符指针 三、指针数组 四、数组指针 函数指针 函数指针数组 指向函数指针数组的指针 回调函数 指针与数组 一维数组 字符数组…...

.locked、locked1勒索病毒的最新威胁:如何恢复您的数据?

导言: 网络安全问题变得愈加严峻。.locked、locked1勒索病毒是近期备受关注的一种恶意软件,给用户的数据带来了巨大威胁。本文将深入探讨.locked、locked1勒索病毒的特征,探讨如何有效恢复被其加密的数据,并提供一些建议&#xf…...

Apache Sqoop使用

1. Sqoop介绍 Apache Sqoop 是在 Hadoop 生态体系和 RDBMS 体系之间传送数据的一种工具。 Sqoop 工作机制是将导入或导出命令翻译成 mapreduce 程序来实现。在翻译出的 mapreduce 中主要是对 inputformat 和 outputformat 进行定制。 Hadoop 生态系统包括:HDFS、Hi…...

【UGUI】实现UGUI背包系统的六个主要交互功能

在这篇教程中,我们将详细介绍如何在Unity中实现一个背包系统的六个主要功能:添加物品、删除物品、查看物品信息、排序物品、搜索物品和使用物品。让我们开始吧! 一、添加物品 首先,我们需要创建一个方法来添加新的物品到背包中。…...

电压驻波比

电压驻波比 关于IF端口的电压驻波比 一个信号变频后,从中频端口输出,它的输出跟输入是互异的。这个电压柱波比反映了它输出的能量有多少可以真正的输送到后端连接的器件或者设备。...

Open3D 最小二乘拟合二维直线(直接求解法)

目录 一、算法原理二、代码实现三、结果展示本文由CSDN点云侠原创,原文链接。爬虫网站自重。 一、算法原理 平面直线的表达式为: y = k x + b...

面试题目总结(二)

1. IoC 和 AOP 的区别 控制反转(Ioc) 和面向切面编程(AOP) 是两个不同的概念,它们在软件设计中有着不同的应用和目的。 IoC 是一种基于对象组合的编程模式,通过将对象的创建、依赖关系和生命周期等管理权交给外部容器或框架来实现程序间的解耦。IoC 的…...

TrustZone概述

目录 一、概述 1.1 在开始之前 二、什么是TrustZone? 2.1 Armv8-M的TrustZone 2.2 Armv9-A Realm Management Ext...

[go 面试] Go Kit中读取原始HTTP请求体的方法

关注公众号【爱发白日梦的后端】分享技术干货、读书笔记、开源项目、实战经验、高效开发工具等,您的关注将是我的更新动力! 在Go Kit中,如果你想读取未序列化的HTTP请求体,可以使用标准的net/http包来实现。以下是一个示例,演示了如何完成这个任务: package mainimport …...

小程序如何刷新当前页面?

在小程序中,刷新当前页面通常有两种方法: 使用 wx.navigateBack 方法: wx.navigateBack({delta: 1 }) 这将返回上一页,并刷新页面。你可以通过调整 delta 参数来控制返回的页面数。例如,如果你想要返回到两页之前的页…...

ChatGPT使用路径:从新手到专家的指南

原文&精华文章&转载注明:ChatGPT与日本首相交流核废水事件-精准Prompt... hello,我是小索奇,有任何问题或者需要帮助的都可以在这里找到我或者留言哈 一、初识ChatGPT 什么是ChatGPT? ChatGPT是一种大型语言模型&…...

VsCode 调试 MySQL 源码

1. 启动 MySQL 2. 查看 MySQL 进程号 [root ~]# ps -ef | grep mysqld root 21479 1 0 Nov01 ? 00:00:00 /bin/sh /usr/local/mysql/bin/mysqld_safe --datadir/usr/local/mysql/data --pid-file/usr/local/mysql/data/mysqld.pid root 26622 21479 0 …...

Mysql中的正经行锁、间隙锁和临键锁

行锁、间隙锁和临键锁是数据库中的三种不同类型的锁,三者都属于行锁,第一个一般叫他正经的行锁(《Mysql是怎样运行的》一书中的说法)。 行锁(Row Lock):行锁是指对数据表中的某一行进行的锁定操…...

最强AI之风袭来,你爱了吗?

2017年,柯洁同阿尔法狗人机大战,AlphaGo以3比0大获全胜,一代英才泪洒当场...... 2019年,换脸哥视频“杨幂换朱茵”轰动全网,时至今日AI换脸仍热度只增不减; 2022年,ChatGPT一经发布便轰动全球&a…...

时间序列预测实战(二十三)进阶版LSTM多元和单元预测(课程设计毕业设计首选)

一、本文介绍 本篇文章给大家带来的是利用我个人编写的架构进行LSTM模型进行时间序列建模(专门为了时间序列领域新人编写的架构,简单且不同于市面上大家用GPT写的代码),包括结果可视化、支持单元预测、多元预测、模型拟合效果检测…...

Python之Appium 2自动化测试(Android篇)

一、环境搭建及准备工作 1、Appium 2 环境搭建 请参考另一篇文章: Windows系统搭建Appium 2 和 Appium Inspector 环境 2、安装 Appium-Python-Client,版本要求3.0及以上 pip install Appium-Python-ClientVersion: 3.1.03、手机连接电脑,并在dos窗口…...

chromium通信系统-ipcz系统(四)-ipcz-分层、和mojo的关系以及handle

在只有mojo的情况下, 进程间通信都是靠unix 域套接字来完成了,由于这种方式比较低效,并且不够灵活,后来引入了ipcz。 但是系统中基本上使用mojo做进程间通信,想要一步到位迁移到ipcz系统是比较困难的。 所以chrome团队…...

推荐一些研发人员经常用到的免费API接口

快递物流订阅与推送(含物流轨迹):【物流订阅与推送、H5物流轨迹、单号识别】支持单号的订阅与推送,订阅国内物流信息,当信息有变化时,推送到您的回调地址。地图轨迹支持在地图中展示包裹运输轨迹。包括顺丰…...

高薪资是跳出来的,好工作是面出来的~

听人劝、吃饱饭,奉劝各位小伙伴,不要订阅该文所属专栏。 如需要项目实战或者是体系化资源,文末名片加V! 作者:哈哥撩编程,工作十余年, 从事过全栈研发、产品经理等工作,目前在公司担任研发部门CTO。荣誉:2022年度博客之星Top4、2023年度超级个体得主、谷歌与亚马逊开发…...

记QListWidget中QPushButton QSS样式失效的“bug”

一、场景 有一个QListWidget的列表;里面存放了若干QListWidgetItem;每个QListWidgetItem与一个自定义类对象绑定——通过QListWidget的setItemWidget()实现。自定义对象继承于QWidget,且内含QPushButton。 二、bug描述 在该QListWidget的外…...

python提取通话记录中的时间信息

您需要安装适合中文的SpaCy模型。您可以通过运行 pip install spacypython -m spacy download zh_core_web_sm来安装和下载所需的模型。 import spacy# 加载中文模型 nlp spacy.load(zh_core_web_sm)# 示例电话记录文本 text """ Agent: 今天我们解决一下这…...

DSShop移动商城网店系统 反序列化RCE漏洞复现

0x01 产品简介 DSShop是长沙德尚网络科技有限公司推出的一款单店铺移动商城网店系统,能够帮助企业和个人快速构建手机移动商城,并减少二次开发带来的成本。 以其丰富的营销功能,精细化的用户运营,解决电商引流、推广难题,帮助企业打造生态级B2C盈利模式商业平台。完备的电商…...

docker搭建node环境开发服务器

docker搭建node环境开发服务器 本文章是我自己搭建node环境开发服务器的过程记录,不一定完全适用所有人。根据个人情况,按需取用。 命名项目路径 为了方便cd到项目路径,将项目路径重命名,方便输入。 vim /etc/profile # 修改p…...

传统制造业企业如何实现数字化转型?

传统制造企业的数字化转型涉及利用数字技术来提高效率、生产力和整体业务流程。以下是实现制造业数字化转型的关键步骤和策略: 1.当前流程的评估: 确定可以从数字化转型中受益的领域。这可能包括生产流程、供应链管理、库存控制和客户关系。 评估技术集…...

网站打开出现建设中/搜索引擎优化的简称是

虚拟化由于其带来的维护费用的大幅降低而受到追捧,如能减少服务器占用空间,降低购买软硬件设备的成本,大幅度提高系统的利用率。然而对其安全问题,人们也一直在争论不休,一方观点认为虚拟化技术能有效提升系统的安全性…...

网站关键词代码怎么做/成人营销管理培训班

Android应用程序均用Java开发,通过google的指导下,实现并总结了apk文件反编译过程,不难,需要相应的工具即可。 一、Apk反编译得到Java源代码 下载上述反编译工具包,打开apk2java目录下的dex2jar-0.0.9.9文件夹&#xf…...

WordPress 如何修改底部栏内容/四川旅游seo整站优化

在本文中,192路由网将给大家详细介绍,设置水星(MERCURY)MAC1300R无线路由器上网的方法。新买的或者恢复出厂设置后的水星MAC1300R路由器,要设置它连接宽带上网,请按照下面的步骤进行操作:1、路由器线路连接2、设置电脑…...

网站的新闻模块怎么做/网站内容seo

I just dont wanna give them any more ammunition than they already have.---《老友记》 第一季 第一集 我只是不想,让他们有藉题发挥的机会。 名词 n. [U] 1. 弹药,军火The ammunition depot is heavily guarded. 弹药库戒备森严。 2. 【喻】"炮弹"(指…...

有没有做外贸的网站啊/重庆最新数据消息

“密码”选项是BetterZip解压缩软件的解压密码管理器,其作用是管理解压密码以及在帮助用户使用密码解压压缩软件的。 要使用“密码”选项设置密码管理器,首先需要设置密码管理器主密码。 设置密码管理器主密码 快捷键“Command ,”打开首选…...

大型门户网站源码/seo网站优化方

对于二进制表示的float类型的2.5,其在内存中的表示为01000000 00100000 00000000 00000000,如果我们想打印出它在内存中是如何表示的,那么我们可以用1进行移位,与每个比特进行与运算,还是看看代码吧: 对…...