当前位置: 首页 > news >正文

深度学习中的早停法

早停法(Early Stopping)是一种用于防止模型过拟合的技术,在训练过程中监视验证集(或者测试集)上的损失值。具体设立早停的限制包括两个主要参数:

  1. Patience(耐心):这是指验证集损失在连续多少个epoch没有显著改善时,才触发早停。当验证集损失连续几个epoch没有下降或者停止减少时,表示模型可能已经过拟合或者陷入局部最优点,这时候早停就会被触发。

  2. Best Loss(最佳损失):这是指在早停过程中保存的最低验证集损失值。当验证集损失值低于当前最佳损失时,更新最佳损失并重置耐心计数器。如果验证集损失连续不降,耐心计数器超过设定的耐心值时,早停就会被触发,训练过程停止。

    早停的具体设立是基于验证集上的损失值 val_loss。每次验证后,如果当前的 val_lossbest_loss 还要低,就更新 best_loss 并重置 patience_counter;否则,增加 patience_counter。当 patience_counter 达到设定的 patience 值时,早停被触发,即停止训练过程以防止模型过拟合。

    总结来说,早停的设立限制是基于耐心参数和最佳损失值,用来判断模型是否应该停止训练以避免过拟合。

# 训练模型
num_epochs = 200  # 总的训练轮数
best_loss = float('inf')  # 初始化最佳验证损失为正无穷大
patience = 10  # 早停的耐心值
patience_counter = 0  # 耐心计数器for epoch in range(num_epochs):model.train()for geno, pheno in train_loader:optimizer.zero_grad()  # 梯度清零outputs = model(geno)  # 前向传播loss = criterion(outputs.squeeze(), pheno)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 优化模型参数model.eval()val_loss = 0with torch.no_grad():  # 不计算梯度for geno, pheno in test_loader:outputs = model(geno)  # 前向传播val_loss += criterion(outputs.squeeze(), pheno).item()  # 计算验证损失val_loss /= len(test_loader)  # 计算平均验证损失print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}, Val Loss: {val_loss:.4f}')scheduler.step(val_loss)  # 更新学习率# 早停法if val_loss < best_loss:best_loss = val_loss  # 更新最佳验证损失patience_counter = 0  # 重置耐心计数器else:patience_counter += 1  # 增加耐心计数器if patience_counter >= patience:  # 如果耐心计数器达到设定的耐心值print("Early stopping triggered")  # 触发早停break
  1. EarlyStopping
    • __init__ 方法初始化早停的参数,如 patience(耐心值)、verbose(是否打印消息)和 delta(损失改进的最小变化)。
    • __call__ 方法根据验证损失来决定是否更新 best_loss,以及是否增加计数器或者触发早停。
  2. 训练循环
    • 训练和验证过程与之前相同。
    • 每个epoch结束时,调用 early_stopping 对象,传入当前的验证损失。
    • 检查 early_stopping.early_stop 标志,如果为 True,则打印消息并停止训练。

通过使用 EarlyStopping 类,你可以更简洁和模块化地实现早停功能,使代码更易于维护和扩展。

import torch
import numpy as npclass EarlyStopping:def __init__(self, patience=10, verbose=False, delta=0):"""EarlyStopping 初始化.Args:patience (int): 当验证集损失在指定的epoch数内没有减少时触发早停.verbose (bool): 如果为True,则每次验证集损失改进时会打印一条消息.delta (float): 验证集损失改进的最小变化."""self.patience = patienceself.verbose = verboseself.delta = deltaself.best_loss = Noneself.counter = 0self.early_stop = Falsedef __call__(self, val_loss):if self.best_loss is None:self.best_loss = val_losselif val_loss > self.best_loss - self.delta:self.counter += 1if self.verbose:print(f'EarlyStopping counter: {self.counter} out of {self.patience}')if self.counter >= self.patience:self.early_stop = Trueelse:self.best_loss = val_lossself.counter = 0if self.verbose:print(f'Validation loss decreased to {self.best_loss:.6f}. Resetting counter.')# 初始化EarlyStopping对象
early_stopping = EarlyStopping(patience=10, verbose=True)# 训练模型
num_epochs = 200
for epoch in range(num_epochs):model.train()for geno, pheno in train_loader:optimizer.zero_grad()outputs = model(geno)loss = criterion(outputs.squeeze(), pheno)loss.backward()optimizer.step()model.eval()val_loss = 0with torch.no_grad():for geno, pheno in test_loader:outputs = model(geno)val_loss += criterion(outputs.squeeze(), pheno).item()val_loss /= len(test_loader)print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}, Val Loss: {val_loss:.4f}')scheduler.step(val_loss)# 检查是否触发早停early_stopping(val_loss)if early_stopping.early_stop:print("Early stopping triggered")break

 

 

