从原理到代码:如何通过 FGSM 生成对抗样本并进行攻击
从原理到代码:如何通过 FGSM 生成对抗样本并进行攻击
简介
在机器学习领域,深度神经网络的强大表现令人印象深刻,尤其是在图像分类等任务上。然而,随着对深度学习的深入研究,研究人员发现了神经网络的一个脆弱性:对抗样本(Adversarial Examples)。简单来说,对抗样本是通过对原始输入数据添加微小扰动,导致神经网络预测错误的一类特殊样本。
Fast Gradient Sign Method (FGSM) 是最早提出的一种简单且有效的对抗攻击方法。它通过利用神经网络的梯度信息,快速生成对抗样本。本篇博客将深入剖析 FGSM 的数学原理,并提供一份基于 PyTorch 的代码实现,帮助你快速上手对抗攻击。
最终效果
当前图片为 truck,但经过对抗攻击后,结果看似应该同样是 truck,然而模型却错误地将其分类为了 cat。
数学原理
1. 神经网络的基本工作原理
神经网络通过对输入数据 x x x 进行一系列的非线性变换,输出一个预测值 y ^ \hat{y} y^。以图像分类任务为例,输入 x x x 是图像,网络输出 y ^ \hat{y} y^ 是一个分类结果。网络的目标是使得预测值 y ^ \hat{y} y^ 尽可能接近真实标签 y y y,通过最小化损失函数 L ( y ^ , y ) L(\hat{y}, y) L(y^,y) 来优化网络的权重。
2. 对抗攻击的目标
对抗攻击的核心思想是在输入数据上添加一个微小的扰动 η \eta η,使得网络在预测时发生错误。这种扰动应尽量保持小到人类肉眼难以察觉,但足以让网络产生不同的输出,即:
y ^ a d v = f ( x + η ) and y ^ a d v ≠ y \hat{y}_{adv} = f(x + \eta) \quad \text{and} \quad \hat{y}_{adv} \neq y y^adv=f(x+η)andy^adv=y
3. FGSM 的核心思想
FGSM 利用了梯度上升的思想,通过损失函数相对于输入图像的梯度来找到 最容易 迷惑网络的方向,并沿着这个方向对图像进行微小的扰动。
具体来说,给定输入图像 x x x 和其真实标签 y y y,我们可以计算损失函数 L ( x , y ) L(x, y) L(x,y),并对输入图像 x x x 计算其梯度:
∇ x L ( f ( x ) , y ) \nabla_x L(f(x), y) ∇xL(f(x),y)
这个梯度告诉我们,如何改变输入图像才能最大化损失。FGSM 的基本想法是,沿着这个梯度的符号方向对图像进行微调,以最大化损失函数。具体公式为:
x a d v = x + ϵ ⋅ sign ( ∇ x L ( f ( x ) , y ) ) x_{adv} = x + \epsilon \cdot \text{sign}(\nabla_x L(f(x), y)) xadv=x+ϵ⋅sign(∇xL(f(x),y))
其中:
- ϵ \epsilon ϵ 是一个小的常数,控制扰动的大小。
- sign ( ∇ x L ( f ( x ) , y ) ) \text{sign}(\nabla_x L(f(x), y)) sign(∇xL(f(x),y)) 是损失函数对输入图像梯度的符号,即正负号矩阵。
通过这种方式,我们可以生成一个对抗样本 x a d v x_{adv} xadv,它与原始图像 x x x 看起来几乎相同,但神经网络会将其错误分类。
4. 梯度上升与对抗样本
FGSM 是基于梯度上升的攻击方法。为了让网络犯错,我们希望最大化损失,因此通过扰动让损失函数 L L L 增加。梯度 ∇ x L \nabla_x L ∇xL 的方向是使 L L L 增加最快的方向,因此,我们可以通过对梯度取符号(sign)来生成一个扰动 η \eta η,即:
η = ϵ ⋅ sign ( ∇ x L ( f ( x ) , y ) ) \eta = \epsilon \cdot \text{sign}(\nabla_x L(f(x), y)) η=ϵ⋅sign(∇xL(f(x),y))
这样生成的扰动被加到原图 x x x 上,即对抗样本的公式为:
x a d v = x + η = x + ϵ ⋅ sign ( ∇ x L ( f ( x ) , y ) ) x_{adv} = x + \eta = x + \epsilon \cdot \text{sign}(\nabla_x L(f(x), y)) xadv=x+η=x+ϵ⋅sign(∇xL(f(x),y))
代码实现与解读
接下来,我们通过实际的代码示例,展示如何使用 PyTorch 实现 FGSM 攻击。
1. 加载CIFAR-10数据集与预训练模型
首先,我们需要加载 CIFAR-10 数据集和一个预训练的模型。这里我们使用 ResNet-18 模型,并针对 CIFAR-10 数据集进行了调整。
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models# 设置设备 (GPU 优先)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 加载 CIFAR-10 数据集
transform = transforms.Compose([transforms.ToTensor(),
])
test_dataset = datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True)# 加载微调后的 ResNet 模型
model = models.resnet18(pretrained=False, num_classes=10).to(device)
model.eval() # 设置为评估模式
在这一步中,我们准备了数据集,并将模型设置为评估模式,这样模型的参数不会在攻击过程中被更新。
提示:如果没有下载过 CIFAR-10 数据集,请将
download=False
改为download=True
。
2. 实现FGSM攻击方法
FGSM 的关键在于利用损失函数的梯度来生成对抗样本。具体步骤如下:
# 定义 FGSM 攻击函数
def fgsm_attack(image, epsilon, data_grad):# 生成扰动方向sign_data_grad = data_grad.sign()# 生成对抗样本perturbed_image = image + epsilon * sign_data_grad# 对抗样本像素值范围约束在 [0,1]perturbed_image = torch.clamp(perturbed_image, 0, 1)return perturbed_image
数学公式:
在代码中,我们实际上实现了如下的公式:
x a d v = x + ϵ ⋅ sign ( ∇ x L ( f ( x ) , y ) ) x_{adv} = x + \epsilon \cdot \text{sign}(\nabla_x L(f(x), y)) xadv=x+ϵ⋅sign(∇xL(f(x),y))
该函数实现了 FGSM 算法,通过将计算得到的梯度方向乘以一个小的扰动系数 ϵ \epsilon ϵ,并添加到原始图像上来生成对抗样本。
3. 攻击流程
接下来,我们定义一个函数,来完成攻击过程:
def attack_example(model, device, data, target, epsilon):# 将数据和标签移动到设备上data, target = data.to(device), target.to(device)data.requires_grad = True # 追踪输入图像的梯度# 前向传播得到初始预测output = model(data)init_pred = output.max(1, keepdim=True)[1] # 选择概率最大的类别作为预测结果# 如果初始预测错误,则不进行攻击if init_pred.item() != target.item():return None, None, None# 计算损失并进行反向传播计算梯度loss_fn = torch.nn.CrossEntropyLoss()loss = loss_fn(output, target)model.zero_grad() # 清除现有的梯度loss.backward() # 计算损失对输入图像的梯度# 使用 FGSM 方法生成对抗样本data_grad = data.grad.dataperturbed_data = fgsm_attack(data, epsilon, data_grad)# 使用对抗样本进行新的预测output = model(perturbed_data)final_pred = output.max(1, keepdim=True)[1] # 获取对抗样本的预测标签return data, perturbed_data, final_pred
对抗攻击的基本思路是:
- 前向传播:将输入数据(如图片)传递给模型,得到输出,并预测标签。
- 计算损失:使用损失函数来评估模型的预测与真实标签之间的差距。
- 反向传播计算梯度:通过反向传播,计算损失函数相对于输入数据的梯度。
- 生成对抗样本:利用这些梯度来调整输入数据,生成扰动的对抗样本。
- 重新预测:使用生成的对抗样本进行重新预测,观察模型是否做出错误的分类。
4. 展示原始图像与对抗样本
最后,我们选择一个样本进行攻击,并展示原始图像和对抗样本的效果。
# Step 6: 选择 epsilon 并进行测试
epsilon = 0.01 # 扰动强度# 遍历数据集,找到预测正确的样本
for data, target in test_loader:orig_data, perturbed_data, final_pred = attack_example(model, device, data, target, epsilon)if orig_data is not None:break# CIFAR-10 类别
cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']if orig_data is not None:# 显示原始图像和对抗样本orig_img = orig_data.squeeze().permute(1, 2, 0).detach().cpu().numpy()perturbed_img = perturbed_data.squeeze().permute(1, 2, 0).detach().cpu().numpy()# 获取语义化标签original_label_name = cifar10_classes[target.item()]adversarial_label_name = cifar10_classes[final_pred.item()]# Step 7: 展示结果fig, (ax1, ax2) = plt.subplots(1, 2)ax1.imshow(orig_img)ax1.set_title(f"Original Label: {original_label_name}")ax2.imshow(perturbed_img)ax2.set_title(f"Adversarial Label: {adversarial_label_name}")plt.show()print(f"Original Label: {original_label_name}, Adversarial Label: {adversarial_label_name}")
else:print("Initial prediction was incorrect. No attack performed.")
在运行这段代码后,你会看到一张原始图像及其对抗样本。尽管对抗样本与原图在视觉上几乎没有区别,但它却被神经网络错误分类了。这种现象突显了神经网络的脆弱性。
结论
FGSM 是一种简单且有效的对抗攻击方法,它利用了神经网络的梯度信息生成对抗样本。这些对抗样本虽然肉眼难以察觉,但能显著影响模型的预测结果。本文介绍了 FGSM 的数学原理,并提供了基于 PyTorch 的实现代码。
通过这个例子,我们可以看到即使是看似强大的深度学习模型,也会受到一些精心设计的输入的严重影响。了解这些脆弱性有助于我们开发出更加健壮和安全的模型。
附录:完整代码
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np# Step 1: 设置设备 (GPU 优先)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# Step 2: 加载 CIFAR-10 数据集
transform = transforms.Compose([transforms.ToTensor(),
])# 加载测试数据集
test_dataset = datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True)# model = models.resnet18(pretrained=True).to(device)# 1. 加载在 CIFAR-10 上微调的 ResNet 模型
model = models.resnet18(pretrained=False, num_classes=10).to(device)model.eval() # 设置为评估模式# 使用的损失函数
loss_fn = nn.CrossEntropyLoss()# Step 4: FGSM 对抗攻击函数
def fgsm_attack(image, epsilon, data_grad):# 生成扰动方向sign_data_grad = data_grad.sign()# 生成对抗样本perturbed_image = image + epsilon * sign_data_grad# 对抗样本像素值范围约束在 [0,1]perturbed_image = torch.clamp(perturbed_image, 0, 1)return perturbed_image# Step 5: 攻击流程
def attack_example(model, device, data, target, epsilon):data, target = data.to(device), target.to(device)# 确保计算图对输入的梯度data.requires_grad = True# 前向传播output = model(data)init_pred = output.max(1, keepdim=True)[1] # 预测标签# 如果预测错误,则不攻击if init_pred.item() != target.item():return None, None, None# 计算损失loss = loss_fn(output, target)# 反向传播计算梯度model.zero_grad()loss.backward()# 提取梯度data_grad = data.grad.data# 执行 FGSM 攻击perturbed_data = fgsm_attack(data, epsilon, data_grad)# 再次进行预测output = model(perturbed_data)final_pred = output.max(1, keepdim=True)[1] # 对抗样本的预测标签return data, perturbed_data, final_pred# Step 6: 选择 epsilon 并进行测试
epsilon = 0.01 # 扰动强度# 遍历数据集,找到预测正确的样本
for data, target in test_loader:orig_data, perturbed_data, final_pred = attack_example(model, device, data, target, epsilon)if orig_data is not None:break# CIFAR-10 类别
cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']if orig_data is not None:# 显示原始图像和对抗样本orig_img = orig_data.squeeze().permute(1, 2, 0).detach().cpu().numpy()perturbed_img = perturbed_data.squeeze().permute(1, 2, 0).detach().cpu().numpy()# 标签# orig_label = target.item()# perturbed_label = final_pred.item()# 获取语义化标签original_label_name = cifar10_classes[target.item()]adversarial_label_name = cifar10_classes[final_pred.item()]# Step 7: 展示结果fig, (ax1, ax2) = plt.subplots(1, 2)ax1.imshow(orig_img)ax1.set_title(f"Original Label: {original_label_name}")ax2.imshow(perturbed_img)ax2.set_title(f"Adversarial Label: {adversarial_label_name}")plt.show()print(f"Original Label: {original_label_name}, Adversarial Label: {adversarial_label_name}")
else:print("Initial prediction was incorrect. No attack performed.")
相关文章:

