当前位置: 首页 > news >正文

ccc-pytorch-回归问题(1)

文章目录

      • 1.简单回归实战:
      • 2.手写数据识别

1.简单回归实战:

用 线性回归拟合二维平面中的100个点
在这里插入图片描述
公式:y=wx+by=wx+by=wx+b
损失函数:∑(yreally−y)2\sum(y_{really}-y)^2(yreallyy)2
迭代方法:梯度下降法,其中www,bbb更新公式如下:
wN+1=wN−η∗∂loss∂wbN+1=bN−η∗∂loss∂bw_{N+1}=w_N-\eta*\frac{\partial loss}{\partial w}\\ b_{N+1}=b_{N}-\eta*\frac{\partial loss}{\partial b}wN+1=wNηwlossbN+1=bNηbloss
其中η\etaη表示学习率,∂\partial表示微分
∂loss∂w=2(wx+b−y)x/n∂loss∂b=2(wx+b−y)/n\frac{\partial loss}{\partial w}=2(wx+b-y)x/n\\ \frac{\partial loss}{\partial b}=2(wx+b-y)/n wloss=2(wx+by)x/nbloss=2(wx+by)/n
项目文件
计算损失函数:

def compute_loss(b,w,points):total = 0for i in range(0,len(points)):x = points[i,0]y = points[i,1]total += (y-(w*x+b)) ** 2return total / float(len(points))

梯度下降迭代更新:

def gradient(b,w,points,leanrningRate):b_gradient = 0w_gradient = 0N = float(len(points))for i in range(0,len(points)):x = points[i,0]y = points[i,1]b_gradient += (2/N) * (((w * x)+b)-y)w_gradient += (2/N) * (((w * x)+b)-y) * xnew_b = b - (leanrningRate * b_gradient)new_w = w - (leanrningRate * w_gradient)return [new_b , new_w]def graient_descent_runner(points, b, w, learning_rate, num_iterations):new_b = bnew_w = wfor i in range(num_iterations):new_b, new_w = gradient(new_b, new_w, np.array(points), learning_rate)return [new_b, new_w]

主函数运行以及绘图结果:

def run():points = np.genfromtxt("data.csv",delimiter=",")learning_rate = 0.0001initial_b = 0initial_w = 0num_iteractions = 1000print("Starting gradient descent at b = {0}, w = {1}, error = {2}".format(initial_b,initial_w,compute_loss(initial_b,initial_w,points)))print("Runing...")[b, w] = graient_descent_runner(points,initial_b,initial_w,learning_rate,num_iteractions)print("After {0} iterations b = {1}, w = {2}, error = {3}".format(num_iteractions,b,w,compute_loss(b,w,points)))x = np.linspace(20, 80, 5)y = w * x + bpyplot.plot(x, y)pyplot.scatter(points[:, 0], points[:, 1])pyplot.show()if __name__ == '__main__':run()

在这里插入图片描述
在这里插入图片描述

2.手写数据识别

工具函数库:

import torch
from matplotlib import pyplot as pltdef plot_curve(data):fig = plt.figure()plt.plot(range(len(data)), data, color='blue')plt.legend(['value'], loc='upper right')plt.xlabel('step')plt.ylabel('value')plt.show()def plot_image(img, label, name):fig = plt.figure()for i in range(6):plt.subplot(2, 3, i + 1)plt.tight_layout()plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')plt.title("{}: {}".format(name, label[i].item()))plt.xticks([])plt.yticks([])plt.show()def one_hot(label, depth=10):out = torch.zeros(label.size(0), depth)idx = torch.LongTensor(label).view(-1, 1)out.scatter_(dim=1, index=idx, value=1)return out

第一步:导入库和图像数据

import torch
from torch import nn #构建神经网络
from torch.nn import functional as F 
from torch import optim #最优化工具
import torchvision #视觉工具
from utils import plot_image, plot_curve, one_hotbatch_size = 512
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data', train=True, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data/', train=False, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=False)x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), y.min())
plot_image(x, y, 'image sample')

在这里插入图片描述
第二步:新建一个三层的非线性的网层

