当前位置: 首页 > news >正文

PyTorch|Dataset与DataLoader使用、构建自定义数据集

文章目录

  • 一、Dataset与DataLoader
  • 二、自定义Dataset类
    • (一)\_\_init\_\_函数
    • (二)\_\_len\_\_函数
    • (三)\_\_getitem\_\函数
    • (四)全部代码
  • 三、将单个样本组成minibatch(DataLoader)
    • (一)PyTorch的DataLoader源码
      • 1、DataLoader的参数
      • 2、init函数
      • 3、iter函数
    • (二)使用DataLoader遍历


一、Dataset与DataLoader

PyTorch提供的两个常用数据API:

  • torch.utils.data.Dataset:用于处理单个训练样本,读取数据特征、size、标签等,并且包括数据转换等;
  • torch.utils.data.DataLoader:DataLoader在Dataset周围重载一个可迭代对象,以便轻松访问样本。

官方案例: Fashion-MNIST数据集
torchvision:torch的一个视觉库,将torchvision中的datasets导入进来,就能获得其中的各种数据集

FashionMNIST图像存储在目录img_dir中,标签存储在CSV文件annotations_file中

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plttraining_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor()
)test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor()
)

对上述数据集进行可视化:

labels_map = {0: "T-Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):sample_idx = torch.randint(len(training_data), size=(1,)).item()img, label = training_data[sample_idx]figure.add_subplot(rows, cols, i)plt.title(labels_map[label])plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")
plt.show()

二、自定义Dataset类

  • 构建自定义的Dataset类,需要继承TensorFlow的官方dataset类
  • 自定义Dataset类必须实现三个函数:__init__,__len__和__getitem__

pytorch中的dataset类是在pytorch的torch下的utils之下的data文件夹里有一个dataset.py
在这里插入图片描述

(一)__init__函数

包含图像、注释文件和两个转换:

  • annotations_file:标注文件
  • img_dir:图像目录
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file) #标签存储在CSV文件annotations_file中self.img_dir = img_dir #FashionMNIST图像存储在目录img_dir中self.transform = transform #图像转换self.target_transform = target_transform

(二)__len__函数

返回数据集的样本数(就是img_labels的长度)

def __len__(self):return len(self.img_labels)

(三)__getitem_\函数

输入索引index,getitem函数从数据集中加载并返回对应index的一个样本:

def __getitem__(self, idx):#img_labels的第index行第0列标注了对应的照片文件名称img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path) #使用read_image将图像转换为张量label = self.img_labels.iloc[idx, 1] #从self中的csv数据中检索相应的标签#调用转换函数if self.transform: image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label #返回张量图像和相应的标签

(四)全部代码

import os
import pandas as pd
from torchvision.io import read_imageclass CustomImageDataset(Dataset):def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transformdef __len__(self):return len(self.img_labels)def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

三、将单个样本组成minibatch(DataLoader)

(一)PyTorch的DataLoader源码

1、DataLoader的参数

DataLoader通常是在torch.utils.data下
在这里插入图片描述
常用的参数有:

  • dataset(数据集):需要提取数据的数据集,Dataset对象
  • batch_size(批大小):每一次装载样本的个数,int型
  • shuffle:是否打乱数据顺序
  • sampler(Sampler, optional): 自定义从数据集中取样本的策略,如果指定这个参数,那么shuffle必须为False
  • num_workers:进行数据加载时使用单个进程还是多进程进行加载,多进程意为加载速度更快,一般默认为0,表示使用主进程进行加载
  • collate_fn (callable, optional): 将一个list的sample组成一个mini-batch的函数,一般用于对于一个batch进行后处理
  • pin_memory (bool, optional): 如果设置为True,那么data loader将会在返回它们之前将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中
  • drop_last:当样本数不能被batchsize整除时, 是否舍弃最后一批数据
from torch.utils.data import DataLoadertrain_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

2、init函数

主要做了三件事:构建sampler、构建batch_sampler、构建collate_fn

定义属性:
在这里插入图片描述
如果设置了自定义的sampler然后又设置了shuffle=true,这种情况是没有意义的:
(shuffle是官方自定义的一个随机sampler)
在这里插入图片描述
设置了batch_sampler的情况下,就不需要设置batch_size、shuffle、sampler和drop_last了:
在这里插入图片描述
如果没有设置sampler,则先判断数据集类型,如果使用的是map-style(else逻辑),就根据是否设置shuffle来选择pytorch内置的sampler:
在这里插入图片描述
设置了batch_size但是没有设置batch_sampler时,会使用内置的BatchSampler:
在这里插入图片描述
如果没有设置collate_fn,就判断auto_collation是否设置(auto_collation是根据batch_sampler是否是None来设置的,如果batch_sampler不是none,auto_collation就是true),default_collate是将batch作为输入,batch输出,并没有对数据做额外处理:
在这里插入图片描述

