python-pytorch基础之神经网络分类
这里写目录标题
- 生成数据函数
- 定义数据集
- 定义loader加载数据
- 定义神经网络模型
- 测试输出是否为2个
- 输入数据,输出结果
- 训练模型函数
- 计算正确率
- 训练数据并保存模型
- 测试模型
- 准备数据
- 加载模型预测
- 对比结果
生成数据函数
import random
def get_rectangle():width=random.random()height=random.random()# 如果width大于height就是1,否则就是0fat=int(width>=height)return width,height,fatget_rectangle()
(0.19267437580138802, 0.645061020860627, 0)
fat=int(0.8>=0.9)
fat
0
定义数据集
import torch
class Dataset(torch.utils.data.Dataset):def __init__(self):passdef __len__(self):return 1000def __getitem__(self,i):width,height,fat=get_rectangle()# 这里注意width和height列表形式再转为tensor,另外是floattensorx=torch.FloatTensor([width,height])y=fatreturn x,y
dataset=Dataset()
# 这里没执行一次都要变,有什么意义?
len(dataset),dataset[999]
(1000, (tensor([0.4756, 0.1713]), 1))
定义loader加载数据
# 定义了数据集,然后使用loader去加载,需要记住batch_size,shuffle,drop_last这个三个常用的
loader=torch.utils.data.DataLoader(dataset=dataset,batch_size=9,shuffle=True,drop_last=True)# 加载完了以后,就可以使用next去遍历了
len(loader),next(iter(loader))
(111,[tensor([[0.1897, 0.6766],[0.2460, 0.2725],[0.5871, 0.7739],[0.3035, 0.9607],[0.7006, 0.7421],[0.4279, 0.9501],[0.6750, 0.1704],[0.5777, 0.1154],[0.5512, 0.3933]]),tensor([0, 0, 0, 0, 0, 0, 1, 1, 1])])
# 输出结果:
# (111,
# [tensor([[0.3577, 0.3401],
# [0.0156, 0.7550],
# [0.0435, 0.4984],
# [0.1329, 0.5488],
# [0.4330, 0.5362],
# [0.1070, 0.8500],
# [0.1073, 0.2496],
# [0.1733, 0.0226],
# [0.6790, 0.2119]]),
# tensor([1, 0, 0, 0, 0, 0, 0, 1, 1])]) 这里是torch自动将y组合成了一个tensor
定义神经网络模型
class Model(torch.nn.Module):def __init__(self):super().__init__()# 定义神经网络结构self.fb=torch.nn.Sequential(# 第一层输入两个,输出32个torch.nn.Linear(in_features=2,out_features=32),# 激活层的意思是小于0的数字变为0,即负数归零操作torch.nn.ReLU(),# 第二层,输入是32个,输出也是32个torch.nn.Linear(in_features=32,out_features=32),# 激活torch.nn.ReLU(),# 第三次,输入32个,输出2个torch.nn.Linear(in_features=32,out_features=2),# 激活,生成的两个数字相加等于1torch.nn.Softmax(dim=1))# 定义网络计算过程def forward(self,x):return self.fb(x)model=Model()
测试输出是否为2个
# 测试 8行2列的数据,让模型测试,看是否最后输出是否也是8行一列?model(torch.rand(8,2)).shape
torch.Size([8, 2])
输入数据,输出结果
model(torch.tensor([[0.3577, 0.3401]]))
tensor([[0.5228, 0.4772]], grad_fn=<SoftmaxBackward>)
训练模型函数
def train():#定义优化器,le-4表示0.0004,10的负四次方opitimizer=torch.optim.Adam(model.parameters(),lr=1e-4)#定义损失函数 ,一般回归使用mseloss,分类使用celossloss_fn=torch.nn.CrossEntropyLoss()#然后trainmodel.train()# 全量数据遍历100轮for epoch in range(100):# 遍历loader的数据,循环一次取9条数据,这里注意unpack时,x就是【width,height】for i,(x,y) in enumerate(loader):out=model(x)# 计算损失loss=loss_fn(out,y)# 计算损失的梯度loss.backward()# 优化参数opitimizer.step()# 梯度清零opitimizer.zero_grad()# 第二十轮的时候打印一下数据if epoch % 20 ==0:# 正确率# out.argmax(dim=1) 表示哪个值大就说明偏移哪一类,dim=1暂时可以看做是固定的# (out.argmax(dim=1)==y).sum() 表示的true的个数acc=(out.argmax(dim=1)==y).sum().item()/len(y)print(epoch,loss.item(),acc)torch.save(model,"4.model")
计算正确率
执行的命令:
print(out,“======”,y,“+++++”,(out.argmax(dim=1)==y).sum().item())
print(out.argmax(dim=1),“------------”,(out.argmax(dim=1)==y),“~~~~~~~~~~~”,(out.argmax(dim=1)==y).sum())
输出的结果:
tensor([[9.9999e-01, 1.4671e-05],
[4.6179e-14, 1.0000e+00], [3.2289e-02, 9.6771e-01], [1.1237e-22, 1.0000e+00],[9.9993e-01, 7.0015e-05],[8.6740e-02, 9.1326e-01],[1.1458e-18, 1.0000e+00],[5.2558e-01, 4.7442e-01],[9.7923e-01, 2.0772e-02]], grad_fn=<SoftmaxBackward>) ====== tensor([0, 1, 1, 1, 0, 1, 1, 1, 0]) +++++ 8
tensor([0, 1, 1, 1, 0, 1, 1, 0, 0]) ------------ tensor([ True, True, True, True, True, True, True, False, True]) ~~~~~~~~~~~ tensor(8)
解释:
out.argmax(dim=1) 表示哪个值大就说明偏移哪一类,dim=1暂时可以看做是固定的
(out.argmax(dim=1)==y).sum() 表示的true的个数
a = torch.tensor([[1,2,3],[4,7,6]])d = a.argmax(dim=1)
print(d)
tensor([2, 1])
训练数据并保存模型
train()
80 0.3329530954360962 1.0
80 0.31511250138282776 1.0
80 0.33394935727119446 1.0
80 0.3242819309234619 1.0
80 0.3188716471195221 1.0
80 0.3405844569206238 1.0
80 0.32696405053138733 1.0
80 0.3540787696838379 1.0
80 0.3390745222568512 1.0
80 0.3645476996898651 0.8888888888888888
80 0.3371085822582245 1.0
80 0.31789034605026245 1.0
80 0.31553390622138977 1.0
80 0.3162603974342346 1.0
80 0.35249051451683044 1.0
80 0.3582523465156555 1.0
80 0.3162645995616913 1.0
80 0.37988030910491943 1.0
80 0.34384390711784363 1.0
80 0.31773826479911804 1.0
80 0.3145104646682739 1.0
80 0.31753242015838623 1.0
80 0.3222736120223999 1.0
80 0.38612237572669983 1.0
80 0.35490038990974426 1.0
80 0.34469687938690186 1.0
80 0.34534531831741333 1.0
80 0.31800928711891174 1.0
80 0.34892910718917847 1.0
80 0.33424195647239685 1.0
80 0.37350085377693176 1.0
80 0.3298128843307495 1.0
80 0.3715909719467163 1.0
80 0.3507140874862671 1.0
80 0.33337005972862244 1.0
80 0.3134789764881134 1.0
80 0.35244104266166687 1.0
80 0.3148314654827118 1.0
80 0.3376845419406891 1.0
80 0.3315282464027405 1.0
80 0.3450225591659546 1.0
80 0.3139556646347046 1.0
80 0.34932857751846313 1.0
80 0.3512738049030304 1.0
80 0.3258627951145172 1.0
80 0.3197799324989319 1.0
80 0.358166366815567 0.8888888888888888
80 0.3716268837451935 1.0
80 0.31426626443862915 1.0
80 0.32130196690559387 1.0
80 0.3207002282142639 1.0
80 0.3891155421733856 1.0
80 0.35045987367630005 1.0
80 0.32332736253738403 1.0
80 0.31951677799224854 1.0
80 0.3184094727039337 1.0
80 0.3341224491596222 1.0
80 0.3408585786819458 1.0
80 0.3139263093471527 1.0
80 0.33058592677116394 1.0
80 0.3134475648403168 1.0
80 0.3281571567058563 1.0
80 0.33370518684387207 1.0
80 0.33172252774238586 1.0
80 0.32849007844924927 1.0
80 0.3604048788547516 1.0
80 0.3651810884475708 1.0
测试模型
准备数据
# 准备数据
x,fat=next(iter(loader))
x,fat
(tensor([[0.6733, 0.4044],[0.6503, 0.0303],[0.9353, 0.9518],[0.4145, 0.6948],[0.9560, 0.8009],[0.6331, 0.0852],[0.5510, 0.8283],[0.1402, 0.2726],[0.3257, 0.8351]]),tensor([1, 1, 0, 0, 1, 1, 0, 0, 0]))
加载模型预测
# 加载模型
modell=torch.load("4.model")
# 使用模型
out=modell(x)
out
tensor([[1.5850e-04, 9.9984e-01],[1.4121e-06, 1.0000e+00],[7.1068e-01, 2.8932e-01],[9.9994e-01, 5.5789e-05],[4.1401e-03, 9.9586e-01],[3.3441e-06, 1.0000e+00],[9.9995e-01, 4.7039e-05],[9.9111e-01, 8.8864e-03],[1.0000e+00, 1.5224e-06]], grad_fn=<SoftmaxBackward>)
out.argmax(dim=1)
tensor([1, 1, 0, 0, 1, 1, 0, 0, 0])
对比结果
fat
tensor([1, 1, 0, 0, 1, 1, 0, 0, 0])
查看上面out的结果和fat的结论一致,不错
相关文章:
![](https://www.ngui.cc/images/no-images.jpg)
python-pytorch基础之神经网络分类
这里写目录标题 生成数据函数定义数据集定义loader加载数据定义神经网络模型测试输出是否为2个输入数据,输出结果 训练模型函数计算正确率 训练数据并保存模型测试模型准备数据加载模型预测对比结果 生成数据函数 import randomdef get_rectangle():widthrandom.ra…...
![](https://img-blog.csdnimg.cn/a7883763657f418ebfcdba3060461461.png)
【C++ 程序设计】实战:C++ 变量实践练习题
目录 01. 变量:定义 02. 变量:初始化 03. 变量:参数传递 04. 变量:格式说明符 ① 占位符 “%d” 改为格式说明符 “%llu” ② 占位符 “%d” 改为格式说明符 “%f” 或 “%e” 05. 变量:字节数统计 06. 变量&a…...
![](https://img-blog.csdnimg.cn/c90ded869c5d44b383189e2dee05f552.png#pic_center)
微软对Visual Studio 17.7 Preview 4进行版本更新,新插件管理器亮相
近期微软发布了Visual Studio 17.7 Preview 4版本,而在这个版本当中,全新设计的扩展插件管理器将亮相,并且可以让用户可更简单地安装和管理扩展插件。 据了解,目前用户可以从 Visual Studio Marketplace 下载各式各样的 VS 扩展插…...
![](https://img-blog.csdnimg.cn/493d3d469f1341edb94d1d4e7ce78fba.png)
Kafka 入门到起飞 - Kafka怎么做到保障消息不会重复消费的? 消费者组是什么?
Kafka怎么做到避免消息重复消费的? 消费者组是什么? 消费者: 1、订阅Topic(主题) 2、从订阅的Topic消费(pull)消息, 3、将消费消息的offset(偏移量)保存在K…...
![](https://www.ngui.cc/images/no-images.jpg)
MongoDB 的增、查、改、删
Monogo使用 增 单条增加 db.member.insertOne({"name":"张三","age":18,"create":new Date()}) db.member.insert({"name":"李四1","age":18,"create":new Date()}) db.member.insertOne(…...
![](https://www.ngui.cc/images/no-images.jpg)
mysql常用操作命令
mysql常用操作命令 mysql:单进程多线程模型,一个SQL语句无法利用多个cpu core 一:基本命令 0.查看当前连接数 show global status like Thread$; show variables like "%timeout%"; show variables like "log_%";1.查看当前连接状态 show processlist…...
![](https://www.ngui.cc/images/no-images.jpg)
数学建模常见模型汇总
优化问题 线性规划、半定规划、几何规划、非线性规划、整数规划、多目标规划(分层序列法)、动态规划、存贮论、代理模型、响应面分析法、列生成算法 预测模型 微分方程、小波分析、回归分析、灰色预测、马尔可夫预测、时间序列分析(AR MAMA.RMA ARTMA LSTM神经网络)、混沌模…...
![](https://www.ngui.cc/images/no-images.jpg)
C#使用LINQ查询操作符实例代码(二)
目录 六、连表操作符 1、内连接2、左外连接(DefaultIfEmpty)3、组连接七、集合操作 八、分区操作符 1、Take():2、TakeWhile():3、Skip():4、SkipWhile():九、聚合操作符 1、Count: 返回集合项数。 2、LongCount&…...
![](https://www.ngui.cc/images/no-images.jpg)
jenkinsfile小试牛刀
序 本文主要演示一下如何用jenkinsfile来编译java服务 安装jenkins 这里使用docker来安装jenkins docker run --name jenkins-docker \ --volume $HOME/jenkins_home:/var/jenkins_home \ -p 8080:8080 jenkins/jenkins:2.416之后访问http://${yourip}:8080,然后…...
![](https://www.ngui.cc/images/no-images.jpg)
C++ xmake构建
文章目录 一、xmake.lua二、xmake常用语句 一、xmake.lua --xmake.luaset_project("XXX")add_rules("mode.debug", "mode.release") set_config("arch", "x64")if is_plat("windows") then -- the release modei…...
![](https://img-blog.csdnimg.cn/60effa480fda42e08d0763b79d3d78ba.png)
推荐带500创作模型的付费创作V2.1.0独立版系统源码
ChatGPT 付费创作系统 V2.1.0 提供最新的对应版本小程序端,上一版本增加了 PC 端绘画功能, 绘画功能采用其他绘画接口 – 意间 AI,本版新增了百度文心一言接口。 后台一些小细节的优化及一些小 BUG 的处理,前端进行了些小细节优…...
![](https://img-blog.csdnimg.cn/6d3c404e6f66427a9c36c5c2476131d2.png)
wps图表怎么改横纵坐标,MLP 多层感知器和CNN卷积神经网络区别
目录 wps表格横纵坐标轴怎么设置? MLP (Multilayer Perceptron) 多层感知器 CNN (Convolutional Neural Network) 卷积神经网络 多层感知器MLP,全连接网络,DNN三者的关系 wps表格横纵坐标轴怎么设置? 1、打开表格点击图的右侧…...
![](https://www.ngui.cc/images/no-images.jpg)
rdb和aof
RDB持久化:原理是将Redis在内存中的数据库记录定时dump到磁盘上的RDB持久化AOF持久化:原理是将Redis的操作日志以追加的方式写入文件 rdb: 开启方式:客户端可以通过向Redis服务器发送save或bgsave命令让服务器生成rdb文件&#…...
![](https://img-blog.csdnimg.cn/d723762109cd4c5ab96038f1a91fa6ae.png)
TCP网络通信编程之网络上传文件
【图片】 【思路解析】 【客户端代码】 import java.io.*; import java.net.InetAddress; import java.net.Socket; import java.net.UnknownHostException;/*** ProjectName: Study* FileName: TCPFileUploadClient* author:HWJ* Data: 2023/7/29 18:44*/ public class TCPFil…...
![](https://img-blog.csdnimg.cn/a9baeb1dd45f4dc58fb8927095f634f9.png#pic_center)
Java中对Redis的常用操作
目录 数据类型五种常用数据类型介绍各种数据类型特点 常用命令字符串操作命令哈希操作命令列表操作命令集合操作命令有序集合操作命令通用命令 在Java中操作RedisRedis的Java客户端Spring Data Redis使用方式介绍环境搭建配置Redis数据源编写配置类,创建RedisTempla…...
![](https://img-blog.csdnimg.cn/d424345310cc479193490e2ef87f86b3.png)
链路追踪设计
...
![](https://img-blog.csdnimg.cn/74ba64dd71ef4484b77c2b064ec90d04.png)
Golang之路---02 基础语法——常量 (包括特殊常量iota)
常量 //显式类型定义const a string "test" //隐式类型定义const b 20 //多个常量定义 const(c "test2"d 2.3e 27)iota iota是Golang语言的常量计数器,只能在常量表达式中使用 iota在const关键字出现时将被重置为0,const中每新…...
![](https://img-blog.csdnimg.cn/abda4d3ea0da4b7f832ff23645f2e0f2.png)
Pytest学习教程_装饰器(二)
前言 pytest装饰器是在使用 pytest 测试框架时用于扩展测试功能的特殊注解或修饰符。使用装饰器可以为测试函数提供额外的功能或行为。 以下是 pytest 装饰器的一些常见用法和用途: 装饰器作用pytest.fixture用于定义测试用例的前置条件和后置操作。可以创建可重…...
![](https://img-blog.csdnimg.cn/0e948fe981364439ba7a9372304f2c86.png)
redis的如何使用
1、redis的使用 1.1windows安装 安装包下载地址:Releases dmajkic/redis GitHub 1.2 redis中常使用的几个文件 1.3 redis中运行 双击redis-server,既可以运行。 1.4使用redis客户单来连接redis 1.5redis的常用指标 redis-serve 服务端,端口号&am…...
![](https://img-blog.csdnimg.cn/6239389f3a4141eb9d92c46837634f42.png)
MyBatis(二)
文章目录 一.MyBatis的模式开发1.1 定义数据表和实体类1.2 配置数据源和MyBatis1.3 编写Mapper接口和增加xxxMapper.xml1.4 测试我们功能的是否实现. 二. Mybatis的增删查改操作2.1 单表查询2.2 多表查询三.动态SQL的实现3.1 什么是动态SQL3.2 动态SQL的使用if标签的使用trim标…...
![](https://img-blog.csdnimg.cn/c5c62b6f7e514c0188646895c2f88dd0.png)
【【51单片机AD转换模块】】
代码是简单的,板子是坏的,电阻是识别不出来的 main.c #include <REGX52.H> #include "delay.h" #include "LCD1602.h" #include "XPT2046.h"unsigned int ADValue;void main(void) {LCD_Init();LCD_ShowString(1,1…...
![](https://www.ngui.cc/images/no-images.jpg)
Longest Divisors Interval(cf)
题意:给定一个正整数n,求正整数的区间[l,r]的最大大小,使得对于区间中的每个i(即l≤i≤r),n是i的倍数。给定两个整数l≤r,区间[l,r]的大小为r−l1(即…...
![](https://www.ngui.cc/images/no-images.jpg)
配置文件、request对象请求方法、Django连接MySQL、Django中的ORM、ORM增删改查字段、ORM增删改查数据
一、配置文件的介绍 1.注册应用的 INSTALLED_APPS [django.contrib.admin,django.contrib.auth,django.contrib.contenttypes,django.contrib.sessions,django.contrib.messages,django.contrib.staticfiles,app01.apps.App01Config, ]################中间件###############…...
![](https://img-blog.csdnimg.cn/f8a16220e3584ab4aa4268c990bcd8ac.png)
CTF学习路线指南(附刷题练习网址)
前言: PWN,Reverse:偏重对汇编,逆向的理解; Gypto:偏重对数学,算法的深入学习; Web:偏重对技巧沉淀,快速搜索能力的挑战; Mic:则更为复杂&…...
![](https://www.ngui.cc/images/no-images.jpg)
【Rust 基础篇】Rust默认泛型参数:简化泛型使用
导言 Rust是一种以安全性和高效性著称的系统级编程语言,其设计哲学是在不损失性能的前提下,保障代码的内存安全和线程安全。在Rust中,泛型是一种非常重要的特性,它允许我们编写一种可以在多种数据类型上进行抽象的代码。然而&…...
![](https://img-blog.csdnimg.cn/316e8c5774694cd38d4d48cf41f778e0.png)
从源码分析Handler面试问题
Handler 老生常谈的问题了,非常建议看一下Handler 的源码。刚入行的时候,大佬们就说 阅读源码 是进步很快的方式。 Handler的基本原理 Handler 的 重要组成部分 Message 消息MessageQueue 消息队列Lopper 负责处理MessageQueue中的消息 消息是如何添加…...
![](https://www.ngui.cc/images/no-images.jpg)
shell编程 变量作用域
变量 变量赋值不用$,访问值时用$,赋值时两边不留空格,双引号括起来的变量被值替换{}标记变量开始和结束,变量名区分大小写,所有bash变量的值变量不区分类型,统一为字符串 变量类型 环境变量,子进程可以继承父进程环境…...
![](https://img-blog.csdnimg.cn/7e8e029abbb7400e98fdab2363013814.png)
华为eNSP:isis的配置
一、拓扑图 二、路由器的配置 配置接口IP AR1: <Huawei>system-view [Huawei]int g0/0/0 [Huawei-GigabitEthernet0/0/0]ip add 1.1.1.1 24 [Huawei-GigabitEthernet0/0/0]qu AR2: <Huawei>system-view [Huawei]int g0/0/0 [Huawei-GigabitEthe…...
![](https://www.ngui.cc/images/no-images.jpg)
FS.05-SAS-UP-Methodology
FS.05-SAS-UP-Methodology-v9.2.pdf 附录 D 数据处理审核 作为现场数据处理系统和支持流程审核的一部分,受审核方最好在审核日期之前准备一些 SAS 特定的测试数据文件。 本文件提供了建议的方法; 受审核方和审核团队将就每次审核的具体方法达成一致。 …...
![](https://img-blog.csdnimg.cn/0741ea56c4d046f0b58ee1069fa1f45a.png)
Jmeter并发测试
基本步骤 1、新建线程组 测试计划右键——>添加——>线程(用户)——>线程组 2、 添加HTTP请求 线程组右键——>添加——>取样器——>HTTP请求 3、 添加HTTP信息头管理器 线程组右键——>添加——>配置元件——>HTTP信息头…...
![](/images/no-images.jpg)
国际网站怎么做/app拉新推广平台渠道
由于种种原因,决定今天开始从centos7转为ubuntu14.0.但是一开始就遇到了一个问题。ubuntu的源貌似被墙了。没有办法,只好换源。如是参考了下面的帖子。 换源参考链接 但是又出现了一个问题,百度了好久才找到 zhaoubuntu:~$ sudo apt-get -f install gcc …...
![](/images/no-images.jpg)
如何给网站做推广/想做一个网站
前端性能优化 减少请求数量 减小资源大小 优化网络连接 优化资源加载 减少重绘回流 性能更好的API webpack优化 常见状态码 100:这个状态码是告诉客户端应该继续发送请求,这个临时响应是用来通知客户端的,部分的请求服务器已经接受…...
![](https://img-blog.csdnimg.cn/20200923070043298.png)
哪个网站做批发最便宜吗/长春网站优化服务
Jackson是Spring Boot(SpringBoot)默认的JSON数据处理框架,但是其并不依赖于任何的Spring 库。有的小伙伴以为Jackson只能在Spring框架内使用,其实不是的,没有这种限制。它提供了很多的JSON数据处理方法、注解,也包括流式API、树模…...
![](/images/no-images.jpg)
贵阳网站建设公司排名/怎么找需要做推广的公司
转自http://blog.csdn.net/gzlaiyonghao/article/details/1674808 Python Python的优点: 1、Python比其它语言有更多扩展模块。 2、在网上可以找到很多Python教程。不仅如此,还有大量的英文书籍和资料。Python.org有很多为初学者准备的依主题组织的资料、…...
![](https://imgsa.baidu.com/exp/w=500/sign=46c550d5f91f4134e037057e151e95c1/80cb39dbb6fd526630cdf122af18972bd4073626.jpg)
定远建设局网站/关键词优化排名的步骤
浏览器记住密码,怎么查看密码是什么? 听语音| 浏览:7891 | 更新:2015-01-28 14:26 | 标签:浏览器 1234567分步阅读现在浏览器都有一种功能叫记住密码,其实这样很不安全。 你眼睛看的那几个‘******’并没有…...
![](https://www.oschina.net/img/hot3.png)
wordpress适合做什么网站/推广竞价账户托管
为什么80%的码农都做不了架构师?>>> apache Httpclient基于java BIO实现的,也是基于apache HttpCore项目。他最基本的功能是执行HTTP方法。HttpClient的API的主要入口就是HttpClient接口,看看这个示例: package httpc…...