当前位置: 首页 > news >正文

从原理到代码:如何通过 FGSM 生成对抗样本并进行攻击

从原理到代码:如何通过 FGSM 生成对抗样本并进行攻击

简介

在机器学习领域,深度神经网络的强大表现令人印象深刻,尤其是在图像分类等任务上。然而,随着对深度学习的深入研究,研究人员发现了神经网络的一个脆弱性:对抗样本(Adversarial Examples)。简单来说,对抗样本是通过对原始输入数据添加微小扰动,导致神经网络预测错误的一类特殊样本。

Fast Gradient Sign Method (FGSM) 是最早提出的一种简单且有效的对抗攻击方法。它通过利用神经网络的梯度信息,快速生成对抗样本。本篇博客将深入剖析 FGSM 的数学原理,并提供一份基于 PyTorch 的代码实现,帮助你快速上手对抗攻击。

最终效果

当前图片为 truck,但经过对抗攻击后,结果看似应该同样是 truck,然而模型却错误地将其分类为了 cat

在这里插入图片描述

数学原理

1. 神经网络的基本工作原理

神经网络通过对输入数据 x x x 进行一系列的非线性变换,输出一个预测值 y ^ \hat{y} y^。以图像分类任务为例,输入 x x x 是图像,网络输出 y ^ \hat{y} y^ 是一个分类结果。网络的目标是使得预测值 y ^ \hat{y} y^ 尽可能接近真实标签 y y y,通过最小化损失函数 L ( y ^ , y ) L(\hat{y}, y) L(y^,y) 来优化网络的权重。

2. 对抗攻击的目标

对抗攻击的核心思想是在输入数据上添加一个微小的扰动 η \eta η,使得网络在预测时发生错误。这种扰动应尽量保持小到人类肉眼难以察觉,但足以让网络产生不同的输出,即:

y ^ a d v = f ( x + η ) and y ^ a d v ≠ y \hat{y}_{adv} = f(x + \eta) \quad \text{and} \quad \hat{y}_{adv} \neq y y^adv=f(x+η)andy^adv=y

3. FGSM 的核心思想

FGSM 利用了梯度上升的思想,通过损失函数相对于输入图像的梯度来找到 最容易 迷惑网络的方向,并沿着这个方向对图像进行微小的扰动。

具体来说,给定输入图像 x x x 和其真实标签 y y y,我们可以计算损失函数 L ( x , y ) L(x, y) L(x,y),并对输入图像 x x x 计算其梯度:

∇ x L ( f ( x ) , y ) \nabla_x L(f(x), y) xL(f(x),y)

这个梯度告诉我们,如何改变输入图像才能最大化损失。FGSM 的基本想法是,沿着这个梯度的符号方向对图像进行微调,以最大化损失函数。具体公式为:

x a d v = x + ϵ ⋅ sign ( ∇ x L ( f ( x ) , y ) ) x_{adv} = x + \epsilon \cdot \text{sign}(\nabla_x L(f(x), y)) xadv=x+ϵsign(xL(f(x),y))

其中:

  • ϵ \epsilon ϵ 是一个小的常数,控制扰动的大小。
  • sign ( ∇ x L ( f ( x ) , y ) ) \text{sign}(\nabla_x L(f(x), y)) sign(xL(f(x),y)) 是损失函数对输入图像梯度的符号,即正负号矩阵。

通过这种方式,我们可以生成一个对抗样本 x a d v x_{adv} xadv,它与原始图像 x x x 看起来几乎相同,但神经网络会将其错误分类。

4. 梯度上升与对抗样本

FGSM 是基于梯度上升的攻击方法。为了让网络犯错,我们希望最大化损失,因此通过扰动让损失函数 L L L 增加。梯度 ∇ x L \nabla_x L xL 的方向是使 L L L 增加最快的方向,因此,我们可以通过对梯度取符号(sign)来生成一个扰动 η \eta η,即:

η = ϵ ⋅ sign ( ∇ x L ( f ( x ) , y ) ) \eta = \epsilon \cdot \text{sign}(\nabla_x L(f(x), y)) η=ϵsign(xL(f(x),y))

