当前位置: 首页 > news >正文

深度学习基础小结_项目实战:手机价格预测

目录

库函数导入

一、构建数据集

 二、构建分类网络模型

三、编写训练函数

四、编写评估函数

五、网络性能调优


鲍勃开了自己的手机公司。他想与苹果、三星等大公司展开硬仗。 他不知道如何估算自己公司生产的手机的价格。在这个竞争激烈的手机市场,你不能简单地假设事情。为了解决这个问题,他收集了各个公司的手机销售数据。

鲍勃想找出手机的特性(例如:RAM、内存等)和售价之间的关系。但他不太擅长机器学习。所以他需要你帮他解决这个问题。 在这个问题中,你不需要预测实际价格,而是要预测一个价格区间,表明价格多高。

需要注意的是: 在这个问题中,我们不需要预测实际价格,而是一个价格范围,它的范围使用 0、1、2、3 来表示,所以该问题也是一个分类问题。

数据说明:手机价格分类_数据集-阿里云天池 Mobile Price Classification

库函数导入

import torch
import torch.nn as nn
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import import Dataset,DataLoader,TensorDataset
from sklearn.preprocessing import StandardScaler
import time

一、构建数据集

数据共有 2000 条, 其中 1600 条数据作为训练集, 400 条数据用作测试集。 我们使用 sklearn 的数据集划分工作来完成。并使用 PyTorch 的 TensorDataset 来将数据集构建为 Dataset 对象,方便构造数据集加载对象。

"""构建数据"""
def phone_data_set(path):# 加载本地数据文件data = pd.read_csv(path)# 抽离特征和目标数据x = data.iloc[:,:-1]y = data.iloc[:,-1]# 标准化transfer = StanderdScaler()x = transfer.fit_transform(x.values)x = torch.tensor(x,dtype=torch.float32)y = torch.tensor(y.values,dtype=torch.int64) # 输出是分类结果,所以用整型# 数据集划分x_trian,x_test,y_train,y_test = train_test_split(x,y,test_size=0.2,random_state=42,stratify=y)return x_trian,x_test,y_train,y_testclass my_phone_data_loader(Dataset):
# TensorDataset 可取代该类的作用def __init__(self,x,y):self.x = xself.y = ydef __len__(self):return len(self.x)def __getitem__(self, index):return self.x[index],self.y[index]def data_loader(x_trian,y_train,batch_size=16):# data=my_phone_data_loader(x_train,y_train)data = TensorDataset(x_train,y_trian)data_loader = DataLoader(data,batch_size=batch_size,shuffle=True)return data_loader

 二、构建分类网络模型

构建用于手机价格分类的模型叫做全连接神经网络。它主要由数个线性层来构建,在每个线性层后,还需使用激活函数。

"""模型 构建和初始化参数"""
class Net(torch.nn.Module):def __init__(self,input_features,out_features):super(Net,self).__init__()# 隐藏层 LeakyReLU 激活和输出层 Softmax 激活self.hide1 = nn.Sequential(nn.Linear(input_features,128),nn.LeakyReLU())self.hide2 = nn.Sequential(nn.Linear(128,256),nn.LeakyReLU())self.hide3 = nn.Sequential(nn.Linear(256,512),nn.LeakyReLU())self.hide4 = nn.Sequential(nn.Linear(512,128),nn.LeakyReLU())self.out = nn.Sequential(nn.Linear(128,out_features),nn.Softmax())self.initdata()def forward(self,input_data):# 前向传播x = self.hide1(iput_data)x=self.hide2(x)x=self.hide3(x)x=self.hide4(x)y_pred=self.out(x)return y_pred#分类结果,混淆矩阵 [[0.9,0.01,.02,.7],[...],...]def initdata(self):# 权重He初始化nn.init.kaiming_uniform_(self.hide1[0].weight,nonlinearity="leaky_relu")nn.init.kaiming_uniform_(self.hide2[0].weight,nonlinearity="leaky_relu")nn.init.kaiming_uniform_(self.hide3[0].weight,nonlinearity="leaky_relu")nn.init.kaiming_uniform_(self.hide4[0].weight,nonlinearity="leaky_relu")       

