当前位置: 首页 > news >正文

Pytorch构建自己的数据集

1.Pytorch内置的Dataset

Pytorch中内置了许多数据集,我们可以从torchvision库中进行导入。比如,我们可以导入Fashion-MNIST数据集

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plttraining_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor()
)test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor()
)

但如果torchvision库中没有该数据集,我们需要自己构建一个。
其中一个方法就是把构建好的数据集使用torch.utils.data.TensorDataset()封装以下,然后再传入torch.utils.data.DataLoader

trainloader  =  torch.utils.data.DataLoader(training_data, batch_size=32, shuffle=True)
testloader  =  torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False)

但是如果自己写一个类的话会高达上一些,嘻嘻。下面看看如何自己构建一个Dataset Class。

2.Build Custom Dataset

构建一个Custom Dataset需要继承``三个函数__init__, __len__, 和 __getitem__

  • __init__: 对类进行初始化
  • __len__: 使该类可以返回dataset样本数量
  • __getitem__: 给定一个idx,从数据集中导入并返回一个样本

下面我们来看看该如何构建Custom Dataset:

import os
import pandas as pd
from torchvision.io import read_imageclass CustomImageDataset(Dataset):def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file) # load labelself.img_dir = img_dirself.transform = transform # transformationself.target_transform = target_transformdef __len__(self):return len(self.img_labels) # 返回sample的个数def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path) # load idx-th imagelabel = self.img_labels.iloc[idx, 1] # load idx-th labelif self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

注意:同时,__len__控制着产生样本的总个数。例如,如果总共有20个样本,我们希望20个样本全都放入dataloader中,则:

def __len(self):return 20

但如果我们只希望有20个样本中的15个放入到dataloader中,则:

def __len(self):return 15

但值得注意的是,return返回的数不能大于样本的总个数,即要小于等于20。并且,当返回的数小于总样本个数的时候,是取索引的前几个数,最后的几个数不会被放入dataloader中。比如return 15,是将data[:15]个数放入dataloader,而后5个数要舍去。可以用如下代码验证:

>>> data = np.arange(15).reshape(5,3)
>>> print(data)
array([[ 0, 1, 2],[ 3, 4, 5],[ 6, 7, 8],[ 9, 10, 11],[12, 13, 14]])
>>> class Data(Dataset):
...		def __init__(self, data) -> None:
...			super(Data, self).__init__()
...			self.data = data
...		def __len__(self):
...			return 4
...		def __getitem__(self, index):
...			out = self.data[index]
...			return torch.from_numpy(out)
>>> loader = DataLoader(Data(data), batch_size=4, shuffle=True)
>>> for i, x in enumerate(loader):
...		print(i, x)0 tensor([[ 3, 4, 5],[ 9, 10, 11],[ 0, 1, 2],[ 6, 7, 8]])

可以发现,无论如何都不会输出[12, 13, 14]

Reference:
Pytorch official tutorial
Writing custom datasets dataloaders and transforms

相关文章:

Pytorch构建自己的数据集

1.Pytorch内置的Dataset Pytorch中内置了许多数据集,我们可以从torchvision库中进行导入。比如,我们可以导入Fashion-MNIST数据集 import torch from torch.utils.data import Dataset from torchvision import datasets from torchvision.transforms …...

信息论小课堂:纠错码(海明码在信息传输编码时,通过巧妙的信道编码保证有了错误能够自动纠错。)

文章目录 引言I 纠错1.1 信息纠错的前提:信息冗余1.2 发现抄写错误的方法1.3 计算机的信息校验原理:奇偶校验1.4 有效的纠错编码II 案例2.1 例子1:自身DNA的编码2.2 例子2:海明码引言 预则立,不预则废:不确定性是我们这个世界自然的属性,在解决问题之前,要考虑到世界的不…...

MySQL执行计划(explain)

MySQL执行计划(explain) 1.什么是执行计划 2.如何分析执行计划 执行计划一共有12列,每一列都有着特殊的含义,接下来我们逐一分析 id select语句的查询顺序,包含一组数字,如果数字相同则从上到下,如果数字不同则从大到小。 select_type …...

思必驰回复第二轮审核问询,如何与科大讯飞、阿里巴巴“虎口夺食”?

‍数据智能产业创新服务媒体——聚焦数智 改变商业3月21日,思必驰科技股份有限公司(以下简称“思必驰”)更新上市申请审核动态,已回复上交所第二轮审核问询函,回复了涵盖关于实际控制人的认定、关于预计持续亏损及关于…...

基于Spring、SpringMVC、MyBatis的汽车租赁系统设计

文章目录 项目介绍主要功能截图:前台首页汽车信息列表汽车租赁留言反馈个人信息管理后台汽车类型管理汽车信息管理租赁信息管理用户管理续租信息管理归还信息管理保险信息管理违章登记管理部分代码展示设计总结项目获取方式🍅 作者主页:Java韩立 🍅 简介:Java领域优质创…...