从原理到代码:如何通过 FGSM 生成对抗样本并进行攻击
从原理到代码:如何通过 FGSM 生成对抗样本并进行攻击 简介 在机器学习领域,深度神经网络的强大表现令人印象深刻,尤其是在图像分类等任务上。然而,随着对深度学习的深入研究,研究人员发现了神经网络的一个脆弱性&…...

从零开始学习OMNeT++系列第一弹——OMNeT++的介绍与安装
最近由于由于工作上的需求,接了一个网络仿真的任务。于是开始调研各个仿真平台,然后根据目前的需求和网络上公开资料的多少,决定使用omnet这个网络仿真平台。现在也是刚开始学习,所以决定记录一下从零开始的这个学习过程。因为虽然…...

Cluster Explanation via Polyhedral Descriptions
通过多面体描述进行聚类解释 本文关注聚类描述问题,即在给定数据集及其聚类划分的情况下,解释这些聚类的任务。我们提出了一种新的聚类解释方法,通过在每个聚类周围构建一个多面体,同时最小化最终多面体的复杂性或用于描述的特征…...
爬虫设计思考之一
爬虫设计思考之一 经常做爬虫的人对于技术比较的执着,尤其是本身从事的擅长的技术领域,从而容易忽视与之相近或者相似的技术。因此我建议大家在遇到此类问题的时候,可以采用对比分析的方式来理解。 本次的思考是基于国内最大的中文搜索引擎百…...
解决centos 删除文件后但空间没有释放
一、问题描述:磁盘空间不足,清理完垃圾日志以后磁盘空间还是没有释放 查看磁盘空间 [rootxwj-qt-65-44 ~]# df -h Filesystem Size Used Avail Use% Mounted on devtmpfs 1.9G 0 1.9G 0% /dev tmpfs 1.9G 0 1.9G …...
微软SCCM:企业级系统管理的核心工具
目录 摘要 1. 引言 2. SCCM的基本概念 2.1 什么是SCCM? 2.2 SCCM的历史 3. SCCM的架构 3.1 中心服务器 3.2 数据库 3.3 管理点(Management Point) 3.4 分发点(Distribution Point) 3.5 客户端代理 3.6 报告服务 4. SCCM的核心功能 4.1 软件部署与管理 4.2 操…...