这样生成的扰动被加到原图 x x x 上,即对抗样本的公式为:

x a d v = x + η = x + ϵ ⋅ sign ( ∇ x L ( f ( x ) , y ) ) x_{adv} = x + \eta = x + \epsilon \cdot \text{sign}(\nabla_x L(f(x), y)) xadv=x+η=x+ϵsign(xL(f(x),y))

代码实现与解读

接下来,我们通过实际的代码示例,展示如何使用 PyTorch 实现 FGSM 攻击。

1. 加载CIFAR-10数据集与预训练模型

首先,我们需要加载 CIFAR-10 数据集和一个预训练的模型。这里我们使用 ResNet-18 模型,并针对 CIFAR-10 数据集进行了调整。

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models# 设置设备 (GPU 优先)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 加载 CIFAR-10 数据集
transform = transforms.Compose([transforms.ToTensor(),
])
test_dataset = datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True)# 加载微调后的 ResNet 模型
model = models.resnet18(pretrained=False, num_classes=10).to(device)
model.eval()  # 设置为评估模式

在这一步中,我们准备了数据集,并将模型设置为评估模式,这样模型的参数不会在攻击过程中被更新。

提示:如果没有下载过 CIFAR-10 数据集,请将 download=False 改为 download=True

2. 实现FGSM攻击方法

FGSM 的关键在于利用损失函数的梯度来生成对抗样本。具体步骤如下:

# 定义 FGSM 攻击函数
def fgsm_attack(image, epsilon, data_grad):# 生成扰动方向sign_data_grad = data_grad.sign()# 生成对抗样本perturbed_image = image + epsilon * sign_data_grad# 对抗样本像素值范围约束在 [0,1]perturbed_image = torch.clamp(perturbed_image, 0, 1)return perturbed_image
数学公式:

在代码中,我们实际上实现了如下的公式:

x a d v = x + ϵ ⋅ sign ( ∇ x L ( f ( x ) , y ) ) x_{adv} = x + \epsilon \cdot \text{sign}(\nabla_x L(f(x), y)) xadv=x+ϵsign(xL(f(x),y))

该函数实现了 FGSM 算法,通过将计算得到的梯度方向乘以一个小的扰动系数 ϵ \epsilon ϵ,并添加到原始图像上来生成对抗样本。

3. 攻击流程

接下来,我们定义一个函数,来完成攻击过程:

def attack_example(model, device, data, target, epsilon):# 将数据和标签移动到设备上data, target = data.to(device), target.to(device)data.requires_grad = True  # 追踪输入图像的梯度# 前向传播得到初始预测output = model(data)init_pred = output.max(1, keepdim=True)[1]  # 选择概率最大的类别作为预测结果# 如果初始预测错误,则不进行攻击if init_pred.item() != target.item():return None, None, None# 计算损失并进行反向传播计算梯度loss_fn = torch.nn.CrossEntropyLoss()loss = loss_fn(output, target)model.zero_grad()  # 清除现有的梯度loss.backward()  # 计算损失对输入图像的梯度# 使用 FGSM 方法生成对抗样本data_grad = data.grad.dataperturbed_data = fgsm_attack(data, epsilon, data_grad)# 使用对抗样本进行新的预测output = model(perturbed_data)final_pred = output.max(1, keepdim=True)[1]  # 获取对抗样本的预测标签return data, perturbed_data, final_pred

对抗攻击的基本思路是:

  1. 前向传播:将输入数据(如图片)传递给模型,得到输出,并预测标签。
  2. 计算损失:使用损失函数来评估模型的预测与真实标签之间的差距。
  3. 反向传播计算梯度:通过反向传播,计算损失函数相对于输入数据的梯度。
  4. 生成对抗样本:利用这些梯度来调整输入数据,生成扰动的对抗样本。
  5. 重新预测:使用生成的对抗样本进行重新预测,观察模型是否做出错误的分类。

