paddle2.3-基于联邦学习实现FedAVg算法-CNN
目录
1. 联邦学习介绍
2. 实验流程
3. 数据加载
4. 模型构建
5. 数据采样函数
6. 模型训练
1. 联邦学习介绍
联邦学习是一种分布式机器学习方法,中心节点为server(服务器),各分支节点为本地的client(设备)。联邦学习的模式是在各分支节点分别利用本地数据训练模型,再将训练好的模型汇合到中心节点,获得一个更好的全局模型。
联邦学习的提出是为了充分利用用户的数据特征训练效果更佳的模型,同时,为了保证隐私,联邦学习在训练过程中,server和clients之间通信的是模型的参数(或梯度、参数更新量),本地的数据不会上传到服务器。
本项目主要是升级1.8版本的联邦学习fedavg算法至2.3版本,内容取材于基于PaddlePaddle实现联邦学习算法FedAvg - 飞桨AI Studio星河社区
2. 实验流程
联邦学习的基本流程是:
1. server初始化模型参数,所有的clients将这个初始模型下载到本地;
2. clients利用本地产生的数据进行SGD训练;
3. 选取K个clients将训练得到的模型参数上传到server;
4. server对得到的模型参数整合,所有的clients下载新的模型。
5. 重复执行2-5,直至收敛或达到预期要求
import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import random
import time
import paddle
import paddle.nn as nn
import numpy as np
from paddle.io import Dataset,DataLoader
import paddle.nn.functional as F
3. 数据加载
mnist_data_train=np.load('data/data2489/train_mnist.npy')
mnist_data_test=np.load('data/data2489/test_mnist.npy')
print('There are {} images for training'.format(len(mnist_data_train)))
print('There are {} images for testing'.format(len(mnist_data_test)))
# 数据和标签分离(便于后续处理)
Label=[int(i[0]) for i in mnist_data_train]
Data=[i[1:] for i in mnist_data_train]
There are 60000 images for training
There are 10000 images for testing
4. 模型构建
class CNN(nn.Layer):def __init__(self):super(CNN,self).__init__()self.conv1=nn.Conv2D(1,32,5)self.relu = nn.ReLU()self.pool1=nn.MaxPool2D(kernel_size=2,stride=2)self.conv2=nn.Conv2D(32,64,5)self.pool2=nn.MaxPool2D(kernel_size=2,stride=2)self.fc1=nn.Linear(1024,512)self.fc2=nn.Linear(512,10)# self.softmax = nn.Softmax()def forward(self,inputs):x = self.conv1(inputs)x = self.relu(x)x = self.pool1(x)x = self.conv2(x)x = self.relu(x)x = self.pool2(x)x=paddle.reshape(x,[-1,1024])x = self.relu(self.fc1(x))y = self.fc2(x)return y
5. 数据采样函数
# 均匀采样,分配到各个client的数据集都是IID且数量相等的
def IID(dataset, clients):num_items_per_client = int(len(dataset)/clients)client_dict = {}image_idxs = [i for i in range(len(dataset))]for i in range(clients):client_dict[i] = set(np.random.choice(image_idxs, num_items_per_client, replace=False)) # 为每个client随机选取数据image_idxs = list(set(image_idxs) - client_dict[i]) # 将已经选取过的数据去除client_dict[i] = list(client_dict[i])return client_dict
# 非均匀采样,同时各个client上的数据分布和数量都不同
def NonIID(dataset, clients, total_shards, shards_size, num_shards_per_client):shard_idxs = [i for i in range(total_shards)]client_dict = {i: np.array([], dtype='int64') for i in range(clients)}idxs = np.arange(len(dataset))data_labels = Labellabel_idxs = np.vstack((idxs, data_labels)) # 将标签和数据ID堆叠label_idxs = label_idxs[:, label_idxs[1,:].argsort()]idxs = label_idxs[0,:]for i in range(clients):rand_set = set(np.random.choice(shard_idxs, num_shards_per_client, replace=False)) shard_idxs = list(set(shard_idxs) - rand_set)for rand in rand_set:client_dict[i] = np.concatenate((client_dict[i], idxs[rand*shards_size:(rand+1)*shards_size]), axis=0) # 拼接return client_dict
class MNISTDataset(Dataset):def __init__(self, data,label):self.data = dataself.label = labeldef __getitem__(self, idx):image=np.array(self.data[idx]).astype('float32')image=np.reshape(image,[1,28,28])label=np.array(self.label[idx]).astype('int64')return image, labeldef __len__(self):return len(self.label)
6. 模型训练
class ClientUpdate(object):def __init__(self, data, label, batch_size, learning_rate, epochs):dataset = MNISTDataset(data,label)self.train_loader = DataLoader(dataset,batch_size=batch_size,shuffle=True,drop_last=True)self.learning_rate = learning_rateself.epochs = epochsdef train(self, model):optimizer=paddle.optimizer.SGD(learning_rate=self.learning_rate,parameters=model.parameters())criterion = nn.CrossEntropyLoss(reduction='mean')model.train()e_loss = []for epoch in range(1,self.epochs+1):train_loss = []for image,label in self.train_loader:# image=paddle.to_tensor(image)# label=paddle.to_tensor(label.reshape([label.shape[0],1]))output=model(image)loss= criterion(output,label)# print(loss)loss.backward()optimizer.step()optimizer.clear_grad()train_loss.append(loss.numpy()[0])t_loss=sum(train_loss)/len(train_loss)e_loss.append(t_loss)total_loss=sum(e_loss)/len(e_loss)return model.state_dict(), total_loss
train_x = np.array(Data)
train_y = np.array(Label)
BATCH_SIZE = 32
# 通信轮数
rounds = 100
# client比例
C = 0.1
# clients数量
K = 100
# 每次通信在本地训练的epoch
E = 5
# batch size
batch_size = 10
# 学习率
lr=0.001
# 数据切分
iid_dict = IID(mnist_data_train, 100)
def training(model, rounds, batch_size, lr, ds,L, data_dict, C, K, E, plt_title, plt_color):global_weights = model.state_dict()train_loss = []start = time.time()# clients与server之间通信for curr_round in range(1, rounds+1):w, local_loss = [], []m = max(int(C*K), 1) # 随机选取参与更新的clientsS_t = np.random.choice(range(K), m, replace=False)for k in S_t:# print(data_dict[k])sub_data = ds[data_dict[k]]sub_y = L[data_dict[k]]local_update = ClientUpdate(sub_data,sub_y, batch_size=batch_size, learning_rate=lr, epochs=E)weights, loss = local_update.train(model)w.append(weights)local_loss.append(loss)# 更新global weightsweights_avg = w[0]for k in weights_avg.keys():for i in range(1, len(w)):# weights_avg[k] += (num[i]/sum(num))*w[i][k]weights_avg[k]=weights_avg[k]+w[i][k] weights_avg[k]=weights_avg[k]/len(w)global_weights[k].set_value(weights_avg[k])# global_weights = weights_avg# print(global_weights)#模型加载最新的参数model.load_dict(global_weights)loss_avg = sum(local_loss) / len(local_loss)if curr_round % 10 == 0:print('Round: {}... \tAverage Loss: {}'.format(curr_round, np.round(loss_avg, 5)))train_loss.append(loss_avg)end = time.time()fig, ax = plt.subplots()x_axis = np.arange(1, rounds+1)y_axis = np.array(train_loss)ax.plot(x_axis, y_axis, 'tab:'+plt_color)ax.set(xlabel='Number of Rounds', ylabel='Train Loss',title=plt_title)ax.grid()fig.savefig(plt_title+'.jpg', format='jpg')print("Training Done!")print("Total time taken to Train: {}".format(end-start))return model.state_dict()#导入模型
mnist_cnn = CNN()
mnist_cnn_iid_trained = training(mnist_cnn, rounds, batch_size, lr, train_x,train_y, iid_dict, C, K, E, "MNIST CNN on IID Dataset", "orange")
Round: 10... Average Loss: [0.024]
Round: 20... Average Loss: [0.015]
Round: 30... Average Loss: [0.008]
Round: 40... Average Loss: [0.003]
Round: 50... Average Loss: [0.004]
Round: 60... Average Loss: [0.002]
Round: 70... Average Loss: [0.002]
Round: 80... Average Loss: [0.002]
Round: 90... Average Loss: [0.001]
Round: 100... Average Loss: [0.]
Training Done!
Total time taken to Train: 759.6239657402039
相关文章:
paddle2.3-基于联邦学习实现FedAVg算法-CNN
目录 1. 联邦学习介绍 2. 实验流程 3. 数据加载 4. 模型构建 5. 数据采样函数 6. 模型训练 1. 联邦学习介绍 联邦学习是一种分布式机器学习方法,中心节点为server(服务器),各分支节点为本地的client(设备&#…...
nuiapp保存canvas绘图
要保存一个 Canvas 绘图,可以使用以下步骤: 获取 Canvas 元素和其绘图上下文: var canvas document.getElementById("myCanvas"); var ctx canvas.getContext("2d");使用 Canvas 绘图 API 绘制图形。 使用 toDataUR…...
Object.defineProperty()方法详解,了解vue2的数据代理
假期第一篇,对于基础的知识点,我感觉自己还是很薄弱的。 趁着假期,再去复习一遍 Object.defineProperty(),对于这个方法,更多的还是停留在面试的时候,面试官问你vue2和vue3区别的时候,不免要提一提这个方法…...
Linux 磁盘管理
Linux 系统的磁盘管理直接关系到整个系统的性能表现。磁盘管理常用三个命令为: df、du 和 fdisk。 df df(英文全称:disk free)。df 命令用于显示磁盘空间的使用情况,包括文件系统的挂载点、总容量、已用空间、可用空间…...
大数据与人工智能的未来已来
大数据与人工智能的定义 大数据: 大数据指的是规模庞大、复杂性高、多样性丰富的数据集合。这些数据通常无法通过传统的数据库管理工具来捕获、存储、管理和处理。大数据的特点包括"3V": 大量(Volume):大数…...
【AI视野·今日Robot 机器人论文速览 第四十一期】Tue, 26 Sep 2023
AI视野今日CS.Robotics 机器人学论文速览 Tue, 26 Sep 2023 Totally 73 papers 👉上期速览✈更多精彩请移步主页 Daily Robotics Papers Extreme Parkour with Legged Robots Authors Xuxin Cheng, Kexin Shi, Ananye Agarwal, Deepak Pathak人类可以通过以高度动态…...
[NOIP2012 提高组] 开车旅行
[NOIP2012 提高组] 开车旅行 题目描述 小 A \text{A} A 和小 B \text{B} B 决定利用假期外出旅行,他们将想去的城市从 $1 $ 到 n n n 编号,且编号较小的城市在编号较大的城市的西边,已知各个城市的海拔高度互不相同,记城市 …...
数据库设计流程---以案例熟悉
案例名字:宠物商店系统 课程来源:点击跳转 信息->概念模型->数据模型->数据库结构模型 将现实世界中的信息转换为信息世界的概念模型(E-R模型) 业务逻辑 构建 E-R 图 确定三个实体:用户、商品、订单...
Miniconda创建paddlepaddle环境
1、conda env list 2、conda create --name paddle_env python3.8 --channel https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 3、activate paddle_env 4、python -m pip install paddlepaddle -i https://mirror.baidu.com/pypi/simple 5、pip install "p…...
postgresql实现单主单从
实现步骤 1.主库创建一个有复制权限的用户 CREATE ROLE 用户名login # 有登录权限的角色即是用户replication #复制权限 encrypted password 密码;2.主库配置开放从库外部访问权限 修改 pg_hba.conf 文件 (相当于开放防火墙) # 类型 数据库 …...
提取PDF数据:Documents for PDF ( GcPdf )
在当今数据驱动的世界中,从 PDF 文档中无缝提取结构化表格数据已成为开发人员的一项关键任务。借助GrapeCity Documents for PDF ( GcPdf ),您可以使用 C# 以编程方式轻松解锁这些 PDF 中隐藏的信息宝藏。 考虑一下 PDF(最常用的文档格式之一…...
adb连接切换到模拟器端口
查看连接状态 adb devices出现以下情况 C:\Users\22560>adb devices List of devices attached 127.0.0.1:5555 offline emulator-5554 device可以发现我们想要连接的雷电模拟器的5555端口目前没有连接,只有emulator-5554被连接了,此时我们需要关…...
为何每个开发者都在谈论Go?
目录 一、引言Go的历史回顾关键时间节点 使用场景Go的语言地位技术社群与企业支持资源投入和生态系统 二、简洁的语法结构基本组成元素变量声明与初始化代码示例 类型推断函数与返回值代码示例输出 接口与结构体:组合而非继承错误处理:明确而不是异常小结…...
【Leetcode】 501. 二叉搜索树中的众数
给你一个含重复值的二叉搜索树(BST)的根节点 root ,找出并返回 BST 中的所有 众数(即,出现频率最高的元素)。 如果树中有不止一个众数,可以按 任意顺序 返回。 假定 BST 满足如下定义…...
怎样给Ubuntu系统安装vmware-tools
首先我要告诉你:Ubuntu无法安装vmware-tools,之所以这么些是因为我一开始也是这样认为的,vmware-tools是给Windows系统准备的我认为,毕竟Windows占有率远远高于Linux,这也可以理解。 那么怎么样实现Ubuntu虚拟机跟Wind…...
DDS信号发生器波形发生器VHDL
名称:DDS信号发生器波形发生器 软件:Quartus 语言:VHDL 要求: 在EDA平台中使用VHDL语言为工具,设计一个常见信号发生电路,要求: 1. 能够产生锯齿波,方波,三角波&…...
Python3操作SQLite3创建表主键自增长|CRUD基本操作
Win11查看安装的Python路径及安装的库 Python PEP8 代码规范常见问题及解决方案 Python3操作MySQL8.XX创建表|CRUD基本操作 Python3操作SQLite3创建表主键自增长|CRUD基本操作 anaconda3最新版安装|使用详情|Error: Please select a valid Python interpreter Python函数绘…...
B. Comparison String
题目: 样例: 输入 4 4 <<>> 4 >><< 5 >>>>> 7 <><><><输出 3 3 6 2 思路: 由题意,条件是 又因为要使用尽可能少的数字,这是一道贪心题,所以…...
python端口扫描
扫描所有端口 import socket, threading, os, timedef port_thread(ip, start, step, timeout):for port in range(start, start step):s socket.socket()s.settimeout(timeout)try:s.connect((ip, port))print(f"port[{port}] 可用")except Exception as e:# pri…...
国庆第二天
#include<th.h>#define ERR_MSG(msg) do{\fprintf(stderr,"__%d__",__LINE__);\perror(msg);\ }while(0)#define PORT 6666 #define IP "192.168.2.3"//键盘输入事件 int serverkeyboard(fd_set readfds) {char buf[128] "";int sndfd -…...
Java安全之servlet内存马分析
目录 前言 什么是中间键 了解jsp的本质 理解servlet运行机制 servlet的生命周期 Tomcat总体架构 查看Context 的源码 servlet内存马实现 参考 前言 php和jsp一句话马我想大家都知道,早先就听小伙伴说过一句话木马已经过时了,现在是内存马的天下…...
2023年第二十届中国研究生数学建模竞赛总结与分享
今天是国庆节,祝祖国繁荣富强。正好也学习不下去,就想着写写博客,总结一下自己在参加2023年第20届中国研究生数学建模比赛的一些感受。 目录 1.基本介绍 2.比赛分享 1.基本介绍 1. 竞赛时间:竞赛定于2023年9月22日8:00至2023年9…...
Web前端-Vue2+Vue3基础入门到实战项目-Day1(初始Vue, Vue指令, 小黑记事本)
Web前端-Vue2Vue3基础入门到实战项目-Day1 Vue快速上手创建一个Vue实例插值表达式Vue响应式特性 Vue指令指令初识 和 v-htmlv-show 和 v-ifv-else 和 v-else-ifv-on内联语句methods处理函数调用传参 v-bind案例 - 波仔的学习之旅v-forv-for基本使用案例 - 小黑的书架v-for的key…...
Sentinel学习(2)——sentinel的使用,引入依赖和配置 对消费者进行流控 对生产者进行熔断降级
前言 Sentinel 是面向分布式、多语言异构化服务架构的流量治理组件,主要以流量为切入点,从流量路由、流量控制、流量整形、熔断降级、系统自适应过载保护、热点流量防护等多个维度来帮助开发者保障微服务的稳定性。 本篇博客介绍sentinel的使用&#x…...
springboot 简单配置mongodb多数据源
准备工作: 本地mongodb一个创建两个数据库 student 和 student-two 所需jar包: # springboot基于的版本 <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</artifactId>&l…...
西门子S7-1200使用LRCF通信库与安川机器人进行EthernetIP通信的具体方法示例
西门子S7-1200使用LRCF通信库与安川机器人进行EthernetIP通信的具体方法示例 准备条件: PLC:S7-1200 1214C DC/DC/DC 系统版本4.5及以上。 机器人控制柜:安川YRC1000。 软件:TIA V17 PLC做主站,机器人做从站。 具体方法可参考以下内容: 使用的库文件为西门子 1200系列…...
pytorch第一天(tensor数据和csv数据的预处理)lm老师版
tensor数据: import torch import numpyx torch.arange(12) print(x) print(x.shape) print(x.numel())X x.reshape(3, 4) print(X)zeros torch.zeros((2, 3, 4)) print(zeros)ones torch.ones((2,3,4)) print(ones)randon torch.randn(3,4) print(randon)a …...
CSP-J第二轮试题-2021年-1.2题
文章目录 参考:总结 [CSP-J 2021] 分糖果题目背景题目描述输入格式输出格式样例 #1样例输入 #1样例输出 #1 样例 #2样例输入 #2样例输出 #2 样例 #3样例输入 #3样例输出 #3 提示答案1答案2-优化 [CSP-J 2021] 插入排序题目描述输入格式输出格式样例 #1样例输入 #1样…...
怒刷LeetCode的第16天(Java版)
目录 第一题 题目来源 题目内容 解决方法 方法一:迭代 方法二:模拟 方法三:循环模拟 方法四:传递 第二题 题目来源 题目内容 解决方法 方法一:回溯 方法二:枚举优化 第三题 题目来源 题目…...
让大脑自由
前言 作者写这本书的目的是什么? 教会我们如何让大脑更好地为自己工作。 1 大脑的运行机制是怎样的? 大脑的基本运行机制是神经元之间通过突触传递信息,神经元的兴奋和抑制状态决定了神经网络的运行和信息处理,神经网络可以通过…...
wordpress分享到微博才能看到/信息发布平台推广有哪些
一、选项卡 如今很多应用都会使用碎片以便在同一个活动中能够显示多个不同的视图。在Android 3.0 以上的版本中,我们已经可以使用ActionBar提供的Tab来实现这种效果,而不需要我们自己去实现碎片的切换。ActionBar默认是不具备选项卡功能的,所…...
wordpress例子/百度爱采购服务商查询
首先,需要打开终端,若桌面没有终端图标,则可以使用ctrlaltt即可,之后将终端锁定在任务栏。 其次,需要获取root权限,使用命令 su root 密码(我的是root)。 进入之后,…...
网站建设策划 流程/在线一键建站系统
Would you like to change the user interface language in any edition of Windows 7 or Vista on your computer? Here’s a free app that can help you do this quickly and easily. 您想在计算机的任何版本的Windows 7或Vista中更改用户界面语言吗? 这是一个…...
wordpress建小说站收费/百度推广关键词质量度
python实现一组典型数据格式转换发布时间:2020-09-26 19:00:07来源:脚本之家阅读:105作者:杰瑞26本文实例为大家分享了一组典型数据格式转换的python实现代码,供大家参考,具体内容如下有一组源数据…...
网上代做论文的网站好/优化系统的软件
在当今的智能手机市场,iPhone和安卓形成了两大操作系统阵营,但是由于两者的操作系统不同、用户群体不同,导致使用手机的时候形成了不同的使用习惯,今天小胖就来总结一下两者在使用手机时的一些区别。清后台因为早年的安卓手机性能…...
做网站一屏一屏的/厦门百度开户
OOP即Object-Oriented Programming(面向对象程序设计)就是选用面向对象的程序设计语言(Object-Oriented Programming Language,OOPL),采用对象,类及其相关概念所进行的程序设计。 对面向对象的理解本身是无止境的&…...