RTSP作为客户端 推流 拉流的过程分析
之前写过一个 rtsp server 作为服务端的简单demo 这次分析下 rtsp作为客户端 推流和拉流时候的过 A.作为客户端拉流 TCP方式 1.Client发送OPTIONS方法 Server回应告诉支持的方法 2.Client发送DESCRIPE方法 这里是从海康摄像机拉流并且设置了用户名密码 Server回复未认证 3.客…...

【MySQL 07】内置函数
目录 1.日期函数 日期函数使用场景: 2.字符串函数 字符串函数使用场景: 3.数学函数 4.控制流函数 1.日期函数 函数示例: 1.在日期的基础上加日期 在该日期下,加上10天。 2.在日期的基础上减去时间 在该日期下减去2天 3.计算两…...

《深度学习》OpenCV 背景建模 原理及案例解析
目录 一、背景建模 1、什么是背景建模 2、背景建模的方法 1)帧差法(backgroundSubtractor) 2)基于K近邻的背景/前景分割算法BackgroundSubtractorKNN 3)基于高斯混合的背景/前景分割算法BackgroundSubtractorMOG2 3、步骤 1)初…...
机器学习(1):机器学习的概念
1. 机器学习的定义和相关概念 机器学习之父 Arthur Samuel 对机器学习的定义是:在没有明确设置的情况下,使计算机具有学习能力的研究领域。 国际机器学习大会的创始人之一 Tom Mitchell 对机器学习的定义是:计算机程序从经验 E 中学习&#…...