4. 展示原始图像与对抗样本

最后,我们选择一个样本进行攻击,并展示原始图像和对抗样本的效果。

# Step 6: 选择 epsilon 并进行测试
epsilon = 0.01  # 扰动强度# 遍历数据集,找到预测正确的样本
for data, target in test_loader:orig_data, perturbed_data, final_pred = attack_example(model, device, data, target, epsilon)if orig_data is not None:break# CIFAR-10 类别
cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']if orig_data is not None:# 显示原始图像和对抗样本orig_img = orig_data.squeeze().permute(1, 2, 0).detach().cpu().numpy()perturbed_img = perturbed_data.squeeze().permute(1, 2, 0).detach().cpu().numpy()# 获取语义化标签original_label_name = cifar10_classes[target.item()]adversarial_label_name = cifar10_classes[final_pred.item()]# Step 7: 展示结果fig, (ax1, ax2) = plt.subplots(1, 2)ax1.imshow(orig_img)ax1.set_title(f"Original Label: {original_label_name}")ax2.imshow(perturbed_img)ax2.set_title(f"Adversarial Label: {adversarial_label_name}")plt.show()print(f"Original Label: {original_label_name}, Adversarial Label: {adversarial_label_name}")
else:print("Initial prediction was incorrect. No attack performed.")

在运行这段代码后,你会看到一张原始图像及其对抗样本。尽管对抗样本与原图在视觉上几乎没有区别,但它却被神经网络错误分类了。这种现象突显了神经网络的脆弱性。

结论

FGSM 是一种简单且有效的对抗攻击方法,它利用了神经网络的梯度信息生成对抗样本。这些对抗样本虽然肉眼难以察觉,但能显著影响模型的预测结果。本文介绍了 FGSM 的数学原理,并提供了基于 PyTorch 的实现代码。

通过这个例子,我们可以看到即使是看似强大的深度学习模型,也会受到一些精心设计的输入的严重影响。了解这些脆弱性有助于我们开发出更加健壮和安全的模型。

附录:完整代码

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np# Step 1: 设置设备 (GPU 优先)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# Step 2: 加载 CIFAR-10 数据集
transform = transforms.Compose([transforms.ToTensor(),
])# 加载测试数据集
test_dataset = datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True)# model = models.resnet18(pretrained=True).to(device)# 1. 加载在 CIFAR-10 上微调的 ResNet 模型
model = models.resnet18(pretrained=False, num_classes=10).to(device)model.eval()  # 设置为评估模式# 使用的损失函数
loss_fn = nn.CrossEntropyLoss()# Step 4: FGSM 对抗攻击函数
def fgsm_attack(image, epsilon, data_grad):# 生成扰动方向sign_data_grad = data_grad.sign()# 生成对抗样本perturbed_image = image + epsilon * sign_data_grad# 对抗样本像素值范围约束在 [0,1]perturbed_image = torch.clamp(perturbed_image, 0, 1)return perturbed_image# Step 5: 攻击流程
def attack_example(model, device, data, target, epsilon):data, target = data.to(device), target.to(device)# 确保计算图对输入的梯度data.requires_grad = True# 前向传播output = model(data)init_pred = output.max(1, keepdim=True)[1]  # 预测标签# 如果预测错误,则不攻击if init_pred.item() != target.item():return None, None, None# 计算损失loss = loss_fn(output, target)# 反向传播计算梯度model.zero_grad()loss.backward()# 提取梯度data_grad = data.grad.data# 执行 FGSM 攻击perturbed_data = fgsm_attack(data, epsilon, data_grad)# 再次进行预测output = model(perturbed_data)final_pred = output.max(1, keepdim=True)[1]  # 对抗样本的预测标签return data, perturbed_data, final_pred# Step 6: 选择 epsilon 并进行测试
epsilon = 0.01  # 扰动强度# 遍历数据集,找到预测正确的样本
for data, target in test_loader:orig_data, perturbed_data, final_pred = attack_example(model, device, data, target, epsilon)if orig_data is not None:break# CIFAR-10 类别
cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']if orig_data is not None:# 显示原始图像和对抗样本orig_img = orig_data.squeeze().permute(1, 2, 0).detach().cpu().numpy()perturbed_img = perturbed_data.squeeze().permute(1, 2, 0).detach().cpu().numpy()# 标签# orig_label = target.item()# perturbed_label = final_pred.item()# 获取语义化标签original_label_name = cifar10_classes[target.item()]adversarial_label_name = cifar10_classes[final_pred.item()]# Step 7: 展示结果fig, (ax1, ax2) = plt.subplots(1, 2)ax1.imshow(orig_img)ax1.set_title(f"Original Label: {original_label_name}")ax2.imshow(perturbed_img)ax2.set_title(f"Adversarial Label: {adversarial_label_name}")plt.show()print(f"Original Label: {original_label_name}, Adversarial Label: {adversarial_label_name}")
else:print("Initial prediction was incorrect. No attack performed.")