相关文章:

深度学习中的早停法

早停法&#xff08;Early Stopping&#xff09;是一种用于防止模型过拟合的技术&#xff0c;在训练过程中监视验证集&#xff08;或者测试集&#xff09;上的损失值。具体设立早停的限制包括两个主要参数&#xff1a; Patience&#xff08;耐心&#xff09;&#xff1a;这是指验…...

科普文:JUC系列之多线程门闩同步器CountDownLatch的使用和源码

CountDownLatch类位于java.util.concurrent包下&#xff0c;利用它可以实现类似计数器的功能。比如有一个任务A&#xff0c;它要等待其他10个线程的任务执行完毕之后才能执行&#xff0c;此时就可以利用CountDownLatch来实现这种功能了。 CountDownLatch是通过一个计数器来实现…...

foreach循环和for循环在PHP中各有什么优势

在PHP中&#xff0c;foreach循环和for循环都是用来遍历数组的常用结构&#xff0c;但它们各有其优势和使用场景。 foreach循环的优势 简化代码&#xff1a;foreach循环提供了一种更简洁的方式来遍历数组&#xff0c;不需要手动控制索引或指针。易于阅读&#xff1a;对于简单的…...

巧用casaos共享挂载自己的外接硬盘为局域网共享

最近入手了个魔改机顶盒&#xff0c;已经刷好了的armbian&#xff0c;虽然是原生的&#xff0c;但是我觉得挺强大的&#xff0c;内置了很多 常用的docker和应用&#xff0c;只需要armbian-software 安装就行&#xff0c;缺点就是emmc太小了。 买到之后第一时间装上了casaos和1p…...

标题:解码“八股文”:助力、阻力,还是空谈?

标题&#xff1a;解码“八股文”&#xff1a;助力、阻力&#xff0c;还是空谈&#xff1f; 在程序员的面试与职场发展中&#xff0c;“八股文”一直是一个备受争议的话题。它既是求职者展示自己技术功底的途径&#xff0c;也是一些公司筛选人才的标准之一。但“八股文”在实际…...

语言无界,沟通无限:2024年好用在线翻译工具推荐

随着技术的发展现在的翻译在线工具从基础词句翻译到复杂的文章翻译都不在话下。为了防止你被五花八门的工具挑花眼&#xff0c;我给你介绍几款我用过的便捷、高效、准确的翻译工具吧。 1.福晰翻译端 链接直通&#xff1a;https://www.foxitsoftware.cn/fanyi/ 这个软件支持…...

【Golang 面试 - 进阶题】每日 3 题(十八)

✍个人博客&#xff1a;Pandaconda-CSDN博客 &#x1f4e3;专栏地址&#xff1a;http://t.csdnimg.cn/UWz06 &#x1f4da;专栏简介&#xff1a;在这个专栏中&#xff0c;我将会分享 Golang 面试中常见的面试题给大家~ ❤️如果有收获的话&#xff0c;欢迎点赞&#x1f44d;收藏…...

二分+dp,CF 1993D - Med-imize

一、题目 1、题目描述 2、输入输出 2.1输入 2.2输出 3、原题链接 D - Med-imize 二、解题报告 1、思路分析 对于n < k的情况直接排序就行 对于n > k的情况 最终的序列长度一定是 (n - 1) % k 1 这个序列是原数组的一个子序列 对于该序列的第一个元素&#xff0…...

三十种未授权访问漏洞复现 合集( 三)

未授权访问漏洞介绍 未授权访问可以理解为需要安全配置或权限认证的地址、授权页面存在缺陷&#xff0c;导致其他用户可以直接访问&#xff0c;从而引发重要权限可被操作、数据库、网站目录等敏感信息泄露。---->目录遍历 目前主要存在未授权访问漏洞的有:NFS服务&a…...

数据湖和数据仓库核心概念与对比

