当前位置: 首页 > news >正文

GoogleNet网络训练集和测试集搭建

测试集和训练集都是在之前搭建好的基础上进行修改的,重点记录与之前不同的代码。

还是使用的花分类的数据集进行训练和测试的。

一、训练集

1、搭建网络

设置参数:使用辅助分类器,采用权重初始化

net = GoogleNet(num_classes=5, aux_logits=True, init_weights=True)

2、参数输出

之前的模型只有 1 个输出,但由于GoogleNet使用了两个辅助分类器,所以会有 3 个输出。

定义三个输出,分别计算主分类器、辅助分类器1、辅助分类器2的损失函数并相加,最后将损失函数反向传播,使用优化器更新参数模型。 

不单独放代码了,不知道哪里是改动的。图片中红色框中是改动的

整个训练集的代码

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib as plt
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import GoogleNet
import os
import json
import timedevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),"val": transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
image_path = data_root + "/data_set/flower_data"
# train set
train_dataset = datasets.ImageFolder(root=image_path + "/train",transform=data_transform["train"])
train_num = len(train_dataset)# {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# 把文件写入接送文件
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices,json', 'w') as json_file:json_file.write(json_str)batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=0)
#
validate_dataset = datasets.ImageFolder(root=image_path + "/val",transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=batch_size,shuffle=False, num_workers=0)# test_data_iter = iter(validate_loader)
# test_image, test_label = next(test_data_iter)
#
# # 查看图片
# def imshow(img):
#     img = img / 2 + 0.5
#     nping = img.numpy()
#     plt.imshow(np.transpose(nping, (1, 2, 0)))
#     plt.show()
# # print labels
# print(' '.join('%5s' % str(cla_dict[test_label[j].item()]) for j in range(4)))
# # show images
# imshow(utils.make_grid(test_image))net = GoogleNet(num_classes=5, aux_logits=True, init_weights=True)
net.to(device)
loss_function = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.0003)best_acc = 0.0
save_path = './GoogleNet.pth'
# best_acc = 0.0
for epoch in range(2):# trainnet.train()running_loss = 0.0t1 = time.perf_counter()for step, data in enumerate(train_loader, start=0):images, labels = dataoptimizer.zero_grad()logits, aux_logits2, aux_logits1 = net(images.to(device))loss0 = loss_function(logits, labels.to(device))loss1 = loss_function(aux_logits1, labels.to(device))loss2 = loss_function(aux_logits2, labels.to(device))loss = loss0 + loss1 * 0.3 + loss2 * 0.3loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()rate = (step+1) / len(train_loader)a = "*" * int(rate*50)b = "." *int((1-rate)*50)print("\rtrain loss: (:3.0f)%[()->:.3f)".format(int(rate * 100), a, b, loss), end="")print()print(time.perf_counter()-t1)net.eval()acc = 0.0with torch.no_grad():for data_test in validate_loader:test_images, test_labels = data_testoutputs = net(test_images.to(device))predict_y = torch.max(outputs, dim=1)[1]acc += (predict_y == test_labels.to(device)).sum().item()accurate_test = acc / val_numif accurate_test > best_acc:best_acc = accurate_testtorch.save(net.state_dict(), save_path)print('[epoch %d] train_loss: %.3f test_accuracy: %.3f' %(epoch + 1, running_loss / step, acc / val_num))
print("Finished Training")

训练完成 

 中间有几次报错,不过在看懂报错后很快改过来了。

二、测试集

载入模型

在创建模型的时候,aux_logits不会构建辅助分类器,但是之前训练的参数会保存。

所以,在载入模型的时候,要设置参数strict=False, 它可以精准匹配当前模型与所需要载入的权重模型的结构。

辅助分类器中的参数全部存放在unexpecte_keys中。

测试集全部代码

 可以自己找图片进行预测看准确率。

import torch
import matplotlib.pyplot as plt
import json
from model import GoogleNet
from PIL import Image
from torchvision import transformsdata_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# load image
img = Image.open("8.jpeg")
plt.imshow(img)
img = data_transform(img)
img = torch.unsqueeze(img, dim=0)# read class_indent
try:json_file = open('./class_indices,json', 'r')class_indict = json.load(json_file)
except Exception as e:print(e)exit(-1)# create model
model = GoogleNet(num_classes=5, aux_logits=False)
model_weight_path = "./GoogleNet.pth"
missing_keys, unexpected_keys = model.load_state_dict(torch.load(model_weight_path), strict=False)
model.eval()
with torch.no_grad():output = torch.squeeze(model(img))predict = torch.softmax(output, dim=0)predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()

准确率好低,可能是模型训练的还不够吧。

相关文章:

GoogleNet网络训练集和测试集搭建