读《刻意练习》后感,与原文好句摘抄

第一章,有目的的练习 所谓“天真的练习”,基本上只是反复的做某件事情,并指望只靠这种反复的练习,就能够提高表现和水平。 有目的练习的四个特点 有目的的练习具有定义明确的特定目标有目的的练习是专注的有目的的练习包含反馈…...

华为OD机试用java实现 -【选座位】

最近更新的博客 华为od 2023 | 什么是华为od,od 薪资待遇,od机试题清单华为OD机试真题大全,用 Python 解华为机试题 | 机试宝典【华为OD机试】全流程解析+经验分享,题型分享,防作弊指南华为od机试,独家整理 已参加机试人员的实战技巧本篇题解:选座位 题目 疫情期间需要大…...

国产蓝牙耳机怎么挑选?口碑最好的国产蓝牙耳机

蓝牙耳机已经成为现代人生活中必不可少的设备之一,因此市场上涌现出了众多的品牌和型号。但是,在这个竞争激烈的市场中,哪些品牌的蓝牙耳机更受欢迎呢?以下是几款口碑不错的蓝牙耳机品牌。 一、南卡小音舱蓝牙耳机 推荐系数&…...

seaborn从入门到精通03-绘图功能实现02-分类绘图Categorical plots

seaborn从入门到精通03-绘图功能实现02-分类绘图Categorical plots总结参考关系-分布-分类分类绘图-Visualizing categorical data图形级接口catplot--figure-level interface导入库与查看tips和diamonds 数据分类散点图参考分布散点图stripplot分布密度散点图-swarmplot&#…...

❤️独特的算法❤️:一文解决编辑距离问题

