
Optimizing Federated Learning on Non-IID Data with Reinforcement Learning


Hao Wang¹, Zakhary Kaplan¹, Di Niu² and Baochun Li¹
¹ University of Toronto, {haowang, bli}@ece.utoronto.ca, zakhary.kaplan@mail.utoronto.ca
² University of Alberta, dniu@ualberta.ca

Abstract: The widespread deployment of machine learning applications in ubiquitous environments has sparked interest in exploiting the vast amount of data stored on mobile devices. To preserve data privacy, Federated Learning has been proposed to learn a shared model by performing distributed training locally on participating devices and aggregating the local models into a global one. However, due to the limited network connectivity of mobile devices, it is not practical for federated learning to perform model updates and aggregation on all participating devices in parallel. Besides, data samples across all devices are usually not independent and identically distributed (IID), posing additional challenges to the convergence and speed of federated learning.

In this paper, we propose FAVOR, an experience-driven control framework that intelligently chooses the client devices to participate in each round of federated learning to counterbalance the bias introduced by non-IID data and to speed up convergence. Through both empirical and mathematical analysis, we observe an implicit connection between the distribution of training data on a device and the model weights trained based on those data, which enables us to profile the data distribution on that device based on its uploaded model weights. We then propose a mechanism based on deep Q-learning that learns to select a subset of devices in each communication round to maximize a reward that encourages the increase of validation accuracy and penalizes the use of more communication rounds. With extensive experiments performed in PyTorch, we show that the number of communication rounds required in federated learning can be reduced by up to 49% on the MNIST dataset, 23% on FashionMNIST, and 42% on CIFAR-10, as compared to the Federated Averaging algorithm.

I. INTRODUCTION

As the primary computing resource for billions of users, mobile devices are constantly generating massive volumes of data, such as photos, voices, and keystrokes, which are of great value for training machine learning models. However, due to privacy concerns, collecting private data from mobile devices to the cloud for centralized model training is not always possible. To efficiently utilize end-user data, Federated Learning (FL) has emerged as a new paradigm of distributed machine learning that orchestrates model training across mobile devices [1]. With federated learning, locally trained models are communicated to a server for aggregation, without collecting any raw data from users. Federated learning has enabled joint model training over privacy-sensitive data in a wide range of applications, including natural language processing, computer vision, and speech recognition.

However, unlike in a server-based environment, distributed machine learning on mobile devices is faced with a few fundamental challenges, such as the limited connectivity of wireless networks, unstable availability of mobile devices, and the non-IID distributions of local datasets, which are hard to characterize statistically [1]–[4]. For these reasons, it is not as practical to perform machine learning on all participating devices simultaneously in a synchronous or semi-synchronous manner as it is in a parameter-server cluster. Therefore, as a common practice in federated learning, only a subset of devices is randomly selected to participate in each round of model training, to avoid long-tailed waiting times due to unstable network conditions and straggler devices [1, 5].

Since the completion time of federated learning is dominated by the communication time, Konečný et al. [6, 7] proposed structured and sketched updates to decrease the time it takes to complete a communication round. McMahan et al. [1] presented the Federated Averaging (FedAvg) algorithm, which aims to reduce the number of communication rounds required by averaging model weights from client devices instead of applying traditional gradient descent updates. FedAvg has also been deployed in a large-scale system [4] to test its effectiveness.

However, existing federated learning methods have not solved the statistical challenges posed by heterogeneous local datasets. Since different users have different device usage patterns, the data samples and labels located on any individual device may follow a different distribution, which cannot represent the global data distribution. It has been recently pointed out that the performance of federated learning, especially FedAvg, may significantly degrade in the presence of non-IID data, in terms of the model accuracy and the communication rounds required for convergence [8]–[10]. In particular, FedAvg randomly selects a subset of devices in each round and averages their local model weights to update the global model. The randomly selected local datasets may not reflect the true data distribution in a global view, which inevitably introduces bias into global model updates. Furthermore, the models locally trained on non-IID data can be significantly different from each other. Aggregating these divergent models can slow down convergence and substantially reduce the model accuracy [8].

In this paper, we present the design and implementation of FAVOR, a control framework that aims to improve the performance of federated learning through intelligent device selection. Based on reinforcement learning, FAVOR aims to

This work is supported by a research contract with Huawei Technologies Co. Ltd. and an NSERC Discovery Research Program.
accelerate and stabilize the federated learning process by learning to actively select the best subset of devices in each communication round that can counterbalance the bias introduced by non-IID data.

Since privacy concerns have ruled out any possibility of accessing the raw data on each device, it is impossible to profile the data distribution on each device directly. However, we observe that there is an implicit connection between the distribution of the training samples on a device and the model weights trained based on those samples. Therefore, our main intuition is to use the local model weights and the shared global model as states to judiciously select devices that may contribute to global model improvement. We propose a reinforcement learning agent based on a Deep Q-Network (DQN), which is trained through a Double DQN for increased efficiency and robustness. We carefully design the reward signal in order to aggressively improve the global model accuracy per round as well as to reduce the total number of communication rounds required. We also propose a practical scheme that compresses model weights to reduce the high dimensionality of the state space, especially for big models, e.g., deep neural network models.

We have implemented FAVOR in a federated learning simulator developed from scratch using PyTorch¹, and evaluated it under a variety of federated learning tasks. Our experimental results on the MNIST, FashionMNIST, and CIFAR-10 datasets have shown that FAVOR can reduce the required number of communication rounds in federated learning by up to 49% on MNIST, by up to 23% on FashionMNIST, and by up to 42% on CIFAR-10, as compared to the FedAvg algorithm.

II. BACKGROUND AND MOTIVATION

In this section, we briefly introduce how federated learning works in general and show how non-IID data poses challenges to existing federated learning algorithms. We demonstrate how to properly select client devices at each round to improve the performance of federated learning on non-IID data.

A. Federated Learning

Federated learning (FL) trains a shared global model by iteratively aggregating model updates from multiple client devices, which may have slow and unstable network connections. Initially, eligible client devices check in with a remote server. The remote server then proceeds with federated learning synchronously in rounds. In each round, the server randomly selects a subset of available client devices to participate in training. The selected devices first download the latest global model from the server, train the model on their local datasets, and report their respective model updates to the server for aggregation.

We formally introduce FL in the context of a $C$-class classification problem, which is defined over a compact feature space $\mathcal{X}$ and a label space $\mathcal{Y} = [C]$, where $[C] = \{1, \dots, C\}$. Let $(x, y)$ denote a particular labeled sample. Let $f : \mathcal{X} \to \mathcal{S}$ denote the prediction function, where $\mathcal{S} = \{z \mid \sum_{i=1}^{C} z_i = 1,\ z_i \geq 0,\ \forall i \in [C]\}$. That is, the vector-valued function $f$ yields a probability vector $z$ for each sample $x$, where $f_i$ predicts the probability that the sample belongs to the $i$th class.

Let the vector $w$ denote model weights. For classification, the commonly used training loss is cross entropy, defined as

$$\ell(w) := \mathbb{E}_{x, y \sim p}\left[\sum_{i=1}^{C} \mathbb{1}_{y=i} \log f_i(x, w)\right] = \sum_{i=1}^{C} p(y = i)\, \mathbb{E}_{x \mid y=i}\left[\log f_i(x, w)\right].$$

The learning problem is to solve the following optimization problem:

$$\underset{w}{\text{minimize}} \quad \sum_{i=1}^{C} p(y = i)\, \mathbb{E}_{x \mid y=i}\left[\log f_i(x, w)\right]. \tag{1}$$

In federated learning, suppose there are $N$ client devices in total. The $k$th device has $m^{(k)}$ data samples following the data distribution $p^{(k)}$, which is a joint distribution of the samples $\{x, y\}$ on this device. In each round $t$, $K$ devices are selected (at random in the original FL), each of which downloads the current global model weights $w_{t-1}$ from the server and performs the following stochastic gradient descent (SGD) training locally:

$$w_t^{(k)} = w_{t-1} - \eta \nabla \ell(w_{t-1}) = w_{t-1} - \eta \sum_{i=1}^{C} p^{(k)}(y = i)\, \nabla_w \mathbb{E}_{x \mid y=i}\left[\log f_i(x, w_{t-1})\right], \tag{2}$$

where $\eta$ is the learning rate. Note that in such notation with expectations, $w_t^{(k)}$ can be the result of one or multiple epochs of local SGD training [1].

Once $w_t^{(k)}$ is obtained, each participating device $k$ reports a model weight difference $\Delta_t^{(k)}$ to the FL server [1], which is defined as

$$\Delta_t^{(k)} := w_t^{(k)} - w_{t-1}^{(k)}.$$

Devices can also upload their local models $w_t^{(k)}$ directly, but transferring $\Delta_t^{(k)}$ is more amenable to compression and communication reduction [4]. After the FL server has collected updates from all $K$ participating devices in round $t$, it performs the federated averaging (FedAvg) algorithm to update the global model:

$$\Delta_t = \frac{\sum_{k=1}^{K} m^{(k)} \Delta_t^{(k)}}{\sum_{k=1}^{K} m^{(k)}}, \qquad w_t \leftarrow w_{t-1} + \Delta_t.$$

B. The Challenges of Non-IID Data Distribution

FedAvg has been shown to work well in approximating the model trained on centrally collected data, given that the data and label distributions on different devices are IID [1]. However, in reality, data owned by each device are typically non-IID, i.e., $p^{(k)}$ are different on different devices, due to different

¹ Our implementation is available as an open-source GitHub repository, located at https://github.com/iqua/flsim.
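As an illustration of the FedAvg aggregation rule in Sec. II-A, the following minimal Python sketch applies the sample-weighted average of the reported weight differences to the previous global weights. It is not taken from the authors' implementation: the function name `fedavg_update` and the representation of model weights as flat lists of floats are assumptions made for this example.

```python
from typing import List, Sequence


def fedavg_update(
    global_w: Sequence[float],
    deltas: Sequence[Sequence[float]],
    num_samples: Sequence[int],
) -> List[float]:
    """Return w_t = w_{t-1} + sum_k m_k * delta_k / sum_k m_k.

    global_w    -- previous global weights w_{t-1} (flat list, an assumption)
    deltas      -- one weight difference per participating device
    num_samples -- local sample counts m^(k), in the same order as deltas
    """
    if len(deltas) != len(num_samples):
        raise ValueError("need one sample count per reported delta")
    total = sum(num_samples)
    if total <= 0:
        raise ValueError("total number of samples must be positive")

    # Sample-weighted average of the reported differences.
    aggregated = [0.0] * len(global_w)
    for delta, m in zip(deltas, num_samples):
        if len(delta) != len(global_w):
            raise ValueError("delta and global weights differ in length")
        for j, d in enumerate(delta):
            aggregated[j] += m * d / total

    # Apply the aggregated difference to the previous global weights.
    return [w + a for w, a in zip(global_w, aggregated)]
```

For example, with two devices holding 1 and 3 samples, the second device's difference receives three times the weight of the first.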
user preferences and usage patterns. When data distributions are non-IID, FedAvg is unstable and may even diverge [8]. This is due to the inconsistency between the locally performed SGD algorithm, which aims to minimize the loss value on the $m^{(k)}$ local samples on each device, and the global objective of minimizing the overall loss on all $\sum_{k=1}^{K} m^{(k)}$ data samples. As we keep fitting models on different devices to heterogeneous local data, the divergence among the weights $w^{(k)}$ of these local models accumulates and eventually degrades the performance of learning [8, 10], leading to more communication rounds before training converges. Reducing the number of communication rounds in federated learning is essential for mobile devices, which have limited computation capacity and communication bandwidth.

We use an experiment to demonstrate that non-IID data may slow down the convergence of federated learning if devices are randomly selected in each round for model aggregation. Instead of random selection, we show that selecting devices with a clustering algorithm can help to even out data distribution and speed up convergence.

In our experiment, we train a two-layer CNN model with PyTorch on the MNIST dataset (containing 60,000 samples) until the model achieves a 99% test accuracy. We consider 100 devices in total, each running on a different thread, and we train the CNN model under the IID and non-IID settings, respectively. For the IID setting, the 60,000 samples are randomly distributed among 100 devices, such that each device owns 600 samples. For the non-IID setting, each device still owns 600 samples, but 80% of them come from a dominant class and the remaining 20% belong to other classes. For example, a "0"-dominated device has 480 data samples with the label "0", while the remaining 120 data samples have labels evenly distributed among "1" to "9". In each round, ten devices are selected to participate in federated learning. Note that FL selects only a subset of all devices at each round to balance computation complexity and convergence rate, and to avoid excessively long waiting times for model aggregation [1, 4], which is also discussed in Sec. IV-D.

FedAvg randomly selects ten devices to participate in each round of training. As shown in Fig. 1, FedAvg takes 47 rounds to achieve the target accuracy in the IID setting². However, FedAvg takes 163 rounds to achieve the 99% target accuracy on non-IID data, which more than triples the number of communication rounds.

To reduce the required communication rounds under the non-IID setting, we then tuned the device selection strategy. Although it is impossible in FL to peek into the data distribution of each device, we observe that there is an implicit connection between the data distribution on a device and the model weights trained on that device. Therefore, we may profile each device by analyzing the local model weights after the first round.

We provide some rationale for this observation. According to (2), given the initial global model weights $w_{init}$, the local weights on device $k$ after one epoch of SGD can be represented by

$$w_1^{(k)} = w_{init} - \eta \sum_{i=1}^{C} p^{(k)}(y = i)\, \nabla_w \mathbb{E}_{x \mid y=i}\left[\log f_i(x, w_{init})\right].$$

The weight divergence between device $k$ and device $k'$ can be measured as $\| w_1^{(k')} - w_1^{(k)} \|$, where $k, k' \in [K]$. Then, by using a bounding technique similar to the one adopted in [8], we can derive the following bound for weight divergence in the first round:

$$\| w_1^{(k')} - w_1^{(k)} \| \leq \eta\, g_{\max}(w_{init}) \sum_{i=1}^{C} \| p^{(k')}(y = i) - p^{(k)}(y = i) \|,$$

where $g_{\max}(w) := \max_{i=1}^{C} \| \nabla_w \mathbb{E}_{x \mid y=i}[\log f_i(x, w)] \|$. This implies that even if local models are trained with the same initial weights $w_{init}$, there is an implicit connection between the discrepancy of data distributions on device $k$ and $k'$, which is the term $\sum_{i=1}^{C} \| p^{(k')}(y = i) - p^{(k)}(y = i) \|$, and their weight divergence.

Based on the insights above, we first let each of the 100 devices download the same initial global weights (randomly generated) and perform one epoch of SGD based on local data to get $w_1^{(k)}$, $k = 1, \dots, 100$. We then apply the K-Center clustering algorithm [11] to $\{w_1^{(1)}, \dots, w_1^{(100)}\}$ to cluster the 100 devices into 10 groups. In each round, we randomly select one device from each group to participate in FL training, which still results in 10 simultaneously participating devices per round. As shown in Fig. 1, the K-Center algorithm takes 131 rounds to reach the target accuracy under the non-IID setting, which is 32 rounds fewer than the vanilla FedAvg algorithm. This example implies that it is possible to improve the performance of federated learning by carefully selecting the set of participating devices in each round, especially under the non-IID data setting.

C. Deep Reinforcement Learning (DRL)

Reinforcement Learning (RL) is the learning process of an agent that acts in response to the environment to

Fig. 1: Training a CNN model on non-IID MNIST data. (Plot of accuracy (%) versus communication round (#) for FedAvg-IID, FedAvg-non-IID, and K-Center-non-IID.)

² The same setting costs FedAvg 50 rounds on TensorFlow in [1].
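The clustering-based selection used in the experiment of Sec. II-B can be sketched as follows. This is a hedged illustration only: the text cites [11] for the K-Center algorithm without giving details, so the sketch assumes the common greedy farthest-first heuristic for the k-center problem, which may differ from what the authors ran. The function names `k_center_groups` and `pick_one_per_group` are hypothetical, and short tuples stand in for flattened local model weights.

```python
import math
import random
from typing import List, Sequence


def k_center_groups(points: Sequence[Sequence[float]], k: int,
                    first: int = 0) -> List[List[int]]:
    """Group point indices around k centers chosen by farthest-first traversal.

    Assumption: greedy farthest-first is used as the k-center heuristic.
    """
    if not 0 < k <= len(points):
        raise ValueError("k must be between 1 and the number of points")

    centers = [first]
    # Distance from every point to its nearest chosen center so far.
    dist = [math.dist(p, points[first]) for p in points]
    while len(centers) < k:
        # The next center is the point farthest from all current centers.
        nxt = max(range(len(points)), key=dist.__getitem__)
        centers.append(nxt)
        for i, p in enumerate(points):
            dist[i] = min(dist[i], math.dist(p, points[nxt]))

    # Assign each point to its nearest center.
    groups: List[List[int]] = [[] for _ in centers]
    for i, p in enumerate(points):
        nearest = min(range(len(centers)),
                      key=lambda c: math.dist(p, points[centers[c]]))
        groups[nearest].append(i)
    return groups


def pick_one_per_group(groups: Sequence[Sequence[int]],
                       rng: random.Random) -> List[int]:
    """Randomly pick one device index from every non-empty group."""
    return [rng.choice(list(g)) for g in groups if g]
```

In the experiment's setting, `points` would hold the 100 first-round local weight vectors and `k` would be 10, so each round still has 10 participants, one per group.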
maximize its rewards. The recent success of RL [12, 13] comes from the ability of RL agents to learn from the interaction with a dynamic environment. Specifically, at each time step $t$, the RL agent observes the state $s_t$ and performs an action $a_t$. The state $s_t$ of the environment then transitions to $s_{t+1}$, and the agent receives reward $r_t$. The state transitions and rewards follow a Markov Decision Process (MDP) [14], which is a discrete-time stochastic control process, denoted by a sequence $s_1, a_1, r_1, s_2, \dots, r_{t-1}, s_t, a_t, \dots$. The objective of RL is to maximize the expectation of the cumulative discounted return $R = \sum_{t=1}^{T} \gamma^{t-1} r_t$, where $\gamma \in (0, 1]$ is a factor discounting future rewards.

The RL agent seeks to learn a cheat sheet that points out the actions leading to the maximum cumulative return under a particular state. Value-based RL algorithms formally use an action-value function $Q(s_t, a)$ to estimate an expected return starting from state $s_t$:

$$Q^{\pi}(s_t, a) := \mathbb{E}_{\pi}\left[\sum_{k=1}^{\infty} \gamma^{k-1} r_{t+k-1} \,\middle|\, s_t, a\right] = \mathbb{E}_{s_{t+1}, a}\left[r_t + \gamma Q^{\pi}(s_{t+1}, a) \mid s_t, a_t\right],$$

where $\pi$ is the policy mapping from states to probabilities of selecting possible actions. The optimal action-value function $Q^{*}(s_t, a)$ is the cheat sheet sought by the RL agent, which is defined as the maximum expectation of the cumulative discounted return starting from $s_t$:

$$Q^{*}(s_t, a) := \mathbb{E}_{s_{t+1}}\left[r_t + \gamma \max_{a} Q^{*}(s_{t+1}, a) \mid s_t, a\right]. \tag{3}$$

Then, we could apply function approximation techniques to learn a parameterized value function $Q(s_t, a; \theta_t)$ approximating the optimal value function $Q^{*}(s_t, a)$. The one-step lookahead $r_t + \gamma \max_{a} Q(s_{t+1}, a; \theta_t)$ becomes the target that $Q(s_t, a; \theta_t)$ is trained to match. Typically, a deep neural network (DNN) is used to represent the function approximator [15]. The RL learning problem becomes minimizing the MSE loss between the target and the approximator, which is defined as:

$$\ell_t(\theta_t) := \left(r_t + \gamma \max_{a} Q(s_{t+1}, a; \theta_t) - Q(s_t, a; \theta_t)\right)^2. \tag{4}$$

III. DRL FOR CLIENT SELECTION

We formulate device selection for federated learning as a deep reinforcement learning (DRL) problem. Based on the formulation, we present a federated learning workflow driven by DRL in Sec. III-B. We then introduce the challenges of dealing with high-dimensional models in federated learning. Finally, we describe training the DRL agent with the double Deep Q-learning Network (DDQN) [16].

The per-round FL process can be modeled as a Markov Decision Process (MDP), with state $s$ represented by the global model weights and the model weights of each client device in each round. Given a state, the DRL agent takes an action $a$ to select a subset of devices that perform local training and update the global model. Then a reward signal $r$ is observed, which is a function of the test accuracy achieved by the global model so far on a held-out validation set. The objective is to train the DRL agent to converge to the target accuracy for federated learning as quickly as possible.

In the proposed framework, the agent does not need to collect or check any data samples from mobile devices; only model weights are communicated, which preserves the same level of privacy as the original FL. The agent relies solely on model weight information to determine which device may improve the global model the most, since there is an implicit connection between the data distribution on a device and its local model weights obtained by performing SGD on those data. In fact, in Sec. III-C we will show that even after dimension reduction, the divergence between local model weights is still obvious and contains information to guide device selection.

A. The Agent based on Deep Q-Network

Suppose there is a federated learning job on $N$ available devices with a target accuracy $\Omega$. In each round, using a Deep Q-Network (DQN) [15], $K$ devices are selected to participate in the training. Considering the limited available traces from federated learning jobs, DQN can be trained more efficiently and can reuse data more effectively than policy gradient methods and actor-critic methods.

State: Let the state of round $t$ be represented by a vector $s_t = (w_t, w_t^{(1)}, \dots, w_t^{(N)})$, where $w_t$ denotes the weights of the global model after round $t$, and $w_t^{(1)}, \dots, w_t^{(N)}$ are the model weights of the $N$ devices, respectively.

The agent is collocated with the FL server and maintains a list of model weights $\{w^{(k)} \mid k \in [N]\}$; a particular $w^{(k)}$ is updated in round $t$ only if device $k$ is selected for training and the resulting $\Delta_t^{(k)}$ is received by the FL server. Therefore, there is no extra communication overhead introduced for the devices.

The resulting state space can be huge, e.g., a CNN model can contain millions of weights. It is challenging to train a DQN with such a large state space. In practice, we propose to apply an effective and lightweight dimension reduction technique on the state space, i.e., on model weights, which will be described in detail in Sec. III-C.

Action: At the beginning of each round $t$, the agent needs to decide which subset of $K$ devices to select from the $N$ devices. This selection would result in a large action space of size $\binom{N}{K}$, which complicates RL training. We propose a trick to keep the action space small while still leveraging the intelligent control provided by the DRL agent.

In particular, the agent is trained by selecting only one out of $N$ devices to participate in FL per round based on DQN, while in testing and application the agent will sample a batch of top-$K$ clients to participate in FL. That is, the DQN agent learns an approximator of the optimal action-value function $Q^{*}(s_t, a)$ through a neural network, which estimates the action that maximizes the expected return starting from $s_t$. The action space is thus reduced to $\{1, 2, \dots, N\}$, where $a = i$ means that device $i$ is selected to participate in FL.

Once DQN has been trained to approximate $Q^{*}(s, a)$, during testing, in round $t$ the DQN agent will compute
$\{Q^{*}(s_t, a) \mid a \in [N]\}$ for all $N$ actions. Each action-value indicates the maximum expected return that the agent can get by selecting a particular action $a$ at state $s_t$. Then we select the $K$ devices, each corresponding to a different action $a$, that lead to the top-$K$ values of $Q^{*}(s_t, a)$.

Reward: We set the reward observed at the end of each round $t$ to be $r_t = \Xi^{(\omega_t - \Omega)} - 1$, $t = 1, \dots, T$, where $\omega_t$ is the testing accuracy achieved by the global model on the held-out validation set after round $t$, $\Omega$ is the target accuracy, and $\Xi$ is a positive constant that ensures that $r_t$ grows exponentially with the testing accuracy $\omega_t$. We have $r_t \in (-1, 0]$, since $0 \leq \omega_t \leq \Omega \leq 1$. The federated learning stops when $\omega_t = \Omega$, at which point $r_t$ reaches its maximum value of 0.

The DQN agent is trained to maximize the expectation of the cumulative discounted reward given by

$$R = \sum_{t=1}^{T} \gamma^{t-1} r_t = \sum_{t=1}^{T} \gamma^{t-1} \left(\Xi^{(\omega_t - \Omega)} - 1\right),$$

where $\gamma \in (0, 1]$ is a factor discounting future rewards.

We now explain the motivations behind the two terms $\Xi^{(\omega_t - \Omega)}$ and $-1$ in $r_t$. The first term, $\Xi^{(\omega_t - \Omega)}$, incentivizes the agent to select devices that achieve a higher test accuracy $\omega_t$. $\Xi$ controls how fast the reward $r_t$ grows with $\omega_t$. In general, as ML training proceeds, the model accuracy will increase at a slower pace, which means $|\omega_t - \omega_{t-1}|$ decreases as round $t$ increases. Therefore, we use an exponential term to amplify the marginal accuracy increase as FL progresses into later stages. $\Xi$ is set to 64 in our experiments.

The second term, $-1$, encourages the agent to complete training in fewer rounds, because the more rounds it takes, the less cumulative reward the agent will receive.

B. Workflow

Fig. 2 shows how our system FAVOR performs federated learning with a DRL agent selecting devices in each round, following the steps below:

Step 1: All $N$ eligible devices check in with the FL server.
Step 2: Each device downloads the initial random model weights $w_{init}$ from the server, performs local SGD training for one epoch, and returns the resulting model weights $\{w_1^{(k)}, k \in [N]\}$ to the FL server.
Step 3: In round $t$, where $t = 1, 2, \dots$, upon receiving the uploaded local weights, the corresponding copies of local model weights stored on the server are updated. The DQN agent computes $Q(s_t, a; \theta)$ for all devices $a = 1, \dots, N$.
Step 4: The DQN agent selects $K$ devices corresponding to the top-$K$ values of $Q(s_t, a; \theta)$, $a = 1, \dots, N$. The selected $K$ devices download the latest global model weights $w_t$ and perform one epoch of SGD locally to obtain $\{w_{t+1}^{(k)} \mid k \in [K]\}$.
Step 5: $\{w_{t+1}^{(k)} \mid k \in [K]\}$ are reported (uploaded) to the server to compute $w_{t+1}$ based on FedAvg. Move into round $t + 1$ and repeat Steps 3-5.

Steps 3-5 will be repeated until completion, e.g., until a target accuracy is reached or after a certain number of rounds.

Note that the above workflow can be easily tweaked to let each selected client upload the model difference $\Delta$ (as compared to the last version of the local model) instead of the new local model itself, without changing the underlying logic. The FL server can always use $\Delta$ to reconstruct a copy of the corresponding local model stored on it.

When a device $k$ has new training data, it will request the global model weights $w_t$ and perform local SGD training for one epoch as in Step 2, which generates new local model weights $w_t^{(k)}$. Then the device pushes $w_t^{(k)}$ to the FL server, and the FL server updates the state $s_t$ to make the DRL agent aware of the device with updated data. It should be noted that the data on each device remain unchanged during the DRL training to ensure that the training follows the Markov Decision Process.

The overhead introduced is minimal since no extra computation or communication overhead is added to devices. The only overhead is that the FL server now needs to store a copy of the latest local model weights from each device in order to form the state $s_t$. However, the FL server is deployed in the cloud, which provides sufficient on-demand computational and storage capacity.

C. Dimension Reduction

One issue of the proposed DQN agent is that it uses the weights of the global model and all local models as the state, which leads to a large state space. Many deep neural networks (e.g., CNN) have millions of weights, making it challenging to train the DQN agent with such a high-dimensional state space. To reduce the dimension of the state space, we propose

Fig. 2: The federated learning workflow with FAVOR.
to apply principal component analysis (PCA) to model weights and use the compressed model weights to represent states instead.

Specifically, we compute the PCA loading vectors of different principal components based only on the local model weights for $t = 1$, i.e., $\{w_1^{(k)}, k \in [N]\}$, obtained in Step 2. In the subsequent rounds $t = 2, 3, \ldots$, such loading vectors are reused to obtain the first several principal components of $\{w_t^{(k)}, k \in [N]\}$ through a linear transformation, without fitting the PCA model again. Therefore, the overhead to compress model weights in subsequent rounds is negligibly small.

We demonstrate the effectiveness of using PCA-compressed model weights to differentiate between data distributions through a simple experiment. Consider a two-layer CNN model (with 431,080 model weights in total) to be trained with federated learning on 100 devices, running PyTorch, on the MNIST dataset. Data samples are distributed in the same way as the experiment in Sec. II: each client has 80% of its data samples belonging to a dominant class, while the remaining 20% of its samples have random labels. After five epochs of local SGD training in Round 1, we project the model weight vectors of the 100 devices, $\{w_1^{(k)}, k \in [N]\}$, onto a two-dimensional space of the first and second principal components.

As shown in Fig. 3, different shapes (or colors) indicate local models trained on devices with different dominant labels. For example, all the yellow “+” signs denote the compressed local models in Round 1 from those devices with a dominant label “6”. Even when reduced from 431,080 dimensions to only two dimensions, and even at Round 1, we can observe a clustering effect of local models according to their dominant labels. Therefore, in the evaluation in Sec. IV, we use PCA-compressed model weights as states to enable efficient training of the DQN agent.

[Fig. 3 plot: scatter of local models projected onto the principal components C0 (horizontal axis) and C1 (vertical axis).]
Fig. 3: PCA of CNN weights on the MNIST dataset.

D. Training the Agent with Double DQN

We propose to use the double deep Q-learning network (DDQN) to learn the function $Q^*(s_t, a)$. Q-learning provides a value estimation for each potential action $a$ at state $s_t$, based on which devices are selected. However, the original Q-learning algorithms can be unstable since they indirectly optimize the agent performance by learning an approximator $Q(s, a; \theta_t)$ to the optimal action-value function $Q^*(s, a)$. DDQN adds another value function $Q(s, a; \theta'_t)$ to stabilize the action-value function estimation. The idea behind DDQN is that this second network is kept frozen and refreshed only every $M$ updates. This adds stability to the action-value evaluation, making it less prone to “jittering.”

[Fig. 4 diagram labels: Agent ($Q$, $Q_{\theta'}$, features, softmax), Environment (FL server), state $s_t$, action $a_t$, reward $r_t$.]
Fig. 4: The DDQN Agent interacting with the FL server.

To train the DRL agent, the FL server first performs random device selection to initialize the states. As shown in Fig. 4, the states are fed into one of the double DQNs, $Q(s_t, a; \theta_t)$. The DQN generates an action $a$ to select a device for the FL server. After several rounds of FL training, the DRL agent has sampled a few action-state pairs, with which the agent learns to solve (4) by minimizing the loss

$$\ell_t(\theta_t) = \left(Y_t^{\text{DoubleQ}} - Q(s_t, a; \theta_t)\right)^2,$$

where $Y_t^{\text{DoubleQ}}$ is the target at round $t$ defined as

$$Y_t^{\text{DoubleQ}} := r_t + \gamma \max_a Q(s_{t+1}, a; \theta_t) \tag{5}$$

$$\phantom{Y_t^{\text{DoubleQ}} :}= r_t + \gamma\, Q\left(s_{t+1}, \operatorname{argmax}_a Q(s_{t+1}, a; \theta_t); \theta'_t\right). \tag{6}$$

Equation (6) uses two action-value functions to update $Y_t^{\text{DoubleQ}}$, in which $\theta_t$ denotes the online parameters, updated at every time step, and $\theta'_t$ denotes the frozen parameters, which add stability to the action-value estimation. The action-value function $Q(s_{t+1}, a; \theta_t)$ is updated to minimize $\ell_t(\theta_t)$ by gradient descent, i.e.,

$$\theta_{t+1} = \theta_t + \alpha \left(Y_t^{\text{DoubleQ}} - Q(s_t, a; \theta_t)\right) \nabla_{\theta_t} Q(s_t, a; \theta_t),$$

where $\alpha$ is a scalar step size.

IV. EVALUATION

We have implemented FAVOR with PyTorch in around 2000 lines of code, which we have released as an open-source project. With the Python threading library, FAVOR can simulate a large number of devices with lightweight threads, each running real-world PyTorch models.

We evaluated FAVOR by training popular CNN models on three benchmark datasets: MNIST, FashionMNIST, and CIFAR-10, with FEDAVG and K-Center as the groups of comparison. We evaluated the accuracy of the trained models using the testing set from each dataset. Our experimental results show that FAVOR can reduce the communication rounds by up to 49% on MNIST, up to 23% on FashionMNIST,
and up to 42% on CIFAR-10, compared to the FEDAVG algorithm. We briefly describe our methodology and settings as follows.

[Fig. 5 plots: total return vs. episode. (a) DRL training on MNIST. (b) DRL training on FashionMNIST. (c) DRL training on CIFAR-10.]
Fig. 5: Training the DRL agent.

[Fig. 6 plots: FL accuracy (%) vs. communication round (#) for Favor, K-Center, and FedAvg against a baseline, in four panels: σ = 1, σ = H, σ = 0.8, σ = 0.5.]

Datasets and models: We explored different combinations of hyper-parameters for the CNN models on different datasets and chose the hyper-parameters leading to the best performance of FEDAVG.
• MNIST. We train a CNN model with two 5×5 convolution layers. The first layer has 20 output channels and the second has 50, with each layer followed by 2×2 max pooling. On each device, the batch size is ten and the epoch number is five.
• FashionMNIST. We train a CNN model with two 5×5 convolution layers. The first layer has 16 output channels and the second has 32, with each layer followed by 2×2 max pooling. On each device, the batch size is 100 and the epoch number is five.
• CIFAR-10. We train a CNN model with two 5×5 convolution layers. The first layer has six output channels and the second has 16, with each layer followed by 2×2 max pooling. On each device, the batch size is 50 and the epoch number is five.
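As a sanity check on the MNIST model described above, the sketch below counts its trainable parameters in plain Python. The text does not state the padding, stride, or fully connected layers, so this assumes unpadded stride-1 convolutions on 28×28 grayscale input, followed by one hidden fully connected layer of 500 units and a 10-way output layer; all function names are illustrative. Under these assumptions the count equals the 431,080 model weights quoted in Sec. III-C, which is consistent with, but does not confirm, that architecture.

```python
def conv_params(c_in, c_out, k):
    """Weights plus biases of a k x k convolution layer."""
    return c_out * (c_in * k * k + 1)


def linear_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_out * (n_in + 1)


def feature_map_size(size, k=5, pool=2):
    """Side length after an unpadded, stride-1 k x k conv and pool x pool max pooling."""
    return (size - k + 1) // pool


def count_cnn_params(in_channels, image_size, c1, c2, hidden, num_classes, k=5):
    """Total parameters of conv-pool-conv-pool followed by two linear layers."""
    side = feature_map_size(feature_map_size(image_size, k), k)
    flat = c2 * side * side
    return (conv_params(in_channels, c1, k) + conv_params(c1, c2, k)
            + linear_params(flat, hidden) + linear_params(hidden, num_classes))


# MNIST model from the list above: 20 and 50 output channels.
# hidden=500 is an assumption; the text does not specify the dense layers.
mnist_params = count_cnn_params(in_channels=1, image_size=28, c1=20, c2=50,
                                hidden=500, num_classes=10)
print(mnist_params)  # 431080
```

Almost all of these parameters (400,500 of them) sit in the assumed first fully connected layer, which is one reason the PCA compression of Sec. III-C matters for the size of the agent's state.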
Fig. 6: Accuracy vs. communication rounds on different levels of non-IID MNIST data.

Performance metrics: In federated learning, due to the limited computation capacity and network bandwidth of mobile devices, reducing the number of communication rounds is crucially important. Thus, we use the number of communication rounds as the performance metric of FAVOR.

A. Training the DRL agent

We train the DRL agent on different datasets with 100 available devices. The DDQN model in the DRL agent consists of two two-layer MLP networks, with 512 hidden states. The input size is 10,100: there are 101 sets of model weights (i.e., the global weights $w$ and the local weights $\{w^{(k)} \mid k \in [100]\}$ from 100 devices), each reduced to 100 dimensions. The output size of the second layer is 100. Each output passing through a softmax layer becomes the probability of selecting the particular device. We train the DRL agent on an AWS EC2 instance p2.2xlarge with a K80 GPU. The DDQN model is lightweight, and each training iteration takes seconds on GPU. Reducing model weights from 431,080 to 100 dimensions takes minutes with sklearn.decomposition.PCA.

Fig. 5 plots the training progress of the DRL agent on three FL tasks. An episode starts at the initialization of a federated learning job and ends when the job converges to the target accuracy $\Omega$. The total return is the cumulative discounted reward $R$ obtained in one episode. The target accuracy $\Omega$ is set to 99% for training on MNIST, 85% for training on FashionMNIST, and 55% for training on CIFAR-10.

B. Different Levels of Non-IID Data

We investigate different levels of non-IID data to test the efficiency of FAVOR. We include the performance of FEDAVG and K-Center as a comparison. At each round, the number of selected devices $K$ is set to 10, as in [1].

We use $\sigma$ to denote the four different levels of non-IID data: $\sigma = 1.0$ indicates that data on each device belong to only one label, $\sigma = 0.8$ indicates that 80% of the data belong to one label and the remaining 20% belong to other labels, $\sigma = 0.5$ indicates that 50% of the data belong to one label and the remaining 50% belong to other labels, and $\sigma = H$ indicates that data on each device belong evenly to two labels. In addition, $\sigma = 0$ indicates that the data on each device is IID. We plot the test-set accuracy vs. the communication rounds of federated learning on the three benchmarks in Fig. 6, Fig. 7, and Fig. 8.
[Fig. 7 plots: FL accuracy (%) vs. communication round (#) for Favor, KCenter, and FedAvg against a baseline, in four panels: σ = 1, σ = H, σ = 0.8, σ = 0.5.]
Fig. 7: Accuracy vs. communication rounds on different levels of non-IID FashionMNIST data.

[Fig. 8 plots: FL accuracy (%) vs. communication round (#) for Favor, KCenter, and FedAvg against a baseline, in four panels: σ = 1, σ = H, σ = 0.8, σ = 0.5.]
Fig. 8: Accuracy vs. communication rounds on different levels of non-IID CIFAR-10 data.
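The accuracy curves in Figs. 6-8 are summarized by a single number per run: the first communication round at which the test-set accuracy reaches the target. A minimal sketch of that metric follows; the function name and the accuracy values are hypothetical and are not taken from the experiments.

```python
def rounds_to_target(accuracy_per_round, target):
    """Return the 1-based round at which accuracy first reaches target, or None."""
    for round_number, accuracy in enumerate(accuracy_per_round, start=1):
        if accuracy >= target:
            return round_number
    return None  # the run never reached the target accuracy


# Hypothetical accuracy curve (made-up values, for illustration only).
curve = [0.62, 0.81, 0.93, 0.97, 0.99, 0.99]
print(rounds_to_target(curve, target=0.99))  # 5
```

Returning `None` keeps runs that plateau below the target distinguishable from runs that reach it late.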
Each entry in Table I shows the number of communication rounds necessary to achieve a test-set accuracy of 99% for the CNN on MNIST, 85% for FashionMNIST, and 55% for CIFAR-10. It should be noted that the italic numbers indicate that the model converges to a lower accuracy than the target test-set accuracy. The model on MNIST with data distribution of $\sigma = 1.0$ converges to a test-set accuracy of 96%. The models trained on FashionMNIST with data distribution of $\sigma = 1.0$ and $\sigma = H$ both converge to a test-set accuracy of 80%. The model trained on CIFAR-10 with data distribution of $\sigma = 1.0$ converges to a test-set accuracy of 50%.

Our experimental results show there is no guarantee that K-Center can always outperform FEDAVG. Devices selected from the K-Center clusters may introduce more bias to federated learning than devices selected randomly by FEDAVG. However, FAVOR has reduced the number of communication rounds by up to 49% on MNIST, up to 23% on FashionMNIST, and up to 42% on CIFAR-10, compared to the FEDAVG algorithm.

Method     σ          MNIST   FashionMNIST   CIFAR-10
FEDAVG     0 (IID)       55             14         47
FEDAVG     1.0         1517           1811       1714
K-Center   1.0         1684           2132       1871
FAVOR      1.0         1232           1497       1383
FEDAVG     H            313           1340        198
K-Center   H            421           1593        188
FAVOR      H            272           1134        114
FEDAVG     0.8          221             52         87
K-Center   0.8          126             62         74
FAVOR      0.8          113             43         61
FEDAVG     0.5           59             19         67
K-Center   0.5           67             21         52
FAVOR      0.5           59             16         50

TABLE I: The number of communication rounds to reach a target accuracy for FAVOR vs. FEDAVG and K-Center.

[Fig. 9 plots: local weights and global weights ($w_{init}$, $w_1$, $w_2$, ...) projected onto the principal components C1 (horizontal axis) and C2 (vertical axis). (a) Training on MNIST with FEDAVG. (b) Training on MNIST with FAVOR.]
Fig. 9: PCA on model weights of FL training with MNIST. $w_1, w_2, \ldots, w_5$ are the global model weights at Round 1-5.

C. Device Selection and Weight Updates

We save the global model weights and local model weights per round when we train the two-layer CNN on MNIST with $\sigma = 0.8$. The saved model weights are reduced to two-
dimension vectors by PCA and plotted in Fig. 9. By examining the model weights trained with FEDAVG and FAVOR, respectively, Fig. 9 shows that FAVOR updates the global model with a larger weight update than FEDAVG in early rounds, which leads to a faster convergence speed.

D. Increasing Parallelism

We also studied the performance of FAVOR on different numbers of selected devices. The CIFAR-10 dataset is distributed to 100 devices with the non-IID level at $\sigma = 0.8$. We apply FAVOR, FEDAVG, and K-Center to train the same CNN on CIFAR-10 separately. The number of selected devices $K$ is set to 10, 50, and 100 to study the performance of federated learning with different parallelism. Fig. 10 shows that increasing the parallelism does not reduce the number of communication rounds and even increases the number of communication rounds.

[Fig. 10 plots: accuracy (%) vs. communication round (#) for K = 10, K = 50, and K = 100, in three panels. FedAvg: K = 10, round # = 87; K = 50, round # = 89; K = 100, round # = 128. KCenter: K = 10, round # = 74; K = 50, round # = 76; K = 100, round # = 128. Favor: K = 10, round # = 61; K = 50, round # = 78; K = 100, round # = 128.]
Fig. 10: FL training on CIFAR-10 with different levels of parallelism.

V. RELATED WORK

Federated learning allows machine learning models to be trained on mobile devices in a distributed fashion without violating user privacy. Existing studies on federated learning have mostly focused on improving its efficiency. We classify the existing literature into the following two categories:

Communication efficiency. Mobile devices typically have unstable and expensive connections, and existing works have attempted to improve the communication efficiency of federated learning. Konečný et al. [6, 7] proposed structured and sketched updates to decrease the completion time of each communication round. Bonawitz et al. [3] developed a secure aggregation protocol that enables a server to perform the computation of high-dimensional data from the devices, which is widely deployed in production environments [4, 17]. McMahan et al. [1] presented the Federated Averaging (FEDAVG) algorithm that allows devices to perform local training of multiple epochs, which further reduces the number of communication rounds by averaging model weights from the client devices. Nishio et al. [18] proposed a resource-aware selection algorithm that maximizes the number of participating devices in each round. Sattler et al. [9] proposed a compression framework, Sparse Ternary Compression (STC), that reduces the communication costs and remains robust to non-IID data. FAVOR applies DRL to select the best subset of participating devices to minimize the number of communication rounds, which is orthogonal to these studies on communication efficiency.

Sample efficiency. Unlike centralized machine learning, federated learning performs training on non-IID data on devices. Zhao et al. [8] presented a mathematical demonstration to show that non-IID data reduces the accuracy of federated learning by a substantial margin, and proposed to push a small set of uniformly distributed data to participating devices. Downloading extra data further increases communication cost and computation workload for the devices. Mohri et al. [10] proposed an agnostic federated learning framework for fairness to avoid biases due to non-IID data from the devices. In contrast, FAVOR is the first work that counterbalances the bias from different non-IID data by dynamically constructing the subset of participating devices with DRL techniques.

VI. CONCLUDING REMARKS

In this paper, we presented our design and implementation of FAVOR, an experience-driven federated learning framework that performs active device selection to minimize the communication rounds of FL training. We argue that non-IID data exacerbates the divergence of model weights on participating devices, and increases the number of communication rounds of federated learning by a substantial margin. With both mathematical demonstrations and empirical studies, we found the implicit connection between model weights and the distribution of data that the model is trained on. We proposed to actively select a specific subset of devices to participate in training at each round, in order to counterbalance the bias introduced by non-IID data on each device and to speed up FL training by minimizing the number of communication rounds. In particular, we designed a DRL-based agent that applies the DDQN algorithm to select the best subset of devices to achieve our objectives. We have implemented an open-source prototype of FAVOR with PyTorch in more than 2K lines of code and evaluated it with a variety of ML models. An extensive comparison with FL training jobs by FEDAVG has shown that FL training with FAVOR has reduced the number of communication rounds by up to 49% on the MNIST dataset, up to 23% on FashionMNIST, and up to 42% on CIFAR-10.
REFERENCES
[1] H. B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. Agüera y Arcas, “Communication-Efficient Learning of Deep Networks from Decentralized Data,” in Proc. the Artificial Intelligence and Statistics Conference (AISTATS), 2017.
[2] V. Smith, C.-K. Chiang, M. Sanjabi, and A. S. Talwalkar, “Federated Multi-Task Learning,” in Proc. the Advances in Neural Information Processing Systems (NeurIPS), 2017.
[3] K. Bonawitz, V. Ivanov, B. Kreuter, A. Marcedone, H. B. McMahan, S. Patel, D. Ramage, A. Segal, and K. Seth, “Practical Secure Aggregation for Privacy-Preserving Machine Learning,” in Proc. the ACM SIGSAC Conference on Computer and Communications Security (CCS), 2017.
[4] K. Bonawitz, H. Eichner, W. Grieskamp, D. Huba et al., “Towards Federated Learning at Scale: System Design,” in Proc. the Conference on Systems and Machine Learning (SysML), 2019.
[5] Y. Lin, S. Han, H. Mao, Y. Wang, and W. J. Dally, “Deep Gradient Compression: Reducing the Communication Bandwidth for Distributed Training,” in Proc. the International Conference on Learning Representations (ICLR), 2018.
[6] J. Konečný, B. McMahan, and D. Ramage, “Federated optimization: Distributed optimization beyond the datacenter,” in Proc. the NIPS Workshop on Optimization for Machine Learning (OPT), 2015.
[7] J. Konečný, H. B. McMahan, F. X. Yu, P. Richtarik, A. T. Suresh, and D. Bacon, “Federated Learning: Strategies for Improving Communication Efficiency,” in Proc. the NIPS Workshop on Private Multi-Party Machine Learning, 2016.
[8] Y. Zhao, M. Li, L. Lai, N. Suda, D. Civin, and V. Chandra, “Federated Learning with Non-IID Data,” arXiv preprint arXiv:1806.00582, 2018.
[9] F. Sattler, S. Wiedemann, K.-R. Müller, and W. Samek, “Robust and Communication-Efficient Federated Learning from Non-IID Data,” arXiv preprint arXiv:1903.02891, 2019.
[10] M. Mohri, G. Sivek, and A. T. Suresh, “Agnostic Federated Learning,” in International Conference on Machine Learning (ICML), 2019.
[11] O. Sener and S. Savarese, “Active Learning for Convolutional Neural Networks: A Core-Set Approach,” in Proc. the International Conference on Learning Representations (ICLR), 2018.
[12] H. Mao, M. Alizadeh, I. Menache, and S. Kandula, “Resource Management with Deep Reinforcement Learning,” in Proc. the 15th ACM Workshop on Hot Topics in Networks. ACM, 2016.
[13] A. Mirhoseini, H. Pham, Q. Le, M. Norouzi, S. Bengio, B. Steiner, Y. Zhou, N. Kumar, R. Larsen, and J. Dean, “Device Placement Optimization with Reinforcement Learning,” in International Conference on Machine Learning (ICML), 2017.
[14] R. S. Sutton and A. G. Barto, Reinforcement Learning: An Introduction. MIT Press, Cambridge, 1998.
[15] V. Mnih, K. Kavukcuoglu, D. Silver, A. Graves, and I. Antonoglou, “Playing Atari with Deep Reinforcement Learning,” in Proc. NIPS Deep Learning Workshop, 2013.
[16] H. Van Hasselt, A. Guez, and D. Silver, “Deep Reinforcement Learning with Double Q-Learning,” in Proc. the Thirtieth AAAI Conference on Artificial Intelligence (AAAI), 2016.
[17] T. Yang, G. Andrew, H. Eichner, H. Sun, W. Li, N. Kong, D. Ramage, and F. Beaufays, “Applied Federated Learning: Improving Google Keyboard Query Suggestions,” arXiv preprint arXiv:1812.02903, 2018.
[18] T. Nishio and R. Yonetani, “Client Selection for Federated Learning with Heterogeneous Resources in Mobile Edge,” in Proc. the IEEE International Conference on Communications (ICC), 2019.