class Net(nn.Module):def __init__(self):super(Net, self).__init__()#第一层(28*28是图片,256根据经验随机决定)self.fc1 = nn.Linear(28 * 28, 256)self.fc2 = nn.Linear(256, 64)#第三层(十分类输出一定是10)self.fc3 = nn.Linear(64, 10)def forward(self, x):# x: [b, 1, 28, 28]# h1 = relu(xw1+b1) h2 = relu(h1w2+b2) h3 = h2w3+b3x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x

第三步:train训练

net = Net()
# [w1, b1, w2, b2, w3, b3]
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
train_loss = []for epoch in range(3):for batch_idx, (x, y) in enumerate(train_loader):# x: [b, 1, 28, 28], y: [512]# [b, 1, 28, 28] => [b, 784],将整个图片看做特征向量x = x.view(x.size(0), 28*28)# => [b, 10]out = net(x)# [b, 10]y_onehot = one_hot(y)# loss = mse(out, y_onehot)loss = F.mse_loss(out, y_onehot)optimizer.zero_grad()loss.backward()#梯度下降# w' = w - lr*gradoptimizer.step()train_loss.append(loss.item())if batch_idx % 10==0:print(epoch, batch_idx, loss.item())plot_curve(train_loss)

在这里插入图片描述
第四步:准确度测试

total_correct = 0
for x,y in test_loader:x  = x.view(x.size(0), 28*28)out = net(x)# out: [b, 10] => pred: [b]pred = out.argmax(dim=1)correct = pred.eq(y).sum().float().item()total_correct += correcttotal_num = len(test_loader.dataset)
acc = total_correct / total_num
print('test acc:', acc)x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, 'test')

在这里插入图片描述
在这里插入图片描述
无注释代码:

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
from utils import plot_image, plot_curve, one_hotbatch_size = 512
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data', train=True, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data/', train=False, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=False)x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), y.min())
plot_image(x, y, 'image sample')class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(28 * 28, 256)self.fc2 = nn.Linear(256, 64)self.fc3 = nn.Linear(64, 10)def forward(self, x):x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xnet = Net()
# [w1, b1, w2, b2, w3, b3]
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
train_loss = []for epoch in range(3):for batch_idx, (x, y) in enumerate(train_loader):x = x.view(x.size(0), 28*28)out = net(x)y_onehot = one_hot(y)loss = F.mse_loss(out, y_onehot)optimizer.zero_grad()loss.backward()optimizer.step()train_loss.append(loss.item())if batch_idx % 10==0:print(epoch, batch_idx, loss.item())plot_curve(train_loss)total_correct = 0
for x,y in test_loader:x  = x.view(x.size(0), 28*28)out = net(x)pred = out.argmax(dim=1)correct = pred.eq(y).sum().float().item()total_correct += correcttotal_num = len(test_loader.dataset)
acc = total_correct / total_num
print('test acc:', acc)x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, 'test')

相关文章:

ccc-pytorch-回归问题(1)

文章目录1.简单回归实战:2.手写数据识别1.简单回归实战: 用 线性回归拟合二维平面中的100个点 公式:ywxbywxbywxb 损失函数:∑(yreally−y)2\sum(y_{really}-y)^2∑(yreally​−y)2 迭代方法:梯度下降法,…...

【JAVA八股文】框架相关

框架相关1. Spring refresh 流程2. Spring bean 生命周期3. Spring bean 循环依赖解决 set 循环依赖的原理4. Spring 事务失效5. Spring MVC 执行流程6. Spring 注解7. SpringBoot 自动配置原理8. Spring 中的设计模式1. Spring refresh 流程 Spring refresh 概述 refresh 是…...

二叉树的相关列题!!

对于二叉树,很难,很难!笔者也是感觉很难!虽然能听懂课程,但是,对于大部分的练习题并不能做出来!所以感觉很尴尬!!因此,笔者经过先前的那篇博客,已…...

Java设计模式 - 原型模式

简介 原型模式(Prototype Pattern)是用于创建重复的对象,同时又能保证性能。这种类型的设计模式属于创建型模式,它提供了一种创建对象的最佳方式。 这种模式是实现了一个原型接口,该接口用于创建当前对象的克隆。当直…...

深度学习中的 “Hello World“