测试集和训练集都是在之前搭建好的基础上进行修改的,重点记录与之前不同的代码。 还是使用的花分类的数据集进行训练和测试的。 一、训练集 1、搭建网络 设置参数:使用辅助分类器,采用权重初始化 net GoogleNet(num_classes5, aux_logi…...

将数字状态码在后台转换为中文状态

这是我们的实体类 可以看出我们的状态status是2如过返回到前端我们根本不知道2代表的是什么,所以我们需要再这里将数字转换成能看懂的中文状态,首先我们创建一个枚举类 先将我们状态码所对应的中文状态枚举出来,然后创建一个静态方法&#…...

2017NOIP普及组真题 4. 跳房子

线上OJ: 一本通:http://ybt.ssoier.cn:8088/problem_show.php?pid1417\ 核心思想 首先、本题中提到 “ 至少 要花多少金币改造机器人,能获得 至少 k分 ”。看到这样的话语,基本可以考虑要使用 二分答案。 那么,本题中…...

网络与 Internet因特网的基本概念

目录 网络Internet (互联网或互连网)Internet(因特网)待续、更新中 网络 指将分布在不同地理位置的、相同或不同类型的网络通过网络互连设备(中继器、网桥、路由器或网关等)相互连接,形成一个范…...

vue-router 中 router-link 与 a 标签的区别

文章目录 前言 a标签定义 router-link定义 总结 前言 vue-router 中 router-link 与 a 标签的区别 a标签定义 <a> 标签定义超链接&#xff0c;用于从一张页面链接到另一张页面。 从一张页面跳转到另一张页面&#xff0c;但从这里来说就违背了多视图的单页Web应用这个…...

MySQL基础知识——MySQL事务

事务背景 什么是事务&#xff1f; 一组由一个或多个数据库操作组成的操作组&#xff0c;能够原子的执行&#xff0c;且事务间相互独立&#xff1b; 简单来说&#xff0c;事务就是要保证一组数据库操作&#xff0c;要么全部成功&#xff0c;要么全部失败。 注&#xff1a;MyS…...

【架构方法论(一)】架构的定义与架构要解决的问题

文章目录 一. 架构定义与架构的作用1. 系统与子系统2. 模块与组件3. 框架与架构4. 重新定义架构&#xff1a;4R 架构 二、架构设计的真正目的-别掉入架构设计的误区1. 是为了解决软件复杂度2. 简单的复杂度分析案例 三. 案例思考 本文关键字 架构定义 架构与系统的关系从业务逻…...

基于springboot实现人口老龄化社区服务与管理系统项目【项目源码+论文说明】计算机毕业设计

基于springboot实现人口老龄化社区服务与管理系统演示 摘要 随着信息技术在管理上越来越深入而广泛的应用&#xff0c;管理信息系统的实施在技术上已逐步成熟。本文介绍了人口老龄化社区服务与管理平台的开发全过程。通过分析人口老龄化社区服务与管理平台方面的不足&#xff…...

代码随想录算法训练营第三十七天| LeetCode 738.单调递增的数字、总结

一、LeetCode 738.单调递增的数字 题目链接/文章讲解/视频讲解&#xff1a;https://programmercarl.com/0738.%E5%8D%95%E8%B0%83%E9%80%92%E5%A2%9E%E7%9A%84%E6%95%B0%E5%AD%97.html 状态&#xff1a;已解决 1.思路 如何求得小于等于N的最大单调递增的整数&#xff1f;98&am…...

C++动态内存管理 解剖new/delete详细讲解(operator new,operator delete)

讨厌抄我作业和不让我抄作业的人 讨厌插队和不让我插队的人 讨厌用我东西和不让我用东西的人 讨厌借我钱和不借给我钱的人 讨厌开车加塞和不让我加塞的人 讨厌内卷和打扰我内卷的人 一、C中动态内存管理 1.new和delete操作内置类型 2.new和delete操作自定义类型 二、operat…...

python-re正则笔记0.2.0

1. 匹配linux文件路径 from re import match, search,findall str"sh refreshConfig.sh /opt/client/ccc.txt /opt/client/ccc.dfs 胜多负少的"patter1"\/.\.\w" print(findall(patter1, str))""" [/opt/client/ccc.txt /opt/client/ccc…...

.NET SignalR Redis实时Web应用

环境 Win10 VS2022 .NET8 Docker Redis 前言 什么是 SignalR&#xff1f; ASP.NET Core SignalR 是一个开放源代码库&#xff0c;可用于简化向应用添加实时 Web 功能。 实时 Web 功能使服务器端代码能够将内容推送到客户端。 适合 SignalR 的候选项&#xff1a; 需要从服…...

【热门话题】常见分类算法解析

&#x1f308;个人主页: 鑫宝Code &#x1f525;热门专栏: 闲话杂谈&#xff5c; 炫酷HTML | JavaScript基础 ​&#x1f4ab;个人格言: "如无必要&#xff0c;勿增实体" 文章目录 常见分类算法解析1. 逻辑回归&#xff08;Logistic Regression&#xff09;2. 朴…...

有效利用MRP能为中小企业带来什么?

在离散制造企业&#xff0c;主流的生产模式主要为面向订单生产和面向库存生产&#xff08;又称为预测生产&#xff09;&#xff0c;在中小企业中&#xff0c;一般为面向订单生产&#xff0c;也有部分面向库存和面向订单混合的生产方式&#xff08;以面向订单为主&#xff0c;面…...

InternlM2

第一次作业 基础作业 进阶作业 1. hugging face下载 2. 部署 首先&#xff0c;从github上git clone仓库 https://github.com/InternLM/InternLM-XComposer.git然后里面的指引安装环境...

2024-12.python高级语法

异常处理 首先我们要理解什么叫做**"异常”**&#xff1f; 在程序运行过程中&#xff0c;总会遇到各种各样的问题和错误。有些错误是我们编写代码时自己造成的&#xff1a; 比如语法错误、调用错误&#xff0c;甚至逻辑错误。 还有一些错误&#xff0c;则是不可预料的错误…...

【C语言】贪吃蛇项目(1) - 部分Win32 API详解 及 贪吃蛇项目思路

文章目录 一、贪吃蛇项目需要实现的基本功能二、Win32 API介绍2.1 控制台2.2 部分控制台命令及调用函数mode 和 title 命令COORD 命令GetStdHandle&#xff08;获取数据&#xff09;GetConsoleCursorInfo&#xff08;获取光标数据&#xff09;SetConsoleCursorInfo &#xff08…...

秋叶Stable diffusion的创世工具安装-带安装包链接

来自B站up秋葉aaaki&#xff0c;近期发布了Stable Diffusion整合包v4.7版本&#xff0c;一键在本地部署Stable Diffusion&#xff01;&#xff01; 适用于零基础想要使用AI绘画的小伙伴~本整合包支持SDXL&#xff0c;预装多种必须模型。无需安装git、python、cuda等任何内容&am…...

华为ensp中aaa(3a)实现telnet远程连接认证配置命令

作者主页&#xff1a;点击&#xff01; ENSP专栏&#xff1a;点击&#xff01; 创作时间&#xff1a;2024年4月14日18点49分 AAA认证的全称是Authentication、Authorization、Accounting&#xff0c;中文意思是认证、授权、计费。 以下是详细解释 认证&#xff08;Authentic…...

前端网络---http协议和https协议的区别

http协议和https的区别 1、http是超文本传输协议&#xff0c;信息是明文传输&#xff0c;https则是具有安全性的ssl加密传输协议。 2、http和https使用的端口不一样&#xff0c;http是80&#xff0c;https是443。 3、http的连接很简单&#xff0c;是无状态的&#xff08;可以…...

FactoryMethod工厂方法模式详解

目录 模式定义实现方式简单工厂工厂方法主要优点 应用场景源码中的应用 模式定义 定义一个用于创建对象的接口&#xff0c;让子类决定实例化哪一个类。 Factory Method 使得一个类的实例化延迟到子类。 实现方式 简单工厂 以下示例非设计模式&#xff0c;仅为编码的一种规…...

Java基础-知识点1(面试|学习)

Java基础-知识点1 Java与C、PythonJava &#xff1a;C&#xff1a;Python: java 与 C的异同相似之处&#xff1a;区别&#xff1a; Java8的新特性Lambda 表达式&#xff1a;Stream API&#xff1a;接口的默认方法和静态方法&#xff1a; 基本数据类型包装类自动装箱与自动拆箱自…...

【InternLM 实战营第二期-笔记1】书生浦语大模型开源体系详细介绍InternLM2技术报告解读(附相关论文)

书生浦语是上海人工智能实验室和商汤科技联合研发的一款大模型,很高兴能参与本次第二期训练营&#xff0c;我也将会通过笔记博客的方式记录学习的过程与遇到的问题&#xff0c;并为代码添加注释&#xff0c;希望可以帮助到你们。 记得点赞哟(๑ゝω╹๑) 书生浦语大模型开源体系…...

【免费】基于SOE算法的多时段随机配电网重构方法

1 主要内容 该程序是完全复现《Switch Opening and Exchange Method for Stochastic Distribution Network Reconfiguration》&#xff0c;也是一个开源代码&#xff0c;网上有些人卖的还挺贵&#xff0c;本次免费分享给大家&#xff0c;代码主要做的是一个通过配电网重构获取…...

Swift面向对象编程

类的定义与实例化&#xff1a; Swift中定义一个类使用class关键字&#xff0c;类的属性和方法都写在大括号内。示例代码如下&#xff1a; class MyClass {var property1: Intvar property2: Stringinit(property1: Int, property2: String) {self.property1 property1self.pr…...

IEDA 的各种常用插件汇总

目录 IEDA 的各种常用插件汇总1、 Alibaba Java Coding Guidelines2、Translation3、Rainbow Brackets4、MyBatisX5、MyBatis Log Free6、Lombok7、Gitee IEDA 的各种常用插件汇总 1、 Alibaba Java Coding Guidelines 作用&#xff1a;阿里巴巴代码规范检查插件&#xff0c;…...

浅谈C语言中异或运算符的10种妙用

目录 1、前言 2、基本准则定律 3、妙用归纳 4、总结 1、前言 C语言中异或运算符^作为一个基本的逻辑运算符&#xff0c;相信大家都知道其概念&#xff1a;通过对两个相同长度的二进制数进行逐位比较&#xff0c;若对应位的值不同&#xff0c;结果为 1, 否则结果为 0。 但是…...

Canal--->准备MySql主数据库---->安装canal

一、安装主数据库 1.在服务器新建文件夹 mysql/data&#xff0c;新建文件 mysql/conf.d/my.cnf 其中my.cnf 内容如下 [mysqld] log_timestampsSYSTEM default-time-zone8:00 server-id1 log-binmysql-bin binlog-do-db mall # 要监听的库 binlog_formatROW2.启动数据库 do…...

vs配置opencv运行时“发生生成错误,是否继续并运行上次的成功生成”BUG解决办法

vs“发生生成错误&#xff0c;是否继续并运行上次的成功生成” 新手在用vs配置opencv时遇到这个错误时&#xff0c;容易无从下手解决。博主亲身经历很有可能是release/debug模式和配置文件不符的问题。 在配置【链接器】→【输入】→【附加依赖项】环节&#xff0c;编辑查看选择…...

Dryad Girl Fawnia

一个可爱的Dryad Girl Fawnia的三维模型。她有ARKit混合形状,人形装备,多种颜色可供选择。她将是一个完美的角色,幻想或装扮游戏。 🔥 Dryad Girl | Fawnia 一个可爱的Dryad Girl Fawnia的三维模型。她有ARKit混合形状,人形装备,多种颜色可供选择。她将是一个完美的角色…...

怎么做自助交易网站/站长工具seo综合查询引流

回溯算法1、简介2、基本思想3、基本过程与求解步骤4、适用条件5、经典例题1、简介 \quad \quad回溯算法实际上是基于DFS(深度优先搜索)的一个类似枚举的搜索尝试过程&#xff0c;主要是在搜索尝试过程中寻找问题的解&#xff0c;当发现已不满足求解条件时&#xff0c;就“回溯”…...

wordpress page样式/大数据培训机构排名前十

学.net有一段时间了&#xff0c;也参与过老师的几个课题项目的开发工作&#xff0c;但却比较少接触Asp.net的UI设计&#xff0c;虽然VS2005推出快两年了&#xff0c;但VS2005中新增的一些控件还没有怎么用过&#xff0c;近来闲着没事&#xff0c;就来学着玩下&#xff0c;从MS的…...

网站怎么制作小程序/google首页

笔者最近开始对机器学习非常感兴趣&#xff0c;作为一个有志向的软设方向的女孩纸&#xff0c;我开始了学习的第一步入门&#xff0c;下面将今天刚刚学习的kNN及其应用进行总结和回顾&#xff0c;希望可以得到更好的提升&#xff0c;当然&#xff0c;有志同道合者&#xff0c;你…...

中国工商建设标准化协会网站/西安关键词seo公司

我已经创建了一个生成行星精灵的程序.我这样做是通过创建一个圆形路径,运行ctx.clip()来保持所有以下图层在圆圈内,然后绘制一个黑色和透明的纹理图层,然后在整个画布上随机着色的矩形,然后是阴影并在它上面发光.问题是裁剪后圆圈下面也会出现彩色线条,我不知道为什么.我需要删…...

印象网站建设/交换链接营销实现方式解读

objects对象所属类原理剖析&#xff1a; 我们通常做查询操作的时候&#xff0c;都是通过 模型名字.objects 的方式进行操作。其实 模型名字.objects 是一个 django.db.models.manager.Manager 对象&#xff0c;而 Manager 这个类是一个“空壳”的类&#xff0c;他本身是没有任何…...

广州 网站建设 制作/seo站长工具查询系统

好久没有写博客了&#xff0c;这些日子的确很忙。刚开始准备补考啊&#xff0c;实变&#xff0c;复变。实变算是考完了&#xff0c;复变下周才考&#xff0c;心里有些慌。工作的事情有些乱&#xff0c;不知道红帽那边是什么情况&#xff0c;我也搞不清楚了&#xff0c;就等着吧…...