3、iter函数

iter函数返回的是get_iterator的值:
在这里插入图片描述
get_iterator根据num_workers的设置选择对应的内置DataLoaderIter:
在这里插入图片描述

所以可知,iter函数最终返回的是一个dataloaderiter对象,以SingleProcessDataLoaderIter为例,类里有next_data函数:
在这里插入图片描述
SingleProcessDataLoaderIter类是继承了BaseDataLoaderIter类,BaseDataLoaderIter类中的next函数就是使用了子类中的next_data:
在这里插入图片描述

(二)使用DataLoader遍历

根据上述源码分析,就可以对dataloader去迭代iter之后调用next函数来获得每一批次的数据:

  • 通过DataLoader实现对于数据集的遍历,每次遍历会得到一个batch的数据,这里设置batch_size为64:
from torch.utils.data import DataLoadertrain_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
  • iter函数将train_dataloader变成一个迭代器,使用next函数可以以此从迭代器中生成一个一个的批次:
# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

在这里插入图片描述由于batch_size=64,因此最终返回的Feature batch shape以及Labels batch shape均为64。


参考:
PyTorch官方文档:Datasets & DataLoaders
5、深入剖析PyTorch DataLoader源码

相关文章:

PyTorch|Dataset与DataLoader使用、构建自定义数据集

