联邦学习:按Dirichlet分布划分Non-IID样本

1 导引

我们在《Python中的随机采样和概率分布(二)》介绍了如何用Python现有的库对一个概率分布进行采样,其中的Dirichlet分布大家一定不会感到陌生,这篇博客我们来更详细地介绍Dirichlet分布的性质及其在联邦学习领域的应用。

2 Dirichlet分布及其性质

Dirichlet分布[1]是定义在(\mathbb{R}^N)上的概率密度。Dirichlet分布以度量(\bm{u})(所有系数(\bm{u}i>0)的一个向量)为参数,可将它写为(\bm{u}=\alpha \bm{m}),这里(\bm{m})是在(N)个分量上的归一化度量((\sum{i=1}^N m_i = 1), (m_i > 0)),且(\alpha)是一个正数。Dirichlet分布的概率密度函数由下式给出:

[f(\bm{x} | \alpha \bm{m}) \propto \prod_{i=1}^{N} x_{i}^{\alpha m_i -1} ]

注意,对(\bm{X} = (X_1, \cdots, X_N)\sim \text{Dir}(\alpha \bm{m})),有(X_i>0 , \sum_{i=1}^N X_i = 1)。

向量(\bm{m})是随机向量(\bm{X})的期望:

[\mathbb{E}(\bm{X}) = \int f(\bm{x} | \alpha \bm{m}) \bm{x} \mathrm{d} \bm{x} = \bm{m} ]

在物理意义上,Dirichlet分布中参数(\alpha)的作用主要体现在两个方面。首先,(\alpha)度量了这个分布的 锐度(sharpness),也即测量我们分布中的典型样本(\bm{X})与其均值(\bm{m})相差多远,就像高斯分布中精度(\tau=1/\sigma^2)度量了样本与它的均值偏差多远一样。 一个大的(\alpha) 值会使得(\bm{X}) 的分布在(\bm{m}) 附近急剧出现尖峰(后文我们会提到,在联邦数据划分中,这将导致不同标签在客户端的分布更为 同构)。下图就体现了(\alpha)对(\bm{X})分布的影响:

联邦学习:按Dirichlet分布划分Non-IID样本

注意我们这里是从满足(N=3)的分布中采样1000个3维样本点,两个轴表示(X_1)和(X_2),(X_3)在可视化中并不使用。

这里附上可视化的代码,感兴趣的童鞋可下来自行尝试:

import numpy as np
import matplotlib.pyplot as plt

us = [(0.1, 0.1, 0.1), (1, 1, 1), (10, 10, 10)]  # 3组不同的u=alpha*m参数
points = [[] for i in range(3)]
for i in range(3):
    points[i] = np.random.dirichlet(us[i], size=100)

xs, ys = [[] for i in range(3)], [[] for i in range(3)]
for i in range(3):
    xs[i], ys[i], _ = list(zip(*points[i]))

fig, axs = plt.subplots(1, 3, figsize=(12, 4), sharey=True)
for i in range(3):
    axs[i].set_title(f"$αm={us[i]}$")
    axs[i].scatter(xs[i], ys[i])
    axs[i].set_xlabel("$X_1$")
    axs[i].set_ylabel("$X_2$")

plt.suptitle(r"The Display of $X_1, X_2$ in $(X_1, X_2, X_3)$")
plt.show()

3 Dirichlet分布在联邦学习中的应用

3.1 划分不独立同分布(Non-IID)数据集

我们在联邦学习中,经常会假设不同client间的数据集不满足独立同分布(Non-IID)。那么我们如何将一个现有的数据集按照Non-IID划分呢?我们知道带标签样本的生成分布看可以表示为(p(\bm{x}, y)),我们进一步将其写作(p(\bm{x}, y)=p(\bm{x}|y)p(y))。其中如果要估计(p(\bm{x}|y))的计算开销非常大,但估计(p(y))的计算开销就很小。所有我们按照样本的标签分布来对样本进行Non-IID划分是一个非常高效、简便的做法。

总而言之,我们采取的算法思路是尽量让每个client上的样本标签分布不同。我们设有(K)个类别标签,(N)个client,每个类别标签的样本需要按照不同的比例划分在不同的client上。我们设矩阵(\bm{X}\in \mathbb{R}^{K*N})为类别标签分布矩阵,其行向量(\bm{x}_k\in \mathbb{R}^N)表示类别(k)在不同client上的概率分布向量(每一维表示(k)类别的样本划分到不同client上的比例),该随机向量就采样自Dirichlet分布(第一次采用Dirichlet分布来划分数据集的论文为《Measuring the Effects of Non-Identical Data Distribution for Federated Visual Classification》[2])。

据此,我们可以写出以下的划分算法:

