【Python实现机器遗忘算法】复现2021年顶会 AAAI算法Amnesiac Unlearning
【Python实现机器遗忘算法】复现2021年顶会 AAAI算法Amnesiac Unlearning
1 算法原理
论文:Graves, L., Nagisetty, V., & Ganesh, V. (2021). Amnesiac machine learning. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 35, 11516–11524.
Amnesiac Unlearning(遗忘性遗忘) 是一种高效且精确的算法,旨在从已经训练好的神经网络模型中删除特定数据的学习信息,而不会显著影响模型在其他数据上的性能。该算法的核心思想是通过选择性撤销与敏感数据相关的参数更新来实现数据的“遗忘”。
1. 训练阶段:记录参数更新
在模型训练过程中,记录每个批次的参数更新以及哪些批次包含敏感数据。
- 步骤:
- 初始化模型参数:从随机初始化的参数 θ i n i t i a l \theta_{initial} θinitial 开始训练模型。
- 训练模型:使用标准训练方法(如随机梯度下降)对模型进行训练,训练过程分为多个 epoch,每个 epoch 包含多个批次(batches)。
- 记录参数更新:
- 对于每个批次 b b b,记录该批次的参数更新 Δ θ e , b \Delta_{\theta_{e,b}} Δθe,b,其中 e e e 表示 epoch 编号, b b b 表示批次编号。
- 同时,记录哪些批次包含敏感数据(即需要删除的数据)。可以将这些批次标记为 S B SB SB(Sensitive Batches)。
- 存储信息:
- 存储所有批次的参数更新 Δ θ e , b \Delta_{\theta_{e,b}} Δθe,b。
- 存储敏感数据批次的索引 S B SB SB。
2. 数据删除阶段:选择性撤销参数更新
当收到数据删除请求时,撤销与敏感数据相关的参数更新。
-
步骤:
- 识别敏感数据批次:从存储的记录中提取包含敏感数据的批次索引 S B SB SB。
- 撤销参数更新:
计算删除敏感数据后的模型参数 θ M \theta_{M} θM:
θ M ′ = θ M − ∑ s b ∈ S B Δ θ s b \theta_{M'} = \theta_{M} - \sum_{sb \in SB} \Delta_{\theta_{sb}} θM′=θM−sb∈SB∑Δθsb其中:
- θ M \theta_{M} θM 是原始训练后的模型参数。
- Δ θ s b \Delta_{\theta_{sb}} Δθsb 是敏感数据批次 s b sb sb 的参数更新。
- 生成保护模型:使用更新后的参数 θ M ′ \theta_{M'} θM′ 作为新的模型参数。
3. 微调阶段(可选)
如果删除的批次较多,可能会对模型性能产生一定影响。此时可以通过少量微调来恢复模型性能。
- 步骤:
- 微调模型:使用删除敏感数据后的数据集对模型进行少量迭代训练。
- 恢复性能:通过微调,模型可以恢复在非敏感数据上的性能。
2 代码实现
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
from models.Base import load_MNIST_data, test_model, device, MLP, load_CIFAR100_data, init_model# AmnesiacForget类:封装撤销与敏感数据相关的参数更新
class AmnesiacForget:def __init__(self, model, all_data, epochs, learning_rate):"""初始化 AmnesiacForget 类。:param model: 需要训练的模型。:param all_data: 训练数据集。:param epochs: 训练的总 epoch 数。:param learning_rate: 优化器的学习率。"""self.model = modelself.all_data = all_dataself.epochs = epochsself.learning_rate = learning_rateself.batch_updates = [] # 存储每个批次的参数更新值self.initial_params = {name: param.clone() for name, param in model.named_parameters()} # 存储初始模型参数self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 设备选择(GPU 或 CPU)def train(self, forgotten_classes):"""训练模型并记录每个批次的参数更新值。:param forgotten_classes: 需要遗忘的类别列表。:return: sensitive_batches: 包含敏感数据的批次索引。"""optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate) # 使用 Adam 优化器self.model.train() # 将模型设置为训练模式sensitive_batches = {} # 记录每个 epoch 中包含敏感数据的批次索引# 训练过程for epoch in range(self.epochs):running_loss = 0.0sensitive_batches[epoch] = set() # 每个 epoch 的敏感批次集for batch_idx, (images, labels) in enumerate(self.all_data):optimizer.zero_grad() # 清空梯度images, labels = images.to(self.device), labels.to(self.device) # 将数据移动到设备上# 前向传播和损失计算outputs = self.model(images)loss = nn.CrossEntropyLoss()(outputs, labels)# 反向传播计算梯度loss.backward()running_loss += loss.item()# 记录当前参数值current_params = {name: param.clone() for name, param in self.model.named_parameters()}# 更新参数optimizer.step()# 记录参数更新值(当前参数值 - 更新前的参数值)batch_update = {}for name, param in self.model.named_parameters():if param.requires_grad:batch_update[name] = param.data - current_params[name].data # 记录参数更新值self.batch_updates.append(batch_update)# 记录包含敏感数据的批次索引if any(label.item() in forgotten_classes for label in labels):sensitive_batches[epoch].add(batch_idx)print(f"Epoch [{epoch+1}/{self.epochs}], Loss: {running_loss/len(self.all_data):.4f}")return sensitive_batchesdef unlearn(self, sensitive_batches):"""撤销与敏感数据相关的批次更新。:param sensitive_batches: 包含敏感数据的批次索引。:return: 更新后的模型。"""# 计算非敏感批次的参数更新总和non_sensitive_updates = {name: torch.zeros_like(param) for name, param in self.model.named_parameters()}for batch_idx, batch_update in enumerate(self.batch_updates):if batch_idx not in {sb for epoch_batches in sensitive_batches.values() for sb in epoch_batches}:for name, update in batch_update.items():non_sensitive_updates[name] += update# 更新模型参数:初始参数 + 非敏感批次的更新for name, param in self.model.named_parameters():param.data = self.initial_params[name].data + non_sensitive_updates[name]return self.model# 全局函数:实现 Amnesiac Forget
def amnesiac_unlearning(model_before, test_loader, forgotten_classes, all_data, epochs=10, learning_rate=0.001):"""执行 Amnesiac Unlearning:训练模型,记录参数更新,并撤销与敏感数据相关的更新。:param model_before: 遗忘前的模型。:param test_loader: 测试数据加载器。:param forgotten_classes: 需要遗忘的类别列表。:param all_data: 训练数据集。:param epochs: 训练的总 epoch 数(默认为 10)。:param learning_rate: 优化器的学习率(默认为 0.001)。:return: 遗忘后的模型。"""# 模拟从头训练的过程,并记录批次更新的过程print("模拟重新训练过程,记录批次更新...")temp_model = MLP().to(device) # 初始化一个新模型amnesiac_forget = AmnesiacForget(temp_model, all_data, epochs, learning_rate) # 初始化 AmnesiacForget 类sensitive_batches = amnesiac_forget.train(forgotten_classes) # 训练模型并记录敏感批次# 测试遗忘前的模型性能overall_acc_before, forgotten_acc_before, retained_acc_before = test_model(amnesiac_forget.model, test_loader)print(f"全部准确率: {overall_acc_before:.2f}%, 保留准确率: {retained_acc_before:.2f}%, 遗忘准确率: {forgotten_acc_before:.2f}%")# 应用遗忘:撤销与敏感数据相关的批次更新model_after = amnesiac_forget.unlearn(sensitive_batches)return model_afterdef main():# 超参数设置batch_size = 256forgotten_classes = [0] # 需要遗忘的类别ratio = 1model_name = "ResNet18" # 模型名称# 加载数据if model_name == "MLP":train_loader, test_loader, retain_loader, forget_loader = load_MNIST_data(batch_size, forgotten_classes, ratio)elif model_name == "ResNet18":train_loader, test_loader, retain_loader, forget_loader = load_CIFAR100_data(batch_size, forgotten_classes, ratio)# 初始化模型model_before = init_model(model_name, train_loader)# 在训练之前测试初始模型准确率overall_acc_before, forgotten_acc_before, retained_acc_before = test_model(model_before, test_loader)# 实现遗忘操作print("执行遗忘 Amnesiac...")model_after = amnesiac_unlearning(model_before, test_loader, forgotten_classes, train_loader, epochs=5, learning_rate=0.001)# 测试遗忘后的模型overall_acc_after, forgotten_acc_after, retained_acc_after = test_model(model_after, test_loader)# 输出遗忘前后的准确率变化print(f"Unlearning 前遗忘准确率: {100 * forgotten_acc_before:.2f}%")print(f"Unlearning 后遗忘准确率: {100 * forgotten_acc_after:.2f}%")print(f"Unlearning 前保留准确率: {100 * retained_acc_before:.2f}%")print(f"Unlearning 后保留准确率: {100 * retained_acc_after:.2f}%")if __name__ == "__main__":main()
3 总结
- 高效性:只需撤销与敏感数据相关的参数更新,避免了从头训练模型的高成本。
- 精确性:能够精确删除特定数据的学习信息,特别适合删除少量数据。
- 存储成本:需要存储每个批次的参数更新,存储成本较高,但通常低于从头训练模型的成本。
- 适用场景:适合删除少量数据(如单个样本或少量样本),而不适合删除大量数据(如整个类别)。
相关文章:

【Python实现机器遗忘算法】复现2021年顶会 AAAI算法Amnesiac Unlearning
【Python实现机器遗忘算法】复现2021年顶会 AAAI算法Amnesiac Unlearning 1 算法原理 论文:Graves, L., Nagisetty, V., & Ganesh, V. (2021). Amnesiac machine learning. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 35, 115…...

Vue 3 30天精进之旅:Day 03 - Vue实例
引言 在前两天的学习中,我们成功搭建了Vue.js的开发环境,并创建了我们的第一个Vue项目。今天,我们将深入了解Vue的核心概念之一——Vue实例。通过学习Vue实例,你将理解Vue的基础架构,掌握数据绑定、模板语法和指令的使…...

【ArcGIS微课1000例】0141:提取多波段影像中的单个波段
文章目录 一、波段提取函数二、加载单波段导出问题描述:如下图所示,img格式的时序NDVI数据有24个波段。现在需要提取某一个波段,该怎样操作? 一、波段提取函数 首先加载多波段数据。点击【窗口】→【影像分析】。 选择需要处理的多波段影像,点击下方的【添加函数】。 在多…...

【第九天】零基础入门刷题Python-算法篇-数据结构与算法的介绍-六种常见的图论算法(持续更新)
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、Python数据结构与算法的详细介绍1.Python中的常用的图论算法2. 图论算法3.详细的图论算法1)深度优先搜索(DFS)2…...

落地 轮廓匹配
个人理解为将一幅不规则的图形,通过最轮廓发现,最大轮廓匹配来确定图像的位置,再通过pt将不规则的图像放在规定的矩形里面,在通过透视变换将不规则的图形放进规则的图像中。 1. findHomography 函数 • Mat h findHomography(s…...

【漫话机器学习系列】064.梯度下降小口诀(Gradient Descent rule of thume)
梯度下降小口诀 为了帮助记忆梯度下降的核心原理和关键注意事项,可以用以下简单口诀来总结: 1. 基本原理 损失递减,梯度为引:目标是让损失函数减少,依靠梯度指引方向。负梯度,反向最短:沿着负…...

JAVA(SpringBoot)集成Kafka实现消息发送和接收。
SpringBoot集成Kafka实现消息发送和接收。 一、Kafka 简介二、Kafka 功能三、POM依赖四、配置文件五、生产者六、消费者 君子之学贵一,一则明,明则有功。 一、Kafka 简介 Kafka 是由 Apache 软件基金会开发的一个开源流处理平台,最初由 Link…...

AI刷题-蛋糕工厂产能规划、优质章节的连续选择
挑两个简单的写写 目录 一、蛋糕工厂产能规划 问题描述 输入格式 输出格式 解题思路: 问题理解 数据结构选择 算法步骤 关键点 最终代码: 运行结果:编辑 二、优质章节的连续选择 问题描述 输入格式 输出格式 解题思路&a…...

在线可编辑Excel
1. Handsontable 特点: 提供了类似 Excel 的表格编辑体验,包括单元格样式、公式计算、数据验证等功能。 支持多种插件,如筛选、排序、合并单元格等。 轻量级且易于集成到现有项目中。 具备强大的自定义能力,可以调整外观和行为…...

什么是词嵌入?Word2Vec、GloVe 与 FastText 的区别
自然语言处理(NLP)领域的核心问题之一,是如何将人类的语言转换成计算机可以理解的数值形式,而词嵌入(Word Embedding)正是为了解决这个问题的重要技术。本文将详细讲解词嵌入的概念及其经典模型(Word2Vec、GloVe 和 FastText)的原理与区别。 1. 什么是词嵌入(Word Em…...

WPS数据分析000010
基于数据透视表的内容 一、排序 手动调动 二、筛选 三、值显示方式 四、值汇总依据 五、布局和选项 不显示分类汇总 合并居中带标签的单元格 空单元格显示 六、显示报表筛选页...

Qt中QVariant的使用
1.使用QVariant实现不同类型数据的相加 方法:通过type函数返回数值的类型,然后通过setValue来构造一个QVariant类型的返回值。 函数: QVariant mainPage::dataPlus(QVariant a, QVariant b) {QVariant ret;if ((a.type() QVariant::Int) &a…...

Avalonia UI MVVM DataTemplate里绑定Command
Avalonia 模板里面绑定ViewModel跟WPF写法有些不同。需要单独绑定Command. WPF里面可以直接按照下面的方法绑定DataContext. <Button Content"Button" Command"{Binding DataContext.ClickCommand, RelativeSource{RelativeSource AncestorType{x:Type User…...

动态规划DP 数字三角型模型 最低通行费用(题目详解+C++代码完整实现)
最低通行费用 原题链接 AcWing 1018. 最低同行费用 题目描述 一个商人穿过一个 NN的正方形的网格,去参加一个非常重要的商务活动。 他要从网格的左上角进,右下角出。每穿越中间 1个小方格,都要花费 1个单位时间。商人必须在 (2N−1)个单位…...

deepseek R1的确不错,特别是深度思考模式
deepseek R1的确不错,特别是深度思考模式,每次都能自我反省改进。比如我让 它写文案: 【赛博朋克版程序员新春密码——2025我们来破局】 亲爱的代码骑士们: 当CtrlS的肌肉记忆遇上抢票插件,当Spring Boot的…...

Linux 常用命令 - sort 【对文件内容进行排序】
简介 sort 命令源于英文单词 “sort”,表示排序。其主要功能是对文本文件中的行进行排序。它可以根据字母、数字、特定字段等不同的标准进行排序。sort 通过逐行读取文件(没有指定文件或指定文件为 - 时读取标准输入)内容,并按照…...

MyBatis最佳实践:提升数据库交互效率的秘密武器
第一章:框架的概述: MyBatis 框架的概述: MyBatis 是一个优秀的基于 Java 的持久框架,内部对 JDBC 做了封装,使开发者只需要关注 SQL 语句,而不关注 JDBC 的代码,使开发变得更加的简单MyBatis 通…...

选择困难?直接生成pynput快捷键字符串
from pynput import keyboard# 文档:https://pynput.readthedocs.io/en/latest/keyboard.html#monitoring-the-keyboard # 博客(pynput相关源码):https://blog.csdn.net/qq_39124701/article/details/145230331 # 虚拟键码(十六进制):https:/…...

DeepSeek-R1:强化学习驱动的推理模型
1月20日晚,DeepSeek正式发布了全新的推理模型DeepSeek-R1,引起了人工智能领域的广泛关注。该模型在数学、代码生成等高复杂度任务上表现出色,性能对标OpenAI的o1正式版。同时,DeepSeek宣布将DeepSeek-R1以及相关技术报告全面开源。…...

国内优秀的FPGA设计公司主要分布在哪些城市?
近年来,国内FPGA行业发展迅速,随着5G通信、人工智能、大数据等新兴技术的崛起,FPGA设计企业的需求也迎来了爆发式增长。很多技术人才在求职时都会考虑城市的行业分布和发展潜力。因此,国内优秀的FPGA设计公司主要分布在哪些城市&a…...

3.日常英语笔记
screening discrepancies 筛选差异 The team found some screening discrepancies in the data. 团队在数据筛选中发现了些差异。 Don’t tug at it ,or it will fall over and crush you. tug 拉,拽,拖 He tugged the door open with all his might…...

基于RIP的MGRE实验
实验拓扑 实验要求 按照图示配置IP地址配置静态路由协议,搞通公网配置MGRE VPNNHRP的配置配置RIP路由协议来传递两端私网路由测试全网通 实验配置 1、配置IP地址 [R1]int g0/0/0 [R1-GigabitEthernet0/0/0]ip add 15.0.0.1 24 [R1]int LoopBack 0 [R1-LoopBack0]i…...

【开源免费】基于Vue和SpringBoot的美食推荐商城(附论文)
本文项目编号 T 166 ,文末自助获取源码 \color{red}{T166,文末自助获取源码} T166,文末自助获取源码 目录 一、系统介绍二、数据库设计三、配套教程3.1 启动教程3.2 讲解视频3.3 二次开发教程 四、功能截图五、文案资料5.1 选题背景5.2 国内…...

Pandas DataFrame 拼接、合并和关联
拼接:使用 pd.concat(),可以沿着行或列方向拼接 DataFrame。 合并:使用 pd.merge(),可以根据一个或多个键进行不同类型的合并(左连接、右连接、全连接、内连接)。 关联:使用 join() 方法,通常在设置了索引的 DataFrame 上进行关联操作。 concat拼接 按列拼接 df1 = …...

【Redis】Redis修改连接数参数
1.重启操作背景 Redis数据库连接数上限,需要修改配置文件里maxclients参数,修改后需重启数据库 1.1、修改操作系统open files参数 1.2、修改redis连接数 2.登录操作系统 登录堡垒机 ssh {ip}3.查看当前状态 3.1、查看操作系统配置 ulimit -a3.2、…...

scratch变魔术 2024年12月scratch三级真题 中国电子学会 图形化编程 scratch三级真题和答案解析
目录 scratch变魔术 一、题目要求 1、准备工作 2、功能实现 二、案例分析 1、角色分析 2、背景分析 3、前期准备 三、解题思路 1、思路分析 2、详细过程 四、程序编写 五、考点分析 六、 推荐资料 1、入门基础 2、蓝桥杯比赛 3、考级资料 4、视频课程 5、py…...

51单片机开发:点阵屏显示数字
实验目标:在8x8的点阵屏上显示数字0。 点阵屏的原理图如下图所示,点阵屏的列接在P0端口,行接在74HC595扩展的DP端口上。 扩展口的使用详见:51单片机开发:IO扩展(串转并)实验-CSDN博客 要让点阵屏显示数字࿰…...

mysql DDL可重入讨论
mysql的bug:当执行 MySQL online DDL 时,期间如有其他并发的 DML 对相同的表进行增量修改,比如 update、insert、insert into … on duplicate key、replace into 等,且增量修改的数据违背唯一约束,那么 DDL 最后都会执…...

DAY01 面向对象回顾、继承、抽象类
学习目标 能够写出类的继承格式public class 子类 extends 父类{}public class Cat extends Animal{} 能够说出继承的特点子类继承父类,就会自动拥有父类非私有的成员 能够说出子类调用父类的成员特点1.子类有使用子类自己的2.子类没有使用,继承自父类的3.子类父类都没有编译报…...

127周一复盘 (165)玩法与难度思考
1.上午测试,小改了点东西, 基本等于啥也没干。 匆忙赶往车站。 从此进入春节期间,没有开发,而思考与设计。 2.火车上思考玩法与难度的问题。 目前的主流作法实际上并不完全符合不同玩家的需求, 对这方面还是要有自…...