相关文章:

从原理到代码:如何通过 FGSM 生成对抗样本并进行攻击

从原理到代码:如何通过 FGSM 生成对抗样本并进行攻击 简介 在机器学习领域,深度神经网络的强大表现令人印象深刻,尤其是在图像分类等任务上。然而,随着对深度学习的深入研究,研究人员发现了神经网络的一个脆弱性&…...

从零开始学习OMNeT++系列第一弹——OMNeT++的介绍与安装

最近由于由于工作上的需求,接了一个网络仿真的任务。于是开始调研各个仿真平台,然后根据目前的需求和网络上公开资料的多少,决定使用omnet这个网络仿真平台。现在也是刚开始学习,所以决定记录一下从零开始的这个学习过程。因为虽然…...

Cluster Explanation via Polyhedral Descriptions

通过多面体描述进行聚类解释 本文关注聚类描述问题,即在给定数据集及其聚类划分的情况下,解释这些聚类的任务。我们提出了一种新的聚类解释方法,通过在每个聚类周围构建一个多面体,同时最小化最终多面体的复杂性或用于描述的特征…...

爬虫设计思考之一

爬虫设计思考之一 经常做爬虫的人对于技术比较的执着,尤其是本身从事的擅长的技术领域,从而容易忽视与之相近或者相似的技术。因此我建议大家在遇到此类问题的时候,可以采用对比分析的方式来理解。 本次的思考是基于国内最大的中文搜索引擎百…...

解决centos 删除文件后但空间没有释放

一、问题描述:磁盘空间不足,清理完垃圾日志以后磁盘空间还是没有释放 查看磁盘空间 [rootxwj-qt-65-44 ~]# df -h Filesystem Size Used Avail Use% Mounted on devtmpfs 1.9G 0 1.9G 0% /dev tmpfs 1.9G 0 1.9G …...

微软SCCM:企业级系统管理的核心工具

目录 摘要 1. 引言 2. SCCM的基本概念 2.1 什么是SCCM? 2.2 SCCM的历史 3. SCCM的架构 3.1 中心服务器 3.2 数据库 3.3 管理点(Management Point) 3.4 分发点(Distribution Point) 3.5 客户端代理 3.6 报告服务 4. SCCM的核心功能 4.1 软件部署与管理 4.2 操…...

RTSP作为客户端 推流 拉流的过程分析

之前写过一个 rtsp server 作为服务端的简单demo 这次分析下 rtsp作为客户端 推流和拉流时候的过 A.作为客户端拉流 TCP方式 1.Client发送OPTIONS方法 Server回应告诉支持的方法 2.Client发送DESCRIPE方法 这里是从海康摄像机拉流并且设置了用户名密码 Server回复未认证 3.客…...

【MySQL 07】内置函数

目录 1.日期函数 日期函数使用场景: 2.字符串函数 字符串函数使用场景: 3.数学函数 4.控制流函数 1.日期函数 函数示例: 1.在日期的基础上加日期 在该日期下,加上10天。 2.在日期的基础上减去时间 在该日期下减去2天 3.计算两…...