0. Pixel3 在Ubuntu22下Android12源码拉取 + 编译
0. Pixel3 在Ubuntu22下Android12源码拉取 编译 原文地址: http://www.androidcrack.com/index.php/archives/3/ 1. 前言 这是一个非常悲伤的故事, 因为一个意外, 不小心把之前镜像的源码搞坏了. 也没做版本管理,恢复不了了. 那么只能说是重新做一次. 再者以前的镜像太老旧…...
ip经过多个服务器转发会网速变慢吗
会的,IP经过多个服务器转发时,网速通常会变慢,主要原因包括: 增加的延迟: 每经过一个服务器,数据包就需要额外的时间进行处理和转发。这种处理时间和网络延迟会累积,导致整体延迟增加。 带宽限制…...
mongodb通过mongoimport导入JSON文件数据
目录 一、概念 二、mongoimport导入工具 三、导入命令 一、概念 MongoDB是一个流行的开源文档数据库,它支持JSON格式的文档,非常适合存储和处理大量的非结构化数据。在实际应用中,我们经常需要将大量的数据批量导入到MongoDB中。mongoimpo…...

【Qt】控件概述 (1)
控件概述 1. QWidget核心属性1.1核心属性概述1.2 enable1.3 geometry——窗口坐标1.4 window frame的影响1.4 windowTitle——窗口标题1.5 windowIcon——窗口图标1.6 windowOpacity——透明度设置1.7 cursor——光标设置1.8 font——字体设置1.9 toolTip——鼠标悬停提示设置1…...