文章目录 一、Dataset与DataLoader二、自定义Dataset类(一)\_\_init\_\_函数(二)\_\_len\_\_函数(三)\_\_getitem\_\函数(四)全部代码 三、将单个样本组成minibatch(Data…...

4.6(信息差)

🌍 山西500千伏及以上输电线路工程首次采用无人机AI自主验收 🌋 中国与泰国将开展国际月球科研站等航天合作 ✨ 网页版微软 PowerPoint 新特性:可直接修剪视频 🍎 特斯拉开始在德国超级工厂生产出口到印度的右舵车 1.马斯克&…...

关于C#操作SQLite数据库的一些函数封装

主要功能:增删改查、自定义SQL执行、批量执行(事务)、防SQL注入、异常处理 1.NuGet中安装System.Data.SQLite 2.SQLiteHelper的封装: using System; using System.Collections.Generic; using System.Data.SQLite; using System.…...

LeetCode-79. 单词搜索【数组 字符串 回溯 矩阵】

LeetCode-79. 单词搜索【数组 字符串 回溯 矩阵】 题目描述:解题思路一:回溯 回溯三部曲。这里比较关键的是给board做标记,防止之后搜索时重复访问。解题思路二:回溯算法 dfs,直接看代码,很容易理解。visited哈希,防止…...

游戏引擎之高级动画技术

一、动画混合 当我们拥有各类动画素材(clips)时,要将它们融合起来成为一套完整的动画。 最经典的例子就是从走的动画自然的过渡到跑的动画。 1.1 线性插值 不同于上节课的LERP(同一个clip内不同pose之间)&#xff…...

Oracle 数据库中的全文搜索

Oracle 数据库中的全文搜索 0. 引言1. 整体流程2. 创建索引2-1. 创建一个简单的表2-2. 创建文本索引2-3. 查看创建的基础表 3. 运行查询3-1. 运行文本查询3-2. CONTAINS 运算符3-3. 混合查询3-4. OR 查询3-5. 通配符3-6. 短语搜索3-7. 模糊搜索(Fuzzy searches&…...

代码随想录阅读笔记-二叉树【二叉搜索树中的众数】

题目 给定一个有相同值的二叉搜索树(BST),找出 BST 中的所有众数(出现频率最高的元素)。 假定 BST 有如下定义: 结点左子树中所含结点的值小于等于当前结点的值结点右子树中所含结点的值大于等于当前结点的…...

AcWing-游戏

1388. 游戏 - AcWing题库 所需知识:博弈论,区间dp 由于双方都采取最优的策略来取数字,所以结果为确定的,有可能会有多个不同的过程,但是我们只需要关注最终结果就行了。 方法一: 定义dp[i][j] 表示区间…...

Mybatis——一对一映射

一对一映射 预置条件 在某网络购物系统中,一个用户只能拥有一个购物车,用户与购物车的关系可以设计为一对一关系 数据库表结构(唯一外键关联) 创建两个实体类和映射接口 package org.example.demo;import lombok.Data;import …...

Web 安全之 SSL 剥离攻击详解

目录 SSL/TLS简介 SSL 剥离攻击原理 SSL 剥离攻击的影响 SSL 剥离攻击的防范措施 小结 SSL 剥离攻击(SSL Stripping Attack)是一种针对安全套接层(SSL)或传输层安全性(TLS)协议的攻击手段,…...

数据结构——顺序表(C语言)

目录 一、顺序表概念 二、顺序表分类 1.静态顺序表 2.动态顺序表 三、顺序表的实现 1.顺序表的结构体定义 2. 顺序表初始化 3.顺序表销毁 4.顺序表的检验 5.顺序表打印 6.顺序表扩容 7.顺序表尾插与头插 8.尾删与头删 9.在pos处插入数据 10.在pos处删除数据 11.查找数据 …...

利用Idea实现Ajax登录(maven工程)

一、新建一个maven工程(不会建的小伙伴可以参考Idea引入maven工程依赖(保姆级)-CSDN博客),工程目录如图 ​​​​​​​ js文件可以上up网盘提取 链接:https://pan.baidu.com/s/1yOFtiZBWGJY64fa2tM9CYg?pwd5555 提取码&…...

环信IM集成教程——Web端UIKit快速集成与消息发送

写在前面: 千呼万唤始出来,环信Web端终于出UIKit了!🎉🎉🎉 文档地址:https://doc.easemob.com/uikit/chatuikit/web/chatuikit_overview.html 环信单群聊 UIKit 是基于环信即时通讯云 IM SDK 开…...

Anaconda如何切换国内镜像源

一、anaconda如何切换阿里镜像源 在Anaconda中切换到阿里云镜像源可以通过以下步骤进行: 1、打开终端(Windows)或者命令行界面(macOS/Linux)。 2、执行以下命令来配置阿里云镜像源: conda config --add…...

Android 14.0 添加自定义服务,并生成jar给第三方app调用

1.概述 在14.0系统ROM产品定制化开发中,由于需要新增加自定义的功能,所以要增加自定义服务,而app上层通过调用自定义服务,来调用相应的功能,所以系统需要先生成jar,然后生成jar 给上层app调用,接下来就来分析实现的步骤,然后来实现相关的功能 从而来实现所需要的功能 …...

解决沁恒ch592单片机在tmos中使用USB总线时,接入USB Hub无法枚举频繁Reset的问题

开发产品时采用了沁恒ch592,做USB开发时遇到了一个奇葩的无法枚举问题。 典型症状 使用USB线直连电脑时没有问题,可以正常使用。 如果接入某些特定方案的USB Hub(例如GL3510、GL3520),可能会出现以下2种情况&#xf…...

nvm保姆级安装使用教程

✅作者简介:大家好,我是Leo,热爱Java后端开发者,一个想要与大家共同进步的男人😉😉 🍎个人主页:Leo的博客 💞当前专栏: 开发环境篇 ✨特色专栏: M…...

大语言模型LLM《提示词工程指南》学习笔记02

文章目录 大语言模型LLM《提示词工程指南》学习笔记02设计提示时需要记住的一些技巧零样本提示少样本提示链式思考(CoT)提示自我一致性生成知识提示 大语言模型LLM《提示词工程指南》学习笔记02 设计提示时需要记住的一些技巧 指令 您可以使用命令来指…...

【realme x2手机解锁BootLoader(简称BL)】

realme手机解锁常识 https://www.realme.com/cn/support/kw/doc/2031665 realme手机解锁支持型号 https://www.realmebbs.com/post-details/1275426081138028544 realme x2手机解锁实践 参考:https://www.realmebbs.com/post-details/1255473809142591488 1 下载apk…...

攻防世界 wife_wife

在这个 JavaScript 示例中,有两个对象:baseUser 和 user。 baseUser 对象定义如下: baseUser { a: 1 } 这个对象有一个属性 a,其值为 1,没有显式指定原型对象,因此它将默认继承 Object.prototype。 …...

Visual Studio安装下载进度为零已解决

因为在安装pytorch3d0.3.0时遇到问题,提示没有cl.exe,VS的C编译组件,可以添加组件也可以重装VS。查了下2019版比2022问题少,选择了安装2019版,下面是下载安装时遇到的问题记录,关于下载进度为零网上有三类解…...

矩阵空间秩1矩阵小世界图

文章目录 1. 矩阵空间2. 微分方程3. 秩为1的矩阵4. 图 1. 矩阵空间 我们以3X3的矩阵空间 M 为例来说明相关情况。目前矩阵空间M中只关心两类计算,矩阵加法和矩阵数乘。 对称矩阵-子空间-有6个3X3的对称矩阵,所以为6维矩阵空间上三角矩阵-子空间-有6个3…...

《QT实用小工具·十三》FlatUI辅助类之各种炫酷的控件集合

1、概述 源码放在文章末尾 FlatUI辅助类之各种炫酷的控件集合 按钮样式设置。文本框样式设置。进度条样式。滑块条样式。单选框样式。滚动条样式。可自由设置对象的高度宽度大小等。自带默认参数值。 下面是demo演示: 项目部分代码如下所示: #ifnd…...

dm8 备份与恢复

dm8 备份与恢复 基础环境 操作系统:Red Hat Enterprise Linux Server release 7.9 (Maipo) 数据库版本:DM Database Server 64 V8 架构:单实例1 设置bak_path路径 --创建备份文件存放目录 su - dmdba mkdir -p /dm8/backup--修改dm.ini 文件…...

Vue项目中引入html页面(vue.js中引入echarts数据大屏html [静态非数据传递!] )

在项目原有vue(例如首页)基础上引入html页面 1、存放位置 vue3原有public文件夹下 我这边是新建一个static文件夹 专门存放要用到的html文件 复制拖拽过来 index为html的首页 2、更改路径引入到vue中 这里用到的是 iframe 方法 不同于vue的 component…...

ASTM C1186-22 纤维水泥平板

以无石棉类无机矿物纤维、有机合成纤维或纤维素纤维,单独或混合作为增强材料,以普通硅酸盐水泥或水泥中添加硅质、钙质材料代替部分水泥为胶凝材料,经制浆、成型、蒸汽或高压蒸汽养护制成的板材,俗称水泥压力板。 ASTM C1186-22纤…...

NoSQL概述

NoSQL概述 目录 一、为什么用NoSQL 二、什么是NoSQL 三、经典应用分析 四、N o S Q L 数 据 模 型 简 介 五、NoSQL四大分类 六、CAP BASE 一、为什么用NoSQL 1、单机MySQL的美好年代 在90年代,一个网站的访问量一般不大,用单个数据库完全可以轻松应…...

爬虫实战一、Scrapy开发环境(Win10+Anaconda3)搭建

#前言 在这儿推荐使用Anaconda进行安装,并不推荐大家用pythonpip安装,因为pythonpip的坑实在是太多了。 #一、环境中准备: Win10(企业版)Anaconda3-5.0.1-Windows-x86_64,下载地址,如果打不开…...

llama.cpp运行qwen0.5B

编译llama.cp 参考 下载模型 05b模型下载 转化模型 创建虚拟环境 conda create --prefixD:\miniconda3\envs\llamacpp python3.10 conda activate D:\miniconda3\envs\llamacpp安装所需要的包 cd G:\Cpp\llama.cpp-master pip install -r requirements.txt python conver…...

【接口】HTTP(3) |GET和POST两种基本请求方法有什么区别

在我面试时,在我招人面试别人时,10次能遇到7次这个问题,我听过我也说回答过: Get: 一般对于从服务器取数据的请求可以设置为get方式 Get方式在传递参数的时候,一般都会把参数直接拼接在url上 Get请求方法…...

陕西省西安市网站建设公司/个人免费网上注册公司

一.FTP 是File Transfer Protocol(文件传输协议)的英文简称,而中文简称为“文传协议”。用于Internet上的控制文件的双向传输。同时,它也是一个应用程序(Application)。 基于不同的操作系统有不同的FTP应用…...

wordpress 主页 导航/武汉百度推广seo

所谓的组件就是指封装了一些代码进行复用,以减少代码冗余性,使代码更加简洁优雅注意:若js注册组件名时采用了驼峰命名法,则在html中要加横线,否则无法解析不论是哪种方式创建出来的组件,必须只有一个根元素…...

世界经济新闻/360优化大师app

能解决题目的代码并不是一次就可以写好的我们需要根据我们的思路写出后通过debug模式找到不足再进行更改多次测试后才可得到能解决题目的代码!通过学习,练习【Java基础经典练习题】,让我们一起来培养这种解决问题思路。一、视频讲解思路讲解h…...

wordpress数据主机名/郭生b如何优化网站

要求 本次任务的目的是处理PO2,PCO2两个指标。这两个指标均为病人的血气指标,以一定的时间间隔采集。一个病人一次住院期间可能收集一次或者多次。要求,按照采集时间的前后顺序,汇总每个病人每次住院期间的所有的pO2, pCO2指标值…...

关于政府网站建设的文件/百度seo优化方案

《我的成功可以复制》--唐骏 这本书花了一周时间把他读完了。从书中学到和了解到了许多以前不知道或很模糊的东西。 我就谈谈我读了这本书的读后感吧。 1.我知道了唐骏是一个什么样的人,他的童年的家境不算富裕,他也是一个从小就过苦生活的人。 但是他的…...

wordpress字母头像纯css生成/游戏推广员怎么做

之前我们为大家介绍过一项非常酸爽的研究“Talking Face Generation”:给定音频或视频后(输入),可以让任意一个人的面部特征与输入的音视频信息保持一致,也就是说出输入的这段话。当时就想到了“杨超越的声音高晓松的脸”这样的神仙搭配。不过…...