Here’s an interesting fact—Each month, there are 186.000 Google searches for the keyword “deep learning.” 大家好✨,这里是bio🦖。每月有超18万的人使用谷歌搜索深度学习这一关键词,是什么让人们对深度学习如此感兴趣?接下来请跟随我来揭开深度学习的神秘面纱。…...

购买WMS系统前,有搞清楚与ERP仓库模块的区别吗

经常有朋友在后台询问我们关于WMS系统的问题,他们自己也有ERP系统,但是总觉得好像还差了点什么,不知道是什么。今天,我想通过本文,来向您简要地阐述ERP与WMS系统在仓储管理上的不同之处。 ERP仓库是以财务为导向的&…...

一文吃透 Spring 中的IOC和DI

✅作者简介:2022年博客新星 第八。热爱国学的Java后端开发者,修心和技术同步精进。 🍎个人主页:Java Fans的博客 🍊个人信条:不迁怒,不贰过。小知识,大智慧。 💞当前专栏…...

分布式任务处理:XXL-JOB分布式任务调度框架

文章目录1.业务场景与任务调度2.任务调度的基本实现2.1 多线程方式实现2.2 Timer方式实现2.3 ScheduledExecutor方式实现2.4 第三方Quartz方式实现3.分布式任务调度4.XXL-JOB介绍5.搭建XXL-JOB —— 调度中心5.1 下载与查看XXL-JOB5.2 创建数据库表5.3 修改默认的配置信息5.4 启…...

【源码解析】Ribbon和Feign实现不同服务不同的配置

Ribbon服务实现不同服务,不同配置是通过RibbonClient和RibbonClients两个注解来实现的。RibbonClient注册的某个Client配置类。RibbonClients注册的全局默认配置类。 Feign实现不同服务,不同配置,是根据FeignClient来获取自定义的配置。 示…...

【webpack5】一些常见优化配置及原理介绍(二)