ping基本使用详解
在网络中ping是一个十分强大的TCP/IP工具。它的作用主要为: 用来检测网络的连通情况和分析网络速度根据域名得到服务器 IP根据 ping 返回的 TTL 值来判断对方所使用的操作系统及数据包经过路由器数量。我们通常会用它来直接 ping ip 地址,来测试网络的连…...

Win10之解决:设置静态IP后,为什么自动获取动态IP问题(七十八)
简介: CSDN博客专家、《Android系统多媒体进阶实战》一书作者 新书发布:《Android系统多媒体进阶实战》🚀 优质专栏: Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏: 多媒体系统工程师系列【…...

【AI论文精读1】针对知识密集型NLP任务的检索增强生成(RAG原始论文)
目录 一、简介一句话简介作者、引用数、时间论文地址开源代码地址 二、摘要三、引言四、整体架构(用一个例子来阐明)场景例子:核心点: 五、方法 (架构各部分详解)5.1 模型1. RAG-Sequence Model2. RAG-Toke…...
踩坑spring cloud gateway /actuator/gateway/refresh不生效
版本 java version: 17 spring boot: 3.2.x spring cloud: 2023.0.3 现象 参考Spring Cloud Gateway -> Actuator API -> Refreshing the Route Cache 说明,先修改routes配置再调用/actuator/gateway/refresh,接口返回200 status,但…...

【STM32开发环境搭建】-3-STM32CubeMX Project Manager配置-自动生成一个Keil(MDK-ARM) 5的工程
目录 1 KEIL(MDK-ARM) 5 Project工程设置 2 MCU和嵌入式软件包的选择 3 Code Generator 3.1 STM32Cube Firmware Library Package 3.2 Generated files 3.3 HAL Settings 3.4 Template Settings 4 Advanced Settings 5 自动生成的KEIL(MDK-ARM) 5 Project工程目录 结…...