《深度学习》OpenCV 背景建模 原理及案例解析

目录 一、背景建模 1、什么是背景建模 2、背景建模的方法 1)帧差法(backgroundSubtractor) 2)基于K近邻的背景/前景分割算法BackgroundSubtractorKNN 3)基于高斯混合的背景/前景分割算法BackgroundSubtractorMOG2 3、步骤 1)初…...

机器学习(1):机器学习的概念

1. 机器学习的定义和相关概念 机器学习之父 Arthur Samuel 对机器学习的定义是:在没有明确设置的情况下,使计算机具有学习能力的研究领域。 国际机器学习大会的创始人之一 Tom Mitchell 对机器学习的定义是:计算机程序从经验 E 中学习&#…...

0. Pixel3 在Ubuntu22下Android12源码拉取 + 编译

0. Pixel3 在Ubuntu22下Android12源码拉取 编译 原文地址: http://www.androidcrack.com/index.php/archives/3/ 1. 前言 这是一个非常悲伤的故事, 因为一个意外, 不小心把之前镜像的源码搞坏了. 也没做版本管理,恢复不了了. 那么只能说是重新做一次. 再者以前的镜像太老旧…...

ip经过多个服务器转发会网速变慢吗

会的,IP经过多个服务器转发时,网速通常会变慢,主要原因包括: 增加的延迟: 每经过一个服务器,数据包就需要额外的时间进行处理和转发。这种处理时间和网络延迟会累积,导致整体延迟增加。 带宽限制…...

mongodb通过mongoimport导入JSON文件数据

目录 一、概念 二、mongoimport导入工具 三、导入命令 一、概念 MongoDB是一个流行的开源文档数据库,它支持JSON格式的文档,非常适合存储和处理大量的非结构化数据。在实际应用中,我们经常需要将大量的数据批量导入到MongoDB中。mongoimpo…...

【Qt】控件概述 (1)

控件概述 1. QWidget核心属性1.1核心属性概述1.2 enable1.3 geometry——窗口坐标1.4 window frame的影响1.4 windowTitle——窗口标题1.5 windowIcon——窗口图标1.6 windowOpacity——透明度设置1.7 cursor——光标设置1.8 font——字体设置1.9 toolTip——鼠标悬停提示设置1…...

ping基本使用详解

在网络中ping是一个十分强大的TCP/IP工具。它的作用主要为: 用来检测网络的连通情况和分析网络速度根据域名得到服务器 IP根据 ping 返回的 TTL 值来判断对方所使用的操作系统及数据包经过路由器数量。我们通常会用它来直接 ping ip 地址,来测试网络的连…...

Win10之解决:设置静态IP后,为什么自动获取动态IP问题(七十八)

简介: CSDN博客专家、《Android系统多媒体进阶实战》一书作者 新书发布:《Android系统多媒体进阶实战》🚀 优质专栏: Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏: 多媒体系统工程师系列【…...

【AI论文精读1】针对知识密集型NLP任务的检索增强生成(RAG原始论文)

目录 一、简介一句话简介作者、引用数、时间论文地址开源代码地址 二、摘要三、引言四、整体架构(用一个例子来阐明)场景例子:核心点: 五、方法 (架构各部分详解)5.1 模型1. RAG-Sequence Model2. RAG-Toke…...

踩坑spring cloud gateway /actuator/gateway/refresh不生效

版本 java version: 17 spring boot: 3.2.x spring cloud: 2023.0.3 现象 参考Spring Cloud Gateway -> Actuator API -> Refreshing the Route Cache 说明,先修改routes配置再调用/actuator/gateway/refresh,接口返回200 status,但…...

【STM32开发环境搭建】-3-STM32CubeMX Project Manager配置-自动生成一个Keil(MDK-ARM) 5的工程

目录 1 KEIL(MDK-ARM) 5 Project工程设置 2 MCU和嵌入式软件包的选择 3 Code Generator 3.1 STM32Cube Firmware Library Package 3.2 Generated files 3.3 HAL Settings 3.4 Template Settings 4 Advanced Settings 5 自动生成的KEIL(MDK-ARM) 5 Project工程目录 结…...

