解析生成对抗网络(GAN):原理与应用
目录
一、引言
二、生成对抗网络原理
(一)基本架构
(二)训练过程
三、生成对抗网络的应用
(一)图像生成
无条件图像生成:
(二)数据增强
(三)风格迁移
四、生成对抗网络训练中的挑战与解决策略
(一)模式崩溃
(二)梯度消失
一、引言
生成对抗网络(GAN)自 2014 年被 Goodfellow 等人提出以来,在深度学习领域引起了广泛的关注和研究热潮。它创新性地引入了一种对抗训练的思想,通过生成器和判别器的相互博弈,使得生成器能够学习到数据的潜在分布,从而生成逼真的样本数据。这种独特的机制使得 GAN 在图像生成、文本生成、音频生成等多个领域展现出了巨大的潜力,为人工智能技术的发展带来了新的突破和方向。
二、生成对抗网络原理
(一)基本架构
GAN 主要由两个核心组件构成:生成器(Generator)和判别器(Discriminator)。
- 生成器:
- 生成器的任务是接收一个随机噪声向量 (通常从一个简单的分布,如标准正态分布 N(0,1)采样得到),并通过一系列的神经网络层将其映射为与真实数据相似的生成数据G(z)
- 例如,在图像生成任务中,生成器的输出将是一张与训练数据集中图像具有相似特征的合成图像。
- 生成器通常采用多层的反卷积神经网络(Deconvolutional Neural Network)或转置卷积神经网络(Transposed Convolutional Neural Network)结构。以生成64*64其网络结构如下:
import torch import torch.nn as nnclass Generator(nn.Module):def __init__(self):super(Generator, self).__init__()# 输入为 100 维的噪声向量self.fc = nn.Linear(100, 4 * 4 * 1024)self.deconv1 = nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1)self.bn1 = nn.BatchNorm2d(512)self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)self.bn2 = nn.BatchNorm2d(256)self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)self.bn3 = nn.BatchNorm2d(128)self.deconv4 = nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1)def forward(self, x):x = self.fc(x)x = x.view(-1, 1024, 4, 4)x = torch.relu(self.bn1(self.deconv1(x)))x = torch.relu(self.bn2(self.deconv2(x)))x = torch.relu(self.bn3(self.deconv3(x)))x = torch.tanh(self.deconv4(x))return x
- 判别器:
- 判别器的作用是区分输入的数据是来自真实数据分布还是由生成器生成的数据。它接收真实数据 x 或生成数据 G(z),并输出一个表示数据真实性的概率值 D(x)或D(G(z)) ,取值范围在 0 到 1之间,接近 表示数据更可能是真实的,接近 表示数据更可能是生成的。
判别器通常采用卷积神经网络(Convolutional Neural Network)结构。例如,对于判断 彩色图像的判别器网络结构如下:
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.conv1 = nn.Conv2d(3, 128, kernel_size=4, stride=2, padding=1)self.bn1 = nn.BatchNorm2d(128)self.conv2 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)self.bn2 = nn.BatchNorm2d(256)self.conv3 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)self.bn3 = nn.BatchNorm2d(512)self.conv4 = nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=0)def forward(self, x):x = torch.relu(self.bn1(self.conv1(x)))x = torch.relu(self.bn2(self.conv2(x)))x = torch.relu(self.bn3(self.conv3(x)))x = torch.sigmoid(self.conv4(x))return x.view(-1)
(二)训练过程
GAN 的训练过程是一个对抗性的迭代过程。
三、生成对抗网络的应用
(一)图像生成
1.无条件图像生成
GAN 可以用于生成各种类型的图像,如人脸图像、风景图像等。例如,在人脸图像生成任务中,通过在大规模人脸数据集上训练 GAN,生成器能够学习到人脸的各种特征,如五官的形状、肤色、表情等,从而生成全新的、逼真的人脸图像。
代码示例:
# 假设已经定义好生成器 G 和判别器 D,以及相关的优化器和损失函数
# 训练循环
num_epochs = 100
for epoch in range(num_epochs):for i, (real_images, _) in enumerate(dataloader):# 训练判别器# 采样噪声z = torch.randn(real_images.shape[0], 100).to(device)# 生成假图像fake_images = G(z)# 计算判别器损失real_loss = criterion(D(real_images), torch.ones(real_images.shape[0]).to(device))fake_loss = criterion(D(fake_images.detach()), torch.zeros(fake_images.shape[0]).to(device))d_loss = (real_loss + fake_loss) / 2# 更新判别器参数d_optimizer.zero_grad()d_loss.backward()d_optimizer.step()# 训练生成器# 再次采样噪声z = torch.randn(real_images.shape[0], 100).to(device)# 生成假图像fake_images = G(z)# 计算生成器损失g_loss = criterion(D(fake_images), torch.ones(fake_images.shape[0]).to(device))# 更新生成器参数g_optimizer.zero_grad()g_loss.backward()g_optimizer.step()
2.条件图像生成
可以通过在生成器和判别器的输入中加入条件信息,实现条件图像生成。例如,根据给定的文本描述生成相应的图像,或者根据特定的类别标签生成属于该类别的图像。
以根据类别标签生成图像为例,在生成器的输入中除了噪声向量 ,还加入类别标签的编码向量 ,生成器的网络结构需要进行相应修改,如:
class ConditionalGenerator(nn.Module):def __init__(self, num_classes):super(ConditionalGenerator, self).__init__()# 输入为 100 维噪声向量和类别编码向量self.fc = nn.Linear(100 + num_classes, 4 * 4 * 1024)# 后续的反卷积层与无条件生成器类似self.deconv1 = nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1)self.bn1 = nn.BatchNorm2d(512)#...def forward(self, x, y):# 拼接噪声向量和类别编码向量x = torch.cat([x, y], dim=1)x = self.fc(x)x = x.view(-1, 1024, 4, 4)x = torch.relu(self.bn1(self.deconv1(x)))#...return x
(二)数据增强
- 图像数据增强:
- 在图像分类、目标检测等任务中,数据量不足可能导致模型过拟合。GAN 可以用于生成额外的图像数据来扩充数据集。通过在原始图像数据集上训练 GAN,生成与原始图像相似但又有一定变化的图像,如不同角度、光照条件下的图像,从而增加数据的多样性,提高模型的泛化能力。
- 其他数据类型的数据增强:
- 除了图像数据,GAN 也可以应用于其他数据类型的数据增强,如文本数据。例如,通过生成与原始文本相似的新文本,扩充文本数据集,有助于训练更强大的文本处理模型,如文本分类、机器翻译等模型。
(三)风格迁移
- 图像风格迁移原理:
- GAN 可以实现图像风格迁移,即将一幅图像的内容与另一幅图像的风格进行融合。其原理是通过定义内容损失和风格损失,利用生成器生成具有目标风格的图像,同时判别器用于判断生成图像的质量和风格一致性。
- 例如,使用预训练的 VGG 网络来计算内容损失和风格损失。内容损失衡量生成图像与原始内容图像在特征表示上的差异,风格损失衡量生成图像与目标风格图像在风格特征(如纹理、颜色分布等)上的差异。
代码示例实现风格迁移:
import torchvision.models as models
import torch.nn.functional as F# 加载预训练的 VGG 网络
vgg = models.vgg19(pretrained=True).features.eval().to(device)# 定义内容损失函数
def content_loss(content_features, generated_features):return F.mse_loss(content_features, generated_features)# 定义风格损失函数
def style_loss(style_features, generated_features):style_loss = 0for s_feat, g_feat in zip(style_features, generated_features):# 计算 Gram 矩阵s_gram = gram_matrix(s_feat)g_gram = gram_matrix(g_feat)style_loss += F.mse_loss(s_gram, g_gram)return style_loss# Gram 矩阵计算函数
def gram_matrix(x):b, c, h, w = x.size()features = x.view(b * c, h * w)gram = torch.mm(features, features.t())return gram.div(b * c * h * w)
四、生成对抗网络训练中的挑战与解决策略
(一)模式崩溃
问题描述:
模式崩溃是 GAN 训练中常见的问题之一,表现为生成器生成的样本多样性不足,往往集中在少数几种模式上。例如,在生成人脸图像时,可能生成的人脸都具有相似的特征,而不能涵盖人脸的多种可能形态。
解决策略:
Wasserstein GAN(WGAN):WGAN 对 GAN 的损失函数进行了改进,采用 Wasserstein 距离来衡量真实数据分布和生成数据分布之间的差异,而不是传统的 JS 散度。这使得训练过程更加稳定,减少了模式崩溃的发生。其关键代码修改如下:
# 判别器的最后一层不再使用 Sigmoid 激活函数
self.conv4 = nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=0)
# 定义 WGAN 的损失函数
def wgan_loss(real_pred, fake_pred):return -torch.mean(real_pred) + torch.mean(fake_pred)
模式正则化:通过在生成器的损失函数中加入正则化项,鼓励生成器生成更多样化的样本。例如,在生成器的损失函数中加入对生成样本的熵约束,使得生成样本的分布更加均匀。
(二)梯度消失
- 问题描述:
- 在 GAN 训练初期,当判别器的性能非常好时,生成器的梯度可能会变得非常小,导致生成器难以更新参数,无法有效地学习到数据的分布。这是因为判别器能够很容易地区分真实数据和生成数据,使得生成器的损失函数接近饱和,梯度趋近于 。
- 解决策略:
- 梯度惩罚(Gradient Penalty):在判别器的损失函数中加入梯度惩罚项,限制判别器的梯度大小,使得判别器不会过于强大,从而缓解生成器的梯度消失问题。例如,在 WGAN-GP(Wasserstein GAN with Gradient Penalty)中,梯度惩罚项的计算如下:
def gradient_penalty(critic, real, fake, device):BATCH_SIZE, C, H, W = real.shape# 随机采样插值系数alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)# 计算插值数据interpolated_images = real * alpha + fake * (1 - alpha)# 计算判别器对插值数据的输出mixed_scores = critic(interpolated_images)# 计算梯度gradient = torch.autograd.grad(inputs=interpolated_images,outputs=mixed_scores,grad_outputs=torch.ones_like(mixed_scores),create_graph=True,retain_graph=True,)[0]# 计算梯度惩罚项gradient = gradient.view(gradient.shape[0], -1)gradient_norm = gradient.norm(2, dim=1)gradient_penalty = torch.mean((gradient_norm - 1) ** 2)return gradient_penalty
- 使用 Leaky ReLU 激活函数:在判别器和生成器的网络中使用 Leaky ReLU 激活函数替代传统的 ReLU 激活函数。Leaky ReLU 允许在负半轴有一个较小的斜率,从而避免了在某些情况下神经元完全不激活导致的梯度消失问题。
- 梯度惩罚(Gradient Penalty):在判别器的损失函数中加入梯度惩罚项,限制判别器的梯度大小,使得判别器不会过于强大,从而缓解生成器的梯度消失问题。例如,在 WGAN-GP(Wasserstein GAN with Gradient Penalty)中,梯度惩罚项的计算如下:
相关文章:
解析生成对抗网络(GAN):原理与应用
目录 一、引言 二、生成对抗网络原理 (一)基本架构 (二)训练过程 三、生成对抗网络的应用 (一)图像生成 无条件图像生成: (二)数据增强 (三ÿ…...
CodeIgniter URL结构
CodeIgniter 的URL 结构设计得简洁且易于管理。通常遵循以下模式: http://<domain>/<index_page>/<controller>/<method>/<parameters> 下面是每个部分的详细说明: <domain>: 这是你的网站域名&#…...
从 App Search 到 Elasticsearch — 挖掘搜索的未来
作者:来自 Elastic Nick Chow App Search 将在 9.0 版本中停用,但 Elasticsearch 拥有你构建强大的 AI 搜索体验所需的一切。以下是你需要了解的内容。 生成式人工智能的最新进展正在改变用户行为,激励开发人员创造更具活力、更直观、更引人入…...
鸿蒙本地模拟器 模拟TCP服务端的过程
鸿蒙模拟器模拟TCP服务端的过程涉及几个关键步骤,主要包括创建TCPSocketServer实例、绑定IP地址和端口、监听连接请求、接收和发送数据以及处理连接事件。以下是详细的模拟过程: **1.创建TCPSocketServer实例:**首先,需要导入鸿蒙…...
Qt/C++基于重力模拟的像素点水平堆叠效果
本文将深入解析一个基于 Qt/C 的像素点模拟程序。程序通过 重力作用,将随机分布的像素点下落并水平堆叠,同时支持窗口动态拉伸后重新计算像素点分布。 程序功能概述 随机生成像素点:程序在初始化时随机生成一定数量的像素点,每个…...
Zookeeper学习心得
本人学zookeeper时按照此文路线学的 Zookeeper学习大纲 - 似懂非懂视为不懂 - 博客园 一、Zookeeper安装 ZooKeeper 入门教程 - Java陈序员 - 博客园 Docker安装Zookeeper教程(超详细)_docker 安装zk-CSDN博客 二、 zookeeper的数据模型 ZooKeepe…...
嵌入式开发工程师面试题 - 2024/11/24
原文嵌入式开发工程师面试题 - 2024/11/24 转载请注明来源 1.若有以下定义语句double a[8],*pa;int i5;对数组元素错误的引用是? A *a B a[5] C *(p1) D p[8] 解析: 在 C 或 C 语言中&am…...
Python中打印当前目录文件树的脚本
效果图: 实现脚本: 1、显示所有文件和文件夹: import osdef list_files(startpath, prefix):items os.listdir(startpath)items.sort()for index, item in enumerate(items):item_path os.path.join(startpath, item)is_last index le…...
全景图像(Panorama Image)向透视图像(Perspective Image)的跨视图转化(Cross-view)
一、概念讲解 全景图像到透视图像的转化是一个复杂的图像处理过程,它涉及到将一个360度的全景图像转换为一个具有透视效果的图像,这种图像更接近于人眼观察世界的方式。全景图像通常是一个矩形图像,它通过将球面图像映射到平面上得到…...
Redis 中的 hcan 命令耗内存,有什么优化的方式吗 ?
Redis 中的 hcan 命令耗内存,有什么优化的方式吗 ? 1. 使用合适的游标值:2. 控制每次迭代返回的键数量:3. 避免长时间运行的迭代:4. 使用HSCAN与SCAN命令结合:5. 优化哈希表结构:6. 监控和调整R…...
豆包MarsCode算法题:三数之和问题
问题描述 思路分析 1. 排序数组 目的: 将数组 arr 按升序排序,这样可以方便地使用双指针找到满足条件的三元组,同时避免重复的三元组被重复计算。优势: 数组有序后,处理两个数和 target - arr[i] 的问题可以通过双指针快速找到所有可能的组…...
【Android】AnimationDrawable帧动画的实现
目录 引言 一、AnimationDrawable常用方法 1.1 导包 1.2 addFrame 1.3 setOneShot 1.4 start 1.5 stop 1.6 isRunning 二、 从xml文件获取并播放帧动画 2.1 创建XML文件 2.2 在布局文件中使用帧动画资源 三、在代码中生成并播放帧动画 3.1 addFrame加入帧动画列…...
【消息序列】详解(7):剖析回环模式--设备测试的核心利器
目录 一、概述 1.1. 本地回环模式 1.2. 远程环回模式 二、本地回环模式(Local Loopback mode) 2.1. 步骤 1:主机进入本地环回模式 2.2. 本地回环测试 2.2.1. 步骤 2a:主机发送HCI数据包并接收环回数据 2.2.2. 步骤 2b&…...
解决Ubuntu 22.04系统中网络Ping问题的方法
在Ubuntu 22.04系统中,网络问题时有发生,尤其是当涉及到静态IP地址配置和网线直连的两台机器时。本文将探讨一种常见问题——断开并重新连接网线后,尽管网卡显示为UP状态,但无法立即ping通对方机器,以及如何解决这一问…...
【大数据学习 | Spark-SQL】Spark-SQL编程
上面的是SparkSQL的API操作。 1. 将RDD转化为DataFrame对象 DataFrame: DataFrame是一种以RDD为基础的分布式数据集,类似于传统数据库中的二维表格。带有schema元信息,即DataFrame所表示的二维表数据集的每一列都带有名称和类型。这样的数…...
15分钟做完一个小程序,腾讯这个工具有点东西
我记得很久之前,我们都在讲什么低代码/无代码平台,这个概念很久了,但是,一直没有很好的落地,整体的效果也不算好。 自从去年 ChatGPT 这类大模型大火以来,各大科技公司也都推出了很多 AI 代码助手ÿ…...
manim动画编程(安装+入门)
文章目录 1.基本介绍2.效果展示3.安装步骤3.1安装manba软件3.2配置环境变量3.3查看是否成功3.4什么是mamba3.5创建虚拟环境3.6尝试进入虚拟环境 4.vscode操作4.1默认配置文件 5.安装ffmpeg6.安装manim软件6.vscode制作7.我的学习收获 1.基本介绍 这个manim就是一款软件&#x…...
STL算法之数值算法<stl_numeric.h>
这一节介绍的算法,统称为数值(numeric)算法。STL规定,欲使用它们,客户端必须包含头文件<numeric>.SGI将它们实现与<stl_numeric.h>文件中。 目录 运用实例 accumulate adjacent_difference inner_product partial_sum pow…...
Oracle如何记录登录用户IP
在运维场景中,在定位到某个SQL引起系统故障之后,想知道是哪台机器发过来的,方便定位源头,该如何解决? 在 Oracle 数据库中记录登录用户的 IP 地址可以通过多种方法实现。以下是几种常见的方法,包括使用触发…...
Python图像处理:打造平滑液化效果动画
液化动画中的强度变化是通过在每一帧中逐渐调整液化效果的强度参数来实现的。在提供的代码示例中,强度变化是通过一个简单的线性插值方法来控制的,即随着动画帧数的增加,液化效果的强度也逐渐增加。 def liquify_image(image, center, radius…...
构建Ceph分布式文件共享系统:手动部署指南
#作者:西门吹雪 文章目录 micro-Services-TutorialCeph分布式文件共享方案部署Ceph集群使用CephCeph在kubernetes集群中的使用 micro-Services-Tutorial 微服务最早由Martin Fowler与James Lewis于2014年共同提出,微服务架构风格是一种使用一套小服务来开发单个应…...
数据结构——用数组实现栈和队列
目录 用数组实现栈和队列 一、数组实现栈 1.stack类 2.测试 二、数组实现队列 1.Queue类 2.测试 查询——数组:数组在内存中是连续空间 增删改——链表:链表的增删改处理更方便一些 满足数据先进后出的特点的就是栈,先进先出就是队列…...
vue3typescript,shims-vue.d.ts中declare module的vue声明
webpack已经有了vue-loader这些loader了,为什么还需要declare module *.vue’呢? declare module 是为了告诉 tsc 这是一个“模块”。 如果不声明, IDE 里因为 tsc 类型检查, lint 会标红。 但vue-loader 是在 Webpack 构建阶段使…...
C/C++基础知识复习(30)
1) 什么是 C 中的 Lambda 表达式?它的作用是什么? Lambda 表达式: 在 C 中,Lambda 表达式是一种可以定义匿名函数的机制,可以在代码中快速创建一个内联的函数对象,而不需要显式地定义一个函数。Lambda 表…...
【NLP 1、人工智能与NLP简介】
人人都不看好你,可偏偏你最争气 —— 24.11.26 一、AI和NLP的基本介绍 1.人工智能发展流程 弱人工智能 ——> 强人工智能 ——> 超人工智能 ① 弱人工智能 人工智能算法只能在限定领域解决特定的问题 eg:特定场景下的文本分类、垂直领域下的对…...
网络安全事件管理
一、背景 信息化技术的迅速发展已经极大地改变了人们的生活,网络安全威胁也日益多元化和复杂化。传统的网络安全防护手段难以应对当前繁杂的网络安全问题,构建主动防御的安全整体解决方案将更有利于防范未知的网络安全威胁。 国内外的安全事件在不断增…...
Swagger记录一次生成失败
最近在接入Swagger的时候遇到一个问题,就是Swagger UI可以使用的,但是/v3/docs 这个接口的json返回的base64类型的json,并不是纯json,后来检查之后是因为springboot3里面配置了json压缩。 Beanpublic HttpMessageConverters cusHt…...
Go 语言常用工具方法总结
在 Go 语言开发中,常常需要进行一些常见的类型转换、字符串处理、时间处理等操作。本文将总结一些常用的工具方法,帮助大家提高编码效率,并提供必要的代码解释和注意事项(go新人浅浅记录一下,以后来翻看🤣&…...
ThingsBoard规则链节点:GCP Pub/Sub 节点详解
目录 引言 1. GCP Pub/Sub 节点简介 2. 节点配置 2.1 基本配置示例 3. 使用场景 3.1 数据传输 3.2 数据分析 3.3 事件通知 3.4 任务调度 4. 实际项目中的应用 4.1 项目背景 4.2 项目需求 4.3 实现步骤 5. 总结 引言 ThingsBoard 是一个开源的物联网平台࿰…...
【Linux】select,poll和epoll
select,poll,epoll都是IO多路复用的机制。I/O多路复用就通过一种机制,可以监视多个描述符fd,一旦某个描述符就绪(一般是读就绪或者写就绪),系统会通知有I/O事件发生了(不能定位是哪一个)。但sel…...
网站设计的基本步骤和方法/广州网站优化
在我们登录一些网站、应用、游戏时,见到动态验证码的频率越来越多了。最常见的应该就是Google Authenticator,暴雪安全令之类的应用,通过不断变换的动态数字来最大限度的保证账号的数据安全。今天 Gitee 推荐的这款开源项目,是依托…...
wordpress优化公司/网站软文代写
DNS 是Domain Name System (域名系统) 的缩写,是一种按域层次结构组织计算机和网络的命名系统。DNS应用于TCP/IP构建的网络,主要用于Internet。在Internet上,用户记忆由数字组成的IP地址比较困难,所以引入了域名的概念。域名与IP地…...
wordpress 4.9.8漏洞/嘉兴seo计费管理
import pandas def data_to_excel(data,row):length len(data)number length //rowfor i in range(number1):data[i*row:(i1)*row].to_excel(./path.xlsx,indexFalse)...
深圳模板建站多少钱/竞价托管运营哪家好
github传送门 目录 前言效果图快速上手CollapsingToolbarLayout折叠模式AppBarLayout滚动方式CoordinatorLayout配合Snackbar自定义伸缩头部最后前言 之前也是写了RecyclerView的内容, 这次再补充伸缩头部的实现. 港真, 伸缩头部是那种看到第一眼就会爱上的视图效果, 好看又简洁…...
口腔医院网站做优化/关键词分析
随着人工智能的发展,人脸识别技术在各个领域的场景应用中日益丰富,在多个场景可以看到人脸识别系统的应用落地,在社区、企业、工地、安防等方面。而现在随着各地智慧校园的建设,有些学校逐步引入人脸识别技术,通过校园…...
新疆建设兵团职称查询官方网站/今日头条搜索引擎
如何选择适合深度学习的GPU?为什么GPU比CPU更适合机器学习或者深度学习?什么是张量处理单元(TPU)?目前主流的GPU厂商:Nvidia和AMD选择GPU时需要关注的主要属性1. GPU的内存需要多少?2. 需要多少核心&#…...