PyTorch深度学习实战——图像着色
PyTorch深度学习实战——图像着色
- 0. 前言
- 1. 模型与数据集分析
- 1.1 数据集介绍
- 1.2 模型策略
- 2. 实现图像着色
- 相关链接
0. 前言
图像着色指的是将黑白或灰度图像转换为彩色图像的过程,传统的图像处理技术通常基于直方图匹配和颜色传递的方法或基于用户交互的方法等完成图像着色操作,不但耗时且需要专业知识,而基于深度学习的方法能够实现自动着色,极大的提高了效率。在训练图着色模型时,我们可以将原始图像转换为黑白图像作为网络输入,原始彩色图像作为输出。
1. 模型与数据集分析
在本节中,我们将利用 CIFAR-10 数据集执行图像着色。
1.1 数据集介绍
CIFAR-10 数据集是一个广泛应用于计算机视觉领域的图像分类数据集。它由 10 个不同类别的彩色图像组成,每个类别包含 6000 张 32 x 32 像素的图像。该数据集涵盖了各种不同的对象类别,包括飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车。与一些只包含灰度图像的数据集相比,CIFAR-10 数据集的图像是彩色的,但由于图像分辨率相对较低,图像中的细节和特征相对较少。
CIFAR-10 数据集在计算机视觉领域的研究和开发中得到了广泛的应用,许多图像分类算法和深度学习模型都在 CIFAR-10 上进行了测试和验证。它提供了一个标准化的基准,用于比较不同算法的性能。
1.2 模型策略
了解了所用数据集后,本节中,我们继续介绍图像着色模型策略:
- 获取训练数据集中的原始彩色图像,将其转换为灰度图像,构造输入(灰度)-输出(原始彩色图像)对
- 执行归一化输入和输出图像
- 构建
U-Net架构 - 训练模型
2. 实现图像着色
接下来,使用 PyTorch 实现以上策略,构建图像着色模型。
(1) 导入所需库:
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'from torchvision import datasets
from torch.utils.data import DataLoader, Dataset
from torch import nn
from torch import optim
import numpy as np
import torchvision
from matplotlib import pyplot as plt
(2) 下载数据集,并定义训练、验证数据集和数据加载器。
下载数据集:
data_folder = 'cifar10/cifar/'
datasets.CIFAR10(data_folder, download=True)
定义训练、验证数据集和数据加载器:
class Colorize(torchvision.datasets.CIFAR10):def __init__(self, root, train):super().__init__(root, train)def __getitem__(self, ix):im, _ = super().__getitem__(ix)bw = im.convert('L').convert('RGB')bw, im = np.array(bw)/255., np.array(im)/255.bw, im = [torch.tensor(i).permute(2,0,1).to(device).float() for i in [bw,im]]return bw, imtrn_ds = Colorize('cifar10/cifar/', train=True)
val_ds = Colorize('cifar10/cifar/', train=False)trn_dl = DataLoader(trn_ds, batch_size=256, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=256, shuffle=False)
输入和输出图像的样本如下:
a,b = trn_ds[0]
plt.subplot(121)
plt.imshow(a.permute(1,2,0).cpu(), cmap='gray')
plt.subplot(122)
plt.imshow(b.permute(1,2,0).cpu())
plt.show()