计算机毕业设计 Java酷听音乐系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍:✌从事软件开发10年之余,专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ 🍅文末获取源码联系🍅 👇🏻 精…...

【kafka】Golang实现分布式Masscan任务调度系统

要求: 输出两个程序,一个命令行程序(命令行参数用flag)和一个服务端程序。 命令行程序支持通过命令行参数配置下发IP或IP段、端口、扫描带宽,然后将消息推送到kafka里面。 服务端程序: 从kafka消费者接收…...

(十)学生端搭建

本次旨在将之前的已完成的部分功能进行拼装到学生端,同时完善学生端的构建。本次工作主要包括: 1.学生端整体界面布局 2.模拟考场与部分个人画像流程的串联 3.整体学生端逻辑 一、学生端 在主界面可以选择自己的用户角色 选择学生则进入学生登录界面…...

【2025年】解决Burpsuite抓不到https包的问题

环境:windows11 burpsuite:2025.5 在抓取https网站时,burpsuite抓取不到https数据包,只显示: 解决该问题只需如下三个步骤: 1、浏览器中访问 http://burp 2、下载 CA certificate 证书 3、在设置--隐私与安全--…...

CMake控制VS2022项目文件分组

我们可以通过 CMake 控制源文件的组织结构,使它们在 VS 解决方案资源管理器中以“组”(Filter)的形式进行分类展示。 🎯 目标 通过 CMake 脚本将 .cpp、.h 等源文件分组显示在 Visual Studio 2022 的解决方案资源管理器中。 ✅ 支持的方法汇总(共4种) 方法描述是否推荐…...

今日学习:Spring线程池|并发修改异常|链路丢失|登录续期|VIP过期策略|数值类缓存

文章目录 优雅版线程池ThreadPoolTaskExecutor和ThreadPoolTaskExecutor的装饰器并发修改异常并发修改异常简介实现机制设计原因及意义 使用线程池造成的链路丢失问题线程池导致的链路丢失问题发生原因 常见解决方法更好的解决方法设计精妙之处 登录续期登录续期常见实现方式特…...

Spring是如何解决Bean的循环依赖:三级缓存机制

1、什么是 Bean 的循环依赖 在 Spring框架中,Bean 的循环依赖是指多个 Bean 之间‌互相持有对方引用‌,形成闭环依赖关系的现象。 多个 Bean 的依赖关系构成环形链路,例如: 双向依赖:Bean A 依赖 Bean B,同时 Bean B 也依赖 Bean A(A↔B)。链条循环: Bean A → Bean…...

代码规范和架构【立芯理论一】(2025.06.08)

1、代码规范的目标 代码简洁精炼、美观,可持续性好高效率高复用,可移植性好高内聚,低耦合没有冗余规范性,代码有规可循,可以看出自己当时的思考过程特殊排版,特殊语法,特殊指令,必须…...

PHP 8.5 即将发布:管道操作符、强力调试

前不久,PHP宣布了即将在 2025 年 11 月 20 日 正式发布的 PHP 8.5!作为 PHP 语言的又一次重要迭代,PHP 8.5 承诺带来一系列旨在提升代码可读性、健壮性以及开发者效率的改进。而更令人兴奋的是,借助强大的本地开发环境 ServBay&am…...

Chrome 浏览器前端与客户端双向通信实战

Chrome 前端(即页面 JS / Web UI)与客户端(C 后端)的交互机制,是 Chromium 架构中非常核心的一环。下面我将按常见场景,从通道、流程、技术栈几个角度做一套完整的分析,特别适合你这种在分析和改…...

简单聊下阿里云DNS劫持事件

阿里云域名被DNS劫持事件 事件总结 根据ICANN规则,域名注册商(Verisign)认定aliyuncs.com域名下的部分网站被用于非法活动(如传播恶意软件);顶级域名DNS服务器将aliyuncs.com域名的DNS记录统一解析到shado…...