线性回归矩阵求解和梯度求解
正规方程求解线性回归
首先正规方程如下:
Θ = ( X T X ) − 1 X T y \begin{equation} \Theta = (X^T X)^{-1} X^T y \end{equation} Θ=(XTX)−1XTy
接下来通过线性代数的角度理解这个问题。
二维空间
在二维空间上,有两个向量 a a a和 b b b,若 b b b投影到 a a a要怎么做,很简单,做垂线, 那么投影后的向量记为 p p p,那么 b b b和 p p p之间的error记为 e = b − p e=b-p e=b−p。同时 p p p在 a a a上,所以 p p p一定是 a a a的 x x x(标量)倍,记为 p = x a p=xa p=xa。因为 e e e垂直 a a a,所以 a T ( b − x a ) = 0 a^T(b-xa)=0 aT(b−xa)=0 ,即 x a T a = a T b xa^Ta=a^Tb xaTa=aTb,得到
x = a T b a T a x=\frac{a^Tb}{a^Ta} x=aTaaTb
那么
p = x a = a a T b a T a p=xa=a\frac{a^Tb}{a^Ta} p=xa=aaTaaTb
根据上面的公式,如果 a a a翻倍了,那么投影不变,如果 b b b翻倍了,投影也翻倍。投影是由一个矩阵 P P P完成的, p = P b p=Pb p=Pb,那么投影矩阵 P P P:
P = a a T a T a P=\frac{aa^T}{a^Ta} P=aTaaaT
用任何向量乘这个投影矩阵,你总会变换到它的列空间中。同时显然有: P T = P P^T=P PT=P , P 2 = P P^2=P P2=P,即投影两次的结果还是和第一次一样。
高维空间
为什么要做投影呢?
因为, A x = b Ax=b Ax=b可能无解,比如一堆等式,比未知数还多,就可能造成无解。那么该怎么办,只能求解最接近的哪个可能解,哪个才是最接近的呢?问题是 A x Ax Ax总是在 A A A的列空间中,而 b b b不一定在。所以要怎么微调 b b b将它变为列空间中最接近它的那一个,那么就将问题换作求解,有解的 A x ^ = p A\hat{x}=p Ax^=p。所以得找最好的那个投影 p p p,以最好的接近 b b b,这就是为什么要引入投影的原因了。
那么我们来看高维空间,这里以三维空间举例,自然可以推广到n维空间。
现在有一个不在平面上的 b b b向量,想要将 b b b投影在平面上,平面可以由两个基向量 a 1 a_1 a1和 a 2 a_2 a2表示。同样的 b b b投影到平面上的误差记为 e = b − p e=b-p e=b−p,这个 e e e是垂直平面的。 p = x 1 ^ a 1 + x 2 ^ a 2 = A x ^ p=\hat{x_1}a_1+\hat{x_2}a_2=A\hat{x} p=x1^a1+x2^a2=Ax^,我们想要解出 x ^ \hat{x} x^。因为 e e e是垂直平面,所以有 b − A x ^ b-A\hat{x} b−Ax^垂直平面,即有 a 1 T ( b − A x ^ ) = 0 a_1^T(b-A\hat{x})=0 a1T(b−Ax^)=0, a 2 T ( b − A x ^ ) = 0 a_2^T(b-A\hat{x})=0 a2T(b−Ax^)=0,表示为矩阵乘法便有
A T ( b − A x ^ ) = A e = 0 A^T(b-A\hat{x})=Ae=0 AT(b−Ax^)=Ae=0
这个形式与二维空间的很像吧。对于 A e = 0 Ae=0 Ae=0,可知 e e e位于 A T A^T AT的零空间,也就是说 e e e垂直于于 A A A的列空间。由上面式子可得
A T A x ^ = A T b A^TA\hat{x}=A^Tb ATAx^=ATb
继而
x ^ = ( A T A ) − 1 A T b \hat{x}=(A^TA)^{-1}A^Tb x^=(ATA)−1ATb
这不就是我们的正规方程吗。到这里我们的正规方程便推导出来了,但为了内容完整,我们下面收个尾。
p = A x ^ = A ( A T A ) − 1 A T b P = A ( A T A ) − 1 A T P T = P P 2 = P p=A\hat{x}=A(A^TA)^{-1}A^Tb \\ P=A(A^TA)^{-1}A^T\\ P^T=P\\ P^2=P p=Ax^=A(ATA)−1ATbP=A(ATA)−1ATPT=PP2=P
这些结论还是和二维空间上的一样, P T = P P^T=P PT=P , P 2 = P P^2=P P2=P,即投影两次的结果还是和第一次一样。
最小二乘法
正规方程的一个常见应用例子是最小二乘法。从线性代数的角度来看,正规方程是通过最小二乘法求解线性回归问题的一种方法。以下是正规方程的概述:
1. 模型表示
在线性回归中,我们假设目标变量 y y y 与特征矩阵 X X X 之间存在线性关系:
y ^ = X θ \hat{y} = X \theta y^=Xθ
其中:
- y ^ \hat{y} y^ 是预测值(一个 m m m 维列向量)。
- X X X 是特征矩阵( m × n m \times n m×n),每行代表一个样本,每列代表一个特征。
- θ \theta θ 是模型参数(权重向量)。
2. 目标函数
我们的目标是最小化预测值与实际值之间的误差,通常使用残差平方和:
J ( θ ) = ∥ y − X θ ∥ 2 J(\theta) = \|y - X\theta\|^2 J(θ)=∥y−Xθ∥2
3. 求解过程
为了找到使得 J ( θ ) J(\theta) J(θ) 最小的 θ \theta θ,我们可以通过对 J ( θ ) J(\theta) J(θ) 关于 θ \theta θ 的导数求解,设导数为零:
∇ J ( θ ) = − 2 X T ( y − X θ ) = 0 \nabla J(\theta) = -2X^T(y - X\theta) = 0 ∇J(θ)=−2XT(y−Xθ)=0
展开后得到:
X T X θ = X T y X^T X \theta = X^T y XTXθ=XTy
4. 正规方程
这个方程称为正规方程,其形式为:
X T X θ = X T y X^T X \theta = X^T y XTXθ=XTy
5. 解的唯一性
- 若 X T X X^T X XTX 是可逆的(即列向量线性无关),则可以通过求逆得到参数的解:
θ = ( X T X ) − 1 X T y \theta = (X^T X)^{-1} X^T y θ=(XTX)−1XTy
- 如果 X T X X^T X XTX 不可逆(即存在多重共线性),则正规方程可能没有唯一解。
6. 几何解释
从几何的角度,正规方程可以被视为在特征空间中寻找一个超平面,使得目标变量 y y y 的投影与预测值 X θ X \theta Xθ 之间的误差最小化。
总结
正规方程通过线性代数的方法为线性回归提供了解的表达式,使得我们可以有效地计算参数。其核心思想是通过最小化残差平方和,寻找最佳拟合的线性模型。
梯度下降求解线性回归
import numpy as np
def linear_regression_gradient_descent(X: np.ndarray, y: np.ndarray, alpha: float, iterations: int) -> np.ndarray:m, n = X.shapetheta = np.zeros((n, 1))for _ in range(iterations):predictions = X @ thetaerrors = predictions - y.reshape(-1, 1)updates = X.T @ errors / mtheta -= alpha * updatesreturn np.round(theta.flatten(), 4)
其他都好理解,下面主要讲梯度updates的推导
1. 定义损失函数
线性回归的损失函数通常是均方误差(Mean Squared Error, MSE):
MSE = 1 2 m ∑ i = 1 m ( h θ ( x ( i ) ) − y ( i ) ) 2 \text{MSE} = \frac{1}{2m} \sum_{i=1}^{m} (h_\theta(x^{(i)}) - y^{(i)})^2 MSE=2m1i=1∑m(hθ(x(i))−y(i))2
这里, h θ ( x ( i ) ) = X ( i ) ⋅ θ h_\theta(x^{(i)}) = X^{(i)} \cdot \theta hθ(x(i))=X(i)⋅θ 是模型的预测值, y ( i ) y^{(i)} y(i) 是实际值。
2. 对损失函数求导
为了最小化损失函数,我们需要对参数 θ \theta θ 求导:
∂ MSE ∂ θ = ∂ ∂ θ ( 1 2 m ∑ i = 1 m ( h θ ( x ( i ) ) − y ( i ) ) 2 ) \frac{\partial \text{MSE}}{\partial \theta} = \frac{\partial}{\partial \theta} \left( \frac{1}{2m} \sum_{i=1}^{m} (h_\theta(x^{(i)}) - y^{(i)})^2 \right) ∂θ∂MSE=∂θ∂(2m1i=1∑m(hθ(x(i))−y(i))2)
应用链式法则,首先求导内部的平方项:
∂ ∂ θ ( h θ ( x ( i ) ) − y ( i ) ) 2 = 2 ( h θ ( x ( i ) ) − y ( i ) ) ⋅ ∂ h θ ( x ( i ) ) ∂ θ \frac{\partial}{\partial \theta} (h_\theta(x^{(i)}) - y^{(i)})^2 = 2(h_\theta(x^{(i)}) - y^{(i)}) \cdot \frac{\partial h_\theta(x^{(i)})}{\partial \theta} ∂θ∂(hθ(x(i))−y(i))2=2(hθ(x(i))−y(i))⋅∂θ∂hθ(x(i))
而且 h θ ( x ( i ) ) = X ( i ) ⋅ θ h_\theta(x^{(i)}) = X^{(i)} \cdot \theta hθ(x(i))=X(i)⋅θ,所以:
∂ h θ ( x ( i ) ) ∂ θ = X ( i ) \frac{\partial h_\theta(x^{(i)})}{\partial \theta} = X^{(i)} ∂θ∂hθ(x(i))=X(i)
将这个结果代入:
∂ MSE ∂ θ = 1 m ∑ i = 1 m ( h θ ( x ( i ) ) − y ( i ) ) X ( i ) \frac{\partial \text{MSE}}{\partial \theta} = \frac{1}{m} \sum_{i=1}^{m} (h_\theta(x^{(i)}) - y^{(i)}) X^{(i)} ∂θ∂MSE=m1i=1∑m(hθ(x(i))−y(i))X(i)
3. 用向量表示
将上述和式转换为向量形式。定义误差向量:
errors = predictions − y \text{errors} = \text{predictions} - y errors=predictions−y
其中 predictions = X ⋅ θ \text{predictions} = X \cdot \theta predictions=X⋅θ。这样,梯度可以表示为:
gradient = 1 m ( X T ⋅ errors ) \text{gradient} = \frac{1}{m} (X^T \cdot \text{errors}) gradient=m1(XT⋅errors)
4. 结论
因此,梯度的计算公式来源于损失函数的求导过程,通过向量化的方式将每个样本的误差与特征相乘,得出对每个参数的影响。这是梯度下降法中更新参数的基础。
相关文章:
线性回归矩阵求解和梯度求解
正规方程求解线性回归 首先正规方程如下: Θ ( X T X ) − 1 X T y \begin{equation} \Theta (X^T X)^{-1} X^T y \end{equation} Θ(XTX)−1XTy 接下来通过线性代数的角度理解这个问题。 二维空间 在二维空间上,有两个向量 a a a和 b b b&…...