(3) 定义网络架构:
class Identity(nn.Module):def __init__(self):super().__init__()def forward(self, x):return xclass DownConv(nn.Module):def __init__(self, ni, no, maxpool=True):super().__init__()self.model = nn.Sequential(nn.MaxPool2d(2) if maxpool else Identity(),nn.Conv2d(ni, no, 3, padding=1),nn.BatchNorm2d(no),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(no, no, 3, padding=1),nn.BatchNorm2d(no),nn.LeakyReLU(0.2, inplace=True),)def forward(self, x):return self.model(x)class UpConv(nn.Module):def __init__(self, ni, no, maxpool=True):super().__init__()self.convtranspose = nn.ConvTranspose2d(ni, no, 2, stride=2)self.convlayers = nn.Sequential(nn.Conv2d(no+no, no, 3, padding=1),nn.BatchNorm2d(no),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(no, no, 3, padding=1),nn.BatchNorm2d(no),nn.LeakyReLU(0.2, inplace=True),)def forward(self, x, y):x = self.convtranspose(x)x = torch.cat([x,y], axis=1)x = self.convlayers(x)return xclass UNet(nn.Module):def __init__(self):super().__init__()self.d1 = DownConv( 3, 64, maxpool=False)self.d2 = DownConv( 64, 128)self.d3 = DownConv( 128, 256)self.d4 = DownConv( 256, 512)self.d5 = DownConv( 512, 1024)self.u5 = UpConv (1024, 512)self.u4 = UpConv ( 512, 256)self.u3 = UpConv ( 256, 128)self.u2 = UpConv ( 128, 64)self.u1 = nn.Conv2d(64, 3, kernel_size=1, stride=1)def forward(self, x):x0 = self.d1( x) # 32x1 = self.d2(x0) # 16x2 = self.d3(x1) # 8x3 = self.d4(x2) # 4x4 = self.d5(x3) # 2X4 = self.u5(x4, x3)# 4X3 = self.u4(X4, x2)# 8X2 = self.u3(X3, x1)# 16X1 = self.u2(X2, x0)# 32X0 = self.u1(X1) # 3return X0
(4) 定义模型、优化器和损失函数:
def get_model():model = UNet().to(device)optimizer = optim.Adam(model.parameters(), lr=1e-3)loss_fn = nn.MSELoss()return model, optimizer, loss_fn
(5) 定义模型在批数据进行训练和验证的函数:
def train_batch(model, data, optimizer, criterion):model.train()x, y = data_y = model(x)optimizer.zero_grad()loss = criterion(_y, y)loss.backward()optimizer.step()return loss.item()@torch.no_grad()
def validate_batch(model, data, criterion):model.eval()x, y = data_y = model(x)loss = criterion(_y, y)return loss.item()
(6) 训练模型:
model, optimizer, criterion = get_model()
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)_val_dl = DataLoader(val_ds, batch_size=1, shuffle=True)n_epochs = 100
train_loss_epochs = []
val_loss_epochs = []for ex in range(n_epochs):N = len(trn_dl)trn_loss = []val_loss = []for bx, data in enumerate(trn_dl):loss = train_batch(model, data, optimizer, criterion)pos = (ex + (bx+1)/N)trn_loss.append(loss)train_loss_epochs.append(np.average(trn_loss))N = len(val_dl)for bx, data in enumerate(val_dl):loss = validate_batch(model, data, criterion)pos = (ex + (bx+1)/N)val_loss.append(loss)val_loss_epochs.append(np.average(val_loss))exp_lr_scheduler.step()if (ex+1)%10 == 0:for _ in range(5):a,b = next(iter(_val_dl))_b = model(a)plt.subplot(131)plt.imshow(a[0].permute(1,2,0).cpu(), cmap='gray')plt.subplot(132)plt.imshow(b[0].permute(1,2,0).cpu())plt.subplot(133)plt.imshow(_b[0].permute(1,2,0).detach().cpu().numpy())plt.show()
epochs = np.arange(n_epochs)+1
plt.plot(epochs, train_loss_epochs, 'bo', label='Training loss')
plt.plot(epochs, val_loss_epochs, 'r', label='Test loss')
plt.title('Training and Test loss over increasing epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid('off')
plt.show()

从前面的输出中,可以看到模型能够很好地为灰度图像着色。
相关链接
PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习
PyTorch深度学习实战(16)——面部关键点检测
PyTorch深度学习实战(17)——多任务学习
PyTorch深度学习实战(18)——目标检测基础
PyTorch深度学习实战(19)——从零开始实现R-CNN目标检测
PyTorch深度学习实战(20)——从零开始实现Fast R-CNN目标检测
PyTorch深度学习实战(21)——从零开始实现Faster R-CNN目标检测
PyTorch深度学习实战(22)——从零开始实现YOLO目标检测
PyTorch深度学习实战(23)——使用U-Net架构进行图像分割
PyTorch深度学习实战(24)——从零开始实现Mask R-CNN实例分割
相关文章:
PyTorch深度学习实战——图像着色
PyTorch深度学习实战——图像着色 0. 前言1. 模型与数据集分析1.1 数据集介绍1.2 模型策略 2. 实现图像着色相关链接 0. 前言 图像着色指的是将黑白或灰度图像转换为彩色图像的过程,传统的图像处理技术通常基于直方图匹配和颜色传递的方法或基于用户交互的方法等完…...
InfiniBand 的前世今生
今年,以 ChatGPT 为代表的 AI 大模型强势崛起,而 ChatGPT 所使用的网络,正是 InfiniBand,这也让 InfiniBand 大火了起来。那么,到底什么是 InfiniBand 呢?下面,我们就来带你深入了解 InfiniBand…...
分享一下微信小程序里怎么添加社区团购功能
随着互联网的快速发展,线上购物已经成为我们日常生活的一部分。而在这个数字化时代,微信小程序作为一种便捷的电商渠道,正逐渐成为新的趋势。其中,社区团购功能更是受到广大用户的热烈欢迎。本文将探讨如何在微信小程序中添加社区…...
软考高项-IT部分
信息化体系 信息化技术应用:龙头 信息资源:核心任务 信息网络:应用基础 信息技术和产业:建设基础 信息化人才:成功之本 信息化法规:保障 信息化趋势 产业信息化、产品信息化、社会生活信息化、国民经济信息化 新型基础设施建设 2018年召开的中央经济工作会议,首…...
hugetlb核心组件
1 概述 hugetlb机制是一种使用大页的方法,与THP(transparent huge page)是两种完全不同的机制,它需要: 管理员通过系统接口reserve一定量的大页,用户通过hugetlbfs申请使用大页, 核心组件如下图: 围绕着…...
vscode配置环境变量
首先点击下面这个链接。 sMinGW-w64 - for 32 and 64 bit Windows - Browse Files at SourceForge.net 然后选择Files这个选项 向下移选择下载这个文件 解压完成之后,找到这个文件的bin目录复制路径后,添加到环境变量中 依次点击后打开cmd࿰…...
react:封装组件
封装 /components/Pagination.tsx import React from react import { Pagination } from antdconst PaginationWarp ({ total, paramsInfo, setParamsInfo }) > {return (<Paginationtotal{total}current{paramsInfo.page}showSizeChangershowQuickJumperdefaultPageSi…...
基于深度学习的视频多目标跟踪实现 计算机竞赛
文章目录 1 前言2 先上成果3 多目标跟踪的两种方法3.1 方法13.2 方法2 4 Tracking By Detecting的跟踪过程4.1 存在的问题4.2 基于轨迹预测的跟踪方式 5 训练代码6 最后 1 前言 🔥 优质竞赛项目系列,今天要分享的是 基于深度学习的视频多目标跟踪实现 …...
linux中各种最新网卡2.5G网卡驱动,不同型号的网卡需要不同的驱动,整合各种网卡驱动,包括有线网卡、无线网卡、Wi-Fi热点
linux中各种最新网卡2.5G网卡驱动,不同型号的网卡需要不同的驱动,整合各种网卡驱动,包括有线网卡、无线网卡、自动安装Wi-Fi热点。 最近在做路由器二次开发,现在市面上卖的新设备,大多数都采用了2.5G网卡,…...
asp.net上传文件
第一种方法 前端: <div> 单文件上传 <form enctype"multipart/form-data" method"post" action"upload.aspx"> <input type"file" name"files" /> …...
JavaEE平台技术——预备知识(Web、Sevlet、Tomcat)
JavaEE平台技术——预备知识(Web、Sevlet、Tomcat) 1. Web基础知识2. Servlet3. Tomcat并发原理 1. Web基础知识 🆒🆒上个CSDN我们讲的是JavaEE的这个渊源,实际上讲了两个小时的历史课,给大家梳理了一下&a…...
基础课23——设计客服机器人
根据调查数据显示,使用纯机器人完全替代客服的情况并不常见,人机结合模式的使用更为普遍。在这两种模式中,不满意用户的占比都非常低,不到1%。然而,在满意用户方面,人机结合模式的用户满意度明显高于其他模…...
mybatis在springboot当中的使用
1.当使用Mybatis实现数据访问时,主要: - 编写数据访问的抽象方法 - 配置抽象方法对应的SQL语句 关于抽象方法: - 必须定义在某个接口中,这样的接口通常使用Mapper作为名称的后缀,例如AdminMapper - Mybatis框架底…...
如何处理前端本地存储和缓存
前端本地存储和缓存的处理是一种重要的技术,它可以帮助改善应用程序的性能和用户体验。下面是一些处理前端本地存储和缓存的常用方法: 1. 使用Web Storage API: 这是一种在浏览器中存储数据的方法,包括两种类型:loca…...
导轨式安装压力应变桥信号处理差分信号输入转换变送器0-10mV/0-20mV/0-±10mV/0-±20mV转0-5V/0-10V/4-20mA
主要特性 DIN11 IPO 压力应变桥信号处理系列隔离放大器是一种将差分输入信号隔离放大、转换成按比例输出的直流信号导轨安装变送模块。产品广泛应用在电力、远程监控、仪器仪表、医疗设备、工业自控等行业。此系列模块内部嵌入了一个高效微功率的电源,向输入端和输…...
人体姿态估计和手部姿态估计任务中神经网络的选择
一、人体姿态估计任务适合使用卷积神经网络(CNN)来解决。 人体姿态估计任务的目标是从给定的图像或视频中推断出人体的关节位置和姿势。这是一个具有挑战性的计算机视觉任务,而CNN在处理图像数据方面表现出色。 使用CNN进行人体姿态估计的一种…...
odoo16 one2many字段的 domain
最近在odoo project模块的基础上做二开,给task表加了一个版本字段version_id,然后重写了 project表的Task_ids, 并且增加了一个domain,结果折腾了大半天才搞定 写法1 这也是最初的写法: version_id fields.Many2one("hx.p…...
一份优秀测试用例的设计策略
日常工作中最为基础核心的内容就是设计测试用例,什么样的测试用例是好的测试用例?我们一般会认为数量越少、发现缺陷越多的用例就是好的用例。那么我们如何才能设计出好的测试用例呢?一份好的用例是设计出来的,是测试人员思路和方法的集合&a…...
自动驾驶行业观察之2023上海车展-----智驾供应链(3)
智驾解决方案商发展 华为:五项重磅技术更新,重点发布华为ADS 2.0和鸿蒙OS 3.0 1)产品方案:五大解决方案都有了全面的升级,分别推出了ADS 2.0、鸿蒙OS 3.0、iDVP智能汽车数字平台、智能车云服务和华为车载光最新 产品…...
倒计时丨3天后,我们直播间见!
倒计时3天,RestCloud 零代码集成自动化平台重磅发布 ⏰11 月 9 日 14:00,期待您的参与! 点击报名:http://c.nxw.so/dfaJ9...
深度学习在微纳光子学中的应用
深度学习在微纳光子学中的主要应用方向 深度学习与微纳光子学的结合主要集中在以下几个方向: 逆向设计 通过神经网络快速预测微纳结构的光学响应,替代传统耗时的数值模拟方法。例如设计超表面、光子晶体等结构。 特征提取与优化 从复杂的光学数据中自…...
51c自动驾驶~合集58
我自己的原文哦~ https://blog.51cto.com/whaosoft/13967107 #CCA-Attention 全局池化局部保留,CCA-Attention为LLM长文本建模带来突破性进展 琶洲实验室、华南理工大学联合推出关键上下文感知注意力机制(CCA-Attention),…...
Unity3D中Gfx.WaitForPresent优化方案
前言 在Unity中,Gfx.WaitForPresent占用CPU过高通常表示主线程在等待GPU完成渲染(即CPU被阻塞),这表明存在GPU瓶颈或垂直同步/帧率设置问题。以下是系统的优化方案: 对惹,这里有一个游戏开发交流小组&…...
YSYX学习记录(八)
C语言,练习0: 先创建一个文件夹,我用的是物理机: 安装build-essential 练习1: 我注释掉了 #include <stdio.h> 出现下面错误 在你的文本编辑器中打开ex1文件,随机修改或删除一部分,之后…...
渲染学进阶内容——模型
最近在写模组的时候发现渲染器里面离不开模型的定义,在渲染的第二篇文章中简单的讲解了一下关于模型部分的内容,其实不管是方块还是方块实体,都离不开模型的内容 🧱 一、CubeListBuilder 功能解析 CubeListBuilder 是 Minecraft Java 版模型系统的核心构建器,用于动态创…...
跨链模式:多链互操作架构与性能扩展方案
跨链模式:多链互操作架构与性能扩展方案 ——构建下一代区块链互联网的技术基石 一、跨链架构的核心范式演进 1. 分层协议栈:模块化解耦设计 现代跨链系统采用分层协议栈实现灵活扩展(H2Cross架构): 适配层…...
AirSim/Cosys-AirSim 游戏开发(四)外部固定位置监控相机
这个博客介绍了如何通过 settings.json 文件添加一个无人机外的 固定位置监控相机,因为在使用过程中发现 Airsim 对外部监控相机的描述模糊,而 Cosys-Airsim 在官方文档中没有提供外部监控相机设置,最后在源码示例中找到了,所以感…...
站群服务器的应用场景都有哪些?
站群服务器主要是为了多个网站的托管和管理所设计的,可以通过集中管理和高效资源的分配,来支持多个独立的网站同时运行,让每一个网站都可以分配到独立的IP地址,避免出现IP关联的风险,用户还可以通过控制面板进行管理功…...
uniapp 开发ios, xcode 提交app store connect 和 testflight内测
uniapp 中配置 配置manifest 文档:manifest.json 应用配置 | uni-app官网 hbuilderx中本地打包 下载IOS最新SDK 开发环境 | uni小程序SDK hbulderx 版本号:4.66 对应的sdk版本 4.66 两者必须一致 本地打包的资源导入到SDK 导入资源 | uni小程序SDK …...
基于Java+VUE+MariaDB实现(Web)仿小米商城
仿小米商城 环境安装 nodejs maven JDK11 运行 mvn clean install -DskipTestscd adminmvn spring-boot:runcd ../webmvn spring-boot:runcd ../xiaomi-store-admin-vuenpm installnpm run servecd ../xiaomi-store-vuenpm installnpm run serve 注意:运行前…...