三、编写训练函数

网络编写完成之后,我们需要编写训练函数。所谓的训练函数,指的是输入数据读取、送入网络、计算损失、更新参数的流程,该流程较为固定。我们使用的是多分类交叉生损失函数、使用 SGD 优化方法。最终,将训练好的模型持久化到磁盘中。

"""训练数据"""
# 加载数据
x_train,x_test,y_trian,y_test=phone_data_set("./data/手机价格预测.csv")
def train():data_loader_=data_loader(x_train,y_trian)# 模型生成x_features=x_train.shape[1]y_features=torch.unique(y_trian).shape[0]# 输出特征(类别的数量)model=Net(x_features,y_features)# 初始化模型参数# 默认是初始化过的torch.nn.init.kaiming_uniform_(model.linear1.weight,nonlinearity="leaky_relu")torch.nn.init.kaiming_uniform_(model.linear2.weight,nonlinearity="leaky_relu") # 3.损失函数loss_fn=torch.nn.CrossEntropyLoss()#torch.nn.MSELoss()#虽然分类的结果也可以用均方误差来计算,但是一般用交叉熵(因为计算出来的梯度更大)# 4.优化器optim=torch.optim.Adam(model.parameters(),lr=1e-4)# 定义训练参数epoch=100for i in range(epoch):e=0count=0start_time=time.time()for x,y in data_loader_:count+=1# 生成预测值y_pred=model(x)#执行model对象的forward# 损失计算loss=loss_fn(y_pred,y)e+=loss# 梯度清零optim.zero_grad()# 反向传播loss.backward()# 更新参数optim.step()end_time=time.time()print(f"epoch:{i},loss:{e/count},time:{end_time-start_time}") # 保存模型参数# model.linear1.weight.datatorch.save(model.state_dict(),"./model/model.pth")

四、编写评估函数

评估函数、也叫预测函数、推理函数,主要使用训练好的模型,对未知的样本的进行预测的过程。这里使用前面单独划分出来的测试集来进行评估。

"""评估函数"""
def test():# 加载数据data_loader_=data_loader(x_test,y_test,batch_size=16)# data_loader_=DataLoader(data,batch_size=4,shuffle=True)# 加载模型# 模型生成x_test_features=x_test.shape[1]y_test_features=torch.unique(y_test).shape[0]# 类别数model=Net(x_test_features,y_test_features)# model.linear1.weight=model.linear1.weight.datastate_dict=torch.load("./model/model.pth",map_location="cpu")model.load_state_dict(state_dict)total=0for x,y in data_loader_:y_pred=model(x)#[0.4,0.3,0.1,0.1,0.1]y_pred=torch.argmax(y_pred,dim=1)# print("y",y)# print("y_pred",y_pred)total+=torch.sum(y_pred==y)print(f"精准度:{total/len(x_test)}")

五、网络性能调优

可以通过以下方面对模型进行调优:

  1. 对输入数据进行标准化

  2. 调整优化方法

  3. 调整学习率

  4. 增加批量归一化层

  5. 增加网络层数、神经元个数

  6. 增加训练轮数

相关文章:

深度学习基础小结_项目实战:手机价格预测

目录 库函数导入 一、构建数据集 二、构建分类网络模型 三、编写训练函数 四、编写评估函数 五、网络性能调优 鲍勃开了自己的手机公司。他想与苹果、三星等大公司展开硬仗。 他不知道如何估算自己公司生产的手机的价格。在这个竞争激烈的手机市场,你不能简单地…...

EMall实践DDD模拟电商系统总结