计算机毕业设计 Java酷听音乐系统的设计与实现 Java实战项目 附源码+文档+视频讲解
博主介绍:✌从事软件开发10年之余,专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ 🍅文末获取源码联系🍅 👇🏻 精…...
conda相比python好处
Conda 作为 Python 的环境和包管理工具,相比原生 Python 生态(如 pip 虚拟环境)有许多独特优势,尤其在多项目管理、依赖处理和跨平台兼容性等方面表现更优。以下是 Conda 的核心好处: 一、一站式环境管理:…...

Spark 之 入门讲解详细版(1)
1、简介 1.1 Spark简介 Spark是加州大学伯克利分校AMP实验室(Algorithms, Machines, and People Lab)开发通用内存并行计算框架。Spark在2013年6月进入Apache成为孵化项目,8个月后成为Apache顶级项目,速度之快足见过人之处&…...

云启出海,智联未来|阿里云网络「企业出海」系列客户沙龙上海站圆满落地
借阿里云中企出海大会的东风,以**「云启出海,智联未来|打造安全可靠的出海云网络引擎」为主题的阿里云企业出海客户沙龙云网络&安全专场于5.28日下午在上海顺利举办,现场吸引了来自携程、小红书、米哈游、哔哩哔哩、波克城市、…...

STM32F4基本定时器使用和原理详解
STM32F4基本定时器使用和原理详解 前言如何确定定时器挂载在哪条时钟线上配置及使用方法参数配置PrescalerCounter ModeCounter Periodauto-reload preloadTrigger Event Selection 中断配置生成的代码及使用方法初始化代码基本定时器触发DCA或者ADC的代码讲解中断代码定时启动…...
VTK如何让部分单位不可见
最近遇到一个需求,需要让一个vtkDataSet中的部分单元不可见,查阅了一些资料大概有以下几种方式 1.通过颜色映射表来进行,是最正规的做法 vtkNew<vtkLookupTable> lut; //值为0不显示,主要是最后一个参数,透明度…...
JDK 17 新特性
#JDK 17 新特性 /**************** 文本块 *****************/ python/scala中早就支持,不稀奇 String json “”" { “name”: “Java”, “version”: 17 } “”"; /**************** Switch 语句 -> 表达式 *****************/ 挺好的ÿ…...

OPENCV形态学基础之二腐蚀
一.腐蚀的原理 (图1) 数学表达式:dst(x,y) erode(src(x,y)) min(x,y)src(xx,yy) 腐蚀也是图像形态学的基本功能之一,腐蚀跟膨胀属于反向操作,膨胀是把图像图像变大,而腐蚀就是把图像变小。腐蚀后的图像变小变暗淡。 腐蚀…...
Fabric V2.5 通用溯源系统——增加图片上传与下载功能
fabric-trace项目在发布一年后,部署量已突破1000次,为支持更多场景,现新增支持图片信息上链,本文对图片上传、下载功能代码进行梳理,包含智能合约、后端、前端部分。 一、智能合约修改 为了增加图片信息上链溯源,需要对底层数据结构进行修改,在此对智能合约中的农产品数…...

mac:大模型系列测试
0 MAC 前几天经过学生优惠以及国补17K入手了mac studio,然后这两天亲自测试其模型行运用能力如何,是否支持微调、推理速度等能力。下面进入正文。 1 mac 与 unsloth 按照下面的进行安装以及测试,是可以跑通文章里面的代码。训练速度也是很快的。 注意…...

热门Chrome扩展程序存在明文传输风险,用户隐私安全受威胁
赛门铁克威胁猎手团队最新报告披露,数款拥有数百万活跃用户的Chrome扩展程序正在通过未加密的HTTP连接静默泄露用户敏感数据,严重威胁用户隐私安全。 知名扩展程序存在明文传输风险 尽管宣称提供安全浏览、数据分析或便捷界面等功能,但SEMR…...