这里写目录标题介绍sourcemap定位报错热模块替换(或热替换,HMR)oneOf精准解析指定或排除编译开启缓存多进程打包移除未引用代码配置babel,减小代码体积代码分割(Code Split)介绍预获取/预加载(prefetch/pre…...

力扣sql简单篇练习(十九)

力扣sql简单篇练习(十九) 1 查询结果的质量和占比 1.1 题目内容 1.1.1 基本题目信息 1.1.2 示例输入输出 1.2 示例sql语句 # 用count是不会统计为null的数据的 SELECT query_name,ROUND(AVG(rating/position),2) quality,ROUND(count(IF(rating<3,rating,null))/count(r…...

线段树c++

前言 在谈论到种种算法知识与数据结构的时候,线段树无疑总是与“简单”和“平常”联系起来的。而这些特征意味着,线段树作为一种常用的数据结构,有常用性,基础性和易用性等诸多特点。因此,今天我来讲一讲关于线段树的话题。 定义 首先,线段树是一棵“树”,而且是一棵…...

HTML+CSS+JavaScript学习笔记~ 从入门到精通!

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录前言一、HTML1. 什么是HTML&#xff1f;一个完整的页面&#xff1a;<!DOCTYPE> 声明中文编码2.HTML基础①标签头部元素标题段落注释水平线文本格式化②属性3.H…...

LeetCode 430. 扁平化多级双向链表

原题链接 难度&#xff1a;middle\color{orange}{middle}middle 题目描述 你会得到一个双链表&#xff0c;其中包含的节点有一个下一个指针、一个前一个指针和一个额外的 子指针 。这个子指针可能指向一个单独的双向链表&#xff0c;也包含这些特殊的节点。这些子列表可以有一…...

2.5|iot|第1章嵌入式系统概论|操作系统概述|嵌入式操作系统

目录 第1章&#xff1a; 嵌入式系统概论 1.嵌入式系统发展史 2.嵌入式系统定义* 3.嵌入式系统特点* 4.嵌入式处理器的特点 5.嵌入式处理分类 6.嵌入式系统的应用领域及嵌入式系统的发展趋势 第8章&#xff1a;Linux内核配置 1.内核概述 2.内核代码结构 第1章&#xf…...

一文教会你使用ChatGPT画图

引言 当今,ChatGPT在各行各业都有着广泛的应用,其自然语言处理技术也日益成熟。ChatGPT是一种被广泛使用的技术,除了能够生成文本,ChatGPT还可以用于绘图,这为绘图技术的学习和应用带来了新的可能性。本文将介绍如何利用ChatGPT轻松绘制各种形状,为对绘图技术感兴趣的读…...

Java资料分享

随着Java开发的薪资越来越高&#xff0c;越来越多人开始学习 Java 。在众多编程语言中&#xff0c;Java学习难度还是偏高的&#xff0c;逻辑性也比较强&#xff0c;但是为什么还有那么多人要学Java呢&#xff1f;Java语言是目前流行的互联网等企业的开发语言&#xff0c;是市面…...

yum/vim工具的使用

yum 我们生活在互联网发达的时代&#xff0c;手机电脑也成为了我们生活的必须品&#xff0c;在你的脑海中是否有着这样的记忆碎片&#xff0c;在一个明媚的早上你下定决心准备发奋学习&#xff0c;“卸载”了你手机上的所有娱乐软件&#xff0c;一心向学&#xff01;可是到了下…...

内网渗透(三十九)之横向移动篇-pass the ticket 票据传递攻击(PTT)横向攻击

系列文章第一章节之基础知识篇 内网渗透(一)之基础知识-内网渗透介绍和概述 内网渗透(二)之基础知识-工作组介绍 内网渗透(三)之基础知识-域环境的介绍和优点 内网渗透(四)之基础知识-搭建域环境 内网渗透(五)之基础知识-Active Directory活动目录介绍和使用 内网渗透(六)之基…...

Unity性能优化之纹理格式终极篇

知识早班车&#xff1a;1、当n大于1时&#xff0c;2的n次幂一定能被4整除&#xff1b;证明&#xff1a;2^n 2^2*2^(n-1) 4*2^(n-1)2、4的倍数不一定都是2的次幂&#xff1b;证明&#xff1a;4*3 12&#xff1b;12不是2的次幂3、Pixel&#xff08;像素&#xff09;是组成图片…...

【Spark分布式内存计算框架——Spark SQL】9. Dataset(下)RDD、DF与DS转换与面试题

5.3 RDD、DF与DS转换 实际项目开发中&#xff0c;常常需要对RDD、DataFrame及Dataset之间相互转换&#xff0c;其中要点就是Schema约束结构信息。 1&#xff09;、RDD转换DataFrame或者Dataset 转换DataFrame时&#xff0c;定义Schema信息&#xff0c;两种方式转换为Dataset时…...

Windows 环境下,cmake工程导入OpenCV库

目录 1、下载 OpenCV 库 2、配置环境变量 3、CmakeLists.txt 配置 1、下载 OpenCV 库 OpenCV官方下载地址&#xff1a;download | OpenCV 4.6.0 下载完毕后解压&#xff0c;便可以得到下面的文件 2、配置环境变量 我们需要添加两个环境变量&#xff0c;一个是 OpenCVConfi…...

微服务架构设计模式-(16)重构

绞杀者应用程序 由微服务组成的应用程序&#xff0c;将新功能作为服务&#xff0c;并逐步从单体应用中提取服务来实现。好处 尽早并频繁的体现价值 快速开发交付&#xff0c;使用 与之相对的是“一步到位”重构&#xff0c;这时间长&#xff0c;且期间有新的功能加入&#xff…...

数据结构:归并排序和堆排序

归并排序 归并排序(merge sort)是利用“归并”操作的一种排序方法。从有序表的讨论中得知,将两个有序表“归并”为一个有序表,无论是顺序表还是链表,归并操作都可以在线性时间复杂度内实现。归并排序的基本操作是将两个位置相邻的有序记录子序列R[i…m]R[m1…n]归并为一个有序…...

基于easyexcel的MySQL百万级别数据的excel导出功能

前言最近我做过一个MySQL百万级别数据的excel导出功能&#xff0c;已经正常上线使用了。这个功能挺有意思的&#xff0c;里面需要注意的细节还真不少&#xff0c;现在拿出来跟大家分享一下&#xff0c;希望对你会有所帮助。原始需求&#xff1a;用户在UI界面上点击全部导出按钮…...

js-DOM02

1.DOM查询 - 通过具体的元素节点来查询 - 元素.getElementsByTagName() - 通过标签名查询当前元素的指定后代元素 - 元素.childNodes - 获取当前元素的所有子节点 - 会获取到空白的文本子节点 …...

作为一名开发工程师,我对 ChatGPT 的一些看法

ChatGPT 又又火了。 ChatGPT 第一次爆火是2022年12月的时候,我从一些球友的讨论中知道了这个 AI 程序。 今年2月,ChatGPT 的热火更加猛烈,这时我才意识到,原来上次的热火只是我们互联网圈子内部火了,这次是真真正正的破圈了,为大众所熟悉了。 这个 AI 程序是一个智能问…...

Flask中基于Token的身份认证

Flask提供了多种身份认证方式&#xff0c;其中基于Token的身份认证是其中一种常用方式。基于Token的身份认证通常是在用户登录之后&#xff0c;为用户生成一个Token&#xff0c;然后在每次请求时用户将该Token作为请求头部中的一个参数进行传递&#xff0c;服务器端在接收到请求…...

波奇学数据结构:时间复杂度和空间复杂度

数据结构&#xff1a;计算机存储&#xff0c;组织数据方式。数据之间存在多种特定关系。时间复杂度&#xff1a;程序基本操作&#xff08;循环等&#xff09;执行的次数大O渐进法表示法用最高阶的项来表示&#xff0c;且常数变为1。F&#xff08;n&#xff09;3*n^22n1//F(n)为…...

移动OA办公系统为企业带来便捷办公

移动OA系统是指企业员工同手机等移动设备来使用OA办公系统&#xff0c;在外出差的员工只需要通过OA系统的手机APP就可以接收相关的新信息。PC办公与移动OA办公的相结合&#xff0c;构建用户单位随时随地办公的一体化环境。 相比PC办公&#xff0c;移动OA办公给企业带来更多的便…...

电商网站开发 数商云/快手seo

概述 有时候我们需要进行一些预处理和后处理&#xff0c;或者是拦截请求&#xff0c;在请求前后后做一些处理 使用Spring MVC框架&#xff0c;那么建议使用HandlerInterceptor&#xff0c;它可以类似于普通bean直接注册到Spring容器中被管理 HandlerInterceptor的三个抽象方法…...

公安门户网站/sem专员

InitList(&L) //操作结果&#xff1a;构造一个空的线性列表 DestroyList(&L) //初始条件&#xff1a;线性表L已经存在 //操作结果&#xff1a;销毁线性表L ClearList(&L) //初始条件&#xff1a;线性表L已存在 //操作条件&#xff1a;将L重置为空表 ListEmpty(L) /…...

国外设计网站app有哪些/台湾搜索引擎

编辑~/.bashrc文件&#xff0c;然后在最后加上你想设置的目录就可以了。 这样做之后就可以在终端中执行你想要的程序了&#xff0c;不过如果你使用其它程序在后台调用的话可能还是会调用不到&#xff0c;因为这个设置是针对bash有效的。 例如我现在使用eclipse来调用arm-none-e…...

网站如何做双语言/网站在线生成app

JSP&#xff08;Java Server Pages)是由Sun Microsystems公司倡导、许多公司参与一起建立的一种动态网页技术标准。JSP技术有点类似ASP技术&#xff0c;它是在传统的网页HTML文件(*.htm,*.html)中插入Java程序段(Scriptlet)和JSP标记(tag)&#xff0c;从而形成JSP文件(*.jsp)。…...

相亲网站/seo的主要内容

https://www.lydsy.com/JudgeOnline/problem.php?id2809 Description 在一个忍者的帮派里&#xff0c;一些忍者们被选中派遣给顾客&#xff0c;然后依据自己的工作获取报偿。在这个帮派里&#xff0c;有一名忍者被称之为 Master。除了 Master以外&#xff0c;每名忍者都有且仅…...

高端网站建设案例/百度公司的业务范围

Alpha第二天 听说 031502543 周龙荣&#xff08;队长&#xff09; 031502615 李家鹏 031502632 伍晨薇 031502637 张柽 031502639 郑秦 1.前言 任务分配是VV、ZQ、ZC负责前端开发&#xff0c;由JP和LL负责建库和服务器。界面开发的教辅材料是《第一行代码》&#xff0c;利用And…...