随着近几年数据湖概念的兴起&#xff0c;业界对于数据仓库和数据湖的对比甚至争论就一直不断。有人说数据湖是下一代大数据平台&#xff0c;各大云厂商也在纷纷的提出自己的数据湖解决方案&#xff0c;一些云数仓产品也增加了和数据湖联动的特性。但是数据仓库和数据湖的区别到…...

探索WebKit的奥秘:打造高效、兼容的现代网页应用

1. 简介 1.1. 主要特点 WebKit 是一个开源的浏览器引擎,它允许开发者构建高性能、功能丰富的 web 应用程序。WebKit 与 Mozilla Firefox 等使用的 Gecko 引擎、Internet Explorer 使用的 Trident 引擎以及 EdgeHTML 引擎共同构成了现代 web 浏览器的核心技术。 1.2. 学习资…...

【leetcode】平衡二叉树、对称二叉树、二叉树的层序遍历(广度优先遍历)(详解)

Hi~&#xff01;这里是奋斗的明志&#xff0c;很荣幸您能阅读我的文章&#xff0c;诚请评论指点&#xff0c;欢迎欢迎 ~~ &#x1f331;&#x1f331;个人主页&#xff1a;奋斗的明志 &#x1f331;&#x1f331;所属专栏&#xff1a;数据结构、LeetCode专栏 &#x1f4da;本系…...

最短路径算法:Floyd-Warshall算法

引言 在图论中&#xff0c;Floyd-Warshall算法是一种用于计算任意两点之间最短路径的动态规划算法。它适用于加权有向图和无向图&#xff0c;可以处理带有负权重边的图&#xff0c;但要求图中不能有负权重环。本文将详细介绍Floyd-Warshall算法的定义、步骤及其实现。 Floyd-…...

3DM游戏运行库合集离线安装包2024最新版

3DM游戏运行库合集离线安装包是一款由国内最大的游戏玩家论坛社区3DM推出的集成式游戏运行库合集软件&#xff0c;旨在解决玩家在玩游戏时遇到的运行库缺失或错误问题。该软件包含多种常用的系统运行库组件&#xff0c;支持32位和64位操作系统&#xff0c;能够自动识别系统版本…...

【Bigdata】什么是混合型联机分析处理

这是我父亲 日记里的文字 这是他的生命 留下留下来的散文诗 几十年后 我看着泪流不止 可我的父亲已经 老得像一个影子 &#x1f3b5; 许飞《父亲写的散文诗》 混合型联机分析处理&#xff08;Hybrid OLAP&#xff0c;简称 HOLAP&#xff09;是一种结合了多…...

Java 并发编程:volatile 关键字介绍与使用

大家好&#xff0c;我是栗筝i&#xff0c;这篇文章是我的 “栗筝i 的 Java 技术栈” 专栏的第 026 篇文章&#xff0c;在 “栗筝i 的 Java 技术栈” 这个专栏中我会持续为大家更新 Java 技术相关全套技术栈内容。专栏的主要目标是已经有一定 Java 开发经验&#xff0c;并希望进…...

【Spark计算引擎----第三篇(RDD)---《深入理解 RDD:依赖、Spark 流程、Shuffle 与缓存》】

前言&#xff1a; &#x1f49e;&#x1f49e;大家好&#xff0c;我是书生♡&#xff0c;本阶段和大家一起分享和探索大数据技术Spark—RDD&#xff0c;本篇文章主要讲述了&#xff1a;RDD的依赖、Spark 流程、Shuffle 与缓存等等。欢迎大家一起探索讨论&#xff01;&#xff0…...

四、日志收集loki+ promtail+grafana

一、简介 Loki是受Prometheus启发由Grafana Labs团队开源的水平可扩展&#xff0c;高度可用的多租户日志聚合系统。 开发语言: Google Go。它的设计具有很高的成本效益&#xff0c;并且易于操作。使用标签来作为索引&#xff0c;而不是对全文进行检索&#xff0c;也就是说&…...

xdma的linux驱动编译给arm使用(中断检测-测试程序)

1、驱动链接 XDMA驱动源码官网下载地址为&#xff1a;https://github.com/Xilinx/dma_ip_drivers 下载最新版本的XDMA驱动源码&#xff0c;即master版本&#xff0c;否则其驱动用不了&#xff08;xdma ip核版本为4.1&#xff09;。 2、驱动 此部分来源于博客&#xff1a;xd…...

探索之路——初识 Vue Router:构建单页面应用的完整指南

