Abstract—The widespread deployment of machine learning applications in ubiquitous environments has sparked interest in exploiting the vast amount of data stored on mobile devices. To preserve data privacy, Federated Learning has been proposed to learn a shared model by performing distributed training locally on participating devices and aggregating the local models into a global one. However, due to the limited network connectivity of mobile devices, it is not practical for federated learning to perform model updates and aggregation on all participating devices in parallel. Besides, data samples across all devices are usually not independent and identically distributed (IID), posing additional challenges to the convergence and speed of federated learning.

In this paper, we propose FAVOR, an experience-driven control framework that intelligently chooses the client devices to participate in each round of federated learning to counterbalance the bias introduced by non-IID data and to speed up convergence. Through both empirical and mathematical analysis, we observe an implicit connection between the distribution of training data on a device and the model weights trained based on those data, which enables us to profile the data distribution on that device based on its uploaded model weights. We then propose a mechanism based on deep Q-learning that learns to select a subset of devices in each communication round to maximize a reward that encourages the increase of validation accuracy and penalizes the use of more communication rounds. With extensive experiments performed in PyTorch, we show that the number of communication rounds required in federated learning can be reduced by up to 49% on the MNIST dataset, 23% on FashionMNIST, and 42% on CIFAR-10, as compared to the Federated Averaging algorithm.

This work is supported by a research contract with Huawei Technologies Co. Ltd. and a NSERC Discovery Research Program.

I. INTRODUCTION

As the primary computing resource for billions of users, mobile devices are constantly generating massive volumes of data, such as photos, voices, and keystrokes, which are of great value for training machine learning models. However, due to privacy concerns, collecting private data from mobile devices to the cloud for centralized model training is not always possible. To efficiently utilize end-user data, Federated Learning (FL) has emerged as a new paradigm of distributed machine learning that orchestrates model training across mobile devices [1]. With federated learning, locally trained models are communicated to a server for aggregation, without collecting any raw data from users. Federated learning has enabled joint model training over privacy-sensitive data in a wide range of applications, including natural language processing, computer vision, and speech recognition.

However, unlike in a server-based environment, distributed machine learning on mobile devices is faced with a few fundamental challenges, such as the limited connectivity of wireless networks, unstable availability of mobile devices, and the non-IID distributions of local datasets which are hard to characterize statistically [1]–[4]. Due to these reasons, it is not as practical to perform machine learning on all participating devices simultaneously in a synchronous or semi-synchronous manner as it is in a parameter-server cluster. Therefore, as a common practice in federated learning, only a subset of devices are randomly selected to participate in each round of model training to avoid long-tailed waiting times due to unstable network conditions and straggler devices [1, 5].

Since the completion time of federated learning is dominated by the communication time, Konečný et al. [6, 7] proposed structured and sketched updates to decrease the time it takes to complete a communication round. McMahan et al. [1] presented the Federated Averaging (FEDAVG) algorithm, which aims to reduce the number of communication rounds required by averaging model weights from client devices instead of applying traditional gradient descent updates. FEDAVG has also been deployed in a large-scale system [4] to test its effectiveness.

However, existing federated learning methods have not solved the statistical challenges posed by heterogeneous local datasets. Since different users have different device usage patterns, the data samples and labels located on any individual device may follow a different distribution, which cannot represent the global data distribution. It has been recently pointed out that the performance of federated learning, especially FEDAVG, may significantly degrade in the presence of non-IID data, in terms of the model accuracy and the communication rounds required for convergence [8]–[10]. In particular, FEDAVG randomly selects a subset of devices in each round and averages their local model weights to update the global model. The randomly selected local datasets may not reflect the true data distribution in a global view, which inevitably introduces bias into global model updates. Furthermore, the models locally trained on non-IID data can be significantly different from each other. Aggregating these divergent models can slow down convergence and substantially reduce the model accuracy [8].

In this paper, we present the design and implementation of FAVOR, a control framework that aims to improve the performance of federated learning through intelligent device selection. Based on reinforcement learning, FAVOR aims to
accelerate and stabilize the federated learning process by learning to actively select the best subset of devices in each communication round that can counterbalance the bias introduced by non-IID data.

Since privacy concerns have ruled out any possibility of accessing the raw data on each device, it is impossible to profile the data distribution on each device directly. However, we observe that there is an implicit connection between the distribution of the training samples on a device and the model weights trained based on those samples. Therefore, our main intuition is to use the local model weights and the shared global model as states to judiciously select devices that may contribute to global model improvement. We propose a reinforcement learning agent based on a Deep Q-Network (DQN), which is trained through a Double DQN for increased efficiency and robustness. We carefully design the reward signal in order to aggressively improve the global model accuracy per round as well as to reduce the total number of communication rounds required. We also propose a practical scheme that compresses model weights to reduce the high dimensionality of the state space, especially for big models, e.g., deep neural network models.

We have implemented FAVOR in a federated learning simulator developed from scratch using PyTorch¹, and evaluated it under a variety of federated learning tasks. Our experimental results on the MNIST, FashionMNIST, and CIFAR-10 datasets have shown that FAVOR can reduce the required number of communication rounds in federated learning by up to 49% on MNIST, by up to 23% on FashionMNIST, and by up to 42% on CIFAR-10, as compared to the FEDAVG algorithm.

II. BACKGROUND AND MOTIVATION

In this section, we briefly introduce how federated learning works in general and show how non-IID data poses challenges to existing federated learning algorithms. We demonstrate how to properly select client devices at each round to improve the performance of federated learning on non-IID data.

A. Federated Learning

Federated learning (FL) trains a shared global model by iteratively aggregating model updates from multiple client devices, which may have slow and unstable network connections. Initially, eligible client devices first check in with a remote server. The remote server then proceeds with federated learning synchronously in rounds. In each round, the server randomly selects a subset of available client devices to participate in training. The selected devices first download the latest global model from the server, train the model on their local datasets, and report their respective model updates to the server for aggregation.

We formally introduce FL in the context of a C-class classification problem, which is defined over a compact feature space $\mathcal{X}$ and a label space $\mathcal{Y} = [C]$, where $[C] = \{1, \dots, C\}$. Let $(x, y)$ denote a particular labeled sample. Let $f: \mathcal{X} \to \mathcal{S}$ denote the prediction function, where $\mathcal{S} = \{z \mid \sum_{i=1}^{C} z_i = 1,\ z_i \ge 0,\ \forall i \in [C]\}$. That is, the vector-valued function $f$ yields a probability vector $z$ for each sample $x$, where $f_i$ predicts the probability that the sample belongs to the $i$th class.

Let the vector $w$ denote model weights. For classification, the commonly used training loss is cross entropy, defined as

$$\ell(w) := \mathbb{E}_{x,y \sim p}\left[\sum_{i=1}^{C} \mathbb{1}_{y=i} \log f_i(x, w)\right] = \sum_{i=1}^{C} p(y = i)\, \mathbb{E}_{x|y=i}[\log f_i(x, w)].$$

The learning problem is to solve the following optimization problem:

$$\underset{w}{\text{minimize}} \ \sum_{i=1}^{C} p(y = i)\, \mathbb{E}_{x|y=i}[\log f_i(x, w)]. \tag{1}$$

In federated learning, suppose there are $N$ client devices in total. The $k$th device has $m^{(k)}$ data samples following the data distribution $p^{(k)}$, which is a joint distribution of the samples $\{x, y\}$ on this device. In each round $t$, $K$ devices are selected (at random in the original FL), each of which downloads the current global model weights $w_{t-1}$ from the server and performs the following stochastic gradient descent (SGD) training locally:

$$w_t^{(k)} = w_{t-1} - \eta \nabla \ell(w_{t-1}) = w_{t-1} - \eta \sum_{i=1}^{C} p^{(k)}(y = i)\, \nabla_w \mathbb{E}_{x|y=i}[\log f_i(x, w_{t-1})], \tag{2}$$

where $\eta$ is the learning rate. Note that in such notation with expectations, $w_t^{(k)}$ can be the result of one or multiple epochs of local SGD training [1].

Once $w_t^{(k)}$ is obtained, each participating device $k$ reports a model weight difference $\Delta_t^{(k)}$ to the FL server [1], which is defined as

$$\Delta_t^{(k)} := w_t^{(k)} - w_{t-1}^{(k)}.$$

Devices can also upload their local models $w_t^{(k)}$ directly, but transferring $\Delta_t^{(k)}$ is more amenable to compression and communication reduction [4]. After the FL server has collected updates from all $K$ participating devices in round $t$, it performs the federated averaging (FEDAVG) algorithm to update the global model:

$$\Delta_t = \sum_{k=1}^{K} m^{(k)} \Delta_t^{(k)} \Big/ \sum_{k=1}^{K} m^{(k)}, \qquad w_t \leftarrow w_{t-1} + \Delta_t.$$

B. The Challenges of Non-IID Data Distribution

FEDAVG has been shown to work well in approximating the model trained on centrally collected data, given that the data and label distributions on different devices are IID [1]. However, in reality, data owned by each device are typically non-IID, i.e., $p^{(k)}$ are different on different devices, due to different

¹ Our implementation is available as an open-source GitHub repository, located at https://github.com/iqua/flsim.
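The FEDAVG update from Sec. II-A can be illustrated with a minimal, dependency-free sketch. It computes the sample-weighted average of the reported weight differences and adds it to the previous global weights. Model weights are represented as flat lists of floats, and the function name `fedavg_update` and the toy numbers are illustrative assumptions, not part of the FAVOR code base.

```python
from typing import List, Sequence


def fedavg_update(
    global_weights: Sequence[float],
    deltas: Sequence[Sequence[float]],
    num_samples: Sequence[int],
) -> List[float]:
    """Apply w_t = w_{t-1} + sum_k m_k * delta_k / sum_k m_k.

    global_weights: w_{t-1}, flattened into a list of floats.
    deltas: one weight difference per participating device, each with the
        same length as global_weights.
    num_samples: m^(k), the number of local samples on each device.
    """
    if len(deltas) != len(num_samples) or not deltas:
        raise ValueError("need one sample count per reported delta")
    total = sum(num_samples)
    if total <= 0:
        raise ValueError("total number of samples must be positive")

    new_weights = list(global_weights)
    for delta, m in zip(deltas, num_samples):
        if len(delta) != len(global_weights):
            raise ValueError("delta has the wrong dimension")
        # Each device contributes in proportion to its local dataset size.
        for j, d in enumerate(delta):
            new_weights[j] += m * d / total
    return new_weights


# Toy round with K = 2 devices holding 30 and 10 samples (made-up numbers).
w_prev = [0.0, 1.0]
updated = fedavg_update(
    w_prev,
    deltas=[[1.0, 0.0], [-1.0, 2.0]],
    num_samples=[30, 10],
)
print(updated)  # [0.5, 1.5]
```

In the toy round, the device with 30 samples pulls the first coordinate three times as strongly as the device with 10 samples, which is why a biased selection of devices can skew the global update.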
rounds to achieve the target accuracy in the IID setting². However, FEDAVG takes 163 rounds to achieve the 99% target accuracy on non-IID data, which more than triples the number of communication rounds.
round $t$ to be $r_t = \Xi^{(\omega_t - \Omega)} - 1$, $t = 1, \dots, T$, where $\omega_t$ is the [...]

$$R = \sum_{t=1}^{T} \gamma^{t-1} r_t = \sum_{t=1}^{T} \gamma^{t-1} \left(\Xi^{(\omega_t - \Omega)} - 1\right),$$

where $\gamma \in (0, 1]$ is a factor discounting future rewards. [...] pace, which means $|\omega_t - \omega_{t-1}|$ decreases as round $t$ increases. Therefore, we use an exponential term to amplify the marginal accuracy increase as FL progresses into later stages. $\Xi$ is set to 64 in our experiments.

The second term, $-1$, encourages the agent to complete training in fewer rounds, because the more rounds it takes, the less cumulative reward the agent will receive.

B. Workflow

Fig. 2 shows how our system FAVOR performs federated learning with a DRL agent selecting devices in each round, following the steps below:

Step 1: All $N$ eligible devices check in with the FL server.
Step 2: Each device downloads the initial random model weights $w_{\text{init}}$ from the server, performs local SGD training for one epoch, and returns the resulting model weights $\{w_1^{(k)}, k \in [N]\}$ to the FL server.
Step 3: In round $t$, where $t = 1, 2, \dots$, upon receiving the uploaded local weights, the corresponding copies of local model weights stored on the server are updated. The DQN agent computes $Q(s_t, a; \theta)$ for all devices $a = 1, \dots, N$.
Step 4: The DQN agent selects $K$ devices corresponding to the top-$K$ values of $Q(s_t, a; \theta)$, $a = 1, \dots, N$. The selected $K$ devices download the latest global model weights $w_t$ and perform one epoch of SGD locally to obtain $\{w_{t+1}^{(k)} \mid k \in [K]\}$.
Step 5: $\{w_{t+1}^{(k)} \mid k \in [K]\}$ are reported (uploaded) to the server to compute $w_{t+1}$ based on FEDAVG. Move into round $t + 1$ and repeat Steps 3-5.

Steps 3-5 will be repeated until completion, e.g., until a target accuracy is reached or after a certain number of rounds.

Fig. 2: The federated learning workflow with FAVOR.

Note that the above workflow can be easily tweaked to let each selected client upload the model difference (as compared to the last version of the local model) instead of the new local model itself, without changing the underlying logic. The FL server can always use the difference $\Delta$ to reconstruct a copy of the corresponding local model stored on it.

When a device $k$ has new training data, it will request the global model weights $w_t$ and perform local SGD training for one epoch as in Step 2, which generates new local model weights $w_t^{(k)}$. Then the device pushes $w_t^{(k)}$ to the FL server, and the FL server updates the state $s_t$ to make the DRL agent aware of the device with updated data. It should be noted that data on each device remain unchanged during the DRL training to ensure that the training follows the Markov Decision Process.

The overhead introduced is minimal since no extra computation/communication overhead is added to devices. The only overhead is that now the FL server needs to store a copy of the latest local model weights from each device in order to form the state $s_t$. However, the FL server is deployed in the cloud, which provisions sufficient on-demand computational and storage capacity.

C. Dimension Reduction

One issue of the proposed DQN agent is that it uses the weights of the global model and all local models as the state, which leads to a large state space. Many deep neural networks (e.g., CNN) have millions of weights, making it challenging to train the DQN agent with such a high dimensional state space. To reduce the dimension of the state space, we propose
to apply principal component analysis (PCA) to model weights and use the compressed model weights to represent states instead.

Fig. 3: PCA of CNN weights on the MNIST dataset. (Axes: C0, C1.)

Specifically, we compute the PCA loading vectors of different principal components only based on local model weights for $t = 1$, i.e., $\{w_1^{(k)}, k \in [N]\}$, obtained in Step 2. In the subsequent rounds $t = 2, 3, \dots$, such loading vectors are reused to obtain the first several principal components of $\{w_t^{(k)}, k \in [N]\}$, through linear transformation, without fitting the PCA model again. Therefore, the overhead to compress model weights in subsequent rounds is negligibly small.

We demonstrate the effectiveness of using PCA-compressed model weights to differentiate between data distributions through a simple experiment. Consider a two-layer CNN model (with 431,080 model weights in total) to be trained with federated learning on 100 devices, running PyTorch, on the MNIST dataset. Data samples are distributed in the same way as the experiment in Sec. II: each client has 80% of its data samples belonging to a dominant class, while the remaining 20% of its samples have random labels. After five epochs of local SGD training in Round 1, we project the model weight vectors of the 100 devices, $\{w_1^{(k)}, k \in [N]\}$, onto a two-dimensional space of the first and second principal components.

As shown in Fig. 3, different shapes (or colors) indicate local models trained on devices with different dominant labels. For example, all the yellow "+" signs denote the compressed local models in Round 1 from those devices with a dominant label "6". Even reduced from 431,080 dimensions to only two dimensions and even at Round 1, we can observe a clustering effect of local models according to their dominant labels. Therefore, in the evaluation in Sec. IV, we use PCA-compressed model weights as states to enable efficient training of the DQN agent.

D. Training the Agent with Double DQN

We propose to use the double deep Q-learning network (DDQN) to learn the function $Q^*(s_t, a)$. Q-learning provides a value estimation for each potential action $a$ at state $s_t$, based on which devices are selected. However, the original Q-learning algorithms can be unstable since they indirectly optimize the agent performance by learning an approximator $Q(s, a; \theta_t)$ to the optimal action-value function $Q^*(s, a)$. DDQN adds another value function $Q(s, a; \theta'_t)$ to stabilize the action-value function estimation. The idea behind DDQN is that the network is frozen every $M$ updates. DDQN adds stability to the action-value evaluation, which is less prone to "jittering."

Fig. 4: The DDQN Agent interacting with the FL server. (Diagram labels: features, softmax, Q, action $a_t$, reward $r_t$, FL server.)

To train the DRL agent, the FL server first performs random device selection to initialize the states. As shown in Fig. 4, the states are fed into one of the double DQNs, $Q(s_t, a; \theta_t)$. The DQN generates an action $a$ to select a device for the FL server. After several rounds of FL training, the DRL agent has sampled a few action-state pairs, with which the agent learns to solve (4) as:

$$\ell_t(\theta_t) = \left(Y_t^{\text{DoubleQ}} - Q(s_t, a; \theta_t)\right)^2,$$

where $Y_t^{\text{DoubleQ}}$ is the target at round $t$ defined as

$$Y_t^{\text{DoubleQ}} := r_t + \gamma \max_a Q(s_{t+1}, a; \theta_t) \tag{5}$$

$$= r_t + \gamma\, Q\left(s_{t+1}, \arg\max_a Q(s_{t+1}, a; \theta_t); \theta'_t\right). \tag{6}$$

(6) uses two action-value functions to update $Y_t^{\text{DoubleQ}}$, in which $\theta_t$ denotes the online parameters updated per time step, and $\theta'_t$ denotes the frozen parameters that add stability to action-value estimation. The action-value function $Q(s_t, a; \theta_t)$ is updated to minimize $\ell_t(\theta_t)$ by gradient descent, i.e.,

$$\theta_{t+1} = \theta_t + \alpha \left(Y_t^{\text{DoubleQ}} - Q(s_t, a; \theta_t)\right) \nabla_{\theta_t} Q(s_t, a; \theta_t),$$

where $\alpha$ is a scalar step size.

IV. EVALUATION

We have implemented FAVOR with PyTorch in around 2000 lines of code, which we have released as an open-source project. With the Python threading library, FAVOR can simulate a large number of devices with lightweight threads, each running real-world PyTorch models.

We evaluated FAVOR by training popular CNN models on three benchmark datasets: MNIST, FashionMNIST, and CIFAR-10, with FEDAVG and K-Center as the baselines for comparison. We evaluated the accuracy of the trained models using the testing set from each dataset. Our experimental results show that FAVOR can reduce the communication rounds by up to 49% on MNIST, up to 23% on FashionMNIST,
and up to 42% on CIFAR-10, compared to the FEDAVG algorithm. We briefly describe our methodology and settings as follows.

Datasets and models: We explored different combinations of hyper-parameters for the CNN models on different datasets and chose the hyper-parameters leading to the best performance of FEDAVG.
• MNIST. We train a CNN model with two 5×5 convolution layers. The first layer has 20 output channels and the second has 50, with each layer followed by 2×2 max pooling. On each device, the batch size is 10 and the epoch number is five.
• FashionMNIST. We train a CNN model with two 5×5 convolution layers. The first layer has 16 output channels and the second has 32, with each layer followed by 2×2 max pooling. On each device, the batch size is 100 and the epoch number is five.
• CIFAR-10. We train a CNN model with two 5×5 convolution layers. The first layer has six output channels and the second has 16, with each layer followed by 2×2 max pooling. On each device, the batch size is 50 and the epoch number is five.

Performance metrics: In federated learning, due to the limited computation capacity and network bandwidth of mobile devices, reducing the number of communication rounds is crucially important. Thus, we use the number of communication rounds as the performance metric of FAVOR.

A. Training the DRL agent

We train the DRL agent on different datasets with 100 available devices. The DDQN model in the DRL agent consists of two two-layer MLP networks, with 512 hidden states. The input size is 10,100, where we have 101 model weights (i.e., the global weights $w$ and the local weights $\{w^{(k)} \mid k \in [100]\}$ from 100 devices) reduced to 100 dimensions. The output size of the second layer is 100. Each output passing through a softmax layer becomes the probability of selecting the particular device. We train the DRL agent on an AWS EC2 instance p2.2xlarge with a K80 GPU. The DDQN model is lightweight, and each training iteration takes seconds on GPU. Reducing model weights from 431,080 to 100 dimensions takes minutes by sklearn.decomposition.PCA.

[Fig. 5 panels: (a) DRL training on MNIST. (b) DRL training on FashionMNIST. (c) DRL training on CIFAR-10. Axes: Episode, Total Return.]

Fig. 5 plots the training progress of the DRL agent on three FL tasks. An episode starts at the initialization of a federated learning job and ends when the job converges to the target accuracy $\Omega$. The total return is the cumulative discounted reward $R$ obtained in one episode. The target accuracy $\Omega$ is set to 99% for training on MNIST, 85% for training on FashionMNIST, and 55% for training on CIFAR-10.

Fig. 6: Accuracy vs. communication rounds on different levels of non-IID MNIST data. (Panels: σ = 1, σ = H, σ = 0.8, σ = 0.5. Legend: Baseline, Favor, K-Center, FedAvg. Axes: Communication Round (#), FL Accuracy (%).)

B. Different Levels of Non-IID Data

We investigate different levels of non-IID data to verify the efficiency of FAVOR. We include the performance of FEDAVG and K-Center as a comparison. At each round, the number of selected devices $K$ is set to 10, as in [1].

We use σ to denote the four different levels of non-IID data: σ = 1.0 indicating that data on each device only belong to one label, σ = 0.8 indicating that 80% of the data belong to one label and the remaining 20% of the data belong to other labels, σ = 0.5 indicating that 50% of the data belong to one label and the remaining 50% of the data belong to other labels, σ = 0 indicating that the data on each device is IID, and σ = H indicating that data on each device evenly belong to two labels. We plot the test-set accuracy vs. the communication rounds of federated learning on the three benchmarks in Fig. 6, Fig. 7, and Fig. 8.
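To make the σ notation concrete, the sketch below draws the labels held by a single device for a given non-IID level. It is an illustration under stated assumptions: the helper `device_labels`, the default of ten classes, the uniform sampling of the non-dominant labels, and the treatment of σ = 0 as uniform over all labels are ours, and the paper's simulator may partition data differently.

```python
import random
from typing import List, Optional, Union


def device_labels(
    sigma: Union[float, str],
    num_samples: int,
    num_classes: int = 10,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Draw the labels of one device at non-IID level sigma.

    sigma in (0, 1]: that fraction of the samples carries one dominant label
        and the rest are drawn uniformly from the other labels (assumption).
    sigma == 0: IID, modeled here as uniform over all labels (assumption).
    sigma == "H": samples are split evenly between two labels.
    """
    rng = rng or random.Random()
    if sigma == "H":
        first, second = rng.sample(range(num_classes), 2)
        half = num_samples // 2
        labels = [first] * half + [second] * (num_samples - half)
    else:
        if not 0.0 <= sigma <= 1.0:
            raise ValueError("sigma must be in [0, 1] or 'H'")
        if sigma == 0:
            labels = [rng.randrange(num_classes) for _ in range(num_samples)]
        else:
            dominant = rng.randrange(num_classes)
            others = [c for c in range(num_classes) if c != dominant]
            n_dominant = round(sigma * num_samples)
            labels = [dominant] * n_dominant
            labels += [rng.choice(others) for _ in range(num_samples - n_dominant)]
    rng.shuffle(labels)
    return labels


# sigma = 0.8 with 100 samples: 80 samples share the dominant label.
labels = device_labels(0.8, 100, rng=random.Random(0))
dominant_count = max(labels.count(c) for c in set(labels))
print(dominant_count)  # 80
```

Passing a seeded `random.Random` keeps the partition reproducible across runs, which helps when comparing device-selection strategies on identical data splits.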
Fig. 7: Accuracy vs. communication rounds on different levels of non-IID FashionMNIST data. (Panels: σ = 1, σ = H, σ = 0.8, σ = 0.5. Legend: Baseline, Favor, K-Center, FedAvg. Axes: Communication Round (#), FL Accuracy (%).)

Fig. 8: Accuracy vs. communication rounds on different levels of non-IID CIFAR-10 data. (Panels: σ = 1, σ = H, σ = 0.8, σ = 0.5. Legend: Baseline, Favor, K-Center, FedAvg. Axes: Communication Round (#), FL Accuracy (%).)

Each entry in Table I shows the number of communication rounds necessary to achieve a test-set accuracy of 99% for the CNN on MNIST, 85% for FashionMNIST, and 55% for CIFAR-10. It should be noted that the italic numbers indicate that the model converges to a lower accuracy than the target accuracy. The model on MNIST with data distribution of σ = 1.0 converges to a test-set accuracy of 96%. The models trained on FashionMNIST with data distribution of σ = 1.0 and σ = H both converge to a test-set accuracy of 80%. The model trained on CIFAR-10 with data distribution of σ = 1.0 converges to a test-set accuracy of 50%.

Algorithm   σ     MNIST   FashionMNIST   CIFAR-10
FEDAVG      1.0   1517    1811           1714
K-Center    1.0   1684    2132           1871
FAVOR       1.0   1232    1497           1383
FEDAVG      H     313     1340           198
K-Center    H     421     1593           188
FAVOR       H     272     1134           114
FEDAVG      0.8   221     52             87
K-Center    0.8   126     62             74
FAVOR       0.8   113     43             61
FEDAVG      0.5   59      19             67
K-Center    0.5   67      21             52
FAVOR       0.5   59      16             50

TABLE I: The number of communication rounds to reach a target accuracy for FAVOR vs. FEDAVG and K-Center.

Fig. 9: PCA on model weights of FL training with MNIST. $w_1, w_2, \dots, w_5$ are the global model weights at Round 1-5. (Panel (a): Training on MNIST with FEDAVG. Axes: C1, C2. Legend: local weights, global weights; the starting point is marked $w_{\text{init}}$.)

[...] the K-Center clusters may introduce more bias to federated learning than devices selected randomly by FEDAVG. However, FAVOR has reduced the number of communication rounds by up to 49% on MNIST, up to 23% on FashionMNIST, and up to 42% on CIFAR-10, compared to the FEDAVG algorithm.

C. Device Selection and Weight Updates

We save the global model weights and local model weights per round when we train the two-layer CNN on MNIST with σ = 0.8. The saved model weights are reduced to two-
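The PCA compression behind Fig. 3 and Fig. 9 (fit the loading vectors once on Round-1 weights, then reuse them by linear transformation) can be sketched without external dependencies. This sketch swaps in plain power iteration with deflation in place of sklearn.decomposition.PCA, which the paper reports using. The helper names `fit_pca` and `transform` and the tiny three-dimensional "weights" are illustrative assumptions, and the dense covariance matrix is only practical for toy sizes.

```python
import math
import random
from typing import List, Tuple

Vector = List[float]


def fit_pca(rows: List[Vector], n_components: int = 2,
            iters: int = 500) -> Tuple[Vector, List[Vector]]:
    """Return (mean, loading vectors) fitted on the given weight vectors."""
    n, d = len(rows), len(rows[0])
    mean = [sum(r[j] for r in rows) / n for j in range(d)]
    centered = [[r[j] - mean[j] for j in range(d)] for r in rows]
    # Sample covariance matrix (d x d).
    cov = [[sum(c[a] * c[b] for c in centered) / max(n - 1, 1)
            for b in range(d)] for a in range(d)]

    rng = random.Random(0)
    components: List[Vector] = []
    for _ in range(n_components):
        v = [rng.uniform(-1.0, 1.0) for _ in range(d)]
        for _ in range(iters):
            # Power iteration: repeatedly apply cov and renormalize.
            v = [sum(cov[a][b] * v[b] for b in range(d)) for a in range(d)]
            norm = math.sqrt(sum(x * x for x in v))
            if norm == 0.0:
                break
            v = [x / norm for x in v]
        eigval = sum(v[a] * sum(cov[a][b] * v[b] for b in range(d))
                     for a in range(d))
        components.append(v)
        # Deflation: remove the direction just found before the next pass.
        cov = [[cov[a][b] - eigval * v[a] * v[b] for b in range(d)]
               for a in range(d)]
    return mean, components


def transform(rows: List[Vector], mean: Vector,
              components: List[Vector]) -> List[Vector]:
    """Project weight vectors onto previously fitted loading vectors."""
    return [[sum((r[j] - mean[j]) * comp[j] for j in range(len(mean)))
             for comp in components] for r in rows]


# Fit once on toy "Round 1" weights, then reuse the loadings in a later round.
round1 = [[-2.0, -2.0, 0.1], [-1.0, -1.0, -0.1], [0.0, 0.0, 0.0],
          [1.0, 1.0, 0.1], [2.0, 2.0, -0.1]]
mean, comps = fit_pca(round1, n_components=2)
round2 = [[0.5, 0.5, 0.0], [1.5, 1.5, 0.05]]
print(transform(round2, mean, comps))
```

Because the loading vectors are fitted only once, compressing the weights of a later round costs one matrix-vector product per device rather than a new PCA fit.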
[Figure: Accuracy (%) vs. Communication Round (#) for K = 10, K = 50, and K = 100, with panels for FedAvg and Favor. Legend in the FedAvg panel: K = 10, round # = 74; K = 50, round # = 76; K = 100, round # = 128.]

and sketched updates to decrease the completion time of each communication round. Bonawitz et al. [3] developed a secure aggregation protocol that enables a server to perform [...] devices in each round. Sattler et al. [9] proposed a compression framework, Sparse Ternary Compression (STC), that reduces the communication costs and remains robust to non-IID data. FAVOR applies DRL to select the best subset of participating devices to minimize the number of communication rounds, which is orthogonal to these studies on communication efficiency.

Sample efficiency. Unlike centralized machine learning,