目录 一、事件风暴 二、系统用例 三、领域上下文 四、架构设计 (一)六边形架构 (二)系统分层 五、系统实现 (一)项目结构 (二)提交订单功能实现 (三&#xff0…...

【随笔】AI技术在电商中的应用

这几年,伴随着ChatGPT开始的AI浪潮席卷全球,从聊天场景逐步向多场景扩散,形成了广泛开花的现象。至今,虽然在部分场景的进展已经略显疲态,但当前的这种趋势仍然还在不断的扩展。不少公司,甚至有不少大型电商…...

序列式容器详细攻略(vector、list)C++

vector std::vector 是 STL 提供的 内存连续的、可变长度 的数组(亦称列表)数据结构。能够提供线性复杂度的插入和删除,以及常数复杂度的随机访问。 为什么要使用 vector 作为 OIer,对程序效率的追求远比对工程级别的稳定性要高得多,而 vector 由于其对内存的动态处理,…...

快速启动项目

1 后端项目 https://gitee.com/liuyunkai666/gungun-boot.git 分支: mini 是 springboot3 jdk17 的基础版本,后续其他功能模块陆续在其基础上追加即可​。 1.1 必备环境 1.1.1 mysql 创建一个 自定义名称 数据库,【只要】 执行对应数据库…...

springboot347基于web的铁路订票管理系统(论文+源码)_kaic

摘 要 当今社会进入了科技进步、经济社会快速发展的新时代。计算机技术对经济社会发展和人民生活改善的影响也日益突出,人类的生存和思考方式也产生了变化。传统铁路订票管理采取了人工的管理方法,但这种管理方法存在着许多弊端,比如效率低…...

使用API管理Dynadot域名,在账户中添加域名服务器(Name Server)

前言 Dynadot是通过ICANN认证的域名注册商,自2002年成立以来,服务于全球108个国家和地区的客户,为数以万计的客户提供简洁,优惠,安全的域名注册以及管理服务。 Dynadot平台操作教程索引(包括域名邮箱&…...

【Linux | 计网】TCP协议深度解析:从连接管理到流量控制与滑动窗口

目录 前言: 1、三次握手和四次挥手的联系: 为什么挥手必须要将ACK和FIN分开呢? 2.理解 CLOSE_WAIT 状态 CLOSE_WAIT状态的特点 3.FIN_WAIT状态讲解 3.1、FIN_WAIT_1状态 3.2、FIN_WAIT_2状态 3.3、FIN_WAIT状态的作用与意义 4.理解…...

go语言的成神之路-筑基篇-对文件的操作

目录 一、对文件的读写 Reader 接口 Writer接口 copy接口 bufio的使用 ioutil库 二、cat命令 三、包 1. 包的声明 2. 导入包 3. 包的可见性 4. 包的初始化 5. 标准库包 6. 第三方包 7. 包的组织 8. 包的别名 9. 包的路径 10. 包的版本管理 四、go mod 1. 初始…...

两道数据结构编程题

1.写出在顺序存储结构下将线性表逆转的算法,要求使用最少的附加空间。 解:输入:长度为n的线性表数组A(1:n) 输出:逆转后的长度为n的线性表数组A(1:n)。 C语言描述如下(其中ET为数据元素的类型):…...

【Qt】QDateTimeEdit控件实现清空(不保留默认时间/最小时间)

一、QDateTimeEdit控件 QDateTimeEdit 提供了一个用于编辑日期和时间的控件。用户可以通过键盘或使用上下箭头键来增加或减少日期和时间值。日期和时间的显示格式根据设置的格式显示,可以通过 setDisplayFormat() 方法来设置。 二、如何清空 我在使用的时候&#…...

12、字符串

1、字符串概念 字符串用来存储一组字符,因此需要字符数组来存。 C语言中字符串的表示 C语言里面字符串只能用字符数组来存 字符串要求这个数组的末尾必须至少有一个\0 char ch1[] {a,b,c}; // 不是字符串 char ch2[5] {h,e,l,l,o}; // 不是字符串 char…...

DPDK用户态协议栈-Tcp Posix API 1

和udp一样&#xff0c;我们需要实现和系统调用一样的接口来实现我们的tcp server。先来看看我们之前写的unix_tcp使用了哪些接口&#xff0c;这边我加上两个系统调用&#xff0c;分别是接收数据和发送数据。 #include <stdio.h> #include <arpa/inet.h> #include …...

【人工智能-科普】图神经网络(GNN):与传统神经网络的区别与优势

文章目录 图神经网络(GNN):与传统神经网络的区别与优势什么是图神经网络?图的基本概念GNN的工作原理GNN与传统神经网络的不同1. 数据结构的不同2. 信息传递方式的不同3. 模型的可扩展性4. 局部与全局信息的结合GNN的应用领域总结图神经网络(GNN):与传统神经网络的区别与…...

LabVIEW实现UDP通信

目录 1、UDP通信原理 2、硬件环境部署 3、云端环境部署 4、UDP通信函数 5、程序架构 6、前面板设计 7、程序框图设计 8、测试验证 本专栏以LabVIEW为开发平台,讲解物联网通信组网原理与开发方法,覆盖RS232、TCP、MQTT、蓝牙、Wi-Fi、NB-IoT等协议。 结合实际案例,展示如何利…...

[pdf,epub]228页《分析模式》漫谈合集01-45提供下载

《分析模式》漫谈合集01-45的pdf、epub文件提供下载。已上传至本号的CSDN资源。 如果CSDN资源下载有问题&#xff0c;可到umlchina.com/url/ap.html。 已排版成适合手机阅读&#xff0c;pdf的排版更好一些。 ★UMLChina为什么叒要翻译《分析模式》&#xff1f; ★[缝合故事]…...

Kafka的消费消息是如何传递的?

大家好&#xff0c;我是锋哥。今天分享关于【Kafka的消费消息是如何传递的&#xff1f;】面试题。希望对大家有帮助&#xff1b; Kafka的消费消息是如何传递的&#xff1f; 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 在Kafka中&#xff0c;消息的消费是通过消费…...

二分查找(Java实现)(1)

二分查找&#xff08;Java实现&#xff09;&#xff08;1&#xff09; leetcode 34.排序数组中查找元素第一个和最后一个位置 题目描述: 给你一个按照非递减顺序排列的整数数组 nums&#xff0c;和一个目标值 target。请你找出给定目标值在数组中的开始位置和结束位置。 如…...

力扣103.二叉树的锯齿形层序遍历

题目描述 题目链接103. 二叉树的锯齿形层序遍历 给你二叉树的根节点 root &#xff0c;返回其节点值的 锯齿形层序遍历 。&#xff08;即先从左往右&#xff0c;再从右往左进行下一层遍历&#xff0c;以此类推&#xff0c;层与层之间交替进行&#xff09;。 示例 1&#xff…...

Search with Orama

1.前言 在不久之前&#xff0c;我把 DevNow 的搜索组件通过 Lunr 进行了重构&#xff0c;从前端角度实现了对文章内容的搜索&#xff0c;但是在使用体验上&#xff0c;感觉不是特别好&#xff0c;大概有如下几个原因&#xff1a; 社区的文章数量比较少&#xff0c;项目的 Com…...

从LiveData到Kotlin Flow:Pokedex响应式编程的终极演进指南

从LiveData到Kotlin Flow&#xff1a;Pokedex响应式编程的终极演进指南 【免费下载链接】Pokedex &#x1f5e1;️ Pokedex demonstrates modern Android development with Hilt, Material Motion, Coroutines, Flow, Jetpack (Room, ViewModel) based on MVVM architecture. …...

AI驱动的3D建模革命:PIFuHD开源工具让零基础用户轻松创建高精度数字人

AI驱动的3D建模革命&#xff1a;PIFuHD开源工具让零基础用户轻松创建高精度数字人 【免费下载链接】pifuhd High-Resolution 3D Human Digitization from A Single Image. 项目地址: https://gitcode.com/gh_mirrors/pi/pifuhd 在数字内容创作、游戏开发和AR/VR应用领域…...

淘宝淘金币自动化脚本:每天节省20分钟的终极解决方案

淘宝淘金币自动化脚本&#xff1a;每天节省20分钟的终极解决方案 【免费下载链接】taojinbi 淘宝淘金币自动执行脚本&#xff0c;包含蚂蚁森林收取能量&#xff0c;芭芭农场全任务&#xff0c;解放你的双手 项目地址: https://gitcode.com/gh_mirrors/ta/taojinbi 淘宝淘…...

Phi-4-reasoning-vision-15B行业应用:银行手机银行截图→交易流程合规性审计

Phi-4-reasoning-vision-15B在银行手机银行截图合规审计中的应用实践 1. 银行业务合规审计的痛点与机遇 在银行业务数字化转型的浪潮中&#xff0c;手机银行已成为客户办理业务的主要渠道。然而&#xff0c;随之而来的是海量的交易截图和操作记录需要人工审核&#xff0c;以确…...

SAP Basis实战:Client创建与数据迁移的完整流程与避坑指南

1. 理解SAP Client的基本概念 在SAP系统中&#xff0c;Client&#xff08;客户端&#xff09;是一个非常重要的概念。简单来说&#xff0c;它就像是系统中的一个独立工作空间&#xff0c;每个Client都有自己的配置和数据。想象一下&#xff0c;一家大型企业有多个子公司&#x…...

gte-base-zh效果展示:中文诗歌风格迁移评估——基于向量空间距离的风格量化分析

gte-base-zh效果展示&#xff1a;中文诗歌风格迁移评估——基于向量空间距离的风格量化分析 1. 引言&#xff1a;当AI遇见古诗词 想象一下&#xff0c;你是一位诗词爱好者&#xff0c;想尝试把李白的豪放诗句改写成李清照的婉约风格。传统上&#xff0c;这需要深厚的文学功底…...

cat-catch:构建智能化媒体资源捕获的浏览器扩展解决方案

cat-catch&#xff1a;构建智能化媒体资源捕获的浏览器扩展解决方案 【免费下载链接】cat-catch 猫抓 chrome资源嗅探扩展 项目地址: https://gitcode.com/GitHub_Trending/ca/cat-catch cat-catch是一款专注于网页媒体资源智能捕获的浏览器扩展工具&#xff0c;通过深度…...

大模型入门学习教程(非常详细)非常详细收藏我这一篇就够了!大模型教程

本文系统介绍了LLM&#xff08;大型语言模型&#xff09;的基础知识&#xff0c;包括机器学习的数学基础、Python编程及其在数据科学中的应用、神经网络原理等。文章深入剖析了LLM科学家和工程师的角色&#xff0c;涵盖了大型语言模型架构、指令数据集构建、预训练模型、监督微…...

3大突破!零门槛掌握资源嗅探:猫抓插件全平台使用指南

3大突破&#xff01;零门槛掌握资源嗅探&#xff1a;猫抓插件全平台使用指南 【免费下载链接】cat-catch 猫抓 chrome资源嗅探扩展 项目地址: https://gitcode.com/GitHub_Trending/ca/cat-catch 一、为什么你需要专业的资源嗅探工具&#xff1f; 场景化痛点直击 作为…...

5款强力资源获取工具深度评测:猫抓媒体解析技术如何重塑内容管理流程

5款强力资源获取工具深度评测&#xff1a;猫抓媒体解析技术如何重塑内容管理流程 【免费下载链接】cat-catch 猫抓 chrome资源嗅探扩展 项目地址: https://gitcode.com/GitHub_Trending/ca/cat-catch 在数字内容爆炸的时代&#xff0c;高效获取和管理网络媒体资源已成为…...