M3U8不知道如何转MP4?包能学会的4种格式转换教学!
在流媒体视频大量生产的今天,M3U8作为一种基于HTTP Live Streaming(HLS)协议的播放列表格式,广泛应用于网络视频直播和点播中。它包含了媒体播放列表的信息,指向了视频文件被分割成的多个TS(Transport Stre…...
C++第4课——swap、switch-case-for循环(含视频讲解)
文章目录 1、课程代码2、课程视频 1、课程代码 #include<iostream> using namespace std; int main(){/* //第一个任务:学会swap int a,b,c;//从小到大排序输出 升序 cin>>a>>b>>c;//5 4 3if(a>b)swap(a,b);//4 5 3 swap()函数是用于交…...

大数据新视界 -- 大数据大厂之大数据重塑影视娱乐产业的未来(4 - 4)
💖💖💖亲爱的朋友们,热烈欢迎你们来到 青云交的博客!能与你们在此邂逅,我满心欢喜,深感无比荣幸。在这个瞬息万变的时代,我们每个人都在苦苦追寻一处能让心灵安然栖息的港湾。而 我的…...

在Java中,需要每120分钟刷新一次的`assetoken`,并且你想使用Redis作为缓存来存储和管理这个令牌
学习总结 1、掌握 JAVA入门到进阶知识(持续写作中……) 2、学会Oracle数据库入门到入土用法(创作中……) 3、手把手教你开发炫酷的vbs脚本制作(完善中……) 4、牛逼哄哄的 IDEA编程利器技巧(编写中……) 5、面经吐血整理的 面试技…...
linux网络编程7——协程设计原理与汇编实现
文章目录 协程设计原理与汇编实现1. 协程概念2. 协程的实现2.1 setjmp2.2 ucontext2.3 汇编实现2.4 优缺点2.5 实现协程原语2.5.1 create()2.5.2 yield()2.5.3 resume()2.5.4 exit()2.5.5 switch()2.5.6 sleep() 2.6 协程调度器 3. 利用hook使用协程版本的库函数学习参考 协程设…...

Ubuntu22.04版本左右,扩充用户可使用内存
1 取得root权限后,输入命令 lsblk 查看所有磁盘和分区,找到想要替换用户可使用文件夹内存的磁盘和分区。若没有进行分区,并转为所需要的分区数据类型,先进行分区与格式化,过程自行查阅。 扩充替换过程,例如…...

基于ArcMap中Python 批量处理栅格数据(以按掩膜提取为例)
注:图片来源于公众号,公众号也是我自己的。 ArcMap中的python编辑器是很多本科生使用ArcMap时容易忽略的一个工具,本人最近正在读一本书《ArcGIS Python 编程基础与应用》,在此和大家分享、交流一些相关的知识。 这篇文章主要分享…...
【flink】之集成mybatis对mysql进行读写
背景: 在现代大数据应用中,数据的高效处理和存储是核心需求之一。Flink作为一款强大的流处理框架,能够处理大规模的实时数据流,提供丰富的数据处理功能,如窗口操作、连接操作、聚合操作等。而MyBatis则是一款优秀的持…...
Java设计模式—观察者模式详解
引言 模式角色 UML图 示例代码 应用场景 优点 缺点 结论 引言 观察者模式(Observer Pattern)是一种行为设计模式,它定义了对象之间的一对多依赖关系,当一个对象的状态发生改变时,所有依赖于它的对象都会得到通知…...

【Cri-Dockerd】安装cri-dockerd
cri-dockerd的作用: 在k8s1.24之前。k8s会通过dockershim来调用docker进行容器运行时containerd,并且会自动安装dockershim,但是从1.24版本之前k8s为了降低容器运行时的调用的复杂度和效率,直接调用containerd了,并且…...

GCC及GDB的使用
参考视频及博客 https://www.bilibili.com/video/BV1EK411g7Li/?spm_id_from333.999.0.0&vd_sourceb3723521e243814388688d813c9d475f https://www.bilibili.com/video/BV1ei4y1V758/?buvidXU932919AEC08339E30CE57D39A2BABF6A44F&from_spmidsearch.search-result.0…...

大数据新视界 -- 大数据大厂之大数据重塑影视娱乐产业的未来(4 - 3)
💖💖💖亲爱的朋友们,热烈欢迎你们来到 青云交的博客!能与你们在此邂逅,我满心欢喜,深感无比荣幸。在这个瞬息万变的时代,我们每个人都在苦苦追寻一处能让心灵安然栖息的港湾。而 我的…...

数据结构——基础知识补充
1.队列 1.普通队列 queue.Queue 是 Python 标准库 queue 模块中的一个类,适用于多线程环境。它实现了线程安全的 FIFO(先进先出)队列。 2.双端队列 双端队列(Deque,Double-Ended Queue)是一种具有队列和…...
只有.git文件夹时如何恢复项目
有时候误删文件但由于.git是隐藏文件夹而幸存,或者项目太大,单单甩给你一个.git文件夹让你自己恢复整个项目,该怎么办呢? 不用担心,只要进行以下步骤,即可把原项目重新搭建起来: 创建一个文件…...
anchor、anchor box、bounding box之间关系
最近学YOLO接触到这些概念,一下子有点蒙,简单总结一下。 anchor和anchor box Anchor:表示一组预定义的尺寸比例,用来代表常见物体的宽高比。可以把它看成是一个模板或规格,定义了物体框的“形状”和“比例”ÿ…...

代码随想录算法训练营第三十天 | 452.用最少数量的箭引爆气球 435.无重叠区间 763.划分字母区间
LeetCode 452.用最少数量的箭引爆气球: 文章链接 题目链接:452.用最少数量的箭引爆气球 思路: 气球的区间有重叠部分,只要弓箭从重叠部分射出来,那么就能减少所使用的弓箭数 **局部最优:**只要有重叠部分…...

海亮科技亮相第84届中国教装展 尽显生于校园 长于校园教育基因
10月25日,第84届中国教育装备展示会(以下简称“教装展”)在昆明滇池国际会展中心开幕。作为国内教育装备领域规模最大、影响最广的专业展会,本届教装展以“数字赋能教育,创新引领未来”为主题,为教育领域新…...
C语言数据结构学习:栈
C语言 数据结构学习 汇总入口: C语言数据结构学习:[汇总] 1. 栈 栈,实际上是一种特殊的线性表。这里使用的是链表栈,链表栈的博客:C语言数据结构学习:单链表 2. 栈的特点 只能在一端进行存取操作&#x…...

如何快速分析音频中的各种频率成分
从视频中提取音频 from moviepy.editor import VideoFileClip# Load the video file and extract audio video_path "/mnt/data/WeChat_20241026235630.mp4" video_clip VideoFileClip(video_path)# Extract audio and save as a temporary file for further anal…...

通过Wrangler CLI在worker中创建数据库和表
官方使用文档:Getting started Cloudflare D1 docs 创建数据库 在命令行中执行完成之后,会在本地和远程创建数据库: npx wranglerlatest d1 create prod-d1-tutorial 在cf中就可以看到数据库: 现在,您的Cloudfla…...
vue3 定时器-定义全局方法 vue+ts
1.创建ts文件 路径:src/utils/timer.ts 完整代码: import { onUnmounted } from vuetype TimerCallback (...args: any[]) > voidexport function useGlobalTimer() {const timers: Map<number, NodeJS.Timeout> new Map()// 创建定时器con…...

现代密码学 | 椭圆曲线密码学—附py代码
Elliptic Curve Cryptography 椭圆曲线密码学(ECC)是一种基于有限域上椭圆曲线数学特性的公钥加密技术。其核心原理涉及椭圆曲线的代数性质、离散对数问题以及有限域上的运算。 椭圆曲线密码学是多种数字签名算法的基础,例如椭圆曲线数字签…...
今日科技热点速览
🔥 今日科技热点速览 🎮 任天堂Switch 2 正式发售 任天堂新一代游戏主机 Switch 2 今日正式上线发售,主打更强图形性能与沉浸式体验,支持多模态交互,受到全球玩家热捧 。 🤖 人工智能持续突破 DeepSeek-R1&…...

【OSG学习笔记】Day 16: 骨骼动画与蒙皮(osgAnimation)
骨骼动画基础 骨骼动画是 3D 计算机图形中常用的技术,它通过以下两个主要组件实现角色动画。 骨骼系统 (Skeleton):由层级结构的骨头组成,类似于人体骨骼蒙皮 (Mesh Skinning):将模型网格顶点绑定到骨骼上,使骨骼移动…...
css3笔记 (1) 自用
outline: none 用于移除元素获得焦点时默认的轮廓线 broder:0 用于移除边框 font-size:0 用于设置字体不显示 list-style: none 消除<li> 标签默认样式 margin: xx auto 版心居中 width:100% 通栏 vertical-align 作用于行内元素 / 表格单元格ÿ…...

3-11单元格区域边界定位(End属性)学习笔记
返回一个Range 对象,只读。该对象代表包含源区域的区域上端下端左端右端的最后一个单元格。等同于按键 End 向上键(End(xlUp))、End向下键(End(xlDown))、End向左键(End(xlToLeft)End向右键(End(xlToRight)) 注意:它移动的位置必须是相连的有内容的单元格…...
Java + Spring Boot + Mybatis 实现批量插入
在 Java 中使用 Spring Boot 和 MyBatis 实现批量插入可以通过以下步骤完成。这里提供两种常用方法:使用 MyBatis 的 <foreach> 标签和批处理模式(ExecutorType.BATCH)。 方法一:使用 XML 的 <foreach> 标签ÿ…...

PHP 8.5 即将发布:管道操作符、强力调试
前不久,PHP宣布了即将在 2025 年 11 月 20 日 正式发布的 PHP 8.5!作为 PHP 语言的又一次重要迭代,PHP 8.5 承诺带来一系列旨在提升代码可读性、健壮性以及开发者效率的改进。而更令人兴奋的是,借助强大的本地开发环境 ServBay&am…...
起重机起升机构的安全装置有哪些?
起重机起升机构的安全装置是保障吊装作业安全的关键部件,主要用于防止超载、失控、断绳等危险情况。以下是常见的安全装置及其功能和原理: 一、超载保护装置(核心安全装置) 1. 起重量限制器 功能:实时监测起升载荷&a…...