Feature Selection With Gradient Descent On Two-Layer Networks in Low-Rotation Regimes
Abstract
This work establishes low test error of gradient flow (GF) and stochastic gradient descent
(SGD) on two-layer ReLU networks with standard initialization, in three regimes where key sets
of weights rotate little (either naturally due to GF and SGD, or due to an artificial constraint),
and making use of margins as the core analytic technique. The first regime is near initialization,
specifically until the weights have moved by O(√m), where m denotes the network width, which
is in sharp contrast to the O(1) weight motion allowed by the Neural Tangent Kernel (NTK);
here it is shown that GF and SGD only need a network width and number of samples inversely
proportional to the NTK margin, and moreover that GF attains at least the NTK margin itself,
which suffices to establish escape from bad KKT points of the margin objective, whereas prior
work could only establish nondecreasing but arbitrarily small margins. The second regime is
the Neural Collapse (NC) setting, where data lies in extremely-well-separated groups, and the
sample complexity scales with the number of groups; here the contribution over prior work is
an analysis of the entire GF trajectory from initialization. Lastly, if the inner layer weights are
constrained to change in norm only and can not rotate, then GF with large widths achieves
globally maximal margins, and its sample complexity scales with their inverse; this is in contrast
to prior work, which required infinite width and a tricky dual convergence assumption. As purely
technical contributions, this work develops a variety of potential functions and other tools which
will hopefully aid future work.
1 Introduction
This work studies standard descent methods on two-layer networks of width m, specifically stochastic
gradient descent (SGD) and gradient flow (GF) on the logistic and exponential losses, with a goal of
establishing good test error (low sample complexity) via margin theory (see Section 1.2 for full
details on the setup). The analysis considers three settings where weights rotate little, but rather
their magnitudes change a great deal, whereby the networks select or emphasize good features.
The motivation and context for this analysis is as follows. While the standard promise of deep
learning is to provide automatic feature learning, by contrast, standard optimization-based analyses
typically utilize the Neural Tangent Kernel (NTK) (Jacot et al., 2018), which suffices to establish
low training error (Du et al., 2018; Allen-Zhu et al., 2018; Zou et al., 2018), and even low testing
error (Arora et al., 2019; Li and Liang, 2018; Ji and Telgarsky, 2020b), but ultimately the NTK is
equivalent to a linear predictor over fixed features given by a Taylor expansion, and fails to achieve
the aforementioned feature learning promise.
Due to this gap, extensive mathematical effort has gone into the study of feature learning, where
the feature maps utilized in deep learning (e.g., those given by Taylor expansions) may change
drastically, as occurs in practice. The promise here is huge: even simple tasks such as 2-sparse parity
∗ <mjt@illinois.edu>; comments greatly appreciated.
(learning the parity of two bits in d dimensions) require d^2/ε samples for a test error of ε > 0 within
the NTK or any other kernel regime, but it is possible for ReLU networks beyond the NTK regime
to achieve an improved sample complexity of d/ε (Wei et al., 2018). This point will be discussed in
detail throughout this introduction, though a key point is that prior work generally changes the
algorithm to achieve good sample complexity; as a brief summary, many analyses either add noise
to the training process (Shi et al., 2022; Wei et al., 2018), or train the first layer for only one step
and thereafter train only the second (Daniely and Malach, 2020; Abbe et al., 2022; Barak et al.,
2022), and lastly a few others idealize the network and make strong assumptions to show that global
margin maximization occurs, giving the desired d/ε sample complexity (Chizat and Bach, 2020). As
will be expanded upon shortly, and indeed is depicted in Figure 1 and summarized in Table 1, in the
case of the aforementioned 2-sparse parity, the present work will achieve the optimal kernel sample
complexity d^2/ε with a standard SGD (cf. Theorem 2.1), and a beyond-kernel sample complexity of
d/ε with an inefficient and somewhat simplified GF (cf. Theorem 3.3).
The technical approach in the present work is to build upon the rich theory of margins in
machine learning (Boser et al., 1992; Schapire and Freund, 2012). While the classical theory provides
that descent methods on linear models can maximize margins (Zhang and Yu, 2005; Telgarsky,
2013; Soudry et al., 2017), recent work in deep networks has revealed that GF eventually converges
monotonically to local optima of the margin function (Lyu and Li, 2019; Ji and Telgarsky, 2020a),
and even that, under a variety of conditions, margins may be globally maximized (Chizat and Bach,
2020). As mentioned above, global margin maximization is sufficient to beat the NTK sample
complexity for many well-studied problems, including 2-sparse parity (Wei et al., 2018).
The contributions of this work fall into roughly two categories.
1. Section 2: Margins at least as good as the NTK. The first set of results may sound a
bit unambitious, but already constitute a variety of improvements over prior work. Throughout
these contributions, let γntk denote the NTK margin, a quantity defined and discussed in detail
in Section 1.2.
2. Section 3: Margins surpassing the NTK. The remaining margin results are able to beat
the preceding NTK margins, however they pay a large price: the network width is exponentially
large, and the method is GF, not the discrete-time algorithm SGD. Even so, the results already
require new proof techniques, and again hopefully form the basis for improvements in future
work.
(a) Theorem 3.2: the first result is for Neural Collapse (NC), a setting where data lies in k
tightly packed clusters which are extremely far apart; in fact, each cluster is correctly
labeled by some vector βk , and no other data may fall in the halfspace defined by βk .
Despite this, it is an interesting setting with a variety of empirical and theoretical works,
for instance showing various asymptotic stability properties (Papyan et al., 2020); the
contribution here is to analyze the entire GF trajectory from initialization, and show that
it achieves a margin (and sample complexity) at least as good as the sparse ReLU network
with one ReLU pointing at each cluster.
(b) Theorem 3.3: under a further idealization that the inner weights are constrained to change
in norm only (meaning the inner weights can not rotate), global margin maximization
occurs. As mentioned before, this suffices to establish sample complexity d/ε in 2-sparse
parity, which beats the d^2/ε of kernel methods. As a brief comparison to prior work, similar
idealized networks were proposed by Woodworth et al. (2020), however a full trajectory
analysis and relationship to global margin maximization under no further assumptions
was left open. Additionally, without the no-rotation constraint, Chizat and Bach (2020)
showed that infinite-width networks under a tricky dual convergence assumption also
globally maximize margins; the proof technique here is rooted in a desire to drop this dual
convergence assumption, a point which will be discussed extensively in Section 3.
(c) New potential functions: both Theorem 3.2 and Theorem 3.3 are proved via new potential
arguments which hopefully aid future work.
This introduction will close with further related work (cf. Section 1.1), notation (cf. Section 1.2),
and detailed definitions and estimates of the various margin notions (cf. Section 1.2). Margin
maximization at least as good as the NTK is presented in Section 2, margins beyond the NTK are
in Section 3, and open problems and concluding remarks appear in Section 4. Proofs appear in the
appendices.
(a) Trajectories (|a_j| v_j)_{j=1}^m with m = 16 across time. (b) Trajectories (|a_j| v_j)_{j=1}^m with m = 256 across time.
(c) Sorted per-node rotations for m ∈ {16, 256}. (d) Sorted per-node norms for m ∈ {16, 256}.
Figure 1. Two runs of approximate GF (GD with small step size) on 2-sparse parity with d = 20
and n = 64 and m ∈ {16, 256}: specifically, a new data point x is sampled uniformly from the
hypercube corners {±1/√d}^d, and the label y for simplicity is the parity of the first two bits,
meaning y = d x_1 x_2. The first two dimensions of the data and the trajectories of the m nodes
are depicted in Figures 1a and 1b for m ∈ {16, 256}. Figure 1c shows per-node rotation over
time in sorted order, meaning ⟨v_j(0)/‖v_j(0)‖, v_j(t)/‖v_j(t)‖⟩ for j = 1, …, m. Figure 1d shows
per-node relative norms in sorted order, meaning ‖a_j v_j‖ / max_k ‖a_k v_k‖ for j = 1, …, m. Due
to projection onto the first two coordinates, the 64 data points land in 4 clusters, and are colored
red if negatively labeled, blue if positively labeled. Individual nodes are colored red or blue based
on the sign of their output weights. The shades of red and blue darken for nodes whose total norm
is larger. Comparing Figure 1c and Figure 1d, large norm and large rotation go together. While
Figure 1a is highly noisy, the behavior of Figures 1b to 1d was highly regular across training. The
trend of larger width leading to less rotation and greater norm imbalance was consistent across
runs. This somewhat justifies the lack of rotation with exponentially large width in Theorem 3.3,
and the overall terminology choice of feature selection.
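The experiment behind Figure 1 is straightforward to reproduce in spirit. Below is a minimal numpy sketch of approximate GF (plain GD with a small step size) on 2-sparse parity, computing the per-node rotation and relative-norm statistics of Figures 1c and 1d; the seed, step size, and iteration count are illustrative choices, not those of the original runs.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n, m, lr, steps = 20, 64, 256, 0.01, 2000

# 2-sparse parity sample: x uniform on {+-1/sqrt(d)}^d, label y = d * x_1 * x_2
X = rng.choice([-1.0, 1.0], size=(n, d)) / np.sqrt(d)
y = d * X[:, 0] * X[:, 1]

# initialization with the stated variances: Var(a_j) = 1/m, Var(V_jk) = 1/d
a = rng.normal(0.0, 1.0 / np.sqrt(m), size=m)
V = rng.normal(0.0, 1.0 / np.sqrt(d), size=(m, d))
V0 = V.copy()

for _ in range(steps):
    Z = X @ V.T                       # preactivations, shape (n, m)
    H = np.maximum(Z, 0.0)            # ReLU activations
    p = y * (H @ a)                   # unnormalized margins p_i(W)
    g = -1.0 / (1.0 + np.exp(p))      # logistic loss derivative l'(p_i)
    a -= lr * (g * y) @ H / n
    V -= lr * ((g * y)[:, None] * (Z > 0) * a).T @ X / n

# Figure 1c statistic: cosine between v_j(0) and v_j(t), per node, sorted
cos = np.sort(np.sum(V0 * V, axis=1) /
              (np.linalg.norm(V0, axis=1) * np.linalg.norm(V, axis=1)))
# Figure 1d statistic: relative norms ||a_j v_j|| / max_k ||a_k v_k||, sorted
norms = np.abs(a) * np.linalg.norm(V, axis=1)
rel = np.sort(norms / norms.max())
```

With larger m, the `cos` curve hugs 1 for most nodes (little rotation) while `rel` becomes sharply imbalanced, mirroring the trend described above.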
Table 1: Guarantees for 2-sparse parity; m is network width, n is sample size, t is time (∞ for flows), and ε is the target test error.

Reference | Algorithm | Technique | m | n | t
(Ji and Telgarsky, 2020b) | SGD | perceptron | d^8 | d^2/ε | d^2/ε
Theorem 2.1 | SGD | perceptron | d^2 | d^2/ε | d^2/ε
(Barak et al., 2022) | 2-phase SGD | correlation | O(1) | d^4/ε^2 | d^2/ε^2
(Wei et al., 2018) | WF+noise | margin | ∞ | d/ε | ∞
(Chizat and Bach, 2020) | WF | margin | ∞ | d/ε | ∞
Theorem 3.3 | scalar GF | margin | d^d | d/ε | ∞
Feature learning. There are many works in feature learning; a few also carrying explicit guarantees
on 2-sparse parity are summarized in Table 1. An early work with high relevance to the
present work is (Wei et al., 2018), which in addition to establishing that the NTK requires Ω(d^2/ε)
samples whereas O(d/ε) suffice for the global maximum margin solution, also provided a noisy
Wasserstein Flow (WF) analysis which achieved the maximum margin solution, albeit using noise,
infinite width, and continuous time to aid in local search. The global maximum margin work of
Chizat and Bach (2020) was mentioned before, and will be discussed in Section 3. The work of
Barak et al. (2022) uses a two-phase algorithm: the first step has a large minibatch and effectively
learns the support of the parity in an unsupervised manner, and thereafter only the second layer is
trained, a convex problem which is able to identify the signs within the parity; as in Table 1, this
work stands alone in terms of the narrow width it can handle. The work of Abbe et al. (2022)
uses a similar two-phase approach, and while it can not learn precisely the parity, it can learn an
interesting class of “staircase” functions, and presents many valuable proof techniques. Another
work which operates in two phases and can learn an interesting class of functions which excludes
the parity (specifically due to a Jacobian condition) is the recent work of Damian et al. (2022).
Other interesting feature learning works are (Shi et al., 2022; Bai and Lee, 2019).
1.2 Notation
Architecture and initialization. With the exception of Theorem 3.3, the architecture will be a
2-layer ReLU network of the form x ↦ F(x; W) = Σ_j a_j σ(v_j^T x) = a^T σ(V x), where σ(z) = max{0, z}
is the ReLU, and where a ∈ R^m and V ∈ R^{m×d} have initialization roughly matching the variances
of pytorch default initialization, meaning a ∼ N_m/√m (m iid Gaussians with variance 1/m)
and V ∼ N_{m×d}/√d (m × d iid Gaussians with variance 1/d). These parameters (a, V) will be
collected into a tuple W = (a, V) ∈ R^m × R^{m×d} ≡ R^{m×(d+1)}, and for convenience per-node tuples
w_j = (a_j, v_j) ∈ R × R^d will often be used as well.
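As a concrete reference point, the architecture and initialization above can be sketched as follows; the helper names (`init`, `F`, `p`) are for illustration only.

```python
import numpy as np

def init(m, d, rng):
    # a ~ N_m / sqrt(m): m iid Gaussians with variance 1/m
    # V ~ N_{m x d} / sqrt(d): m*d iid Gaussians with variance 1/d
    a = rng.normal(0.0, 1.0 / np.sqrt(m), size=m)
    V = rng.normal(0.0, 1.0 / np.sqrt(d), size=(m, d))
    return a, V

def F(x, a, V):
    # two-layer ReLU network: x -> a^T sigma(V x)
    return a @ np.maximum(V @ x, 0.0)

def p(x, y, a, V):
    # unnormalized margin mapping p(x, y; W) = y F(x; W)
    return y * F(x, a, V)

rng = np.random.default_rng(0)
a, V = init(m=1000, d=5, rng=rng)
x = np.ones(5) / np.sqrt(5.0)   # example input with ||x|| = 1
val = p(x, 1.0, a, V)
```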
Given a pair (x, y) with x ∈ R^d and y ∈ {±1}, the prediction or unnormalized margin mapping is
p(x, y; W) = y F(x; W) = y a^T σ(V x); when examples ((x_i, y_i))_{i=1}^n are available, a simplified notation
p_i(W) := p(x_i, y_i; W) is often used, and moreover define a single-node variant p_i(w_j) := y_i a_j σ(v_j^T x_i).
Throughout this work, ‖x‖ ≤ 1.
It will often be useful to consider normalized parameters within proofs: define ṽ_j := v_j/‖v_j‖
and ã_j := sgn(a_j) = a_j/|a_j|.
SGD and GF. The loss function ℓ in this work will always be either the exponential loss
ℓ_exp(z) := exp(−z), or the logistic loss ℓ_log(z) := ln(1 + exp(−z)); the corresponding empirical risk
R̂ is

    R̂(p(W)) := (1/n) Σ_{i=1}^n ℓ(p_i(W)),

and the two descent methods are

    W_{i+1} := W_i − η ∂̂_W ℓ(p_i(W_i)),   stochastic gradient descent (SGD),   (1.1)

    Ẇ_t := (d/dt) W_t = −∂̄_W R̂(p(W_t)),   gradient flow (GF),   (1.2)

where ∂̂ and ∂̄ are appropriate generalizations of subgradients for the present nonsmooth nonconvex
setting, detailed as follows. For SGD, ∂̂ will denote any valid element of the Clarke differential (i.e.,
a measurable selection); for example, ∂̂F(x; W) = ( σ(V x), Σ_j a_j σ′(v_j^T x) e_j x^T ), where e_j denotes
the jth standard basis vector, and σ′(v_j^T x_i) ∈ [0, 1] is chosen in some consistent and measurable
way. This approach is able to model what happens in a standard software library such as pytorch.
For GF, ∂̄ will denote the unique minimum norm element of the Clarke differential; typically, GF
is defined as a differential inclusion, which agrees with this minimum norm Clarke flow almost
everywhere, but here the minimum norm element is used to define the flow purely for fun. Since at
this point there is a vast array of ReLU network literature using Clarke differentials, it is merely
asserted here that chain rules and other basic properties work as required, and ∂̂ and ∂̄ are used
essentially as gradients, with details deferred to the detailed discussions in (Lyu and Li, 2019; Ji
and Telgarsky, 2020a; Lyu et al., 2021).
Many expressions will have a fixed time index t (e.g., Wt ), but to reduce clutter, it will either
appear as a subscript, or as a function (e.g., W (t)), or even simply dropped entirely.
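One SGD update (1.1) can be sketched explicitly for the logistic loss; the Clarke selection σ′(z) = 1[z > 0] (so σ′(0) = 0) is one measurable choice, matching common autodiff behavior, and the helper name `sgd_step` is for illustration.

```python
import numpy as np

def sgd_step(a, V, x, y, eta):
    # one SGD step on l_log(p) with p = y a^T sigma(V x),
    # using the Clarke selection sigma'(z) = 1[z > 0]
    z = V @ x
    h = np.maximum(z, 0.0)
    p = y * (a @ h)
    lprime = -1.0 / (1.0 + np.exp(p))               # l'(p) for the logistic loss
    grad_a = lprime * y * h                          # dl/da = l'(p) y sigma(V x)
    grad_V = lprime * y * np.outer(a * (z > 0), x)   # row j: l'(p) y a_j 1[z_j>0] x^T
    return a - eta * grad_a, V - eta * grad_V

rng = np.random.default_rng(1)
m, d = 8, 3
a = rng.normal(0.0, 1.0 / np.sqrt(m), m)
V = rng.normal(0.0, 1.0 / np.sqrt(d), (m, d))
x = np.array([1.0, 0.0, 0.0]); y = 1.0
a2, V2 = sgd_step(a, V, x, y, eta=0.1)
```

Since p is linear in a (for fixed V), the gradient in a can be sanity-checked against finite differences regardless of ReLU kinks.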
Margins and dual variables. Firstly, for convenience, often ℓ_i(W) = ℓ(p_i(W)) and ℓ′_i(W) :=
ℓ′(p_i(W)) are used; since ℓ′_i is negative, often |ℓ′_i| is written.
The start of developing margins and dual variables is the observation that F and p_i are 2-homogeneous
in W, meaning F(x; cW) = c a^T σ(cV x) = c^2 F(x; W) for any x ∈ R^d and c ≥ 0 (and
p_i(cW) = c^2 p_i(W)). It follows that F(x; W) = ‖W‖^2 F(x; W/‖W‖), and thus F and p_i scale
quadratically in ‖W‖, and it makes sense to define a normalized prediction mapping p̃_i and margin
γ as

    p̃_i(W) := p_i(W)/‖W‖^2,   γ(W) := min_i p_i(W)/‖W‖^2 = min_i p̃_i(W).

Due to nonsmoothness, γ can be hard to work with, thus, following Lyu and Li (2019), define the
smoothed margin γ̃ and the normalized smoothed margin γ̊ as

    γ̃(W) := ℓ^{−1}( n R̂(W) ) = ℓ^{−1}( Σ_i ℓ(p_i(W)) ),   γ̊(W) := γ̃(W)/‖W‖^2.
When working with gradients of losses and of smoothed margins, it will be convenient to define
dual variables (q_i)_{i=1}^n via

    q_i := ∇_{p_i} ℓ^{−1}( Σ_i ℓ(p_i) ) = ℓ′(p_i) / ℓ′( ℓ^{−1}( Σ_i ℓ(p_i) ) ) = ℓ′(p_i) / ℓ′( γ̃(p) ).
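For the exponential loss these quantities have closed forms: γ̃(W) = −ln Σ_i exp(−p_i) and q is a softmax of −p, so the dual weight concentrates on the smallest margins. A small numerically-stable sketch (helper names are illustrative):

```python
import numpy as np

def smoothed_margin_exp(p):
    # gamma~(W) = l_exp^{-1}( sum_i l_exp(p_i) ) = -log sum_i exp(-p_i)
    c = np.max(-p)
    return -(c + np.log(np.sum(np.exp(-p - c))))   # stable -logsumexp(-p)

def dual_variables_exp(p):
    # q_i = l'(p_i) / l'(gamma~) = exp(-p_i) / sum_k exp(-p_k): softmax of -p
    z = np.exp(-(p - np.min(p)))
    return z / np.sum(z)

p = np.array([2.0, 1.0, 3.0])
g = smoothed_margin_exp(p)    # lies between min_i p_i - log(n) and min_i p_i
q = dual_variables_exp(p)     # sums to 1; largest weight on the smallest p_i
```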
Margin assumptions. It is easiest to first state the global margin definition, used in Section 3.
The first version is stated on a finite sample.
Assumption 1.1. For given examples ((x_i, y_i))_{i=1}^n, there exists a scalar γ_gl > 0 and parameters
((α_k, β_k))_{k=1}^r with α_k ∈ R with ‖α‖_1 = 1 and β_k ∈ R^d with ‖β_k‖_2 = 1, where

    min_i y_i Σ_{k=1}^r α_k σ(β_k^T x_i) ≥ γ_gl.   ♦
This definition can also be extended to hold over a distribution, whereby it holds for all finite
samples almost surely.
Assumption 1.2. There exists γ_gl > 0 and parameters ((α_k, β_k))_{k=1}^r so that Assumption 1.1 holds
almost surely for any ((x_i, y_i))_{i=1}^n drawn iid for any n from the underlying distribution (with the
same γ_gl and ((α_k, β_k))_{k=1}^r). ♦
Assumptions 1.1 and 1.2 are used in this work to approximate the best possible margin amongst
all networks of any width; in particular, the formalism here is a simplification of the max-min
characterization given in (Chizat and Bach, 2020, Proposition 12, optimality conditions). As will
be seen shortly in Proposition 1.6, this definition suffices to achieve sample complexity d/ε with
gradient flow on 2-sparse parity, as in Table 1. Lastly, while the appearance of an ℓ_1 norm may be a
surprise, it can be seen to naturally arise from 2-homogeneity:

    p_i(W)/‖W‖^2 = Σ_j p_i(w_j)/‖W‖^2 = Σ_j α_j p̃_i(w_j),   where α_j := ‖w_j‖^2/‖W‖^2, thus ‖α‖_1 = 1.
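Both the 2-homogeneity and the ℓ_1 decomposition above are easy to verify numerically; the following sketch checks both identities on random parameters.

```python
import numpy as np

rng = np.random.default_rng(2)
m, d = 6, 4
a = rng.normal(size=m); V = rng.normal(size=(m, d))
x = rng.normal(size=d); x /= 2 * np.linalg.norm(x)   # ||x|| = 1/2 <= 1
y = 1.0

def p_node(aj, vj):      # single-node margin p_i(w_j) = y a_j sigma(v_j^T x)
    return y * aj * max(vj @ x, 0.0)

p_full = y * (a @ np.maximum(V @ x, 0.0))            # p_i(W) = sum_j p_i(w_j)

# 2-homogeneity: p_i(cW) = c^2 p_i(W) for c >= 0
c = 3.0
assert np.isclose(y * ((c * a) @ np.maximum((c * V) @ x, 0.0)), c**2 * p_full)

# the l1 decomposition: p_i(W)/||W||^2 = sum_j alpha_j p~_i(w_j),
# with alpha_j = ||w_j||^2/||W||^2, so ||alpha||_1 = 1
w_norms2 = a**2 + np.sum(V**2, axis=1)               # per-node ||w_j||^2
W_norm2 = w_norms2.sum()                             # ||W||^2
alpha = w_norms2 / W_norm2
p_tilde = np.array([p_node(a[j], V[j]) / w_norms2[j] for j in range(m)])
assert np.isclose(alpha.sum(), 1.0)
assert np.isclose(p_full / W_norm2, alpha @ p_tilde)
```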
Next comes the definition of the NTK margin γ_ntk. As a consequence of 2-homogeneity,
⟨W, ∂̄_W p_i(W)⟩ = 2 p_i(W), which can be interpreted as a linear predictor with weights W and
features ∂̄_W p_i(W)/2. Decoupling the weights and features gives ⟨W, ∂̄_W p_i(W_0)⟩, where W_0 is at
initialization, and W is some other choice. To get the full definition from here, W is replaced with
an infinite-width object via expectations; similar definitions originated with the work of Nitanda
and Suzuki (2019), and were then used in (Ji and Telgarsky, 2020b; Chen et al., 2019).
Assumption 1.3. For given examples ((x_i, y_i))_{i=1}^n, there exists a scalar γ_ntk > 0 and a mapping
θ : R^{d+1} → R^{d+1} with θ(w) = 0 whenever ‖w‖ ≥ 2 so that
Similarly to the global maximum margin definition, Assumption 1.3 can be restated over the
distribution.
Assumption 1.4. There exists γntk > 0 and a transport θ : Rd+1 → Rd+1 so that Assumption 1.3
holds almost surely for any ((xi , yi ))ni=1 drawn iid for any n from the underlying distribution (with
the same γntk and θ). ♦
Before closing this section, here are a few estimates of γntk and γgl . Firstly, both function classes
are universal approximators, and thus the assumption can be made to work for any prediction
problem with pure conditional probabilities (Ji et al., 2020). Next, as a warmup, note the following
estimates of γntk and γgl , for linear predictors, with an added estimate of showing the value of
working with both layers in the definition of γntk .
Proposition 1.5. Let examples ((x_i, y_i))_{i=1}^n be given, and suppose they are linearly separable:
there exists ū with ‖ū‖ = 1 and γ̂ > 0 with min_i y_i x_i^T ū ≥ γ̂.

1. Choosing θ(a, v) := (0, sgn(a)ū) · 1[‖(a, v)‖ ≤ 2], then Assumption 1.3 holds with γ_ntk ≥ γ̂/32.

2. Choosing θ(a, v) := (sgn(ū^T v), 0) · 1[‖(a, v)‖ ≤ 2], then Assumption 1.3 holds with γ_ntk ≥ γ̂/(16√d).

3. Choosing α = (1/2, −1/2) and β = (ū, −ū), then Assumption 1.1 holds with γ_gl ≥ γ̂/2.
Margin estimates for 2-sparse parity are as follows; the key is that γ_ntk scales with 1/d whereas
γ_gl scales with 1/√d, which suffices to yield the separations in Table 1. The bound on γ_ntk is also
necessarily an upper bound, since otherwise the estimates due to Ji and Telgarsky (2020b), which
are within the NTK regime, would beat NTK lower bounds (Wei et al., 2018).

Proposition 1.6. Suppose 2-sparse parity data, meaning inputs are supported on H_d := {±1/√d}^d,
and for any x ∈ H_d, the label is the product of two fixed coordinates d x_a x_b with a ≠ b.

1. Assumption 1.4 holds with γ_ntk ≥ 1/(50d).
Theorem 2.1. Suppose the data distribution satisfies Assumption 1.4 for some γ_ntk > 0, let time t
be given, and suppose width m and step size η satisfy

    m ≥ ( 64 ln(t/δ) / γ_ntk )^2,   η ∈ [ γ_ntk/(10√m), γ_ntk^2/6400 ].
Then, with probability at least 1 − 7δ, the SGD iterates (W_s)_{s≤t} with logistic loss ℓ = ℓ_log satisfy

    min_{s<t} Pr[ p(x, y; W_s) ≤ 0 ] ≤ 8 ln(1/δ)/t + 2560/(t γ_ntk^2),   (test error bound),

    max_{s<t} ‖W_s − W_0‖ ≤ 80η√m/γ_ntk,   (norm bound).
√
Note that while maxs<t kWs −W0 k ≤ 80η m/γntk is only an upper bound, in fact, by Lemma A.3,
√ 2 /6400) is enough
the first gradient has norm γntk m, thus one step (with maximal step size η = γntk
to exit the NTK regime. As another incidental remark, note that this proof requires the logistic loss
`log , and in fact breaks with the exponential loss `exp . Lastly, for the 2-sparse parity, γntk ≥ 1/(50d)
as in Proposition 1.6, which after plugging in to Theorem 2.1 gives the corresponding row of Table 1.
2 , whereas prior work has 1/γ 8 (Ji and Telgarsky,
As discussed previously, the width is only 1/γntk ntk
2020b; Chen et al., 2019). The proof of Theorem 2.1 is in Appendix B.1, but here is a sketch.
1. Sampling a good finite-width comparator W. Perhaps the heart of the proof is showing that the
parameters θ ∈ R^{m×(d+1)} given by θ_j := θ(w_j) satisfy

    ⟨θ, ∂̂p_i(W_0)⟩ ≥ γ_ntk√m/2,   ∀i.

More elaborate versions of this are used in the proof, and sometimes require a bit of surprising
algebra (cf. Lemma A.3), which for instance seems to be able to treat σ as though it were
smooth, and without the usual careful activation-accounting in ReLU NTK proofs.
3. Controlling ℓ′_s(W_s)^2 ‖∂̂p_s(W_s)‖^2: the perceptron argument. Using the first point above, it is
possible to show

    ‖W_t − W_0‖ ≥ (1/2) ⟨−θ, W_t − W_0⟩ ≥ (η γ_ntk √m / 4) Σ_{s<t} |ℓ′_s(W_s)|,

which can then be massaged to control the aforementioned squared gradient term. Interestingly,
this proof step is reminiscent of the perceptron convergence proof, and the quantity
Σ_{s<t} |ℓ′_s(W_s)| is exactly the analog of the mistake bound quantity central in perceptron proofs
(Novikoff, 1962). Indeed, this proof never ends up caring about the loss terms, and derives the
final test error bound (a zero-one loss!) via this perceptron term.
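For context, the mistake bound quantity in question comes from the classical perceptron; the sketch below runs the perceptron on separable toy data and checks Novikoff's bound of at most (R/γ)^2 updates, here with radius R = 1 and margin γ = 1/2. The data construction is illustrative.

```python
import numpy as np

def perceptron(X, y, passes=50):
    # classical perceptron; returns final weights and total mistake count.
    # Novikoff: on data with ||x|| <= R and margin gamma, mistakes <= (R/gamma)^2.
    w = np.zeros(X.shape[1])
    mistakes = 0
    for _ in range(passes):
        for xi, yi in zip(X, y):
            if yi * (w @ xi) <= 0:
                w += yi * xi
                mistakes += 1
    return w, mistakes

# separable toy data on the unit circle with margin 1/2 w.r.t. u
rng = np.random.default_rng(3)
u = np.array([1.0, 0.0])
X = rng.normal(size=(200, 2))
X /= np.linalg.norm(X, axis=1, keepdims=True)
X = X[np.abs(X @ u) >= 0.5]          # keep only points with margin >= 1/2
y = np.sign(X @ u)
w, mistakes = perceptron(X, y)       # at most (1 / 0.5)^2 = 4 mistakes
```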
Theorem 2.2. Suppose the data distribution satisfies Assumption 1.4 for some γ_ntk > 0, and the
GF curve (W_s)_{s≥0} uses ℓ ∈ {ℓ_exp, ℓ_log} on an architecture whose width m satisfies

    m ≥ ( 640 ln(n/δ) / γ_ntk )^2.

Then, with probability at least 1 − 15δ, there exists t with ‖W_t − W_0‖ = γ_ntk√m/32, and

    γ̊(W_s) ≥ γ_ntk^2/4096   and   Pr[ p(x, y; W_s) ≤ 0 ] ≤ O( ln(n)^3/(n γ_ntk^4) + ln(1/δ)/n )   ∀s ≥ t.
A few brief remarks are as follows. Firstly, for Wt , the sample complexity matches the SGD
sample complexity in Theorem 2.1, though via a much more complicated proof. Secondly, this proof
can handle {`log , `exp } and not just `log . Lastly, an odd point is that a bit of algebra grants a better
generalization bound at iterate Wt , but it is not clear if this improved bound holds for all time
s ≥ t; in particular it is not immediately clear that the nondecreasing margin property established
by Lyu and Li (2019) can be applied.
One interesting comparison is to a leaky ReLU convergence analysis on a restricted form of
linearly separable data due to Lyu et al. (2021). That work, through an extremely technical and
impressive analysis, establishes convergence to a solution which is equivalent to the best linear
predictor. By contrast, while the work here does not recover that analysis, due to γntk being a
constant multiple of the linear margin (cf. Proposition 1.5), the sample complexity is within a
constant factor of the best linear predictor, thus giving a sample complexity comparable to that of
(Lyu et al., 2021) via a simpler proof in a more general setting.
To prove Theorem 2.2, the first step is essentially the same as the proof of Theorem 2.1, however
it yields only a training error guarantee, not a test error guarantee.
Lemma 2.3. Suppose the data distribution satisfies Assumption 1.4 for some γ_ntk > 0, let time t
be given, and suppose width m satisfies

    m ≥ ( 640 ln(t/δ) / γ_ntk )^2.

Then, with probability at least 1 − 7δ, the GF curve (W_s)_{s∈[0,t]} on empirical risk R̂ with loss
ℓ ∈ {ℓ_log, ℓ_exp} satisfies

    R̂(W_t) ≤ 1/(5t),   (training error bound),

    sup_{s<t} ‖W_s − W_0‖ ≤ γ_ntk√m/80,   (norm bound).
Note that this bound is morally equivalent to the SGD bound in Theorem 2.1 after accounting
for the γ_ntk^2 “units” arising from the step size.
The second step of the proof of Theorem 2.2 is an explicit margin guarantee, which is missing
from the SGD analysis.
Lemma 2.4. Let data ((x_i, y_i))_{i=1}^n be given satisfying Assumption 1.3 with margin γ_ntk > 0, and let
(W_s)_{s≥0} denote the GF curve resulting from loss ℓ ∈ {ℓ_log, ℓ_exp}. Suppose the width m satisfies

    m ≥ 256 ln(n/δ) / γ_ntk^2,

fix a distance parameter R := γ_ntk√m/32, and let time τ be given so that ‖W_τ − W_0‖ ≤ R/2 and
R̂(W_τ) < ℓ(0)/n. Then, with probability at least 1 − 7δ, there exists a time t with ‖W_t − W_0‖ = R
so that for all s ≥ t,

    ‖W_s − W_0‖ ≥ R   and   γ̊(W_s) ≥ γ_ntk^2/4096,

and moreover the rebalanced iterate Ŵ_t := (a_t/√γ_ntk, V_t√γ_ntk) satisfies p(x, y; W_t) = p(x, y; Ŵ_t) for
all (x, y), and

    γ̊(Ŵ_t) ≥ γ_ntk/4096.
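The rebalancing step can be checked directly: by positive homogeneity of the ReLU, scaling the two layers in opposite directions leaves every prediction unchanged while changing the parameter norm. A small sketch with an arbitrary illustrative scale:

```python
import numpy as np

rng = np.random.default_rng(4)
m, d, gamma = 5, 3, 0.25
a = rng.normal(size=m); V = rng.normal(size=(m, d))
x = rng.normal(size=d); x /= 2 * np.linalg.norm(x)
y = -1.0

def p(y_, a_, V_):
    return y_ * (a_ @ np.maximum(V_ @ x, 0.0))

# rebalanced iterate: scale the layers in opposite directions
a_hat, V_hat = a / np.sqrt(gamma), V * np.sqrt(gamma)

# predictions are unchanged, since sigma(c z) = c sigma(z) for c > 0
assert np.isclose(p(y, a, V), p(y, a_hat, V_hat))
```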
Before discussing the proof, a few remarks are in order. Firstly, the final large margin iterate
W_t is stated as explicitly achieving some distance from initialization; needing such a claim is
unsurprising, as the margin definition requires a lot of motion in a good direction to clear the noise
in W_0. In particular, it is unsurprising that moving O(√m) is needed to achieve a good margin,
given that the initial weight norm is O(√m); analogously, it is not surprising that Lemma 2.3 can
not be used to produce a meaningful lower bound on γ̊(W_τ) directly.
Regarding the proof, surprisingly it can almost verbatim follow a proof scheme originally designed
for margin rates of coordinate descent (Telgarsky, 2013). Specifically, noting that ∂̄_W γ̃(W_s) and Ẇ_s
are colinear, the fundamental theorem of calculus (adapted to Clarke differentials) gives

    γ̃(W_t) − γ̃(W_τ) = ∫_τ^t (d/ds) γ̃(W_s) ds = ∫_τ^t ⟨∂̄_W γ̃(W_s), Ẇ_s⟩ ds = ∫_τ^t ‖∂̄_W γ̃(W_s)‖ · ‖Ẇ_s‖ ds,

and now the terms can be controlled separately. Assuming the exponential loss for simplicity and
recalling the dual variable notation from Section 1.2, then ∂̄_W γ̃(W_s) = Σ_i q_i ∂̄_W p_i(W_s). Consequently,
using the same good property of θ discussed in Section 2.1 gives

    ‖∂̄_W γ̃(W_s)‖ = ‖ Σ_i q_i ∂̄_W p_i(W_s) ‖ ≥ ⟨ θ/‖θ‖, Σ_i q_i ∂̄_W p_i(W_s) ⟩ ≥ γ_ntk√m/4.
Figure 2. An arrangement of positively labeled points (the blue x’s) where the margin objective
has multiple KKT points, and the gradient flow is able to avoid certain bad ones. Specifically, as
the two cones of data S1 and S2 are rotated away from each other, the linear predictor u may still
achieve a positive margin, but it will become arbitrarily small. By contrast, pointing two ReLUs at
each of S1 and S2 achieves a much better margin. Lemma 2.3 is strong enough to establish this
occurs, at least for some arrangements of the cones, as detailed in Proposition 2.6. This construction
is reminiscent of other bad KKT constructions in the literature (Lyu et al., 2021; Vardi et al., 2021).
This leaves the other term of the integral, which is even easier:

    ∫_τ^t ‖Ẇ_s‖ ds ≥ ‖ ∫_τ^t Ẇ_s ds ‖ = ‖W_t − W_τ‖ ≥ R/2,
which completes the proof for the exponential loss. For the logistic loss, the corresponding elements
qi do not form a probability vector, and necessitate the use of a 2-phase analysis which warm-starts
with Lemma 2.3.
To finish the proof of Theorem 2.2, it remains to relate margins to test error, which in order
to scale with 1/n rather than 1/√n makes use of a beautiful refined margin-based Rademacher
complexity bound due to Srebro et al. (2010, Theorem 5), and thereafter uses the special structure
of 2-homogeneity to treat the network as an ℓ_1-bounded linear combination of nodes, and thereby
achieve no dependence, even logarithmic, on the width m.
Lemma 2.5. With probability at least 1 − δ over the draw of ((x_i, y_i))_{i=1}^n, for every width m, every
choice of weights W ∈ R^{m×(d+1)} with γ̊(W) > 0 satisfies

    Pr[ p(x, y; W) ≤ 0 ] ≤ O( ln(n)^3/(n γ̊(W)^2) + ln(1/δ)/n ).
Combining the preceding pieces yields the proof of Theorem 2.2. To conclude this section, note
that these margin guarantees suffice to establish that GF can escape bad KKT points of the margin
objective. The construction appears in Figure 2 and is elementary, detailed as follows. Consider
data, all of the same label, lying in two narrow cones S1 and S2 . If S1 and S2 are close together, the
global maximum margin network corresponds to a single linear predictor. As the angle between S1
and S2 is increased, eventually the global maximum margin network chooses two separate ReLUs,
one pointing towards each cone; meanwhile, before the angle becomes too large, if S1 and S2 are
sufficiently narrow, there exists a situation whereby a single linear predictor still has positive margin,
but worse than the 2 ReLU solution, and is still a KKT point.
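The gap in Figure 2 can be seen numerically even in a stripped-down instance where each cone is collapsed to a single point (an illustrative simplification, not the exact construction of Proposition 2.6): two unit vectors at inner product −1/√2, both labeled +1, where a pair of ReLUs in the sense of Assumption 1.1 beats the best linear predictor.

```python
import numpy as np

# two positively-labeled unit vectors with inner product -1/sqrt(2)
theta = np.arccos(-1 / np.sqrt(2))
x1 = np.array([1.0, 0.0])
x2 = np.array([np.cos(theta), np.sin(theta)])

# best linear predictor: the normalized mean direction
u = x1 + x2
u /= np.linalg.norm(u)
linear_margin = min(u @ x1, u @ x2)          # about 0.38

# two-ReLU predictor in the sense of Assumption 1.1:
# alpha = (1/2, 1/2), beta = (x1, x2), one ReLU pointing at each cluster
def relu(z):
    return max(z, 0.0)

net_margin = min(
    0.5 * relu(x1 @ x1) + 0.5 * relu(x2 @ x1),
    0.5 * relu(x1 @ x2) + 0.5 * relu(x2 @ x2),
)                                            # exactly 0.5

assert net_margin > linear_margin            # the ReLU pair wins
```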
Proposition 2.6. Let ((x_i, y_i))_{i=1}^n be given as in Figure 2, where y_i = +1, and (x_i)_{i=1}^n all have
‖x_i‖ = 1, and are partitioned into two sets, S_1 and S_2, such that max{ ⟨x_i, x_j⟩ : x_i ∈ S_1, x_j ∈ S_2 } ≤
−1/√2. Then there always exists a margin parameter γ̂ > 0 and a single new data point x_0 so that
the resulting data S_1 ∪ S_2 ∪ {x_0} satisfies the following conditions.

1. The maximum margin linear predictor u achieves margin γ̂, and is also a KKT point for any
shallow ReLU network of width m ≥ 1.

2. There exists m_0 so that for any width m ≥ m_0, GF achieves lim_t γ̊(W_t) > γ̂.
Note that Theorem 3.2 only implies that GF selects a predictor with margins at least as good
as the NC solution, and does not necessarily converge to the NC solution (i.e., rotating all ReLUs
to point in the directions (β_k)_{k=1}^r specified by Theorem 3.2). In fact, this may fail to be true, and
Proposition 2.6 and Figure 2 already gave one such construction; moreover, this is not necessarily
bad, as GF may converge to a solution with better margins, and potentially better generalization.
Overall, the relationship of NC to the bias of 2-layer network training in practical regimes remains
open.
The proof of Theorem 3.2 proceeds by developing a potential function that asserts that either
mass grows in the directions (β_k)_{k=1}^r, or their margin is exceeded. Within the proof, large width
ensures that the mass in each good direction is initially positive, and thereafter Assumption 3.1 is
used to ensure that the fraction of mass in these directions is increasing. The proofs of Theorem 3.2
and of Theorem 3.3 invoke the same abstract potential function lemma, and discussion is momentarily
deferred until after the presentation of Theorem 3.3.
One small point is worth explaining now, however. It may have seemed unusual to use ‖a_j v_j‖
as a (squared!) norm in Figure 1. Of course, layers asymptotically balance, thus asymptotically
not only is there the Fenchel inequality 2‖a_j v_j‖ ≤ a_j^2 + ‖v_j‖^2 = ‖w_j‖^2, but also a reverse inequality
2‖a_j v_j‖ ≳ ‖w_j‖^2. Despite this fact, the disagreement between 2‖a_j v_j‖ and ‖w_j‖^2, namely the
imbalance between a_j^2 and ‖v_j‖^2, can cause real problems, and one solution used within the proofs
is to replace ‖w_j‖^2 with ‖a_j v_j‖.
The main points of comparison for Theorem 3.3 are the global margin maximization proofs of
Wei et al. (2018) and Chizat and Bach (2020). The analysis by Wei et al. (2018) is less similar,
as it heavily relies upon the benefits to local search arising from weight re-initialization, whereas
the analysis here in some sense is based on the technique in (Chizat and Bach, 2020), but diverges
sharply due to dropping the two key assumptions therein. Specifically, (Chizat and Bach, 2020)
requires infinite width and dual convergence, meaning q(t) converges, which is open even for linear
predictors in general settings. The infinite width assumption is also quite strenuous: it is used to
ensure not just that weights cover the sphere at initialization (a consequence of exponentially large
width), but in fact that they cover the sphere for all times t.
The proof strategy of Theorem 3.3 (and Theorem 3.2) is as follows. The core of the proof scheme
in (Chizat and Bach, 2020) is to pick two weights wj and wk , where wk achieves better margins
than wj in some sense, and consider
$$\frac{d}{dt}\,\frac{\|w_j\|^2}{\|w_k\|^2} = 4Q\,\frac{\|w_j\|^2}{\|w_k\|^2}\sum_i q_i\Big(\tilde p_i(w_j) - \tilde p_i(w_k)\Big);$$
as a purely technical aside, it is extremely valuable that this ratio potential automatically normalizes
the margins, leading to the appearance of $\tilde p_i(w_j)$ and not $p_i(w_j)$, and a similar idea is used in the proofs
here, albeit starting from $\ln\|w_j\| - \ln\|w_k\|$, the idea of $\ln(\cdot)$ also appearing in the proofs by Lyu
and Li (2019). Furthermore, this expression already shows the role of dual convergence: if we
can assume every $q_i$ converges, then we need only pick nodes $w_k$ for which the margin surrogate
$\sum_i q_i(\infty)\,\tilde p_i(w_k(\infty))$ is very large, and the above time derivative becomes negative if $w_j$ has bad
margin, which implies mass accumulates in directions with good margin. This is the heart of the
proof scheme due to Chizat and Bach (2020), and circumventing or establishing dual convergence
seems tricky.
The approach here is to replace $\|w_j\|$ and $\|w_k\|$ with other quantities which can be handled
without dual convergence. First, $\|w_j\|$ is replaced with $\|W\|$, the norm of all nodes, which can
be controlled in an elementary way for arbitrary $L$-homogeneous networks: as summarized in
Lemma A.5, as soon as $\|W\|$ becomes large, then $\sum_i q_i\,\tilde p_i(W) \approx \gamma(W_t)$, essentially by properties of
$\ln\sum\exp$.
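The $\ln\sum\exp$ property invoked here is easy to illustrate numerically. In the following standalone sketch (with hypothetical margin values), the $q$-weighted margin and the smoothed margin sandwich the minimum normalized margin, and both converge to it as the scale grows.

```python
import numpy as np

p = np.array([0.5, 0.7, 1.2])    # hypothetical normalized margins p_i(W) / ||W||^2
for scale in [10.0, 100.0, 1000.0]:
    z = scale * p                                  # margins at squared norm `scale`
    q = np.exp(-z) / np.exp(-z).sum()              # softmin dual weights q_i
    weighted = (q * z).sum() / scale               # q-weighted normalized margin
    smoothed = -np.log(np.exp(-z).sum()) / scale   # smoothed margin (ln-sum-exp)
    assert smoothed <= p.min() <= weighted + 1e-12 # sandwich around the minimum margin
assert abs(weighted - p.min()) < 1e-3 and abs(smoothed - p.min()) < 1e-3
```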
Replacing $\ln\|w_k\|$ is much harder, since $q_i(t)$ may oscillate and thus the notion of nodes with
good margin seems to be time-varying. If there is little rotation, then nodes near the reference
directions $(\beta_k)_{k=1}^r$ can be swapped with $(\beta_k)_{k=1}^r$, and the expression $\sum_i q_i\,\tilde p_i(w_k)$ can be swapped
with $\gamma_{\mathrm{gl}}$. A potential function that replaces $\ln\|w_k\|$ and allows this swapping need only satisfy a few
abstract but innocuous conditions, as summarized in Lemma A.6. Unfortunately, verifying these
conditions is rather painful, and handling general settings (without explicitly disallowing rotation)
seems to still need quite a few more ideas.
Acknowledgements
The author thanks Peter Bartlett, Spencer Frei, Danny Son, and Nati Srebro for discussions. The
author thanks the Simons Institute for hosting a short visit during the 2022 summer cluster on
Deep Learning Theory, the audience at the corresponding workshop for exciting and clarifying
participation during a talk presenting this work, and the NSF for support under grant IIS-1750051.
References
Emmanuel Abbe, Enric Boix-Adsera, and Theodor Misiakiewicz. The merged-staircase property: a
necessary and nearly sufficient condition for sgd learning of sparse functions on two-layer neural
networks. arXiv preprint arXiv:2202.08658, 2022.
Alekh Agarwal, Daniel Hsu, Satyen Kale, John Langford, Lihong Li, and Robert Schapire. Taming
the monster: A fast and simple algorithm for contextual bandits. In International Conference on
Machine Learning, pages 1638–1646. PMLR, 2014.
Zeyuan Allen-Zhu, Yuanzhi Li, and Zhao Song. A convergence theory for deep learning via
over-parameterization. arXiv preprint arXiv:1811.03962, 2018.
Sanjeev Arora, Simon S Du, Wei Hu, Zhiyuan Li, and Ruosong Wang. Fine-grained analysis of
optimization and generalization for overparameterized two-layer neural networks. arXiv preprint
arXiv:1901.08584, 2019.
Yu Bai and Jason D Lee. Beyond linearization: On quadratic and higher-order approximation of
wide neural networks. arXiv preprint arXiv:1910.01619, 2019.
Boaz Barak, Benjamin L Edelman, Surbhi Goel, Sham Kakade, Eran Malach, and Cyril Zhang.
Hidden progress in deep learning: Sgd learns parities near the computational limit. arXiv preprint
arXiv:2207.08799, 2022.
Avrim Blum, John Hopcroft, and Ravindran Kannan. Foundations of data science, 2017. URL
https://www.cs.cornell.edu/jeh/book.pdf.
Bernhard E. Boser, Isabelle M. Guyon, and Vladimir N. Vapnik. A training algorithm for optimal
margin classifiers. In Proceedings of the Fifth Annual Workshop on Computational Learning Theory,
COLT ’92, page 144–152, New York, NY, USA, 1992. Association for Computing Machinery. ISBN
089791497X. doi: 10.1145/130385.130401. URL https://doi.org/10.1145/130385.130401.
Zixiang Chen, Yuan Cao, Difan Zou, and Quanquan Gu. How much over-parameterization is
sufficient to learn deep relu networks? arXiv preprint arXiv:1911.12360, 2019.
Lenaic Chizat and Francis Bach. Implicit bias of gradient descent for wide two-layer neural networks
trained with the logistic loss. arXiv preprint arXiv:2002.04486, 2020.
Alexandru Damian, Jason Lee, and Mahdi Soltanolkotabi. Neural networks can learn representations
with gradient descent. In Conference on Learning Theory, pages 5413–5452. PMLR, 2022.
Amit Daniely and Eran Malach. Learning parities with neural networks. Advances in Neural
Information Processing Systems, 33:20356–20365, 2020.
Simon S Du, Jason D Lee, Haochuan Li, Liwei Wang, and Xiyu Zhai. Gradient descent finds global
minima of deep neural networks. arXiv preprint arXiv:1811.03804, 2018.
Arthur Jacot, Franck Gabriel, and Clément Hongler. Neural tangent kernel: Convergence and
generalization in neural networks. In Advances in neural information processing systems, pages
8571–8580, 2018.
Ziwei Ji and Matus Telgarsky. Gradient descent aligns the layers of deep linear networks. arXiv
preprint arXiv:1810.02032, 2018a.
Ziwei Ji and Matus Telgarsky. Risk and parameter convergence of logistic regression.
arXiv:1803.07300v3 [cs.LG], 2018b.
Ziwei Ji and Matus Telgarsky. Characterizing the implicit bias via a primal-dual analysis. arXiv
preprint arXiv:1906.04540, 2019.
Ziwei Ji and Matus Telgarsky. Directional convergence and alignment in deep learning. arXiv
preprint arXiv:2006.06657, 2020a.
Ziwei Ji and Matus Telgarsky. Polylogarithmic width suffices for gradient descent to achieve
arbitrarily small test error with shallow ReLU networks. In ICLR, 2020b.
Ziwei Ji, Matus Telgarsky, and Ruicheng Xian. Neural tangent kernels, transportation mappings,
and universal approximation. In ICLR, 2020.
Yuanzhi Li and Yingyu Liang. Learning overparameterized neural networks via stochastic gradient
descent on structured data. In Advances in Neural Information Processing Systems, pages
8157–8166, 2018.
Kaifeng Lyu and Jian Li. Gradient descent maximizes the margin of homogeneous neural networks.
arXiv preprint arXiv:1906.05890, 2019.
Kaifeng Lyu, Zhiyuan Li, Runzhe Wang, and Sanjeev Arora. Gradient descent on two-layer nets:
Margin maximization and simplicity bias. Advances in Neural Information Processing Systems,
34, 2021.
Atsushi Nitanda and Taiji Suzuki. Refined generalization analysis of gradient descent for over-
parameterized two-layer neural networks with smooth activations on classification problems. arXiv
preprint arXiv:1905.09870, 2019.
Vardan Papyan, XY Han, and David L Donoho. Prevalence of neural collapse during the terminal
phase of deep learning training. Proceedings of the National Academy of Sciences, 117(40):
24652–24663, 2020.
Robert E. Schapire and Yoav Freund. Boosting: Foundations and Algorithms. MIT Press, 2012.
Robert E. Schapire, Yoav Freund, Peter Bartlett, and Wee Sun Lee. Boosting the margin: A new
explanation for the effectiveness of voting methods. In ICML, pages 322–330, 1997.
Shai Shalev-Shwartz and Shai Ben-David. Understanding Machine Learning: From Theory to
Algorithms. Cambridge University Press, 2014.
Zhenmei Shi, Junyi Wei, and Yingyu Liang. A theoretical analysis on feature learning in neural
networks: Emergence from inputs and advantage over fixed features. In International Conference
on Learning Representations, 2022. URL https://openreview.net/forum?id=wMpS-Z_AI_E.
Daniel Soudry, Elad Hoffer, Mor Shpigel Nacson, Suriya Gunasekar, and Nathan Srebro. The
implicit bias of gradient descent on separable data. arXiv preprint arXiv:1710.10345, 2017.
Nathan Srebro, Karthik Sridharan, and Ambuj Tewari. Smoothness, low noise and fast rates. In
NIPS, 2010.
Gal Vardi, Ohad Shamir, and Nathan Srebro. On margin maximization in linear and relu networks.
arXiv preprint arXiv:2110.02732, 2021.
Colin Wei, Jason D Lee, Qiang Liu, and Tengyu Ma. Regularization matters: Generalization and
optimization of neural nets vs their induced kernel. arXiv preprint arXiv:1810.05369, 2018.
Blake Woodworth, Suriya Gunasekar, Jason D Lee, Edward Moroshko, Pedro Savarese, Itay Golan,
Daniel Soudry, and Nathan Srebro. Kernel and rich regimes in overparametrized models. In
Conference on Learning Theory, pages 3635–3673. PMLR, 2020.
Tong Zhang and Bin Yu. Boosting with early stopping: Convergence and consistency. The Annals
of Statistics, 33:1538–1579, 2005.
Difan Zou, Yuan Cao, Dongruo Zhou, and Quanquan Gu. Stochastic gradient descent optimizes
over-parameterized deep relu networks. arXiv preprint arXiv:1811.08888, 2018.
A Technical preliminaries
The following are basic technical tools used throughout.
Proof of Proposition 1.5. The proof considers the three settings separately.
1. For any $i$, first note that
$$\mathbb{E}_{w\sim\mathcal N_\theta}\left\langle \theta(w), \hat\partial_w p_i(w)\right\rangle
= \mathbb{E}_{(a,v)\sim\mathcal N_\theta}\, |a|\,\bar u^T x_i y_i\,\sigma'(v^T x_i)\,\mathbf 1[\|(a,v)\|\le 2]
\ge \hat\gamma\,\mathbb{E}_{(a,v)\sim\mathcal N_\theta}\,|a|\,\sigma'(v^T x_i)\,\mathbf 1[\|(a,v)\|\le 2].$$
To control the expectation, note that with probability at least $1/2$, $1/4 \le |a| \le \sqrt 2$, and
thus by rotational invariance
$$\mathbb{E}_{(a,v)\sim\mathcal N_\theta}\,|a|\,\sigma'(v^T x_i)\,\mathbf 1[\|(a,v)\|\le 2]
\ge \frac18\,\mathbb{E}_{(a,v)\sim\mathcal N_\theta}\,\sigma'(v^T x_i)\,\mathbf 1[\|v\|\le\sqrt 2]
\ge \frac18\,\mathbb{E}_{(a,v)\sim\mathcal N_\theta}\,\sigma'(v_1)\,\mathbf 1[\|v\|\le\sqrt 2]
\ge \frac1{32}.$$
2. For convenience, fix any example $(x,y) \in ((x_i,y_i))_{i=1}^n$, and write $(a,v) = w$, whereby $w\sim\mathcal N_w$
means $a\sim\mathcal N_a$ and $v\sim\mathcal N_v$. With this out of the way, define an orthonormal matrix $M\in\mathbb R^{d\times d}$
whose first column is $\bar u$, whose second column is $(I - \bar u\bar u^T)x/\|(I - \bar u\bar u^T)x\|$, and whose remaining
columns are arbitrary so long as $M$ is orthonormal, and note that $M^T\bar u = e_1$ and $M^T x =
e_1\,\bar u^T x + e_2 r_2$ where $r_2 := \sqrt{\|x\|^2 - (\bar u^T x)^2}$. Then, using rotational invariance of the Gaussian,
$$\begin{aligned}
\mathbb E_w\left\langle \hat\theta(w), \partial p_i(w)\right\rangle
&= y\,\mathbb E_{w=(a,v)}\,\mathrm{sgn}(\bar u^T v)\,\sigma(v^T x)\,\mathbf 1[\|w\|\le 2]
 = y\,\mathbb E_{\|(a,Mv)\|\le 2}\,\alpha(Mv)\,\sigma(v^T M^T x) \\
&= \mathbb E_{\|(a,v)\|\le 2}\,y\,\mathrm{sgn}(v_1)\,\sigma\big(\mathrm{sgn}(v_1)|v_1|\,\bar u^T x + v_2 r_2\big) \\
&= \mathbb E_{\substack{\|(a,v)\|\le 2\\ y\,\mathrm{sgn}(v_1)=1,\ v_2\ge 0}}
\Big[\sigma\big(|v_1|\,\bar u^T x\,y + v_2 r_2\big) - \sigma\big(-|v_1|\,\bar u^T x\,y + v_2 r_2\big) \\
&\qquad\qquad + \sigma\big(|v_1|\,\bar u^T x\,y - v_2 r_2\big) - \sigma\big(-|v_1|\,\bar u^T x\,y - v_2 r_2\big)\Big].
\end{aligned}$$
Considering cases, the first ReLU argument is always positive, exactly one of the second and
third is positive, and the fourth is negative, whereby
$$\begin{aligned}
y\,\mathbb E_{\|(a,v)\|\le 2}\,\alpha(v)\,\sigma(v^T x)
&= \mathbb E_{\substack{\|(a,v)\|\le 2\\ y\,\mathrm{sgn}(v_1)=1,\ v_2\ge 0}}
\Big[|v_1|\,\bar u^T x\,y + v_2 r_2 + |v_1|\,\bar u^T x\,y - v_2 r_2\Big] \\
&= 2\,\mathbb E_{\substack{\|(a,v)\|\le 2\\ y\,\mathrm{sgn}(v_1)=1}}\,|v_1|\,\bar u^T x\,y
\ \ge\ 2\hat\gamma\,\mathbb E_{\substack{\|(a,v)\|\le 2\\ y\,\mathrm{sgn}(v_1)=1}}\,|v_1|
= \hat\gamma\,\Pr[\|(a,v)\|\le 2]\;\mathbb E\big[\,|v_1|\ \big|\ \|(a,v)\|\le 2\,\big],
\end{aligned}$$
where $\Pr[\|(a,v)\|\le 2]\ge 1/4$ since (for example) the $\chi^2$ random variables corresponding to $|a|^2$
and $\|v\|^2$ have median less than one, and the expectation term is at least $1/(4\sqrt d)$ by standard
Gaussian computations (Blum et al., 2017, Theorem 2.8).
3. For any pair $(x_i, y_i)$,
$$2 y_i\sum_{j=1}^{2}\alpha_j\,\sigma(\beta_j^T x_i) = y_i\,\sigma(\bar u^T x_i) - y_i\,\sigma(-\bar u^T x_i).$$
Next, the construction for 2-sparse parity. As is natural in maximum margin settings, but in
contrast with most studies of sparse parity, only the support of the distribution matters (and the
labeling), but not the marginal distribution of the inputs.
Proof of Proposition 1.6. This proof shares ideas with (Wei et al., 2018; Ji and Telgarsky, 2020b),
though with some adjustments to exactly fit the standard 2-sparse parity setting, and to shorten
the proofs.
Without loss of generality, due to the symmetry of the data distribution about the origin, suppose
$a = 1$ and $b = 2$, meaning for any $x\in H_d$, the correct label is $d\,x_1 x_2$, the (rescaled) product of the first
two coordinates. Both proofs will use the global margin construction (the parameters for $\gamma_{\mathrm{gl}}$), given as
follows: $p(x,y;(\alpha,\beta)) = y\sum_{j=1}^4\alpha_j\,\sigma(\beta_j^T x)$, where $\alpha = (1/4, -1/4, -1/4, 1/4)$ and
$$\beta_1 := \Big(\tfrac{1}{\sqrt2}, \tfrac{1}{\sqrt2}, 0, \ldots, 0\Big)\in\mathbb R^d,\qquad
\beta_2 := \Big(\tfrac{1}{\sqrt2}, \tfrac{-1}{\sqrt2}, 0, \ldots, 0\Big)\in\mathbb R^d,$$
$$\beta_3 := \Big(\tfrac{-1}{\sqrt2}, \tfrac{1}{\sqrt2}, 0, \ldots, 0\Big)\in\mathbb R^d,\qquad
\beta_4 := \Big(\tfrac{-1}{\sqrt2}, \tfrac{-1}{\sqrt2}, 0, \ldots, 0\Big)\in\mathbb R^d.$$
Note moreover that for any x ∈ Hd , then βjT x > 0 for exactly one j, which will be used for both γntk
and γgl . The proof now splits into the two different settings, and will heavily use symmetry within
Hd and also within (α, β).
note that this satisfies the condition kθ(w)k ≤ 1 thanks to the factor 1/2, since each βj gets
a hemisphere, and (β1 , β4 ) together partition the sphere once, and (β2 , β3 ) similarly together
partition the sphere once.
Now let any $x$ be given, which as above has label $y = x_1 x_2$. By rotational symmetry of the data
and also the transport mapping, suppose $\beta_1$ is the unique choice with $\beta_1^T x > 0$, which
implies $y = 1$, and also $\beta_2^T x = 0 = \beta_3^T x$, whereas $\beta_4^T x = -\beta_1^T x$. Using these observations,
and also rotational invariance of the Gaussian,
$$\begin{aligned}
\mathbb E_{a,v}\left\langle\bar\theta(a,v), \partial p(x,y;w)\right\rangle
&= \mathbb E_{a,v}\,\frac{|a|}{2}\sum_{j=1}^4\beta_j^T x\,\mathbf 1[\beta_j^T v\ge 0]\cdot\mathbf 1[v^T x\ge 0] \\
&= \beta_1^T x\;\mathbb E_a\frac{|a|}{2}\cdot\Big(\mathbb E_v\,\mathbf 1[\beta_1^T v\ge 0]\cdot\mathbf 1[v^T x\ge 0]
- \mathbb E_v\,\mathbf 1[-\beta_1^T v\ge 0]\cdot\mathbf 1[v^T x\ge 0]\Big).
\end{aligned}$$
Performing a similar calculation for the other term (arising from $\beta_4^T x$) and plugging all of this
back in,
$$\mathbb E_{a,v}\left\langle\bar\theta(a,v), \partial p(x,y;w)\right\rangle
= \sqrt{\frac2d}\;\mathbb E_a\frac{|a|}{2}\cdot\mathbb E_v\,\mathbf 1[v_1\ge 0]
\Big(\mathbf 1\big[v_1 + v_2\sqrt{d/2 - 1}\ge 0\big] - \mathbf 1\big[-v_1 + v_2\sqrt{d/2 - 1}\ge 0\big]\Big).$$
The first term is at least $1/5$, and the second can be calculated via brute force:
$$\Pr\big[v_2\in[0, 1/\sqrt d]\big]
= \frac1{\sqrt{2\pi}}\int_0^{1/\sqrt d}\exp(-x^2/2)\,dx
\ge \frac1{\sqrt{2\pi}}\int_0^{1/\sqrt d}\exp(-1/d)\,dx
\ge \frac{1}{\sqrt{2\pi}}\,\frac{1}{\sqrt d}\,\frac{1}{\sqrt e},$$
which completes the proof after similarly using $\mathbb E_a|a|\ge 1$, and simplifying the various constants.
2. Let any $x\in H_d$ be given, and as above note that $\beta_j^T x > 0$ for exactly one $j$. By symmetry,
suppose it is $\beta_1$, whereby $y = x_1 x_2 = 1$, and
$$\gamma_{\mathrm{gl}} \ge p(x,y;(\alpha,\beta)) = y\sum_j\alpha_j\,\sigma(\beta_j^T x) = |\alpha_1|\cdot\beta_1^T x
= \frac14\cdot\frac{2}{\sqrt{2d}} = \frac{1}{\sqrt{8d}}.$$
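These two computations can be confirmed by brute force. The following standalone sketch assumes (as the normalization suggests, though it is not restated in this excerpt) that $H_d = \{\pm 1/\sqrt d\}^d$ with label $y = d\,x_1 x_2$, and checks that the four-direction network attains margin exactly $1/\sqrt{8d}$ at every point, with exactly one active direction per point.

```python
import itertools
import numpy as np

d = 6
betas = np.zeros((4, d))
betas[0, :2] = [1, 1]; betas[1, :2] = [1, -1]
betas[2, :2] = [-1, 1]; betas[3, :2] = [-1, -1]
betas /= np.sqrt(2)                        # each beta_j has unit norm
alpha = np.array([1, -1, -1, 1]) / 4.0

margins = []
for signs in itertools.product([-1, 1], repeat=d):
    x = np.array(signs) / np.sqrt(d)       # point of the scaled hypercube H_d (assumed form)
    y = d * x[0] * x[1]                    # 2-sparse parity label in {-1, +1} (assumed form)
    assert (betas @ x > 1e-12).sum() == 1  # exactly one direction is active
    margins.append(y * (alpha * np.maximum(betas @ x, 0.0)).sum())

assert np.allclose(margins, 1 / np.sqrt(8 * d))
```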
Lastly, an estimate of γntk in a simplified version of Assumption 3.1, which is used in the proof
of Proposition 2.6.
Lemma A.1. Suppose Assumption 1.1 holds for data $((x_i,y_i))_{i=1}^n$ with reference solution $((\alpha_k,\beta_k))_{k=1}^r$
and margin $\gamma_{\mathrm{gl}} > 0$, and additionally $\alpha_k > 0$ and $y_i = +1$ and $\|x_i\| = 1$. Then Assumption 1.3 holds
with margin $\gamma_{\mathrm{ntk}} \ge \gamma_{\mathrm{gl}}/(8\sqrt d)$.
Proof. Define $\theta(a,v) := \big(\sum_{k=1}^r\alpha_k\,\sigma'(\beta_k^T v),\, 0\big)\,\mathbf 1[\|(a,v)\|\le 2]$. Fix any $x\in(x_i)_{i=1}^n$, and for each $k$
define an orthonormal matrix $M_k$ with first column $\beta_k$ and second column $(I - \beta_k\beta_k^T)x/\|(I - \beta_k\beta_k^T)x\|$,
whereby $M_k^T\beta_k = e_1$ and $M_k^T x = e_1\,\beta_k^T x + e_2 r_k$ where $r_k := \sqrt{1 - (\beta_k^T x)^2}$. Then, using rotational
invariance, the relevant expectation is at least
$$\gamma_{\mathrm{gl}}\,\mathbb E_{\|v\|\le 1,\ v_1\ge 0}\,v_1,$$
which is bounded below by $\gamma_{\mathrm{gl}}/(8\sqrt d)$ via similar arguments to those in the proof of Proposition 1.5.
$$\max_i\sum_j a_j\,\sigma(v_j^T x_i) \le 4\ln(n/\delta).$$
Proof. 1. Rewrite $\tilde a := a\sqrt m$, so that $\tilde a\sim\mathcal N_m$. Since $\tilde a\mapsto\|\tilde a\|/\sqrt m = \|a\|$ is $(1/\sqrt m)$-Lipschitz,
then by Gaussian concentration (Wainwright, 2019, Theorem 2.26),
$$\|a\| = \|\tilde a\|/\sqrt m
\le \mathbb E\|\tilde a\|/\sqrt m + \sqrt{2\ln(1/\delta)/m}
\le \sqrt{\mathbb E\|\tilde a\|^2}/\sqrt m + \sqrt{2\ln(1/\delta)/m}
= 1 + \sqrt{2\ln(1/\delta)/m}.$$
Similarly for $V$, defining $\tilde V := V\sqrt d$ whereby $\tilde V\sim\mathcal N_{m\times d}$, Gaussian concentration grants
$$\|V\| = \|\tilde V\|/\sqrt d \le \sqrt m + \sqrt{2\ln(1/\delta)/d}.$$
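A Monte Carlo sketch of the first bound (with hypothetical sizes, and $a\sim\mathcal N(0, I_m/m)$ as the rewriting $\tilde a := a\sqrt m$ suggests): the empirical $(1-\delta)$ quantile of $\|a\|$ stays below $1 + \sqrt{2\ln(1/\delta)/m}$.

```python
import numpy as np

rng = np.random.default_rng(0)
m, delta, trials = 100, 0.01, 20000
a = rng.normal(scale=1 / np.sqrt(m), size=(trials, m))  # a ~ N(0, I_m / m)
norms = np.linalg.norm(a, axis=1)
bound = 1 + np.sqrt(2 * np.log(1 / delta) / m)
# the empirical (1 - delta) quantile should sit below the concentration bound
assert np.quantile(norms, 1 - delta) <= bound
```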
2. Fix any example $x_i$, and constants $\epsilon_1 > 0$ and $\epsilon_2 > 0$ to be optimized at the end of the proof,
and define $d_i := d/\|x_i\|^2$ for convenience. By rotational invariance of Gaussians and since $x_i$
is fixed, $\sigma(Vx_i)$ is equal in distribution to $\|x_i\|\sigma(g)/\sqrt d = \sigma(g)/\sqrt{d_i}$ where $g\sim\mathcal N_m$.
Meanwhile, $g\mapsto\|\sigma(g)\|/\sqrt{d_i}$ is $(1/\sqrt{d_i})$-Lipschitz with $\mathbb E\|\sigma(g)\|\le\sqrt m$, and so, by Gaussian
concentration (Wainwright, 2019, Theorem 2.26),
$$\Pr\big[\|\sigma(Vx_i)\|\ge\epsilon_1 + \sqrt m\big]
= \Pr\big[\|\sigma(g)\|/\sqrt{d_i}\ge\epsilon_1 + \sqrt m\big]
\le \exp\Big(\frac{-d_i\epsilon_1^2}{2}\Big).$$
Next consider the original expression $a^T\sigma(Vx_i)$. To simplify handling of the $1/m$ variance
of the coordinates of $a$, define another Gaussian $h := a\sqrt m$, and a new constant $c_i := m d_i$
for convenience, whereby $a^T\sigma(Vx_i)$ is equal in distribution to
$h^T\sigma(g)/\sqrt{c_i}$ since $a$ and $V$ are independent (and thus $h$ and $g$ are independent). Conditioned
on $g$, since $\mathbb E h = 0$, then $\mathbb E[h^T\sigma(g)\,|\,g] = 0$. As such, applying Gaussian concentration to this
conditioned random variable, since $h\mapsto h^T\sigma(g)/\sqrt{c_i}$ is $(\|\sigma(g)\|/\sqrt{c_i})$-Lipschitz, then
$$\Pr\big[h^T\sigma(g)/\sqrt{c_i}\ge\epsilon_2\ \big|\ g\big] \le \exp\Big(\frac{-c_i\epsilon_2^2}{2\|\sigma(g)\|^2}\Big).$$
Returning to the original expression, it can now be controlled via the two preceding bounds,
conditioning, and the tower property of conditional expectation:
$$\begin{aligned}
\Pr\big[h^T\sigma(g)/\sqrt{c_i}\ge\epsilon_2\big]
&\le \Pr\big[h^T\sigma(g)/\sqrt{c_i}\ge\epsilon_2\ \big|\ \|\sigma(g)\|/\sqrt{d_i}\le\epsilon_1+\sqrt m\big]\cdot\Pr\big[\|\sigma(g)\|/\sqrt{d_i}\le\epsilon_1+\sqrt m\big] \\
&\quad + \Pr\big[h^T\sigma(g)/\sqrt{c_i}\ge\epsilon_2\ \big|\ \|\sigma(g)\|/\sqrt{d_i}>\epsilon_1+\sqrt m\big]\cdot\Pr\big[\|\sigma(g)\|/\sqrt{d_i}>\epsilon_1+\sqrt m\big] \\
&= \mathbb E\Big[\Pr\big[h^T\sigma(g)/\sqrt{c_i}\ge\epsilon_2\ \big|\ g\big]\ \Big|\ \|\sigma(g)\|/\sqrt{d_i}\le\epsilon_1+\sqrt m\Big]\,\Pr\big[\|\sigma(g)\|/\sqrt{d_i}\le\epsilon_1+\sqrt m\big] \\
&\quad + \Pr\big[h^T\sigma(g)/\sqrt{c_i}\ge\epsilon_2\ \big|\ \|\sigma(g)\|/\sqrt{d_i}>\epsilon_1+\sqrt m\big]\,\Pr\big[\|\sigma(g)\|/\sqrt{d_i}>\epsilon_1+\sqrt m\big] \\
&\le \mathbb E\Big[\exp\Big(\frac{-c_i\epsilon_2^2}{2\|\sigma(g)\|^2}\Big)\ \Big|\ \|\sigma(g)\|/\sqrt{d_i}\le\epsilon_1+\sqrt m\Big] + \exp\big(-d_i\epsilon_1^2/2\big) \\
&\le \exp\Big(\frac{-c_i\epsilon_2^2}{4d_i\epsilon_1^2 + 4d_i m}\Big) + \exp\big(-d_i\epsilon_1^2/2\big).
\end{aligned}$$
As such, choosing $\epsilon_2 := 4\sqrt{\ln(n/\delta)\,m d_i/c_i} = 4\sqrt{\ln(n/\delta)}$ and $\epsilon_1 := \sqrt{2\ln(n/\delta)/d_i}$ gives
$$\Pr\big[a^T\sigma(Vx_i)\ge\epsilon_2\big] = \Pr\big[h^T\sigma(g)/\sqrt{c_i}\ge\epsilon_2\big] \le \frac\delta n + \frac\delta n,$$
which is a sub-exponential concentration bound. Union bounding over the reverse inequality
and over all $n$ examples and using $\max_i\|x_i\|\le 1$ gives the final bound.
Next comes a key tool in all the proofs using γntk : guarantees that the infinite-width margin
assumptions imply the existence of good finite-width networks. These bounds are stated for
Assumption 1.3, however they will also be applied with Assumption 1.4, since Assumption 1.4
implies that Assumption 1.3 holds almost surely over any finite sample.
Lemma A.3. Let examples $((x_i,y_i))_{i=1}^n$ be given, and suppose Assumption 1.3 holds, with corresponding
$\theta : \mathbb R^{d+1}\to\mathbb R^{d+1}$ and $\gamma_{\mathrm{ntk}} > 0$ given.

1. With probability at least $1-\delta$ over the draw of $(w_j)_{j=1}^m$, defining $\theta_j := \theta(w_j)/\sqrt m$, then
$$\min_i\sum_j\left\langle\theta_j, \hat\partial p_i(w_j)\right\rangle \ge \gamma_{\mathrm{ntk}}\sqrt m - \sqrt{32\ln(n/\delta)}.$$

2. With probability at least $1-7\delta$ over the draw of $W$ with rows $(w_j)_{j=1}^m$ with $m\ge 2\ln(1/\delta)$,
defining rows $\theta_j := \theta(w_j)/\sqrt m$ of $\theta\in\mathbb R^{m\times(d+1)}$, then for any $W'$ and any $R\ge\|W - W'\|$ and
any $r_\theta\ge 0$ and $r_w\ge 0$, for every $i$,
$$\left\langle r_\theta\theta + r_w W, \hat\partial p_i(W')\right\rangle - r_w\,p_i(W')
\ge \gamma_{\mathrm{ntk}} r_\theta\sqrt m - r_\theta\Big[\sqrt{32\ln(n/\delta)} + 8R + 4\Big]
- r_w\Big[4\ln(n/\delta) + 2R + 2R\sqrt m + 4\sqrt m\Big],$$
and moreover, writing $W = (a, V)$, then $\|a\|\le 2$ and $\|V\|\le 2\sqrt m$. For the particular choice
$r_\theta := R/8$ and $r_w = 1$, if $R\ge 8$ and $m\ge(64\ln(n/\delta)/\gamma_{\mathrm{ntk}})^2$, then
$$\left\langle r_\theta\theta + W, \hat\partial p_i(W')\right\rangle - p_i(W') \ge \frac{\gamma_{\mathrm{ntk}} r_\theta\sqrt m}{2} - 160 r_\theta^2.$$
where $\mu\ge\gamma_{\mathrm{ntk}}$ by assumption. By the various conditions on $\theta$, it holds for any $(a,v) := w\in\mathbb R^{d+1}$
and corresponding $(\bar a, \bar v) := \theta(w)\in\mathbb R^{d+1}$ that
$$\left\langle\theta(w), \hat\partial p_i(w)\right\rangle \le \bar a\,\sigma(v^T x_i) + \left\langle\bar v,\ a x_i\,\sigma'(v^T x_i)\right\rangle,$$
and therefore, by Hoeffding's inequality, with probability at least $1-\delta/n$ over the draw of $m$
iid copies of this random variable,
$$\sum_j\left\langle\theta(w_j), \hat\partial p_i(w_j)\right\rangle \ge m\mu - \sqrt{32 m\ln(n/\delta)} \ge m\gamma_{\mathrm{ntk}} - \sqrt{32 m\ln(n/\delta)},$$
which gives the desired bound after dividing by $\sqrt m$, recalling $\theta_j := \theta(w_j)/\sqrt m$, and union
bounding over all $n$ examples.
2. First, suppose with probability at least $1-7\delta$ that the consequences of Lemma A.2 and the
preceding part of the current lemma hold, whereby simultaneously $\|a\|\le 2$, and $\|V\|\le 2\sqrt m$,
and
$$\min_i p_i(W) \ge -4\ln(n/\delta), \qquad
\min_i\sum_j\left\langle\theta_j, \hat\partial p_i(w_j)\right\rangle \ge \gamma_{\mathrm{ntk}}\sqrt m - \sqrt{32\ln(n/\delta)}.$$
The remainder of the proof proceeds by separately lower bounding the terms on the right hand side of
$$\left\langle r_\theta\theta + r_w W, \hat\partial p_i(W')\right\rangle - r_w\,p_i(W')
= r_\theta\left\langle\theta, \hat\partial p_i(W)\right\rangle
+ r_\theta\left\langle\theta, \hat\partial p_i(W') - \hat\partial p_i(W)\right\rangle
+ r_w\Big(\left\langle W, \hat\partial p_i(W')\right\rangle - p_i(W')\Big).$$
For the middle term, writing $(\bar a, \bar V) = \theta$ and noting $\|\bar a\|\le 2$ and $\|\bar V\|\le 2$, then for any
$W' = (a', V')$, since $\sigma$ is $1$-Lipschitz and $\sigma'\in\{0,1\}$,
$$\begin{aligned}
\Big|\left\langle\theta, \hat\partial p_i(W') - \hat\partial p_i(W)\right\rangle\Big|
&\le \Big|\sum_j\bar a_j\big(\sigma(x_i^T v_j') - \sigma(x_i^T v_j)\big)\Big|
+ \Big|\sum_j x_i^T\bar v_j\big(a_j'\,\sigma'(x_i^T v_j') - a_j\,\sigma'(x_i^T v_j)\big)\Big| \\
&\le \sqrt{\sum_j\bar a_j^2}\sqrt{\sum_j\big(\sigma(x_i^T v_j') - \sigma(x_i^T v_j)\big)^2}
+ \sum_j|x_i^T\bar v_j|\cdot\big|a_j' - a_j\big|\,\sigma'(x_i^T v_j') \\
&\qquad + \sum_j|x_i^T\bar v_j|\cdot|a_j|\cdot\big|\sigma'(x_i^T v_j') - \sigma'(x_i^T v_j)\big| \\
&\le \|\bar a\|\cdot\|V' - V\| + \|\bar V\|\cdot\|a' - a\| + \|\bar V\|\cdot\|a\|
\ \le\ 2R + 2R + 4 = 4R + 4.
\end{aligned}$$
For the last term, since $\langle v_j', x_i\rangle\,\sigma'(x_i^T v_j') = \sigma(x_i^T v_j')$ by $1$-homogeneity,
$$\begin{aligned}
\left\langle W, \hat\partial p_i(W')\right\rangle - p_i(W')
&= y_i\sum_j a_j\,\sigma(x_i^T v_j') + y_i\sum_j a_j'\left\langle v_j - v_j', x_i\right\rangle\sigma'(x_i^T v_j') \\
&\ge p_i(W) - \sum_j|a_j|\cdot\big|\sigma(x_i^T v_j') - \sigma(x_i^T v_j)\big| - \sum_j|a_j'|\cdot\|v_j - v_j'\| \\
&\ge -4\ln(n/\delta) - \|a\|\cdot\|V' - V\| - \big(\|a' - a\| + \|a\|\big)\cdot\|V - V'\|
\ \ge\ -\big(4\ln(n/\delta) + 4R + R^2\big).
\end{aligned}$$
Multiplying through by $r_\theta$ and $r_w$ and combining these inequalities gives, for every $i$,
$$\left\langle r_\theta\theta + r_w W, \hat\partial p_i(W')\right\rangle - r_w\,p_i(W')
\ge \gamma_{\mathrm{ntk}} r_\theta\sqrt m - r_\theta\Big[\sqrt{32\ln(n/\delta)} + 4R + 4\Big]
- r_w\Big[4\ln(n/\delta) + 4R + R^2\Big],$$
which establishes the first inequality. For the particular choice $r_\theta := R/8$ with $R\ge 8$ and
$r_w = 1$, and using $m\ge(64\ln(n/\delta)/\gamma_{\mathrm{ntk}})^2$, the preceding bound simplifies to
$$\begin{aligned}
\left\langle r_\theta\theta + r_w W, \hat\partial p_i(W')\right\rangle - r_w\,p_i(W')
&\ge \gamma_{\mathrm{ntk}} r_\theta\sqrt m
- r_\theta\left[\frac{\gamma_{\mathrm{ntk}}\sqrt m}{8} + 32 r_\theta + 32 r_\theta\right]
- \left[\frac{\gamma_{\mathrm{ntk}}\sqrt m}{16} + 32 r_\theta + 64 r_\theta^2\right] \\
&\ge \frac{\gamma_{\mathrm{ntk}} r_\theta\sqrt m}{2} - 160 r_\theta^2.
\end{aligned}$$
The first property is that norms increase once there is a positive margin.
Lemma A.4 (Restatement of (Lyu and Li, 2019, Lemma B.1)). Suppose the setting of eq. (A.1)
and also $\ell\in\{\ell_{\exp}, \ell_{\log}\}$. If $\hat R(u_\tau) < \ell(0)/n$, then, for every $t\ge\tau$,
$$\frac{d}{dt}\|u_t\| > 0 \qquad\text{and}\qquad \langle u_t, \dot u_t\rangle > 0,$$
and moreover $\lim_t\|u_t\| = \infty$.
Proof. Since $\hat R$ is nonincreasing during gradient flow, it suffices to consider any $u_s$ with $\hat R(u_s) <
\ell(0)/n$. To apply (Lyu and Li, 2019, Lemma B.1), first note that both the exponential and logistic
losses can be handled, e.g., via the discussion of the assumptions at the beginning of (Lyu and Li,
2019, Appendix A.1). Next, the statement of that lemma gives
$$\frac{d}{ds}\ln\|u_s\| > 0,$$
which implies the main part of the statement; all that remains is to show $\|u_s\|\to\infty$, but this is
given by (Lyu and Li, 2019, Lemma B.6).
Next, even without the assumption $\hat R(u_s) < \ell(0)/n$ (which at a minimum requires a two-phase
proof, and certain other annoyances), note that once $\|u_s\|$ is large, then the gradient can be related
to margins, even if they are negative, which will be useful in circumventing the need for dual
convergence and other assumptions present in prior work (e.g., as in (Chizat and Bach, 2020)).
Lemma A.5 (See also (Ji and Telgarsky, 2020a, Proof of Lemma C.5)). Suppose the setting of
eq. (A.1) and also $\ell = \ell_{\exp}$. Then, for any $u$ and any $((x_i,y_i))_{i=1}^n$ (and corresponding $\hat R$),
$$\frac{\left\langle u, -n\hat\partial_u\hat R(u)\right\rangle}{L\|u\|^L}
\le Q\left(\mathring\gamma(u) + \frac{\ln n}{\|u\|^L}\right)
\le Q\left(\gamma(u) + \frac{\ln n}{\|u\|^L}\right).$$
Proof. Define $\pi(p) := -\ln\sum_i\exp(-p_i) = \tilde\gamma(u)$, whereby $q = \nabla_p\pi(p)$. Since $\pi$ is concave in $p$,
$$\left\langle u, -n\hat\partial_u\hat R(u)\right\rangle
= \sum_i|\ell_i'|\left\langle u, \bar\partial p_i\right\rangle
= LQ\sum_i q_i p_i
= LQ\left\langle\nabla_p\pi(p), p - 0\right\rangle
\le LQ\big(\pi(p) - \pi(0)\big)
= LQ\big[\tilde\gamma + \ln n\big].$$
Moreover, by standard properties of $\pi$, letting $k$ be the index of any example with $p_k(u) = \min_i p_i(u)$,
$$\tilde\gamma = -\ln\sum_i\exp(-p_i) \le -\ln\exp(-p_k) = p_k = \gamma(u)\,\|u\|^L.$$
Combining these inequalities and dividing by $L\|u\|^L$ gives the desired bounds.
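Both properties of $\pi$ used in this proof can be sanity-checked numerically; the following standalone sketch draws arbitrary margin vectors and verifies the concavity bound and the comparison with the minimum margin.

```python
import numpy as np

rng = np.random.default_rng(1)
for _ in range(100):
    p = rng.normal(size=5)               # arbitrary unnormalized margins p_i
    e = np.exp(-p)
    pi_p = -np.log(e.sum())              # pi(p) = -ln sum_i exp(-p_i)
    q = e / e.sum()                      # q = grad pi(p), the dual weights
    pi_0 = -np.log(len(p))               # pi(0) = -ln n
    assert (q * p).sum() <= pi_p - pi_0 + 1e-9  # concavity: <grad pi(p), p - 0> <= pi(p) - pi(0)
    assert pi_p <= p.min() + 1e-12              # pi(p) <= min_i p_i
```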
Lastly, a key abstract potential function lemma: this potential function is a proxy for mass
accumulating on certain weights with good margin, and once it satisfies a few conditions, large
margins are implied directly. This is the second component needed to remove dual convergence
from (Chizat and Bach, 2020).
Lemma A.6. Suppose the setup of eq. (A.1) with $L = 2$, and additionally that there exists a
constant $\hat\gamma > 0$, a time $\tau$, and a potential function $\Phi(u)$ so that $\Phi(u_\tau) > -\infty$, and for all $t\ge\tau$,
$$\Phi(u) \le \frac1L\ln\|u\|, \qquad \frac{d}{dt}\Phi(u) \ge Q(u)\,\hat\gamma.$$
Then it follows that $\hat R(u_t)\to 0$, and $\|u_t\|\to\infty$, and $\int_\tau^\infty Q(u_s)\,ds = \infty$, and $\liminf_t\gamma(u_t)\ge\hat\gamma$.
thus $\|u_s\|\to\infty$. On the other hand, if $\inf_s\hat R(u_s) = 0$, then there exists $t_1$ so that for all $t\ge t_1$,
then $\gamma_t > 0$ (also making use of nondecreasing margins (Lyu and Li, 2019)), which is only possible
if $\|u_s\|\to\infty$, and thus moreover we can take $t_2\ge t_1$ so that additionally $\|u_t\|\ge\ln(n)/\gamma_{t_1}$ (which
will hold for all $t' > t_2$ by Lemma A.4), whereby Lemma A.5 and the restriction $L = 2$ mean
$$\frac{d}{dt}\ln\|u_t\|
= \frac{\sum_i|\ell_i'|\left\langle u, \bar\partial p_i(u)\right\rangle}{\|u\|^2}
= \frac{QL\sum_i q_i p_i(u)}{\|u\|^2}
\le QL(\gamma_t + \gamma_{t_1})
\le 2LQ\gamma_t,$$
which after integrating and upper bounding $\gamma_t\le 1$ means $\int_{t_2}^\infty Q_s\,ds \ge \frac1{2L}\lim_t\big(\ln\|u_t\| - \ln\|u_{t_2}\|\big) = \infty$.
As such, independent of whether $\inf_s\hat R(u_s) = 0$, still $\|u_s\|\to\infty$ and $\int_\tau^\infty Q_s\,ds = \infty$.
This now suffices to complete the proof. If $\liminf_t\gamma_t\ge\hat\gamma$, then in fact $\lim_t\gamma_t$ is well-defined
(by nondecreasing margins and $\|u\|\to\infty$) and $\lim_t\gamma_t\ge\hat\gamma > 0$, whereby $\limsup_t\hat R(u_t)\le
\limsup_t\ell\big(\gamma_t\|u_t\|^L\big) = 0$. Alternatively, suppose contradictorily that $\liminf_t\gamma_t<\hat\gamma$, and choose
any $\epsilon\in(0, \hat\gamma/4)$ so that $\liminf_t\gamma_t<\hat\gamma - 3\epsilon$; noting that $\gamma_t$ is monotone once there exists some
$\gamma_s > 0$, then, choosing $t_3$ large enough so that $\|u_t\|^2\ge\ln(n)/\epsilon$ for all $t\ge t_3$ and $\gamma_t<\hat\gamma - 2\epsilon$ for all
$t\ge t_3$, it follows by Lemma A.5 that
$$\begin{aligned}
0 &\le \liminf_t\left(\frac1L\ln\|u_t\| - \Phi(u_t)\right) \\
&\le \frac1L\ln\|u_{t_3}\| - \Phi(u_{t_3}) + \liminf_t\int_{t_3}^t\frac{d}{ds}\left(\frac1L\ln\|u_s\| - \Phi(u_s)\right)ds \\
&\le \frac1L\ln\|u_{t_3}\| - \Phi(u_{t_3}) + \liminf_t\int_{t_3}^t\Big(Q(\hat\gamma - \epsilon) - Q\hat\gamma\Big)\,ds \\
&\le \frac1L\ln\|u_{t_3}\| - \Phi(u_{t_3}) + \liminf_t\int_{t_3}^t\big[-\epsilon Q\big]\,ds
= -\infty,
\end{aligned}$$
a contradiction.
Proof. This proof is essentially a copy of one due to Ji and Telgarsky (2020b, Lemma 4.3); that one
is stated for the analog of $p_i$ used there, and thus needs to be re-checked.
Let $\mathcal F_i := \sigma\big(((x_j, y_j))_{j<i}\big)$ denote the $\sigma$-field of all information before time $i$, whereby $x_i$ is
independent of $\mathcal F_i$, whereas $W_i$ is deterministic after conditioning on $\mathcal F_i$. Consequently,
$\mathbb E\big[Q(W_i) - Q_i(W_i)\,\big|\,\mathcal F_i\big] = 0$, whereby $\sum_{i<t}\big(Q(W_i) - Q_i(W_i)\big)$ is a martingale difference sequence.
The high probability bound will now follow via a version of Freedman's inequality (Agarwal
et al., 2014, Lemma 9). To apply this bound, the conditional variances must be controlled: noting
that $|\ell'(z)|\in[0,1]$, then $|Q(W_i) - Q_i(W_i)|\le 1$, and since $Q_i(W_i)\in[0,1]$, then $Q_i(W_i)^2\le Q_i(W_i)$,
and thus
$$\mathbb E\Big[\big(Q(W_i) - Q_i(W_i)\big)^2\ \Big|\ \mathcal F_i\Big]
= \mathbb E\big[Q_i(W_i)^2\ \big|\ \mathcal F_i\big] - Q(W_i)^2
\le \mathbb E\big[Q_i(W_i)\ \big|\ \mathcal F_i\big] - 0
= Q(W_i).$$
As such, by the aforementioned version of Freedman's inequality (Agarwal et al., 2014, Lemma 9),
$$\sum_{i<t}\big(Q(W_i) - Q_i(W_i)\big)
\le (e-2)\sum_{i<t}\mathbb E\Big[\big(Q(W_i) - Q_i(W_i)\big)^2\ \Big|\ \mathcal F_i\Big] + \ln(1/\delta)
\le (e-2)\sum_{i<t}Q(W_i) + \ln(1/\delta).$$
With Lemma B.1 and the Gaussian concentration inequalities from Appendix A.2 in hand, a
proof of a generalized form of Theorem 2.1 is as follows.

Proof of Theorem 2.1. Let $(w_j)_{j=1}^m$ be given with corresponding $(\bar a_j, \bar v_j) := \theta_j := \theta(w_j)/\sqrt m$
(whereby $\|\theta_j\|\le 2$ by construction), and define
$$r := \frac{10\eta\sqrt m}{\gamma} \le \frac{\gamma\sqrt m}{640}, \qquad
R := 8r = \frac{80\eta\sqrt m}{\gamma} \le \frac{\gamma\sqrt m}{80}, \qquad
\bar W := r\theta + W_0,$$
which implies $r\ge 1$, and $R\ge 1$, and $\eta\le R/16$. Since Assumption 1.4 implies that Assumption 1.3
holds for $((x_i,y_i))_{i\le t}$ with probability 1, for the remainder of the proof, rule out the $7\delta$ failure
probability associated with the second part of Lemma A.3 (which is stated in terms of Assumption 1.3
not Assumption 1.4), whereby simultaneously for every $\|W' - W_0\|\le R$,
$$\min_i\Big(\left\langle\bar W, \hat\partial p_i(W')\right\rangle - p_i(W')\Big)
\ge \frac{r\gamma\sqrt m}{2} - 160 r^2 \ge \frac{r\gamma\sqrt m}{4} = \frac{\gamma^2 m}{2560} \ge \ln(t), \tag{B.1}$$
$$\min_i\left\langle\theta, \hat\partial p_i(W')\right\rangle
\ge \gamma\sqrt m - \sqrt{32\ln(n/\delta)} - 4R - 4
\ge \gamma\sqrt m - \frac{\gamma\sqrt m}{8} - \frac{\gamma\sqrt m}{10}
\ge \frac{\gamma\sqrt m}{2}, \tag{B.2}$$
and also $\|a_0\|\le 2$ and $\|V_0\|\le 2\sqrt m$.
The proof now proceeds as follows. Let $\tau$ denote the first iteration where $\|W_\tau - W_0\|\ge R$,
whereby $\tau > 0$ and $\max_{s<\tau}\|W_s - W_0\|\le R$. Assume contradictorily that $\tau\le t$; it will be shown
that this implies $\|W_\tau - W_0\| < R$.
Consider any iteration $s < \tau$. Expanding the square,
$$\begin{aligned}
\|W_{s+1} - \bar W\|^2 &= \big\|W_s - \eta\hat\partial\ell_s(W_s) - \bar W\big\|^2 \\
&= \|W_s - \bar W\|^2 - 2\eta\left\langle\hat\partial\ell_s(W_s), W_s - \bar W\right\rangle + \eta^2\big\|\hat\partial\ell_s(W_s)\big\|^2 \\
&= \|W_s - \bar W\|^2 + 2\eta\,\ell_s'(W_s)\left\langle\hat\partial p_s(W_s), \bar W - W_s\right\rangle + \eta^2\,\ell_s'(W_s)^2\,\big\|\hat\partial p_s(W_s)\big\|^2.
\end{aligned}$$
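The expansion of the square above is pure algebra, and can be checked with random matrices standing in for the iterate, comparator, and gradient (a standalone sketch):

```python
import numpy as np

rng = np.random.default_rng(3)
W = rng.normal(size=(5, 4))       # current iterate W_s
Wbar = rng.normal(size=(5, 4))    # comparator
g = rng.normal(size=(5, 4))       # stochastic gradient
eta = 0.1
lhs = np.linalg.norm(W - eta * g - Wbar) ** 2
rhs = (np.linalg.norm(W - Wbar) ** 2
       - 2 * eta * (g * (W - Wbar)).sum()
       + eta**2 * np.linalg.norm(g) ** 2)
assert abs(lhs - rhs) < 1e-9
```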
By convexity, $\|W_s - W_0\|\le R$, and eq. (B.1),
$$\ell_s'(W_s)\left\langle\hat\partial p_s(W_s), \bar W - W_s\right\rangle
= \ell_s'(W_s)\left(\Big(\left\langle\hat\partial p_s(W_s), \bar W\right\rangle - p_s(W_s)\Big) - p_s(W_s)\right)
\le \ell_s\Big(\left\langle\hat\partial p_s(W_s), \bar W\right\rangle - p_s(W_s)\Big) - \ell_s(W_s).$$
To simplify the last term, note by eq. (B.2) that
$$\begin{aligned}
\|W_\tau - W_0\| &= \sup_{\|W\|\le 1}\left\langle W, W_\tau - W_0\right\rangle
\ \ge\ \frac12\left\langle\theta, W_\tau - W_0\right\rangle
= \frac\eta2\sum_{s<\tau}\left\langle -\theta, \hat\partial\ell_s(W_s)\right\rangle \\
&= \frac\eta2\sum_{s<\tau}|\ell_s'(W_s)|\left\langle\theta, \hat\partial p_s(W_s)\right\rangle
\ \ge\ \frac\eta2\sum_{s<\tau}|\ell_s'(W_s)|\,\frac{\gamma\sqrt m}{2}.
\end{aligned}\tag{B.3}$$
Lastly, for the generalization bound, defining $Q(W) := \mathbb E_{x,y}\big|\ell'(p(x,y;W))\big|$, discarding an additional
$\delta$ failure probability, by Lemma B.1,
$$\sum_{s<t}Q(W_s) \le 4\ln(1/\delta) + 4\sum_{s<t}|\ell_s'(W_s)| \le 4\ln(1/\delta) + \frac{1280}{\gamma^2}.$$
B.2 GF proofs
This section culminates in the proof of Theorem 2.2, which is immediate once Lemmas 2.3 and 2.4
are established.
Before proceeding with the main proofs, the following technical lemma is used to convert a
bound on $\ell'$ to a bound on $\ell$.

Lemma B.2. For $\ell\in\{\ell_{\log}, \ell_{\exp}\}$, $|\ell'(z)|\le 1/8$ implies $\ell(z)\le 2|\ell'(z)|$.

Proof. If $\ell = \ell_{\exp}$, then $\ell' = -\ell$, and thus $\ell(z)\le 2|\ell'(z)|$ automatically. If $\ell = \ell_{\log}$, the logistic
loss, then $|\ell'(z)|\le 1/8$ implies $z\ge\ln 7$, whence $1 + e^{-z}\le 8/7\le 7/6$; by the concavity of $\ln(\cdot)$,
$$\ell(z) = \ln(1 + e^{-z}) \le e^{-z} \le \frac{(7/6)\,e^{-z}}{1 + e^{-z}} \le 2|\ell'(z)|,$$
thus completing the proof.
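Lemma B.2 can be sanity-checked on a grid; in the following standalone sketch, the logistic threshold $|\ell'(z)|\le 1/8$ corresponds to $z\ge\ln 7\approx 1.95$.

```python
import numpy as np

z = np.linspace(-5.0, 20.0, 4001)
losses = [
    (lambda t: np.exp(-t), lambda t: -np.exp(-t)),                     # exponential loss
    (lambda t: np.log1p(np.exp(-t)), lambda t: -1 / (1 + np.exp(t))),  # logistic loss
]
for loss, dloss in losses:
    mask = np.abs(dloss(z)) <= 1 / 8
    assert np.all(loss(z[mask]) <= 2 * np.abs(dloss(z[mask])) + 1e-12)
```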
Next comes the proof of Lemma 2.3, which follows the same proof plan as Theorem 2.1.

Proof of Lemma 2.3. This proof is basically identical to the SGD proof of Theorem 2.1. Despite this,
proceeding with amnesia, let rows $(w_j)_{j=1}^m$ of $W_0$ be given with corresponding $(\bar a_j, \bar v_j) := \theta_j :=
\theta(w_j)/\sqrt m$ (whereby $\|\theta_j\|\le 2$ by construction), and define
$$r := \frac{\gamma_{\mathrm{ntk}}\sqrt m}{640}, \qquad R := 8r = \frac{\gamma_{\mathrm{ntk}}\sqrt m}{80}, \qquad \bar W := r\theta + W_0,$$
with immediate consequences that $r\ge 1$ and $R\ge 8$. For the remainder of the proof, rule out the $7\delta$
failure probability associated with the second part of Lemma A.3, whereby simultaneously for every
$\|W' - W_0\|\le R$,
$$\min_i\Big(\left\langle\bar W, \hat\partial p_i(W')\right\rangle - p_i(W')\Big)
\ge \frac{r\gamma_{\mathrm{ntk}}\sqrt m}{2} - 160 r^2 \ge \frac{r\gamma_{\mathrm{ntk}}\sqrt m}{4} = \frac{\gamma_{\mathrm{ntk}}^2 m}{2560} \ge \ln(t), \tag{B.4}$$
$$\min_i\left\langle\theta, \hat\partial p_i(W')\right\rangle
\ge \gamma_{\mathrm{ntk}}\sqrt m - \sqrt{32\ln(n/\delta)} - 4R - 4
\ge \gamma_{\mathrm{ntk}}\sqrt m - \frac{\gamma_{\mathrm{ntk}}\sqrt m}{8} - \frac{\gamma_{\mathrm{ntk}}\sqrt m}{10}
\ge \frac{\gamma_{\mathrm{ntk}}\sqrt m}{2}. \tag{B.5}$$
The proof now proceeds as follows. Let τ denote the earliest time such that kWτ − W0 k = R; since
Ws traces out a continuous curve and since R > 0 = kW0 − W0 k, this quantity is well-defined. As a
consequence of the definition, sups<τ kWs − W0 k ≤ R. Assume contradictorily that τ ≤ t; it will be
shown that this implies kWτ − W0 k < R.
By the fundamental theorem of calculus (and the chain rule for Clarke differentials), convexity
of $\ell$, and since $\|W_s - W_0\|\le R$ holds for $s\in[0,\tau)$, which implies eq. (B.4) holds,
$$\begin{aligned}
\|W_\tau - \bar W\|^2 - \|W_0 - \bar W\|^2
&= \int_0^\tau\frac{d}{ds}\|W_s - \bar W\|^2\,ds
= \int_0^\tau 2\left\langle\dot W_s, W_s - \bar W\right\rangle ds \\
&= \frac2n\int_0^\tau\sum_i\ell_i'(W_s)\left\langle\bar\partial p_i(W_s), \bar W - W_s\right\rangle ds \\
&= \frac2n\int_0^\tau\sum_i\ell_i'(W_s)\left(\Big(\left\langle\hat\partial p_i(W_s), \bar W\right\rangle - p_i(W_s)\Big) - p_i(W_s)\right) ds \\
&\le \frac2n\int_0^\tau\sum_i\left(\ell_i\Big(\left\langle\hat\partial p_i(W_s), \bar W\right\rangle - p_i(W_s)\Big) - \ell_i(W_s)\right) ds \\
&\le \frac2n\int_0^\tau\sum_i\left(\frac1t - \ell_i(W_s)\right) ds
\ \le\ 2 - 2\int_0^\tau\hat R(W_s)\,ds,
\end{aligned}$$
which implies
$$\|W_\tau - W_0\| \le 2r \le \frac R2 < R,$$
a contradiction, since $\tau$ is well-defined as the earliest time with $\|W_\tau - W_0\| = R$; this thus
contradicts $\tau\le t$. As such, $\tau\ge t$, and all of the preceding inequalities follow with $\tau$ replaced by $t$.
To obtain an error bound, similarly to the key perceptron argument before, eq. (B.5) implies
$$\frac1n\int_0^t\sum_i|\ell_i'(W_s)|\,ds \le \frac{4\|W_t - W_0\|}{\gamma_{\mathrm{ntk}}\sqrt m} \le \frac1{20},$$
and in particular
$$\inf_{s\in[0,t]}\frac1n\sum_i|\ell_i'(W_s)|
\le \frac1{tn}\int_0^t\sum_i|\ell_i'(W_s)|\,ds
\le \frac{4\|W_t - W_0\|}{t\,\gamma_{\mathrm{ntk}}\sqrt m}
\le \frac1{20t}.$$
Since this also implies $\max_i|\ell_i'(W_k)|\le n/(10t)\le 1/10$ at some time $k\in[0,t]$, it follows by Lemma B.2
that $\hat R(W_k)\le 1/(5t)$, and the claim also holds for $t'\ge t$ since the empirical risk is nonincreasing
with gradient flow.
Next, the proof of the Rademacher complexity bound used for all GF sample complexities.

Proof of Lemma 2.5. For any $(a,v)\in\mathbb R^{d+1}$ and any $x$, recalling the notation defining $\tilde a := \mathrm{sgn}(a)$
and $\tilde v := v/\|v\|$,
$$a\,\sigma(v^T x) = \|a v\|\,\tilde a\,\sigma(\tilde v^T x) \le \frac{\|(a,v)\|^2}{2}\,\tilde a\,\sigma(\tilde v^T x),$$
and therefore, letting
$$\mathrm{sconv}(S) := \left\{\sum_{j=1}^m p_j u_j \ :\ m\ge 0,\ p\in\mathbb R^m,\ \|p\|_1\le 1,\ u_j\in S\right\}$$
denote the symmetrized convex hull as used throughout Rademacher complexity (Shalev-Shwartz
and Ben-David, 2014), then
$$\begin{aligned}
\mathcal F &:= \left\{x\mapsto\frac{1}{\|W\|^2}\sum_j a_j\,\sigma(v_j^T x)\ :\ m\ge 0,\ W\in\mathbb R^{m\times(d+1)}\right\} \\
&= \left\{x\mapsto\frac{1}{\|W\|^2}\sum_j\|a_j v_j\|\,\tilde a_j\,\sigma(\tilde v_j^T x)\ :\ m\ge 0,\ W\in\mathbb R^{m\times(d+1)}\right\} \\
&\subseteq \left\{x\mapsto\sum_j\frac{\|(a_j, v_j)\|^2}{2\|W\|^2}\,\tilde a_j\,\sigma(\tilde v_j^T x)\ :\ m\ge 0,\ W\in\mathbb R^{m\times(d+1)}\right\} \\
&\subseteq \left\{x\mapsto\sum_j p_j\,\sigma(u_j^T x)\ :\ m\ge 0,\ p\in\mathbb R^m,\ \|p\|_1\le\frac12,\ \|u_j\|_2 = 1\right\} \\
&= \frac12\,\mathrm{sconv}\Big(\big\{x\mapsto\sigma(v^T x)\ :\ \|v\|_2 = 1\big\}\Big).
\end{aligned}$$
As such, by standard rules of Rademacher complexity (Shalev-Shwartz and Ben-David, 2014),
$$\mathrm{Rad}(\mathcal F) \le \frac12\,\mathrm{Rad}\Big(\big\{x\mapsto\sigma(v^T x)\ :\ \|v\|_2 = 1\big\}\Big) \le \frac{1}{2\sqrt n},$$
and thus, by a refined margin bound for Rademacher complexity (Srebro et al., 2010, Theorem
5), with probability at least $1-\delta$, simultaneously for all $\gamma_{\mathrm{gl}}$ and all $m$, every $W\in\mathbb R^{m\times(d+1)}$ with
$\mathring\gamma(W)\ge\gamma_{\mathrm{gl}}$ satisfies
$$\Pr[p(x,y;W)\le 0]
= O\!\left(\frac{\ln(n)^3}{\gamma_{\mathrm{gl}}^2}\,\mathrm{Rad}(\mathcal F)^2 + \frac{\ln\ln\frac{1}{\gamma_{\mathrm{gl}}} + \ln\frac1\delta}{n}\right)
= O\!\left(\frac{\ln(n)^3}{n\,\gamma_{\mathrm{gl}}^2} + \frac{\ln\ln\frac{1}{\gamma_{\mathrm{gl}}} + \ln\frac1\delta}{n}\right),$$
and to finish, instantiating this bound with $\mathring\gamma(W)$ gives the desired form.
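The rescaling identity at the start of this proof relies only on the positive 1-homogeneity of the ReLU; a standalone numeric sketch of it, together with the Fenchel-type bound on the coefficient:

```python
import numpy as np

rng = np.random.default_rng(4)
relu = lambda t: np.maximum(t, 0.0)
for _ in range(100):
    a = rng.normal()
    v = rng.normal(size=6)
    x = rng.normal(size=6)
    ta, tv = np.sign(a), v / np.linalg.norm(v)      # tilde-a and tilde-v
    # 1-homogeneity: a * sigma(v' x) = ||a v|| * tilde-a * sigma(tilde-v' x)
    assert abs(a * relu(v @ x)
               - abs(a) * np.linalg.norm(v) * ta * relu(tv @ x)) < 1e-9
    # the coefficient is at most half the squared parameter norm
    assert abs(a) * np.linalg.norm(v) <= (a**2 + np.linalg.norm(v)**2) / 2 + 1e-12
```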
Proof of Lemma 2.4. By the second part of Lemma A.3, with probability at least $1-7\delta$, simultaneously $\|a\|\le 2$, and $\|V\|\le 2\sqrt m$, and for any $\|W'-W_0\|\le R$,
\[
\min_i \big\langle \theta,\ \bar\partial p_i(W') \big\rangle
\;\ge\; \gamma_{\mathrm{ntk}}\sqrt m - \Big[\, 32\sqrt{\ln(n/\delta)} + 8R + 4 \,\Big]
\;\ge\; \frac{\gamma_{\mathrm{ntk}}\sqrt m}{2},
\]
where $\theta_j := \theta(w_j)/\sqrt m$ as usual, and $\|\theta\|\le 2$; for the remainder of the proof, suppose these bounds, and discard the corresponding $7\delta$ failure probability. Moreover, for any $W'$ with $\widehat{\mathcal R}(W') < \ell(0)/n$, by
Ji and Telgarsky (2019, Lemma 5.4, first part, which does not depend on linear predictors),
\begin{align*}
\big\|\bar\partial\tilde\gamma(W')\big\| = \sup_{\|W\|\le 1}\big\langle W,\ \bar\partial\tilde\gamma(W')\big\rangle
&\ge \Big\langle \frac{\theta}{2},\ \sum_i q_i\,\bar\partial p_i(W') \Big\rangle\\
&= \frac12\sum_i q_i\,\big\langle \theta,\ \bar\partial p_i(W')\big\rangle\\
&\ge \frac{\gamma_{\mathrm{ntk}}\sqrt m}{4}\sum_i q_i\\
&\ge \frac{\gamma_{\mathrm{ntk}}\sqrt m}{4}.
\end{align*}
Furthermore, it holds that γ̊(Ws ) ≥ γ̊(Wt ) for all s ≥ t (Lyu and Li, 2019), which completes the
proof for Wt under the standard parameterization.
Now consider the rebalanced parameters $\widehat W_t := (a_t/\sqrt{\gamma_{\mathrm{ntk}}},\ V_t\sqrt{\gamma_{\mathrm{ntk}}})$; since $m \ge 256/\gamma_{\mathrm{ntk}}^2$, which means $16 \le \gamma_{\mathrm{ntk}}\sqrt m$, then
\begin{align*}
\|a_t\| &\le \|a_0\| + \|a_t - a_0\| \le 2 + R \le \frac{\gamma_{\mathrm{ntk}}\sqrt m}{8} + \frac{\gamma_{\mathrm{ntk}}\sqrt m}{32} \le \frac{\gamma_{\mathrm{ntk}}\sqrt m}{4},\\
\|V_t\| &\le \|V_0\| + \|V_t - V_0\| \le 2\sqrt m + R \le 3\sqrt m,
\end{align*}
then the rebalanced parameters satisfy
\[
\|\widehat W_t\| \le \big\|a_t/\sqrt{\gamma_{\mathrm{ntk}}}\big\| + \big\|V_t\sqrt{\gamma_{\mathrm{ntk}}}\big\|
\le \frac{\sqrt{\gamma_{\mathrm{ntk}}\, m}}{4} + 3\sqrt{\gamma_{\mathrm{ntk}}\, m}
\le 4\sqrt{\gamma_{\mathrm{ntk}}\, m},
\]
and thus, for any $(x,y)$, since
\[
p(x,y;W_t) = \sum_j a_j\,\sigma(v_j^\top x) = \sum_j \frac{a_j}{\sqrt{\gamma_{\mathrm{ntk}}}}\,\sigma\big(\sqrt{\gamma_{\mathrm{ntk}}}\, v_j^\top x\big) = p(x,y;\widehat W_t),
\]
then
\[
\mathring\gamma(\widehat W_t) = \min_i \frac{p_i(\widehat W_t)}{\|\widehat W_t\|^2} = \min_i \frac{p_i(W_t)}{\|\widehat W_t\|^2}
\ge \frac{\gamma_{\mathrm{ntk}}^2\, m/256}{16\,\gamma_{\mathrm{ntk}}\, m} = \frac{\gamma_{\mathrm{ntk}}}{4096},
\]
which completes the proof.
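The rebalancing step above uses nothing more than the positive homogeneity of the ReLU: $a\,\sigma(v^\top x) = (a/\sqrt c)\,\sigma(\sqrt c\, v^\top x)$ for any $c>0$. As a minimal numerical sanity check (the helper names and sizes are illustrative, not from the paper):

```python
import numpy as np

rng = np.random.default_rng(1)

def relu(z):
    return np.maximum(z, 0.0)

def predict(a, V, x):
    # Two-layer network output: sum_j a_j * relu(v_j^T x).
    return float(a @ relu(V @ x))

m, d = 8, 4
a = rng.normal(size=m)
V = rng.normal(size=(m, d))
x = rng.normal(size=d)
c = 0.37  # stands in for gamma_ntk; any c > 0 works

# Rebalanced parameters (a / sqrt(c), V * sqrt(c)) leave the function unchanged.
a_hat, V_hat = a / np.sqrt(c), V * np.sqrt(c)
original = predict(a, V, x)
rebalanced = predict(a_hat, V_hat, x)
```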
Thanks to Lemmas 2.3 and 2.4, the proof of Theorem 2.2 is now immediate.
Proof of Theorem 2.2. As in the statement, define $R := \gamma_{\mathrm{ntk}}\sqrt m/32$, and note that Assumption 1.3 holds almost surely for any finite sample due to Assumption 1.4. The analysis now uses two stages.
The first stage is handled by Lemma 2.3 run until time $\tau := n$, whereby, with probability at least $1-7\delta$,
\[
\widehat{\mathcal R}(W_\tau) \le \frac{1}{5n} < \frac{\ell(0)}{n},
\qquad
\|W_\tau - W_0\| \le \frac{\gamma_{\mathrm{ntk}}\sqrt m}{80} \le \frac{R}{2}.
\]
The second stage now follows from Lemma 2.4: since $W_\tau$ as above satisfies all the conditions of Lemma 2.4, there exists $W_t$ with $\|W_t - W_0\| = R$, and $\mathring\gamma(W_s) \ge \gamma_{\mathrm{ntk}}^2/4096$ for all $s\ge t$, and $\mathring\gamma(\widehat W_t) \ge \gamma_{\mathrm{ntk}}/4096$ for the rebalanced iterates $\widehat W_t = (a_t/\sqrt{\gamma_{\mathrm{ntk}}},\ V_t\sqrt{\gamma_{\mathrm{ntk}}})$. The first generalization bound now follows by Lemma 2.5. The second generalization bound uses $p(x,y;W_t) = p(x,y;\widehat W_t)$ for all $(x,y)$, whereby $\Pr[p(x,y;W_t)\le 0] = \Pr[p(x,y;\widehat W_t)\le 0]$, and the earlier margin-based generalization bound can be invoked with the improved margin of $\widehat W_t$.
Lastly, here are the details of the construction leading to Proposition 2.6, which gives a setting where GF converges to parameters with higher margin than the maximum margin linear predictor.
Proof of Proposition 2.6. Let $u$ be the maximum margin linear separator for $((x_i,y_i))_{i=1}^n$; necessarily, there exists at least one support vector in each of $S_1$ and $S_2$, since otherwise $u$ is also the maximum margin linear predictor over just one of the sets, but then the condition
\[
\min\big\{ \langle x_i, x_j\rangle \;:\; x_i\in S_1,\ x_j\in S_2 \big\} \le -\frac{1}{\sqrt 2}
\tag{B.6}
\]
implies that $u$ points away from and is incorrect on the set with no support vectors.
As such, let $v_1\in S_1$ and $v_2\in S_2$ be support vectors, and consider the addition of a single data point $(x_0,+1)$ from a parameterized family of points $\{z_\alpha : \alpha\in[0,1]\}$, defined as follows. Define $v_0 := (I - v_1v_1^\top)u / \|(I - v_1v_1^\top)u\|$, which is orthogonal to $-v_1$ by construction, and consider the geodesic between $v_0$ and $-v_1$:
\[
z_\alpha := \alpha v_0 - \sqrt{1-\alpha^2}\, v_1,
\]
which satisfies $\|z_\alpha\| = 1$ by this orthogonality, since
\[
\|z_\alpha\|^2 = \alpha^2 - 2\alpha\sqrt{1-\alpha^2}\,\langle v_0, v_1\rangle + (1-\alpha^2) = 1.
\]
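As a quick numerical sanity check (illustrative only), for any orthonormal pair $(v_0, v_1)$ the geodesic point $z_\alpha$ is a unit vector for every $\alpha\in[0,1]$; the dimension and sampling below are arbitrary choices for the demonstration.

```python
import numpy as np

rng = np.random.default_rng(2)
d = 6

# Build an orthonormal pair (v0, v1), mirroring v0 := (I - v1 v1^T) u / ||...||.
v1 = rng.normal(size=d)
v1 /= np.linalg.norm(v1)
u = rng.normal(size=d)
v0 = u - (u @ v1) * v1   # project out the v1 component
v0 /= np.linalg.norm(v0)

def z(alpha):
    # Geodesic between v0 and -v1.
    return alpha * v0 - np.sqrt(1.0 - alpha**2) * v1

norms = [np.linalg.norm(z(alpha)) for alpha in np.linspace(0.0, 1.0, 11)]
max_deviation = max(abs(nrm - 1.0) for nrm in norms)
```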
In order to pick a specific point along the geodesic, here are a few observations.
1. First note that $-v_2$ is a good predictor for $S_1$: eq. (B.6) implies
\[
\min_{x\in S_1} \langle -v_2, x\rangle = -\max_{x\in S_1}\langle v_2, x\rangle \ge \frac{1}{\sqrt 2}.
\]
It follows similarly that $-v_1$ is a good predictor for $S_2$: analogously, $\min_{x\in S_2}\langle -v_1, x\rangle \ge 1/\sqrt 2$.

2. It is also the case that $-v_1$ is a good predictor for every $z_\alpha$ with $\alpha \le 1/\sqrt 2$:
\[
\langle -v_1, z_\alpha\rangle = \sqrt{1-\alpha^2} \ge \frac{1}{\sqrt 2}.
\]
3. Lastly, as $\alpha\to 0$, then $z_\alpha \to -v_1$, but the resulting set of points $S_1\cup S_2\cup\{z_0\}$ is linearly separable only with margin at most zero (if it is linearly separable at all). Consequently, there exist choices of $\alpha_0\in[0,1]$ so that the resulting maximum margin linear predictor over $S_1\cup S_2\cup\{z_{\alpha_0}\}$ has arbitrarily small yet still positive margins, and for concreteness choose some $\alpha_0$ so that the resulting linear separability margin $\gamma$ satisfies $\gamma\in(0, c/(2^{15}d))$, where $c$ is the positive constant in Lemma 2.3.
As a consequence of the last two bullets, using label $+1$ and inputs $S_1\cup S_2\cup\{z_{\alpha_0}\}$, the maximum margin linear predictor has margin $\gamma$, Assumption 1.3 is satisfied with margin at least $\gamma_1$, but most importantly, by Lemma 2.3 (with $m_0$ given by the width lower bound required there), GF will achieve a margin $\gamma_2 \ge c\gamma_1^2/4 \ge 2\gamma$.
It only remains to argue that the linear predictor itself is a KKT point (for the margin objective), but this is direct and essentially from prior work (Lyu and Li, 2019): e.g., taking all outer weights to be equal and positive, and all inner weights to be equal to $u$ and of magnitude equal to the outer layer magnitude, and then taking all these balanced norms to infinity, it can be checked that the gradient and parameter alignment conditions are asymptotically satisfied (indeed, the ReLU can be ignored since all nodes have positive inner product with all examples), which implies convergence to a KKT point (Lyu and Li, 2019, Appendix C).
Lemma C.1. Let $(\beta_1,\dots,\beta_r)$ be given with $\|\beta_k\| = 1$, and suppose $(v_j)_{j=1}^m$ are sampled with $v_j \sim \mathcal N_d/\sqrt d$, with corresponding normalized weights $\tilde v_j$. If
\[
m \ge 2\Big(\frac{2}{\epsilon}\Big)^{d} \ln\frac{r}{\delta},
\]
then $\max_k \min_j \|\tilde v_j - \beta_k\| \le \epsilon$, or alternatively $\min_k \max_j \tilde v_j^\top \beta_k \ge 1 - \epsilon^2/2$.
Proof. By standard sampling estimates (Ball, 1997, Lemma 2.3), for any fixed $k$ and $j$,
\[
\Pr\big[\|\tilde v_j - \beta_k\| \le \epsilon\big] \ge \frac12\Big(\frac{\epsilon}{2}\Big)^{d-1},
\]
and since all $(v_j)_{j=1}^m$ are iid,
\begin{align*}
\Pr\big[\exists j\ \|\tilde v_j - \beta_k\| \le \epsilon\big]
= 1 - \Pr\big[\forall j\ \|\tilde v_j - \beta_k\| > \epsilon\big]
&= 1 - \Big(1 - \Pr\big[\|\tilde v_1 - \beta_k\| \le \epsilon\big]\Big)^m\\
&\ge 1 - \exp\Big( -\frac m2 \Big(\frac\epsilon2\Big)^{d-1} \Big) \ge 1 - \frac\delta r,
\end{align*}
and union bounding over all $(\beta_k)_{k=1}^r$ gives $\max_k\min_j \|\tilde v_j - \beta_k\| \le \epsilon$. To finish, the inner product form comes by noting $\|\tilde v_j - \beta_k\|^2 = 2 - 2\tilde v_j^\top\beta_k$ and rearranging.
First comes the proof of Theorem 3.2, whose entirety is the construction of a potential Φ and a
verification that it satisfies the conditions in Lemma A.6.
Proof of Theorem 3.2. To start the construction of $\Phi$, first define per-weight potentials $\phi_{k,j}$ which track rotation of mass towards each $\beta_k$:
\[
\phi_{k,j} := \phi_k(w_j) := \phi\Big( \tilde\alpha_k\, a_j\,\sigma(v_j^\top\beta_k) - (1-\epsilon)\|a_jv_j\| \Big).
\]
The goal will be to show that $\phi_{k,j}$ is monotone nondecreasing, which shows that weights get trapped pointing towards each $\beta_k$ once sufficiently close. As such, to develop $\frac{\mathrm d}{\mathrm dt}\phi_{k,j}$, note (using $\tilde a_j := \operatorname{sgn}(a_j)$ and $\tilde v_j := v_j/\|v_j\|$),
\begin{align*}
\frac{\mathrm d}{\mathrm dt}\Big[\tilde\alpha_k\, a_j\sigma(v_j^\top\beta_k)\Big]
&= \tilde\alpha_k\Big[\dot a_j\,\sigma(v_j^\top\beta_k) + a_j\frac{\mathrm d}{\mathrm dt}\sigma(v_j^\top\beta_k)\Big]\\
&= -\tilde\alpha_k\sum_i \ell_i' y_i\Big[ \sigma(v_j^\top x_i)\sigma(v_j^\top\beta_k) + a_j^2\,\sigma'(v_j^\top\beta_k)\, x_i^\top\beta_k \Big]\\
&= -\tilde\alpha_k\sum_i \ell_i' y_i\Big[ \|v_j\|^2\,\sigma(\tilde v_j^\top x_i)\sigma(\tilde v_j^\top\beta_k) + a_j^2\,\sigma'(v_j^\top\beta_k)\, x_i^\top\beta_k \Big],\\
\frac{\mathrm d}{\mathrm dt}\|a_jv_j\|
&= \frac{\mathrm d}{\mathrm dt}\big\langle a_jv_j, a_jv_j\big\rangle^{1/2}
= \frac{2\big\langle a_jv_j,\ \dot a_jv_j + a_j\dot v_j\big\rangle}{2\big\langle a_jv_j, a_jv_j\big\rangle^{1/2}}\\
&= \frac{\Big\langle a_jv_j,\ -\sum_i \ell_i' y_i\big[ v_j\,\sigma(v_j^\top x_i) + a_j^2\,\sigma'(v_j^\top x_i)\,x_i \big]\Big\rangle}{\|a_jv_j\|}\\
&= \frac{-\sum_i \ell_i'\, p_i(w_j)\,\|w_j\|^2}{\|a_jv_j\|}\\
&= -\sum_i \ell_i' y_i\,\tilde a_j\,\sigma(\tilde v_j^\top x_i)\,\|w_j\|^2,
\end{align*}
whereby, when $\phi_{k,j}\in(0,1)$,
\begin{align*}
\frac{\mathrm d}{\mathrm dt}\phi_{k,j}
&= \frac{\mathrm d}{\mathrm dt}\Big[\tilde\alpha_k\, a_j\sigma(v_j^\top\beta_k) - (1-\epsilon)\|a_jv_j\|\Big]\\
&= -\sum_i \ell_i' y_i\Big[ \|v_j\|^2\,\tilde a_j\,\sigma(\tilde v_j^\top x_i)\Big( \tilde\alpha_k\tilde a_j\,\sigma(\tilde v_j^\top\beta_k) - (1-\epsilon) \Big)\\
&\qquad\qquad\qquad + a_j^2\Big( \tilde\alpha_k\,\sigma'(v_j^\top\beta_k)\,\beta_k^\top x_i - (1-\epsilon)\,\tilde a_j\,\sigma(\tilde v_j^\top x_i) \Big)\Big].
\end{align*}
Analyzing the last two terms separately, the first (the coefficient of $\|v_j\|^2$) is nonnegative since the term in parentheses is a rescaling of $\phi_{k,j}$:
\[
\tilde\alpha_k\tilde a_j\,\sigma(\tilde v_j^\top\beta_k) - (1-\epsilon) = \frac{\phi_{k,j}}{\|a_jv_j\|} > 0,
\]
and moreover this holds for any choice of $\epsilon > 0$. The second term (the coefficient of $a_j^2$) is more complicated; to start, fixing any example $(x_i,y_i)$ and writing $z_i := y_ix_i = c\beta_k + z_\perp$, where necessarily $c \ge \gamma_{\mathrm{nc}}$, note (using $\tilde u_j := \tilde a_j\tilde v_j$)
\[
\Big\langle \frac{z_\perp}{\|z_\perp\|},\ \tilde u_j \Big\rangle^2
\le \big\|(I - \beta_k\beta_k^\top)\tilde u_j\big\|^2
= 1 - \langle\beta_k, \tilde u_j\rangle^2
\le 1 - (1-\epsilon)^2 = 2\epsilon - \epsilon^2 \le 2\epsilon,
\]
and thus, since $\phi_{k,j}\in(0,1)$ whereby $\sigma(v_j^\top x_i) = v_j^\top x_i$, the nonnegativity of the second term follows from the narrowness of the cone around $\beta_k$ (cf. Assumption 3.1):
\begin{align*}
\tilde\alpha_k\,\beta_k^\top x_iy_i - (1-\epsilon)\,\tilde a_j\tilde v_j^\top x_iy_i
&\ge c - (1-\epsilon)\big\langle \tilde a_j\tilde v_j,\ c\beta_k + z_\perp \big\rangle\\
&\ge c\epsilon - (1-\epsilon)\,\|z_\perp\|\sqrt{2\epsilon}\\
&\ge c\epsilon - (1-\epsilon)\,\epsilon\,\gamma_{\mathrm{nc}}\\
&\ge 0,
\end{align*}
whereby, with $\Phi(w) := \frac14\sum_k |\alpha_k|\ln\big(\sum_j \phi_{k,j}\|a_jv_j\|\big)$,
\begin{align*}
\frac{\mathrm d}{\mathrm dt}\Phi(w)
&= \frac14\sum_k |\alpha_k|\,
\frac{\sum_j \Big[ \phi_{k,j}\frac{\mathrm d}{\mathrm dt}\|a_jv_j\| + \|a_jv_j\|\frac{\mathrm d}{\mathrm dt}\phi_{k,j} \Big]}{\sum_j \phi_{k,j}\|a_jv_j\|}\\
&\ge \frac14\sum_k |\alpha_k|\,
\frac{-\sum_i \ell_i' y_i \sum_j \phi_{k,j}\,\tilde a_j\,\sigma(\tilde v_j^\top x_i)\,\|w_j\|^2}{\sum_j \phi_{k,j}\|a_jv_j\|}\\
&= \frac{Q}{4}\sum_k |\alpha_k| \sum_{i\in S_k} q_i\,
\frac{y_i\sum_j \phi_{k,j}\,\tilde\alpha_k\,\sigma\big((\tilde v_j - \beta_k + \beta_k)^\top x_i\big)\,\|w_j\|^2}{\sum_j \phi_{k,j}\|a_jv_j\|}\\
&\ge \frac{Q}{4}\sum_k |\alpha_k| \sum_{i\in S_k} q_i\,
\frac{\sum_j \phi_{k,j}\,(\gamma_{\mathrm{nc}} - \epsilon)\, 2\|a_jv_j\|}{\sum_j \phi_{k,j}\|a_jv_j\|}\\
&\ge Q\,\frac{\gamma_{\mathrm{nc}} - \epsilon}{r}.
\end{align*}
Written another way, this establishes $\mathrm d\Phi(W_t)/\mathrm dt \ge Q\hat\gamma$ with $\hat\gamma := (\gamma_{\mathrm{nc}} - \epsilon)/r$, which is one of the conditions needed in Lemma A.6. The other properties meanwhile are direct:
\[
\Phi(W_t) \le \frac14\sum_k |\alpha_k|\ln\Big(\sum_j \phi_{k,j}\|w_j\|^2\Big) \le \frac14\sum_k |\alpha_k|\ln\|W_t\|^2 = \frac12\ln\|W_t\|,
\]
and $\Phi(W_0) > -\infty$ due to random initialization (cf. Lemma C.1), which allows the application of Lemma A.6 and Lemma 2.5 and completes the proof.
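One small step in the preceding proof, the bound $\|(I-\beta_k\beta_k^\top)\tilde u_j\|^2 \le 2\epsilon$ whenever $\langle\beta_k,\tilde u_j\rangle \ge 1-\epsilon$ for unit vectors, is elementary and easy to verify numerically. The sampling construction below (drawing unit $u$ with prescribed inner product against $\beta$) is an illustrative choice, not from the paper.

```python
import numpy as np

rng = np.random.default_rng(4)
d, eps = 5, 0.1

beta = rng.normal(size=d)
beta /= np.linalg.norm(beta)

max_residual = 0.0
for _ in range(1000):
    # Unit u with <beta, u> = c >= 1 - eps: u = c*beta + sqrt(1-c^2)*w, w unit, w ⟂ beta.
    c = rng.uniform(1.0 - eps, 1.0)
    w = rng.normal(size=d)
    w -= (w @ beta) * beta
    w /= np.linalg.norm(w)
    u = c * beta + np.sqrt(1.0 - c**2) * w
    # ||(I - beta beta^T) u||^2 = 1 - c^2 <= 2*eps - eps^2 <= 2*eps.
    residual = float(np.linalg.norm(u - (u @ beta) * beta) ** 2)
    max_residual = max(max_residual, residual)
```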
To close, the proof of Theorem 3.3.
Proof of Theorem 3.3. Throughout the proof, use $W = ((a_j, b_j))_{j=1}^m$ to denote the full collection of parameters, even in this scalar parameter setting.

To start, note that $a_k^2 = b_k^2$ for all times $t$; this follows directly from the initial condition $a_k(0)^2 = b_k(0)^2 = 1/\sqrt m$, since
\begin{align*}
a_k(t)^2 - b_k(t)^2 &= a_k(t)^2 - a_k(0)^2 - b_k(t)^2 + b_k(0)^2\\
&= 2\int_0^t \big( a_k\dot a_k - b_k\dot b_k \big)\,\mathrm ds\\
&= 2\int_0^t \sum_i |\ell_i'|\, y_i\, a_k\Big( \sigma(b_kv_k^\top x_i) - \sigma'(b_kv_k^\top x_i)\, b_kv_k^\top x_i \Big)\,\mathrm ds\\
&= 0.
\end{align*}
This also implies that $a_k^2 + b_k^2 = 2a_k^2 = 2|a_k|\cdot|b_k|$.
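This balance identity is straightforward to observe numerically (illustrative only, and not part of the formal development): a forward-Euler discretization of the flow in the norm-only setting (frozen directions $v_j$) conserves $a_k^2 - b_k^2$ up to $O(\eta^2)$ error per step. The data, widths, and step size below are arbitrary choices for the demonstration.

```python
import numpy as np

rng = np.random.default_rng(5)
n, d, m = 20, 3, 10
lr, steps = 1e-3, 2000

# Random unit-norm data and frozen unit directions v_j (norm-only training).
X = rng.normal(size=(n, d))
X /= np.linalg.norm(X, axis=1, keepdims=True)
y = rng.choice([-1.0, 1.0], size=n)
V = rng.normal(size=(m, d))
V /= np.linalg.norm(V, axis=1, keepdims=True)
S = X @ V.T                        # S[i, j] = v_j^T x_i

a = rng.choice([-1.0, 1.0], size=m) * m**-0.25
b = rng.choice([-1.0, 1.0], size=m) * m**-0.25
balance0 = a**2 - b**2             # exactly zero at initialization; tracked anyway

for _ in range(steps):
    pre = S * b                    # b_j v_j^T x_i
    act = np.maximum(pre, 0.0)
    f = act @ a                    # f(x_i) = sum_j a_j relu(b_j v_j^T x_i)
    lgrad = -1.0 / (1.0 + np.exp(y * f))      # l'(y_i f_i), logistic loss
    grad_a = (lgrad * y) @ act / n
    grad_b = ((lgrad * y) @ ((pre > 0.0) * S)) * a / n
    a = a - lr * grad_a
    b = b - lr * grad_b

max_drift = float(np.max(np.abs(a**2 - b**2 - balance0)))
```

In exact gradient flow the drift is zero; here it is nonzero only because of the Euler discretization.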
For each $\beta_k$, choose $j$ so that $\|\tilde b_j\tilde v_j - \beta_k\| \le \epsilon := \gamma_{\mathrm{gl}}/2$ and $\tilde a_j = \operatorname{sgn}(\alpha_k)$; this holds with probability at least $1-3\delta$ over the draw of $W$ due to the choice of $m$, first by noting that with probability at least $1-\delta$, there are at least $m/4$ positive $a_j$ and $m/4$ negative $a_j$, and then with probability $1-2\delta$ by applying Lemma C.1 to $(\tilde b_jv_j)_{j=1}^m$ (which are equivalent in distribution to $(v_j)_{j=1}^m$) for each choice of output sign. For the rest of the proof, reorder the weights $((a_j,b_j,v_j))_{j=1}^m$ so that each $(\alpha_k,\beta_k)$ is associated with $(a_k,b_k,v_k)$.
Now consider the potential function
\[
\Phi(W) := \frac14\sum_{k=1}^r |\alpha_k|\,\ln\big( a_k^2 + b_k^2 \big).
\]
The time derivative of this potential can be lower bounded in terms of the reference margin:
\begin{align*}
\frac{\mathrm d}{\mathrm dt}\Phi(W)
&= \sum_i\sum_k |\ell_i'|\, y_i\, |\alpha_k|\, \frac{a_k\,\sigma(b_kv_k^\top x_i)}{a_k^2+b_k^2}\\
&= Q\sum_k\sum_i q_iy_i\alpha_k\, \frac{|a_k|\,\|b_kv_k\|\,\sigma(\tilde b_k\tilde v_k^\top x_i)}{2|a_k|\cdot|b_k|}\\
&= Q\sum_k\sum_i q_iy_i\alpha_k\, \sigma\big( (\tilde b_k\tilde v_k - \beta_k + \beta_k)^\top x_i \big)\\
&\ge Q\sum_k\sum_i q_iy_i\alpha_k\, \sigma\big( \beta_k^\top x_i \big) - Q\sum_k\sum_i q_i\,|\alpha_k|\,\big\|\tilde b_k\tilde v_k - \beta_k\big\|\\
&\ge Q\gamma_{\mathrm{gl}}\sum_i q_i - Q\epsilon\sum_k\sum_i q_i\,|\alpha_k|\\
&= Q\sum_i q_i\,\big( \gamma_{\mathrm{gl}} - \epsilon \big) > 0.
\end{align*}
As an immediate consequence, the signs of all $((a_k,b_k))_{k=1}^r$ never flip (since this would require their values to pass through $0$, which would cause $\Phi(W) = -\infty < \Phi(W(0))$). This implies the preceding lower bound always holds, and since $\Phi(W_0) > -\infty$ thanks to random initialization, and
\[
\Phi(W) = \frac14\sum_{k=1}^r |\alpha_k|\,\ln\big(a_k^2+b_k^2\big) \le \frac14\sum_{k=1}^r |\alpha_k|\,\ln\|W\|^2 = \frac12\ln\|W\|,
\]
all conditions of Lemma A.6 are satisfied, and the proof is complete after applying Lemma 2.5.
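The monotonicity of $\Phi$ driving this proof can be seen even on a deliberately minimal instance (illustrative only, and not part of the formal development): one unit, one example perfectly aligned with its direction ($v^\top x = 1$, label $+1$, $r = 1$, $|\alpha_1| = 1$), trained by a forward-Euler discretization of the flow. All constants below are arbitrary demo choices.

```python
import numpy as np

# Minimal instance: frozen direction v = x (so v^T x = 1), label y = +1,
# scalar parameters (a, b) trained on the logistic loss l(z) = ln(1 + e^{-z}).
a, b = 0.5, 0.5
lr, steps = 1e-2, 500

def phi(a, b):
    # Potential from the proof with r = 1 and |alpha_1| = 1.
    return 0.25 * np.log(a**2 + b**2)

phis = [phi(a, b)]
for _ in range(steps):
    f = a * max(b, 0.0)                   # a * relu(b * v^T x) with v^T x = 1
    lprime = -1.0 / (1.0 + np.exp(f))     # l'(f) < 0, so both gradients push a, b up
    grad_a = lprime * max(b, 0.0)
    grad_b = lprime * a * (1.0 if b > 0 else 0.0)
    a -= lr * grad_a
    b -= lr * grad_b
    phis.append(phi(a, b))

increments = np.diff(phis)
```

Since $a$ and $b$ start positive and only grow here, $\Phi$ increases at every step, mirroring the sign-preservation argument above.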