目录 1. Vue Router 简介 2. 安装与配置 Vue Router 安装步骤 配置路由 3. 在 Vue 应用中使用路由 4. 进阶使用 路由守卫 懒加载 高级路由技术 嵌套路由 动态路由匹配 编程式的路由导航 路由懒加载 路由元信息 在现代前端开发中,单页面应用(SPA)因其出…...

传输层_计算机网络

文章目录 运输层UDPTCPTCP连接管理TCP三次握手TCP四次挥手 可靠机制流量控制拥塞控制 QUIC 运输层 网络层提供了主机之间的逻辑通信 运输层为运行在不同主机上的进程之间提供了逻辑通信 UDP(用户数据报协议)提供一种不可靠、无连接的服务&#xff0c;数据报 TCP(传输控制协议)…...

自动驾驶的六个级别是什么?

自动驾驶汽车和先进的驾驶辅助系统&#xff08;ADAS&#xff09;预计将帮助拯救全球数百万人的生命&#xff0c;消除拥堵&#xff0c;减少排放&#xff0c;并使我们能够在人而不是汽车周围重建城市。 自动驾驶的世界并不只由一个维度组成。从没有任何自动化到完整的自主体验&a…...

深度学习复盘与论文复现F

文章目录 1、Environment construction1.1 macos conda1.2 macos PyTorch1.3 iTerm settings1.4 install jupyter 2、beam search2.1 greedy search2.2 exhaustive search2.3 beam search 3、Attention score3.1 Masking softmax operation3.2 Additive attention3.3 Zoom dot …...

如何学习自动化测试工具!

要学习和掌握自动化测试工具的使用方法&#xff0c;可以按照以下步骤进行&#xff1a; 一、明确学习目标 首先&#xff0c;需要明确你想要学习哪种自动化测试工具。自动化测试工具种类繁多&#xff0c;包括但不限于Selenium、Appium、JMeter、Postman、Robot Framework等&…...

短信接口被恶意盗刷

短信接口被恶意盗刷是指攻击者通过各种手段&#xff0c;大量发送短信请求&#xff0c;导致短信资源被浪费&#xff0c;服务提供商可能面临经济损失&#xff0c;正常用户的服务也可能受到影响。以下是一些可能导致短信接口被恶意盗刷的原因和相应的解决方案&#xff1a; 原因&a…...

实验4-2-1 求e的近似值

//实验4-2-1 求e的近似值 /* 自然常数 e 可以用级数 11/1!1/2!⋯1/n!⋯ 来近似计算。 本题要求对给定的非负整数 n&#xff0c;求该级数的前 n1 项和。 输入格式:输入第一行中给出非负整数 n&#xff08;≤1000&#xff09;。 输出格式:在一行中输出部分和的值&#xff0c;保留…...

内网穿透--LCX+portmap转发实验

实验背景 通过公司带有防火墙功能的路由器接入互联网&#xff0c;然后由于私网IP的缘故&#xff0c;公网 无法直接访问内部web服务器主机&#xff0c;通过内网其它主机做代理&#xff0c;穿透访问内网web 服务器主机 实验设备 1. 路由器、交换机各一台 2. 外网 kali 一台&…...

缓存一致性问题

1. 引言 1.1 数据库与缓存的工程实践 在软件工程领域&#xff0c;数据库&#xff08;Database&#xff09;和缓存&#xff08;Cache&#xff09;是两种常见的数据存储解决方案&#xff0c;它们在系统架构中扮演着至关重要的角色。数据库是数据持久化的后端存储&#xff0c;它…...

【MYSQL】MYSQL逻辑架构

mysql逻辑架构分为3层 mysql逻辑架构分为3层 1). 连接层&#xff1a;主要完成一些类似连接处理&#xff0c;授权认证及相关的安全方案。 2). 服务层&#xff1a;在 MySQL据库系统处理底层数据之前的所有工作都是在这一层完成的&#xff0c;包括权限判断&#xff0c;SQL接口&…...

【Python】数据类型之字符串

本篇文章将继续讲解字符串其他功能&#xff1a; 1、求字符串长度 功能&#xff1a;len(str) &#xff0c;该功能是求字符串str的长度。 代码演示&#xff1a; 2、通过索引获取字符串的字符。 功能&#xff1a;str[a] str为字符串&#xff0c;a为整型。该功能是获取字符…...