def dirichlet_split_noniid(train_labels, alpha, n_clients):
    '''
    按照参数为alpha的Dirichlet分布将样本索引集合划分为n_clients个子集
    '''
    n_classes = train_labels.max()+1
    # (K, N) 类别标签分布矩阵X,记录每个类别划分到每个client去的比例
    label_distribution = np.random.dirichlet([alpha]*n_clients, n_classes)
    # (K, ...) 记录K个类别对应的样本索引集合
    class_idcs = [np.argwhere(train_labels == y).flatten()
                  for y in range(n_classes)]

    # 记录N个client分别对应的样本索引集合
    client_idcs = [[] for _ in range(n_clients)]
    for k_idcs, fracs in zip(class_idcs, label_distribution):
        # np.split按照比例fracs将类别为k的样本索引k_idcs划分为了N个子集
        # i表示第i个client,idcs表示其对应的样本索引集合idcs
        for i, idcs in enumerate(np.split(k_idcs,
                                          (np.cumsum(fracs)[:-1]*len(k_idcs)).

                                          astype(int))):
            client_idcs[i] += [idcs]

    client_idcs = [np.concatenate(idcs) for idcs in client_idcs]

    return client_idcs

其中 np.random.dirichlet函数的具体用法大家可以参见我的上一篇博客《Python中的随机采样和概率分布(二)》和numpy文档《numpy.random.dirichlet函数》[3],此处不再赘述。

3.2 算法测试与可视化呈现

接下来我们在EMNIST数据集上调用该函数进行测试,并进行可视化呈现。我们设client数量(N=10),Dirichlet概率分布的参数(\alpha=1.0)(也是我们联邦学习常用的设置),(\bm{m}\in \mathbb{R}^N)在我们这里表示client的先验分布,我们规定是均匀分布(\bm{m}= (1, 1, \cdots, 1))。数据集划分的可视化呈现如下:

import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets
from torch.utils.data import ConcatDataset

n_clients = 10
dirichlet_alpha = 1.0
seed = 42

if __name__ == "__main__":
    np.random.seed(seed)
    train_data = datasets.EMNIST(
        root=".", split="byclass", download=True, train=True)
    test_data = datasets.EMNIST(
        root=".", split="byclass", download=True, train=False)

    classes = train_data.classes
    n_classes = len(classes)

    labels = np.concatenate(
        [np.array(train_data.targets), np.array(test_data.targets)], axis=0)
    dataset = ConcatDataset([train_data, test_data])

    # 我们让每个client不同label的样本数量不同,以此做到Non-IID划分
    client_idcs = dirichlet_split_noniid(
        labels, alpha=dirichlet_alpha, n_clients=n_clients)

    # 展示不同label划分到不同client的情况
    plt.figure(figsize=(12, 8))
    plt.hist([labels[idc]for idc in client_idcs], stacked=True,
             bins=np.arange(min(labels)-0.5, max(labels) + 1.5, 1),
             label=["Client {}".format(i) for i in range(n_clients)],
             rwidth=0.5)
    plt.xticks(np.arange(n_classes), train_data.classes)
    plt.xlabel("Label type")
    plt.ylabel("Number of samples")
    plt.legend(loc="upper right")
    plt.title("Display Label Distribution on Different Clients")
    plt.show()

最终的可视化结果如下:

联邦学习:按Dirichlet分布划分Non-IID样本

可以看到,62个类别标签在不同client上的分布确实不同,证明我们的样本划分算法是有效的。

我们尝试将(\alpha)设置为(0.1),可以看到标签分布的异构程度确实有所加大(结合我们前面所讲的Dirichlet分布性质,也就是表示标签概率分布的样本点变得分散):

联邦学习:按Dirichlet分布划分Non-IID样本

最后,如果我们想将(x)轴变为client,(y)轴变为标签类别,即更明确地可视化不同client上的标签分布情况,我们可以将数据可视化部分的代码修改如下:

展示不同client上的label分布
plt.figure(figsize=(12, 8))
label_distribution = [[] for _ in range(n_classes)]
for c_id, idc in enumerate(client_idcs):
    for idx in idc:
        label_distribution[labels[idx]].append(c_id)

plt.hist(label_distribution, stacked=True,
            bins=np.arange(-0.5, n_clients + 1.5, 1),
            label=classes, rwidth=0.5)
plt.xticks(np.arange(n_clients), ["Client %d" %
                                    c_id for c_id in range(n_clients)])
plt.xlabel("Client ID")
plt.ylabel("Number of samples")
plt.legend()
plt.title("Display Label Distribution on Different Clients")
plt.show()

此时我们可以看到不同client上的标签分布情况如下图所示:

联邦学习:按Dirichlet分布划分Non-IID样本

这里有个很尴尬问题:类别数量太多,导致图右边的图例放不下了。因此建议如果采用这种可视化方法的话最好选择类别数量少的数据集,比如CIFAR10。

参考