编辑距离问题 题目关键点115. 不同的子序列 - 力扣(LeetCode)*dp数组定义,情况讨论583. 两个字符串的删除操作 - 力扣(LeetCode)两个字符串删除,情况讨论多加一种72. 编辑距离 - 力扣(LeetCode…...

三次样条样条:Bézier样条和Hermite样条

总结 What is the Difference Between Natural Cubic Spline, Hermite Spline, Bzier Spline and B-spline? 1.多项式拟合中的 Runge Phenomenon 找到一条通过N1个点的多项式曲线 ,需要N次曲线。通过两个点的多项式曲线为一次,三个点的多项式曲线为二…...

Redis面试题 (2023最新版)

文章目录一、Redis为什么快?1、纯内存访问2、单线程,避免上下文切换3、渐进式ReHash、缓存时间戳(1)渐进式ReHash:(2)缓存时间戳:二、Redis合适的应用场景常用基本数据类型&#xff…...

基于springboot实现家乡特色食品景点推荐系统【源码+论文】分享

基于springboot实现家乡特色推荐系统演示开发语言:Java 框架:springboot JDK版本:JDK1.8 服务器:tomcat7 数据库:mysql 5.7 数据库工具:Navicat11 开发软件:eclipse/myeclipse/idea Maven包&…...

Spring MVC 启动之 HandlerMapping

在上一篇文章中,我们介绍了 Spring MVC 的启动流程,接下来我们将发分多个篇章详细介绍流程中的重点步骤 今天我们从 HandlerMapping 开始分析,HandlerMapping 是框架中的一个非常重要的组件。它的作用是将URL请求映射到合适的处理程序&#x…...

基于YOLOv5的停车位检测系统(清新UI+深度学习+训练数据集)

摘要:基于YOLOv5的停车位检测系统用于露天停车场车位检测,应用深度学习技术检测停车位是否占用,以辅助停车场对车位进行智能化管理。在介绍算法原理的同时,给出Python的实现代码、训练数据集以及PyQt的UI界面。博文提供了完整的Py…...

【Linux系统编程】5.vim基本操作命令

目录 跳转到指定行 命令模式 末行模式 跳转行首 跳转行尾 自动格式化代码 大括号、中括号、小括号对应 光标移至行首 光标移至行尾 删除单个字符 删除一个单词 删除光标至行尾 删除光标至行首 替换单个字符 删除指定区域 删除指定1行 删除指定多行 复制一行 …...

主流机器学习平台调研与对比分析

梗概 本报告主要调研目前主流的机器学习平台,包括但不限于Amazon的Sage maker,Alibaba的PAI,Baidu的PaddlePaddle。对产品的定位、功能、实践、定价四个方面进行详细解析,并通过标杆对比分析提出一套机器学习平台评价体系&#x…...

作业帮基于明道云开展的硬件业务数字化建设

今天由我代表作业帮来介绍公司在低代码平台应用的一些经验和心得。我今天分享的内容包含两部分,一个是作业帮硬件的介绍,另一个是基于明道云的系统能力建设,也是我们自己总结的经验,希望能给大家带来一些启发。 一、关于作业帮 …...

位图及布隆过滤器的模拟实现与面试题

位图 模拟实现 namespace yyq {template<size_t N>class bitset{public:bitset(){_bits.resize(N / 8 1, 0);//_bits.resize((N >> 3) 1, 0);}void set(size_t x)//将某位做标记{size_t i x / 8; //第几个char对象size_t j x % 8; //这个char对象的第几个比特…...

在 Python 中将天数添加到日期

使用 datetime 模块中的 timedelta() 方法将天数添加到日期中&#xff0c;例如 result_1 date_1 timedelta(days3)。 timedelta 方法可以传递天数参数并将指定的天数添加到日期。 from datetime import datetime, date, timedelta# ✅ 将天数添加到日期 my_str 09-24-2023 …...

(LeetCode 每日一题) 3442. 奇偶频次间的最大差值 I (哈希、字符串)

题目&#xff1a;3442. 奇偶频次间的最大差值 I 思路 &#xff1a;哈希&#xff0c;时间复杂度0(n)。 用哈希表来记录每个字符串中字符的分布情况&#xff0c;哈希表这里用数组即可实现。 C版本&#xff1a; class Solution { public:int maxDifference(string s) {int a[26]…...

Java 语言特性(面试系列2)

一、SQL 基础 1. 复杂查询 &#xff08;1&#xff09;连接查询&#xff08;JOIN&#xff09; 内连接&#xff08;INNER JOIN&#xff09;&#xff1a;返回两表匹配的记录。 SELECT e.name, d.dept_name FROM employees e INNER JOIN departments d ON e.dept_id d.dept_id; 左…...

vscode(仍待补充)

写于2025 6.9 主包将加入vscode这个更权威的圈子 vscode的基本使用 侧边栏 vscode还能连接ssh&#xff1f; debug时使用的launch文件 1.task.json {"tasks": [{"type": "cppbuild","label": "C/C: gcc.exe 生成活动文件"…...

大数据零基础学习day1之环境准备和大数据初步理解

学习大数据会使用到多台Linux服务器。 一、环境准备 1、VMware 基于VMware构建Linux虚拟机 是大数据从业者或者IT从业者的必备技能之一也是成本低廉的方案 所以VMware虚拟机方案是必须要学习的。 &#xff08;1&#xff09;设置网关 打开VMware虚拟机&#xff0c;点击编辑…...

【Redis技术进阶之路】「原理分析系列开篇」分析客户端和服务端网络诵信交互实现(服务端执行命令请求的过程 - 初始化服务器)

服务端执行命令请求的过程 【专栏简介】【技术大纲】【专栏目标】【目标人群】1. Redis爱好者与社区成员2. 后端开发和系统架构师3. 计算机专业的本科生及研究生 初始化服务器1. 初始化服务器状态结构初始化RedisServer变量 2. 加载相关系统配置和用户配置参数定制化配置参数案…...

1688商品列表API与其他数据源的对接思路

将1688商品列表API与其他数据源对接时&#xff0c;需结合业务场景设计数据流转链路&#xff0c;重点关注数据格式兼容性、接口调用频率控制及数据一致性维护。以下是具体对接思路及关键技术点&#xff1a; 一、核心对接场景与目标 商品数据同步 场景&#xff1a;将1688商品信息…...

Linux简单的操作

ls ls 查看当前目录 ll 查看详细内容 ls -a 查看所有的内容 ls --help 查看方法文档 pwd pwd 查看当前路径 cd cd 转路径 cd .. 转上一级路径 cd 名 转换路径 …...

oracle与MySQL数据库之间数据同步的技术要点

Oracle与MySQL数据库之间的数据同步是一个涉及多个技术要点的复杂任务。由于Oracle和MySQL的架构差异&#xff0c;它们的数据同步要求既要保持数据的准确性和一致性&#xff0c;又要处理好性能问题。以下是一些主要的技术要点&#xff1a; 数据结构差异 数据类型差异&#xff…...

如何为服务器生成TLS证书

TLS&#xff08;Transport Layer Security&#xff09;证书是确保网络通信安全的重要手段&#xff0c;它通过加密技术保护传输的数据不被窃听和篡改。在服务器上配置TLS证书&#xff0c;可以使用户通过HTTPS协议安全地访问您的网站。本文将详细介绍如何在服务器上生成一个TLS证…...

在Ubuntu中设置开机自动运行(sudo)指令的指南

在Ubuntu系统中&#xff0c;有时需要在系统启动时自动执行某些命令&#xff0c;特别是需要 sudo权限的指令。为了实现这一功能&#xff0c;可以使用多种方法&#xff0c;包括编写Systemd服务、配置 rc.local文件或使用 cron任务计划。本文将详细介绍这些方法&#xff0c;并提供…...