of-vocabulary (OOV) words under federated learning settings, for the purpose of expanding the vocabulary of a virtual keyboard for smartphones without exporting sensitive text to servers. High-frequency words can be sampled from the trained generative model by drawing from the joint posterior directly. We study the feasibility of the approach in two settings: (1) using simulated federated learning on a publicly available non-IID per-user dataset from a popular social networking website, (2) using federated learning on data hosted on user mobile devices. The model achieves good recall and precision compared to ground-truth OOV words in setting (1). With (2) we demonstrate the practicality of this approach by showing that we can learn meaningful OOV words with good character-level prediction accuracy and cross entropy loss.

1 Introduction

Gboard — the Google keyboard — is a virtual keyboard for touch screen mobile devices with support for more than 600 language varieties and over 1 billion installs as of 2018. Gboard provides a variety of input features, including tap and word-gesture typing, auto-correction, word completion and next word prediction.

Learning frequently typed words from user-generated data is an important component of the development of a mobile keyboard. Example usage includes incorporating new trending words (celebrity names, pop culture words, etc.) as they arise, or simply compensating for omissions in the initial keyboard implementation, especially for low-resource languages.

The list of words is often referred to as the "vocabulary", and may or may not be hand-curated. Words outside this list may get autocorrected even when typed correctly (Ouyang et al., 2017). Moreover, for latency and reliability reasons, mobile keyboard models run on-device. This means that the vocabulary supporting the models is intrinsically limited in size, e.g. to a couple hundred thousand words per language.

It is therefore crucial to discover and include the most useful words in this rather short vocabulary list. Words not in the vocabulary are often called "out-of-vocabulary" (OOV) words. Note that the concept of vocabulary is not limited to mobile keyboards. Other natural language applications, such as neural machine translation (NMT), rely on a vocabulary to encode words during end-to-end training. Learning OOV words and their rankings is thus a fairly generic technology need.

The focus of our work is learning OOV words in an environment without transmitting and storing sensitive user content on centralized servers. Privacy is easier to ensure when words are learned on device. Our work builds upon recent advances in Federated Learning (FL) (Konečný et al., 2016; Yang et al., 2018; Hard et al., 2018), a decentralized learning technique to train models such as neural networks on users' devices, uploading only ephemeral model updates to the server for aggregation, and leaving the users' raw data on their devices. Approaches based on hash maps, count sketches (Charikar et al., 2002), or tries (Gunasinghe and Alahakoon, 2012) require significant adaptation to run in FL settings (Zhu et al., 2019). Our work builds upon a very general FL framework for neural networks, where we train a federated character-based recurrent neural network (RNN) on device. OOV words are Monte Carlo sampled on servers during inference (details in Section 2.1).
While Federated Learning removes the need to upload raw user material — here OOV words — to the server, the privacy risk of unintended memorization still exists, as demonstrated in (Carlini et al., 2018). Such risk can be mitigated, usually at some accuracy cost, using techniques including differential privacy (McMahan et al., 2018). Exploring these trade-offs is beyond the scope of this paper.

Figure 1: Monte Carlo sampling of OOV words from the LSTM model.
The proposed approach relies on a learned probabilistic model, and may therefore not generate the words it is trained on faithfully. It may "daydream" OOV words, that is, come up with character sequences never seen in the training data. And it may not be able to regenerate some words that would be interesting to learn. The key to demonstrating the practicality of the approach is to answer: (1) How frequently do daydreamed words occur? (2) How well does the sampled distribution represent the true word frequencies in the dataset? In response to these questions, our contributions include the following:

1. We train the proposed LSTM model on a public Reddit comments dataset in a simulated FL environment. The Reddit dataset contains user ID information for each entry (user's comments or posts), which can be used to mimic the FL process of learning from each client's local data. The simulated FL model is able to achieve 90.56% precision and 81.22% recall for the top 10^5 unique words, based on a total of 10^8 parallel independent samplings.

2. We show the feasibility of training LSTM models from daily Gboard user data in real, on-device FL settings. The FL model is able to reach 55.8% character-level top-3 prediction accuracy and 2.35 cross entropy loss on users' on-device data. We further show that the top sampled words are very meaningful and capture words we know to be trending in the news at the time of the experiments.

2 Method

2.1 LSTM Modeling

LSTM models (Hochreiter and Schmidhuber, 1997) have been successfully used in a variety of sequence processing tasks. In this work we use a variant of LSTM with a Coupled Input and Forget Gate (CIFG) (Greff et al., 2017), peephole connections (Gers and Schmidhuber, 2000) and a projection layer (Sak et al., 2014). CIFG couples the forget and input decisions together, thereby reducing the number of parameters by 25%. The projection layer, located at the output gate, generates the hidden state h_t with reduced dimension to speed up training. Peephole connections let the gate layers look at the cell state. We use multi-layer LSTMs (Sutskever et al., 2014) to increase the representation power of the model. The loss function of the LSTM model is computed as the cross entropy (CE) between the predicted character distribution ŷ_t^i at each step t and a one-hot encoding of the current true character y_t^i, summed over the words in the dataset.

During the inference stage, sampling is based on the chain rule of probability,

    Pr(x_{0:T-1}) = \prod_{t=0}^{T-1} Pr(x_t | x_{0:t-1}),    (1)

where x_{0:T-1} is a sequence of arbitrary length T. RNNs estimate Pr(x_t | x_{0:t-1}) as

    P_{f_θ}(x_t | x_{0:t-1}) = Pr(x_t | f_θ(x_{0:t-1})),    (2)

where f_θ can be seen as a function mapping the input sequence x_{0:t-1} to the hidden state of the RNN. The sampling process starts with multiple threads, each beginning with the start-of-word token. At step t each thread generates a random index based on P_{f_θ}(x_t | x_{0:t-1}). This is done iteratively until the end-of-word token is hit. This parallel sampling approach avoids the dependency between sampling threads, which might occur in beam search or shortest-path search sampling (Carlini et al., 2018).
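The sampling procedure in Equations (1) and (2) can be sketched as follows. This is an illustrative sketch rather than the authors' implementation: the uniform `char_probs` stand-in, the token names, and the tiny vocabulary are placeholders for the trained CIFG-LSTM.

```python
import random

# Toy character vocabulary with explicit start- and end-of-word tokens.
VOCAB = ["<s>", "</s>", "a", "b", "c"]

def char_probs(prefix):
    """Stand-in for the trained RNN f_theta: maps a prefix x_{0:t-1}
    to a distribution Pr(x_t | x_{0:t-1}) over VOCAB. Uniform over
    non-start tokens here, just to make the sketch runnable."""
    choices = [tok for tok in VOCAB if tok != "<s>"]
    p = 1.0 / len(choices)
    return {tok: (p if tok != "<s>" else 0.0) for tok in VOCAB}

def sample_word(rng, max_len=20):
    """One Monte Carlo rollout: draw characters from the per-step
    distribution (Eq. 2) until the end-of-word token is hit (Eq. 1)."""
    prefix = ["<s>"]
    while len(prefix) < max_len:
        dist = char_probs(prefix)
        toks, weights = zip(*dist.items())
        nxt = rng.choices(toks, weights=weights)[0]
        if nxt == "</s>":
            break
        prefix.append(nxt)
    return "".join(prefix[1:])

def sample_oov_words(n_samples, seed=0):
    """Independent samplings; the paper runs these as parallel
    threads on the server, here sequentially for clarity."""
    rng = random.Random(seed)
    counts = {}
    for _ in range(n_samples):
        w = sample_word(rng)
        counts[w] = counts.get(w, 0) + 1
    # The empirical word frequencies approximate the model posterior.
    return {w: c / n_samples for w, c in counts.items()}
```

Because each rollout depends only on its own prefix, the samplings are embarrassingly parallel, unlike beam search where threads share state.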
2.2 Federated Learning

Federated learning (FL) is a new paradigm of machine learning that keeps training data localized on mobile devices and never collects it centrally. FL is shown to be robust to unbalanced or non-IID (independent and identically distributed) data distributions (McMahan et al., 2016; Konečný et al., 2016). It learns a shared model by aggregating locally-computed gradients. FL is especially useful for OOV word learning, since OOV words typically include sensitive user content. FL avoids the need for transmitting and storing such content on centralized servers. FL can also be combined with privacy-preserving techniques like secure aggregation (Bonawitz et al., 2016) and differential privacy (McMahan et al., 2018) to offer stronger privacy guarantees to users.

             FL_S^SGD   FL_L^SGD   FL_L^M
m            0.0        0.0        0.9
Bs           64         64         64
S            0.0        0.0        6.0*
η            1.0        1.0        1.0
η_client     0.1        0.1        0.5
Nr           2          3          3
N            256        256        256
d            16         128        128
Np           64         128        128

Table 1: Hyper-parameters for three different FL settings. *Here adaptive clipping with a 0.99 percentile ratio is used together with server-side L2 clipping.
We use the FederatedAveraging algorithm presented in (McMahan et al., 2016) to combine client updates w_{t+1}^k ← ClientUpdate(k, w^t) after each round t+1 of local training, producing a new global model with weight

    w_{t+1} ← \sum_{k=1}^{K} (n_k / n) w_{t+1}^k,

where w^t is the global model weight at round t, w_{t+1}^k is the model weight for each participating device k, and n_k and n are the number of data points on client k and the total over all K devices, respectively. Adaptive L2-norm clipping is performed on each client's gradient, as it is found to improve the robustness of model convergence.
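The aggregation step above can be sketched in NumPy. This is a minimal illustration assuming flat weight vectors and a fixed clipping norm, not the production FederatedAveraging implementation; the function names are ours.

```python
import numpy as np

def clip_l2(update, clip_norm):
    """Per-client L2-norm clipping for convergence robustness."""
    norm = np.linalg.norm(update)
    if norm > clip_norm:
        update = update * (clip_norm / norm)
    return update

def federated_averaging(global_w, client_updates, client_sizes, clip_norm=1.0):
    """One round: w_{t+1} = sum_k (n_k / n) * w^k_{t+1}, where each
    client weight is the global weight plus its clipped local update."""
    n = sum(client_sizes)
    new_w = np.zeros_like(global_w)
    for delta, n_k in zip(client_updates, client_sizes):
        client_w = global_w + clip_l2(delta, clip_norm)
        new_w += (n_k / n) * client_w
    return new_w
```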
3 Experiments

In this section, we show the details of our experiments in two different settings: (1) simulated FL with the public Reddit dataset (Al-Rfou et al., 2016), (2) on-device FL where user text never leaves the device. In all of our experiments, we use a filter to exclude "invalid" OOV patterns that represent things we don't want the model to learn. The filtering excludes words that: (1) start with non-ASCII alphabetical characters (mostly emoji), (2) contain numbers (street or telephone numbers), (3) contain repetitive patterns like "hahaha" or "yesss", (4) are no longer than 2 characters in length. In both the simulated and on-device FL settings, training and evaluation are defined by separate computation tasks that sample users' data independently.
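The four filtering rules can be sketched as a predicate. The regular expressions and the exact repetition threshold below are our assumptions about the rules' intent, not the production filter.

```python
import re

def is_valid_oov(word):
    """Return False for 'invalid' OOV patterns per the four rules."""
    # (1) Must start with an ASCII alphabetical character
    #     (filters emoji and other non-ASCII-initial tokens).
    if not re.match(r"[A-Za-z]", word):
        return False
    # (2) No digits (street or telephone numbers).
    if re.search(r"\d", word):
        return False
    # (3) No repetitive patterns: a chunk of 2+ characters repeated
    #     back-to-back ("hahaha"), or one character 3+ times in a
    #     row ("yesss"). The exact threshold is our assumption.
    if re.search(r"(.{2,})\1", word) or re.search(r"(.)\1\1", word):
        return False
    # (4) Must be longer than 2 characters.
    if len(word) <= 2:
        return False
    return True
```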
3.1 Evaluation metrics

In OOV word learning, we are interested in how many words are either missing or daydreamed from the model sampling. In simulated FL, we have access to the datasets and know the ground-truth OOV words and their frequencies. Thus, the quality of the model can be evaluated based on precision and recall (PR). For the on-device FL setting, it is not possible to compute PR since users' raw input data is inaccessible by design. We show that the model is able to converge to good CE loss and top-K character-level prediction accuracy. Unlike PR, CE and accuracy do not require computationally-intensive sampling and can be computed on the fly during training.

3.2 Model parameters

Table 1 shows the hyper-parameters of the three models used in our federated experiments. Nr, η, m, and Bs refer to the number of RNN layers, the server-side learning rate, momentum, and batch size, respectively. FL_S^SGD and FL_L^SGD apply standard SGD without momentum or clipping. They vary in the LSTM model architecture: model FL_S^SGD contains 216K parameters and FL_L^SGD contains 758K parameters. FL_L^M has the same model architecture as FL_L^SGD and further applies adaptive gradient clipping, combined with Nesterov accelerated gradient (Nesterov, 1983) and a momentum hyper-parameter of 0.9. Unlike server-based training, FL uses a client-side learning rate η_client with local mini-batch updates, in addition to the server-side learning rate η. FL_L^M converges with η_client = 0.5, while FL_S^SGD and FL_L^SGD diverge with such a high value.

3.3 Federated learning settings

For both simulated and on-device FL, 200 client updates are required to close each round. For each training round we set the number of local epochs to 1.

3.3.1 Simulated FL on Reddit data

The Reddit conversation corpus is a publicly-accessible and fairly large dataset. It includes diverse topics from 300,000 sub-forums (Al-Rfou et al., 2016). The data are organized as "Reddit
Figure 2: Precision vs. top-K uniquely sampled words in simulated FL experiments.

Ground truth            Samples
yea          0.0050     yea          0.0057
upvote       0.0033     upvote       0.0040
downvoted    0.0030     downvoted    0.0033
alot         0.0026     alot         0.0029
downvote     0.0023     downvote     0.0026
downvotes    0.0018     downvotes    0.0022
upvotes      0.0016     upvotes      0.0021
wp-content   0.0016     op's         0.0019
op's         0.0015     wp-content   0.0017
restrict sr  0.0014     redditors    0.0016

Table 2: Top 10 OOV words and their probabilities from ground truth (left) vs. samples (right) from the simulated federated model FL_L^M trained on Reddit data.
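The precision and recall evaluation of Section 3.1 can be sketched as set overlap between top-K lists. This is one plausible formulation with toy data; the paper does not specify tie-breaking or the exact reference set, so treat the details as assumptions.

```python
def top_k(freqs, k):
    """Top-k words of a word-frequency dict, by descending frequency."""
    return {w for w, _ in sorted(freqs.items(), key=lambda kv: -kv[1])[:k]}

def precision_recall(sampled_freqs, truth_freqs, k):
    """Precision: fraction of the top-k sampled words that are true
    top-k OOV words. Recall: fraction of the true top-k recovered
    by the sampling."""
    sampled, truth = top_k(sampled_freqs, k), top_k(truth_freqs, k)
    hits = len(sampled & truth)
    return hits / len(sampled), hits / len(truth)
```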
posts", where users can comment on each other's comments indefinitely and user IDs are kept. The data contain 133 million posts from 326 thousand different sub-forums, consisting of 2.1 billion comments. Unlike Twitter, Reddit message sizes are unlimited. Similar to the work in (Al-Rfou et al., 2016), Reddit posts that have more than 1000 comments are excluded. The final filtered data used for FL simulation contain 492 million comments coming from 763 thousand unique users. There are 259 million filtered OOV words, among which 19 million are unique. As user-tagged data is needed for FL experiments, Reddit posts are sharded in FL simulations based on user ID to mimic the local client cache scenario.

3.3.2 FL on client device data

In the on-device FL setting, the original raw data content is not accessible for human inspection since it remains stored in local caches on client devices. In order to participate in a round of FL, client devices must: (1) have at least 2 GB of memory, (2) be charging, (3) be connected to an unmetered network, and (4) be idle. In this study we experiment with FL on three languages: (1) American English (en_US), (2) Brazilian Portuguese

4 Results

4.1 FL simulation on Reddit data

During training, FL_L^M converges faster and to a better optimum than FL_S^SGD and FL_L^SGD in both CE loss and prediction accuracy, reaching 66.3% top-3 accuracy and 1.887 CE loss, respectively. The larger model does not lead to significant gains. Momentum and adaptive clipping lead to faster convergence and more stable performance.

Table 2 shows the top 10 OOV words with their occurrence probability in the Reddit dataset (left) and the generative model (right). The model generally learns the probability of word occurrences, where the absolute values and relative ranks for top words are very close to the ground truth.

In Figures 2 and 3, PR are computed using the model checkpoint of FL_L^M with 10^8 samplings after 3000 rounds, giving 90.56% precision and 81.22% recall, respectively. The PR rates are plotted against the top K unique words (x-axis). Curves shown in red, green, and blue represent 10^6, 10^7 and 10^8 independent samplings, respectively. Both precision and recall improve significantly when the amount of sampling is increased from 10^6 to 10^8. We also observe increased PR over the training rounds.
Figure 4: Cross entropy loss on live client evaluation data for three different FL settings for en_US.

              en_US       in_ID      pt_BR
abbr.         rlly        noh        pqp
              srry        yws        pfv
              lmaoo       gtw        rlx
              adip        tlpn       sdds
slang/typo    yea         gimana     nois
              tommorow    duwe       perai
              gunna       clan       fuder
              sumthin     beb        ein
              ewwww       siim       tadii
repetitive    hahah       rsrs       lahh
              youu        oww        lohh
              yeahh       diaa       kuyy
foreign       muertos     block      buenas
              quiero      contract   fake
              bangaram    cream      the
names         kavanaugh   N.A.       N.A.
              khabib
              cardi