Original: https://www.cnblogs.com/orion-orion/p/15897853.html
Author: orion-orion
Title: 联邦学习:按Dirichlet分布划分Non-IID样本

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/805420/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

  • pandas的拼接操作

    pandas的拼接分为两种: 级联:pd.concat, pd.append 合并:pd.merge 1. 使用pd.concat()级联 def make_df(index,co…

    Python 2023年8月7日
    044
  • 使用用户名密码或ssl证书获取zabbix-api token,批量操作管理zabbix-server

    个人博客地址 Zabbix API 允许你以编程方式检索和修改 Zabbix 的配置,并提供对历史数据的访问。它广泛用于: 创建新的应用程序以使用Zabbix; 将Zabbix与第…

    Python 2023年6月6日
    028
  • MySQL基础笔记

    单行注释:#注释文字 单行注释:–注释文字 多行注释:/*注释文字*/ 启动服务:net start 服务名 停止服务:net stop 服务名 退出:exit 登录:mysql…

    Python 2023年6月6日
    057
  • [ Shell ] 用 while + case 实现 GetOptions 效果

    https://www.cnblogs.com/yeungchie/ 可以用 getopt,但我还是喜欢自己写这个过程,便于我控制更多细节。 下面要实现的效果是,从命令行参数中分析…

    Python 2023年6月16日
    078
  • Unity-粒子特效

    粒子特效 创建粒子系统 * 模块 创建粒子系统 方法一:创建一个空物体(游戏物体),给物体加上Particle System组件。方法二:在Hierarchy视图右击 选择 Eff…

    Python 2023年9月29日
    060
  • Oauth2系列2:授权码模式

    目录 传送门 再次重申oauth2的定义 定义 作用 标准授权码流程 4个参与角色 资源拥有者 客户端 授权服务 受保护资源 授权码流程 引导授权 获取code 用授权码code换…

    Python 2023年10月1日
    040
  • 你想制作一款属于自己的游戏吗?

    💂 个人网站:【海拥 】【摸鱼游戏 】【神级源码资源网站 】 🤟 前端学习课程:👉【28个案例趣学前端 】【400个JS面试题 】 💅 想寻找共同学习交流、摸鱼划水的小伙伴,请点击…

    Python 2023年10月10日
    034
  • Python xx直聘 | 数据分析师岗位 | 分析可视化

    关注微信公共号:小程在线 关注CSDN博客:程志伟的博客 import numpy as npimport pandas as pdfrom pyecharts.charts im…

    Python 2023年8月8日
    035
  • 金融数据挖掘—科学计算语言—基础篇2

    金融数据挖掘—科学计算语言—基础篇2 一、Python基础 二、Numpy科学计算包 (一)数组的创建与运算 数组是一个同一类型的数据的有限类集合列表是由一系列按特定顺序排列的元素…

    Python 2023年8月26日
    038
  • 现代 CSS 之高阶图片渐隐消失术

    在过往,我们想要实现一个图片的渐隐消失。最常见的莫过于整体透明度的变化,像是这样: div { width: 300px; height: 300px; background: u…

    Python 2023年10月13日
    032
  • 基于Kubernetes容器云平台的CI/CD

    基于 Kubernetes 实现 CI/CD 配置,其实和往常那些 CI/CD 配置并没有太大区别。都是通过 提交代码,拉取代码,构建代码,发布代码来实现的。 只不过要是通过 K8…

    Python 2023年9月27日
    035
  • python带你采集桌游、剧本杀游戏店数据信息~

    前言 嗨喽~大家好呀,这里是魔王呐 ! 又是学习的一天,让我们开始叭~ ; 环境使用: Python 3.8 Pycharm 模块使用: requests >>>…

    Python 2023年9月20日
    063
  • Linux下ifconfig未找到命令

    ifconfig是linux中用于显示或配置网络设备(网络接口卡)的命令。 学习过程中使用 CentOS-7-x86_64-Minimal-1708.iso 这类最小化安装Cent…

    Python 2023年6月12日
    064
  • python学习路线

    Python ​​ ​​ ​​ ​​ ​​ ​​ ​​ ​​ ​​ ​​ ​​ ​​ ​​ ​​ ​​ ​​ Python – 字符串 encode ​​ ​​ ​​ …

    Python 2023年5月24日
    063
  • 写技术博客的一些心得分享

    写技术博客的一些心得分享 作者: Grey 原文地址: 博客园:写技术博客的一些心得分享 CSDN:写技术博客的一些心得分享 写博客有很多好处,见[BetterExplained]…

    Python 2023年10月22日
    031
  • 安装 scikit-learn之先安装numpy、SciPy、joblib之各种跳坑

    先来看看安装Scikit-learn 要求:Python (>= 3.5),NumPy (>= 1.11.0),SciPy (>= 0.17.0),joblib …

    Python 2023年8月29日
    086
亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球