
Feature selection with gradient descent on two-layer networks in low-rotation regimes


Matus Telgarsky∗
arXiv:2208.02789v1 [cs.LG] 4 Aug 2022

Abstract
This work establishes low test error of gradient flow (GF) and stochastic gradient descent
(SGD) on two-layer ReLU networks with standard initialization, in three regimes where key sets
of weights rotate little (either naturally due to GF and SGD, or due to an artificial constraint),
and making use of margins as the core analytic technique. The first regime is near initialization,
specifically until the weights have moved by O(√m), where m denotes the network width, which
is in sharp contrast to the O(1) weight motion allowed by the Neural Tangent Kernel (NTK);
here it is shown that GF and SGD only need a network width and number of samples inversely
proportional to the NTK margin, and moreover that GF attains at least the NTK margin itself,
which suffices to establish escape from bad KKT points of the margin objective, whereas prior
work could only establish nondecreasing but arbitrarily small margins. The second regime is
the Neural Collapse (NC) setting, where data lies in extremely-well-separated groups, and the
sample complexity scales with the number of groups; here the contribution over prior work is
an analysis of the entire GF trajectory from initialization. Lastly, if the inner layer weights are
constrained to change in norm only and can not rotate, then GF with large widths achieves
globally maximal margins, and its sample complexity scales with their inverse; this is in contrast
to prior work, which required infinite width and a tricky dual convergence assumption. As purely
technical contributions, this work develops a variety of potential functions and other tools which
will hopefully aid future work.

1 Introduction
This work studies standard descent methods on two-layer networks of width m, specifically stochastic
gradient descent (SGD) and gradient flow (GF) on the logistic and exponential losses, with a goal of
establishing good test error (low sample complexity) via margin theory (see Section 1.2 for full
details on the setup). The analysis considers three settings where weights rotate little, but rather
their magnitudes change a great deal, whereby the networks select or emphasize good features.
The motivation and context for this analysis is as follows. While the standard promise of deep
learning is to provide automatic feature learning, by contrast, standard optimization-based analyses
typically utilize the Neural Tangent Kernel (NTK) (Jacot et al., 2018), which suffices to establish
low training error (Du et al., 2018; Allen-Zhu et al., 2018; Zou et al., 2018), and even low testing
error (Arora et al., 2019; Li and Liang, 2018; Ji and Telgarsky, 2020b), but ultimately the NTK is
equivalent to a linear predictor over fixed features given by a Taylor expansion, and fails to achieve
the aforementioned feature learning promise.
Due to this gap, extensive mathematical effort has gone into the study of feature learning, where
the feature maps utilized in deep learning (e.g., those given by Taylor expansions) may change
drastically, as occurs in practice. The promise here is huge: even simple tasks such as 2-sparse parity

<mjt@illinois.edu>; comments greatly appreciated.

(learning the parity of two bits in d dimensions) require d²/ε samples for a test error of ε > 0 within
the NTK or any other kernel regime, but it is possible for ReLU networks beyond the NTK regime
to achieve an improved sample complexity of d/ε (Wei et al., 2018). This point will be discussed in
detail throughout this introduction, though a key point is that prior work generally changes the
algorithm to achieve good sample complexity; as a brief summary, many analyses either add noise
to the training process (Shi et al., 2022; Wei et al., 2018), or train the first layer for only one step
and thereafter train only the second (Daniely and Malach, 2020; Abbe et al., 2022; Barak et al.,
2022), and lastly a few others idealize the network and make strong assumptions to show that global
margin maximization occurs, giving the desired d/ε sample complexity (Chizat and Bach, 2020). As
will be expanded upon briefly, and indeed is depicted in Figure 1 and summarized in Table 1, in the
case of the aforementioned 2-sparse parity, the present work will achieve the optimal kernel sample
complexity d²/ε with a standard SGD (cf. Theorem 2.1), and a beyond-kernel sample complexity of
d/ε with an inefficient and somewhat simplified GF (cf. Theorem 3.3).
The technical approach in the present work is to build upon the rich theory of margins in
machine learning (Boser et al., 1992; Schapire and Freund, 2012). While the classical theory provides
that descent methods on linear models can maximize margins (Zhang and Yu, 2005; Telgarsky,
2013; Soudry et al., 2017), recent work in deep networks has revealed that GF eventually converges
monotonically to local optima of the margin function (Lyu and Li, 2019; Ji and Telgarsky, 2020a),
and even that, under a variety of conditions, margins may be globally maximized (Chizat and Bach,
2020). As mentioned above, global margin maximization is sufficient to beat the NTK sample
complexity for many well-studied problems, including 2-sparse parity (Wei et al., 2018).
The contributions of this work fall into roughly two categories.

1. Section 2: Margins at least as good as the NTK. The first set of results may sound a
bit unambitious, but already constitute a variety of improvements over prior work. Throughout
these contributions, let γntk denote the NTK margin, a quantity defined and discussed in detail
in Section 1.2.

(a) Theorem 2.1: Õ(1/(γntk² ε)) steps of SGD (with one example per iteration) on a network of
width m = Ω̃(1/γntk²) suffice to achieve test error ε > 0. For 2-sparse parity, this suffices to
achieve the optimal within-kernel sample complexity d²/ε, and in fact the computational
cost and sample complexity improve upon all existing prior work on efficient methods (cf.
Table 1).
(b) Theorem 2.2: with similar width and samples, GF also achieves test error  > 0. Arguably
more importantly, this analysis gives the first guarantee in a general setting that constant
margins (in fact γntk /4096) are achieved; prior work only established nondecreasing but
arbitrarily small margins (Lyu and Li, 2019). Furthermore, this result suffices to imply
that GF can escape bad local optima of the margin objective (cf. Proposition 2.6).
(c) These proofs are not within the NTK: the NTK requires weights to move at most O(1),
whereas weights can move by O(√m) within these proofs; in fact, the first gradient step of
SGD is shown to have norm at least γntk√m (and the step size is O(1)). These proofs will
therefore hopefully serve as the basis of other proofs outside the NTK in future work.

2. Section 3: Margins surpassing the NTK. The remaining margin results are able to beat
the preceding NTK margins, however they pay a large price: the network width is exponentially
large, and the method is GF, not the discrete-time algorithm SGD. Even so, the results already
require new proof techniques, and again hopefully form the basis for improvements in future
work.

(a) Theorem 3.2: the first result is for Neural Collapse (NC), a setting where data lies in k
tightly packed clusters which are extremely far apart; in fact, each cluster is correctly
labeled by some vector βk , and no other data may fall in the halfspace defined by βk .
Despite this, it is an interesting setting with a variety of empirical and theoretical works,
for instance showing various asymptotic stability properties (Papyan et al., 2020); the
contribution here is to analyze the entire GF trajectory from initialization, and show that
it achieves a margin (and sample complexity) at least as good as the sparse ReLU network
with one ReLU pointing at each cluster.
(b) Theorem 3.3: under a further idealization that the inner weights are constrained to change
in norm only (meaning the inner weights can not rotate), global margin maximization
occurs. As mentioned before, this suffices to establish sample complexity d/ in 2-sparse
parity, which beats the d2 / of kernel methods. As a brief comparison to prior work, similar
idealized networks were proposed by Woodworth et al. (2020), however a full trajectory
analysis and relationship to global margin maximization under no further assumptions
was left open. Additionally, without the no-rotation constraint, Chizat and Bach (2020)
showed that infinite-width networks under a tricky dual convergence assumption also
globally maximize margins; the proof technique here is rooted in a desire to drop this dual
convergence assumption, a point which will be discussed extensively in Section 3.
(c) New potential functions: both Theorem 3.2 and Theorem 3.3 are proved via new potential
arguments which hopefully aid future work.

This introduction will close with further related work (cf. Section 1.1), notation (cf. Section 1.2),
and detailed definitions and estimates of the various margin notions (cf. Section 1.2). Margin
maximization at least as good as the NTK is presented in Section 2, margins beyond the NTK are
in Section 3, and open problems and concluding remarks appear in Section 4. Proofs appear in the
appendices.

1.1 Further related work


Margin maximization. The concept and analytical use of margins in machine learning originated
in the classical perceptron convergence analysis of Novikoff (1962). The SGD analysis in Theorem 2.1,
as well as the training error analysis in Lemma 2.3 were both established with a variant of the
perceptron proof; similar perceptron-based proofs appeared before (Ji and Telgarsky, 2020b; Chen
et al., 2019), however they required width 1/γntk⁸, unlike the 1/γntk² here, and moreover the proofs
themselves were in the NTK regime, whereas the proof here is not.
Works focusing on the implicit margin maximization or implicit bias of descent methods are more
recent. Early works on the coordinate descent side are (Schapire et al., 1997; Zhang and Yu, 2005;
Telgarsky, 2013); the proof here of Lemma 2.4 uses roughly the proof scheme in (Telgarsky, 2013).
More recently, margin maximization properties of gradient descent were established, first showing
global margin maximization in linear models (Soudry et al., 2017; Ji and Telgarsky, 2018b), then
showing nondecreasing smoothed margins of general homogeneous networks (including multi-layer
ReLU networks) (Lyu and Li, 2019), and the aforementioned global margin maximization result for
2-layer networks under dual convergence and infinite width (Chizat and Bach, 2020). The potential
functions used here in Theorems 3.2 and 3.3 use ideas from (Soudry et al., 2017; Lyu and Li, 2019;
Chizat and Bach, 2020), but also the shallow linear and deep linear proofs of Ji and Telgarsky (2019,
2018a).

(a) Trajectories (|aj| vj)mj=1 with m = 16 across time. (b) Trajectories (|aj| vj)mj=1 with m = 256 across time.
(c) Sorted per-node rotations for m ∈ {16, 256}. (d) Sorted per-node norms for m ∈ {16, 256}.

Figure 1. Two runs of approximate GF (GD with small step size) on 2-sparse parity with d = 20
and n = 64 and m ∈ {16, 256}: specifically, a new data point x is sampled uniformly from the
hypercube corners {±1/√d}^d, and the label y for simplicity is the parity of the first two bits,
meaning y = d x1 x2. The first two dimensions of the data and the trajectories of the m nodes
are depicted in Figures 1a and 1b for m ∈ {16, 256}. Figure 1c shows per-node rotation over
time in sorted order, meaning (⟨vj(0)/‖vj(0)‖, vj(t)/‖vj(t)‖⟩)mj=1. Figure 1d shows per-node relative norms in
sorted order, meaning (‖aj vj‖ / maxk ‖ak vk‖)mj=1. Due to projection onto the first two coordinates, the 64
data points land in 4 clusters, and are colored red if negatively labeled, blue if positively labeled.
Individual nodes are colored red or blue based on the sign of their output weights. The shades of
red and blue darken for nodes whose total norm is larger. Comparing Figure 1c and Figure 1d, large
norm and large rotation go together. While Figure 1a is highly noisy, the behavior of Figures 1b
to 1d was highly regular across training. The trend of larger width leading to less rotation and
greater norm imbalance was consistent across runs. This somewhat justifies the lack of rotation with
exponentially large width in Theorem 3.3, and the overall terminology choice of feature selection.
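For concreteness, the following is a minimal numpy sketch of the experiment just described: plain gradient descent with a small step size (as a stand-in for GF) on a two-layer ReLU network and a 2-sparse parity sample. The logistic loss, step size, and iteration count here are illustrative assumptions, not the exact settings behind Figure 1.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n, m = 20, 64, 256

# 2-sparse parity sample: x uniform on the hypercube corners {+-1/sqrt(d)}^d, label y = d * x1 * x2
X = rng.choice([-1.0, 1.0], size=(n, d)) / np.sqrt(d)
y = d * X[:, 0] * X[:, 1]

# initialization variances as in Section 1.2: a ~ N_m / sqrt(m), rows of V ~ N_d / sqrt(d)
a = rng.standard_normal(m) / np.sqrt(m)
V = rng.standard_normal((m, d)) / np.sqrt(d)
V0 = V.copy()

def forward(a, V):
    H = np.maximum(X @ V.T, 0.0)          # relu(v_j^T x_i), shape (n, m)
    return H, H @ a                       # F(x_i; W) = a^T relu(V x_i)

eta, steps = 1e-2, 20000                  # small constant step, so GD roughly tracks GF
for _ in range(steps):
    H, F = forward(a, V)
    dF = -y / (1.0 + np.exp(y * F)) / n   # derivative of the mean logistic loss wrt F
    a -= eta * (H.T @ dF)
    V -= eta * ((dF[:, None] * (H > 0)) * a).T @ X

H, F = forward(a, V)
cos_rot = np.sum((V0 / np.linalg.norm(V0, axis=1, keepdims=True))
                 * (V / np.linalg.norm(V, axis=1, keepdims=True)), axis=1)
rel_norm = np.abs(a) * np.linalg.norm(V, axis=1)
print("train error:", np.mean(y * F <= 0))
print("per-node rotation cosines (as in Figure 1c):", np.sort(cos_rot))
print("per-node relative norms (as in Figure 1d):", np.sort(rel_norm / rel_norm.max()))
```

The rotation and norm statistics printed at the end correspond to the per-node quantities plotted in Figures 1c and 1d.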

Reference | Algorithm | Technique | m | n | t
(Ji and Telgarsky, 2020b) | SGD | perceptron | d⁸ | d²/ε | d²/ε
Theorem 2.1 | SGD | perceptron | d² | d²/ε | d²/ε
(Barak et al., 2022) | 2-phase SGD | correlation | O(1) | d⁴/ε² | d²/ε²
(Wei et al., 2018) | WF+noise | margin | ∞ | d/ε | ∞
(Chizat and Bach, 2020) | WF | margin | ∞ | d/ε | ∞
Theorem 3.3 | scalar GF | margin | d^d | d/ε | ∞

Table 1. Performance on 2-sparse parity by a variety of works, loosely organized by technique;
see Section 1.1 for details. Briefly, m denotes width, n denotes total number of samples (across all
iterations), and t denotes the number of algorithm iterations. Overall, the table captures tradeoffs
in all parameters, and understanding the Pareto frontier is an interesting avenue for future work.

Feature learning. There are many works in feature learning; a few also carrying explicit guar-
antees on 2-sparse parity are summarized in Table 1. An early work with high relevance to the
present work is (Wei et al., 2018), which in addition to establishing that the NTK requires Ω(d²/ε)
samples whereas O(d/ε) suffice for the global maximum margin solution, also provided a noisy
Wasserstein Flow (WF) analysis which achieved the maximum margin solution, albeit using noise,
infinite width, and continuous time to aid in local search. The global maximum margin work of
Chizat and Bach (2020) was mentioned before, and will be discussed in Section 3. The work of
Barak et al. (2022) uses a two phase algorithm: the first step has a large minibatch and effectively
learns the support of the parity in an unsupervised manner, and thereafter only the second layer is
trained, a convex problem which is able to identify the signs within the parity; as in Table 1, this
work stands alone in terms of the narrow width it can handle. The work of (Abbe et al., 2022)
uses a similar two-phase approach, and while it can not learn precisely the parity, it can learn an
interesting class of “staircase” functions, and presents many valuable proof techniques. Another
work which operates in two phases and can learn an interesting class of functions which excludes
the parity (specifically due to a Jacobian condition) is the recent work of (Damian et al., 2022).
Other interesting feature learning works are (Shi et al., 2022; Bai and Lee, 2019).

1.2 Notation
Architecture and initialization. With the exception of Theorem 3.3, the architecture will be a
2-layer ReLU network of the form x ↦ F(x; W) = Σj aj σ(vjT x) = aT σ(V x), where σ(z) = max{0, z}
is the ReLU, and where a ∈ Rm and V ∈ Rm×d have initialization roughly matching the variances
of pytorch default initialization, meaning a ∼ Nm/√m (m iid Gaussians with variance 1/m)
and V ∼ Nm×d/√d (m × d iid Gaussians with variance 1/d). These parameters (a, V ) will be
collected into a tuple W = (a, V ) ∈ Rm × Rm×d ≡ Rm×(d+1) , and for convenience per-node tuples
wj = (aj , vj ) ∈ R × Rd will often be used as well.
Given a pair (x, y) with x ∈ Rd and y ∈ {±1}, the prediction or unnormalized margin mapping is
p(x, y; W ) = yF (x; W ) = yaT σ(V x); when examples ((xi , yi ))ni=1 are available, a simplified notation
pi (W ) := p(xi , yi ; W ) is often used, and moreover define a single-node variant pi (wj ) := yi aj σ(vjT xi ).
Throughout this work, kxk ≤ 1.
It will often be useful to consider normalized parameters within proofs: define ṽj := vj/‖vj‖
and ãj := sgn(aj) := aj/|aj|.

SGD and GF. The loss function ℓ in this work will always be either the exponential loss
ℓexp(z) := exp(−z), or the logistic loss ℓlog(z) := ln(1 + exp(−z)); the corresponding empirical risk
R̂ is

    R̂(p(W)) := (1/n) Σ_{i=1}^n ℓ(pi(W)),

which used p(W) := (p1(W), . . . , pn(W)) ∈ Rn. With ℓ and R̂ in hand, the descent methods are
defined as

    Wi+1 := Wi − η ∂̂W ℓ(pi(Wi)),    stochastic gradient descent (SGD),    (1.1)
    Ẇt := (d/dt) Wt = −∂̄W R̂(p(Wt)),    gradient flow (GF),    (1.2)

where ∂ˆ and ∂¯ are appropriate generalizations of subgradients for the present nonsmooth nonconvex
setting, detailed as follows. For SGD, ∂̂ will denote any valid element of the Clarke differential (i.e.,
a measurable selection); for example, ∂̂F(x; W) = ( σ(V x), Σj aj σ′(vjT x) ej xT ), where ej denotes
the jth standard basis vector, and σ′(vjT xi) ∈ [0, 1] is chosen in some consistent and measurable
way. This approach is able to model what happens in a standard software library such as pytorch.
For GF, ∂¯ will denote the unique minimum norm element of the Clarke differential; typically, GF
is defined as a differential inclusion, which agrees with this minimum norm Clarke flow almost
everywhere, but here the minimum norm element is used to define the flow purely for fun. Since at
this point there is a vast array of ReLU network literature using Clarke differentials, it is merely
asserted here that chain rules and other basic properties work as required, and ∂ˆ and ∂¯ are used
essentially as gradients, and details are deferred to the detailed discussions in (Lyu and Li, 2019; Ji
and Telgarsky, 2020a; Lyu et al., 2021).
Many expressions will have a fixed time index t (e.g., Wt ), but to reduce clutter, it will either
appear as a subscript, or as a function (e.g., W (t)), or even simply dropped entirely.
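A minimal sketch of the SGD update (1.1) with one example per iteration and the logistic loss follows; the subgradient selection σ′(0) := 0, the step size, and the parity data stream are illustrative choices consistent with the measurable-selection convention above, not the tuned parameters of Theorem 2.1.

```python
import numpy as np

def relu_prime(z):
    # one valid (measurable) selection from the Clarke differential of the ReLU: sigma'(0) := 0
    return (z > 0).astype(float)

def sgd_step(a, V, x, y, eta):
    """One step of (1.1): W <- W - eta * dhat_W l(p(x, y; W)), with the logistic loss."""
    z = V @ x
    h = np.maximum(z, 0.0)
    lprime = -1.0 / (1.0 + np.exp(y * (a @ h)))                        # l'_log(p), which is negative
    grad_a = lprime * y * h                                            # dhat_a l
    grad_V = lprime * y * (a * relu_prime(z))[:, None] * x[None, :]    # dhat_V l
    return a - eta * grad_a, V - eta * grad_V

# illustrative usage on a stream of 2-sparse parity examples
rng = np.random.default_rng(0)
d, m, eta = 10, 64, 0.1
a = rng.standard_normal(m) / np.sqrt(m)
V = rng.standard_normal((m, d)) / np.sqrt(d)
for _ in range(1000):
    x = rng.choice([-1.0, 1.0], size=d) / np.sqrt(d)
    y = d * x[0] * x[1]
    a, V = sgd_step(a, V, x, y, eta)
```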

Margins and dual variables. Firstly, for convenience, often ℓi(W) = ℓ(pi(W)) and ℓ′i(W) :=
ℓ′(pi(W)) are used; since ℓ′i is negative, often |ℓ′i| is written.
The start of developing margins and dual variables is the observation that F and pi are 2-homogeneous
in W, meaning F(x; cW) = c aT σ(cV x) = c² F(x; W) for any x ∈ Rd and c ≥ 0 (and
pi(cW) = c² pi(W)). It follows that F(x; W) = ‖W‖² F(x; W/‖W‖), and thus F and pi scale
quadratically in ‖W‖, and it makes sense to define a normalized prediction mapping p̃i and margin
γ as

    p̃i(W) := pi(W)/‖W‖²,    γ(W) := mini pi(W)/‖W‖² = mini p̃i(W).

Due to nonsmoothness, γ can be hard to work with, thus, following Lyu and Li (2019), define the
smoothed margin γ̃ and the normalized smoothed margin γ̊ as

    γ̃(W) := ℓ−1( nR̂(W) ) = ℓ−1( Σi ℓ(pi(W)) ),    γ̊(W) := γ̃(W)/‖W‖²,

where a key result is that γ̊ is eventually nondecreasing (Lyu and Li, 2019). These quantities may
look complicated and abstract, but note for ℓexp that γ̃(W) = −ln( Σi exp(−pi(W)) ).

When working with gradients of losses and of smoothed margins, it will be convenient to define
dual variables (qi)ni=1

    q := ∇p ℓ−1( Σi ℓ(pi) ) = ∇p( Σi ℓ(pi) ) / ℓ′( ℓ−1( Σi ℓ(pi) ) ) = ∇p( Σi ℓ(pi) ) / ℓ′(γ̃(p)),

which made use of the inverse function theorem. Correspondingly define Q := |ℓ′(γ̃(p))|, whereby
−ℓ′i = qi Q; for the exponential loss, Q = Σi exp(−pi) and Σi qi = 1, and while these quantities are
more complicated for the logistic loss, they eventually satisfy Σi qi ≥ 1 (Ji and Telgarsky, 2019,
Lemma 5.4, first part, which does not depend on linear predictors).
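For the exponential loss these quantities can be computed directly; the following small numpy sketch (with hypothetical unnormalized margins p) checks the identities Σi qi = 1 and −ℓ′i = qi Q stated above.

```python
import numpy as np

def exp_loss_quantities(p, W_norm_sq):
    """Smoothed margin, normalized smoothed margin, dual variables, and Q for l_exp."""
    losses = np.exp(-p)                       # l_exp(p_i)
    Q = losses.sum()                          # Q = sum_i exp(-p_i) = |l'(gamma_tilde)|
    gamma_tilde = -np.log(Q)                  # gamma_tilde = l^{-1}(sum_i l(p_i))
    gamma_ring = gamma_tilde / W_norm_sq      # normalized smoothed margin
    q = losses / Q                            # dual variables
    return gamma_tilde, gamma_ring, q, Q

p = np.array([2.0, 3.5, 1.2, 4.0])            # hypothetical unnormalized margins p_i(W)
gamma_tilde, gamma_ring, q, Q = exp_loss_quantities(p, W_norm_sq=10.0)
assert np.isclose(q.sum(), 1.0)               # sum_i q_i = 1 for l_exp
assert np.allclose(np.exp(-p), q * Q)         # -l'_i = q_i * Q
print(gamma_tilde, gamma_ring, p.min())       # gamma_tilde <= min_i p_i
```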

Margin assumptions. It is easiest to first state the global margin definition, used in Section 3.
The first version is stated on a finite sample.
Assumption 1.1. For given examples ((xi, yi))ni=1, there exists a scalar γgl > 0 and parameters
((αk, βk))rk=1 with αk ∈ R, ‖α‖1 = 1, and βk ∈ Rd with ‖βk‖2 = 1, where

    mini Σ_{k=1}^r yi αk σ(βkT xi) ≥ γgl. ♦

This definition can also be extended to hold over a distribution, whereby it holds for all finite
samples almost surely.
Assumption 1.2. There exists γgl > 0 and parameters ((αk , βk ))rk=1 so that Assumption 1.1 holds
almost surely for any ((xi , yi ))ni=1 drawn iid for any n from the underlying distribution (with the
same γgl and ((αk , βk ))rk=1 ). ♦
Assumptions 1.1 and 1.2 are used in this work to approximate the best possible margin amongst
all networks of any width; in particular, the formalism here is a simplification of the max-min
characterization given in (Chizat and Bach, 2020, Proposition 12, optimality conditions). As will
be seen shortly in Proposition 1.6, this definition suffices to achieve sample complexity d/ε with
gradient flow on 2-sparse parity, as in Table 1. Lastly, while the appearance of an ℓ1 norm may be a
surprise, it can be seen to naturally arise from 2-homogeneity:
    pi(W)/‖W‖² = ( Σj pi(wj) )/‖W‖² = Σj αj p̃i(wj),    where αj := ‖wj‖²/‖W‖², thus ‖α‖1 = 1.

Next comes the definition of the NTK margin γntk. As a consequence of 2-homogeneity,
⟨W, ∂̄W pi(W)⟩ = 2pi(W), which can be interpreted as a linear predictor with weights W and
features ∂̄W pi(W)/2. Decoupling the weights and features gives ⟨W, ∂̄W pi(W0)⟩, where W0 is at
initialization, and W is some other choice. To get the full definition from here, W is replaced with
an infinite-width object via expectations; similar definitions originated with the work of Nitanda
and Suzuki (2019), and were then used in (Ji and Telgarsky, 2020b; Chen et al., 2019).
Assumption 1.3. For given examples ((xi, yi))ni=1, there exists a scalar γntk > 0 and a mapping
θ : Rd+1 → Rd+1 with θ(w) = 0 whenever ‖w‖ ≥ 2 so that

    mini E_{w∼Nθ} ⟨θ(w), ∂̄w pi(w)⟩ ≥ γntk,

where w = (a, v) ∼ Nθ means a ∼ N and v ∼ Nd/√d. ♦

Similarly to the global maximum margin definition, Assumption 1.3 can be restated over the
distribution.
Assumption 1.4. There exists γntk > 0 and a transport θ : Rd+1 → Rd+1 so that Assumption 1.3
holds almost surely for any ((xi , yi ))ni=1 drawn iid for any n from the underlying distribution (with
the same γntk and θ). ♦
Before closing this section, here are a few estimates of γntk and γgl . Firstly, both function classes
are universal approximators, and thus the assumption can be made to work for any prediction
problem with pure conditional probabilities (Ji et al., 2020). Next, as a warmup, note the following
estimates of γntk and γgl, for linear predictors, with an added estimate showing the value of
working with both layers in the definition of γntk.

Proposition 1.5. Let examples ((xi, yi))ni=1 be given, and suppose they are linearly separable:
there exists ū with ‖ū‖ = 1 and γ̂ > 0 with mini yi xiT ū ≥ γ̂.

1. Choosing θ(a, v) := (0, sgn(a)ū) · 1[‖(a, v)‖ ≤ 2], then Assumption 1.3 holds with γntk ≥ γ̂/32.

2. Choosing θ(a, v) := (sgn(ūT v), 0) · 1[‖(a, v)‖ ≤ 2], then Assumption 1.3 holds with γntk ≥ γ̂/(16√d).

3. Choosing α = (1/2, −1/2) and β = (ū, −ū), then Assumption 1.1 holds with γgl ≥ γ̂/2.

Margin estimates for 2-sparse parity are as follows; the key is that γntk scales with 1/d whereas
γgl scales with 1/√d, which suffices to yield the separations in Table 1. The bound on γntk is also
necessarily an upper bound, since otherwise the estimates due to Ji and Telgarsky (2020b), which
are within the NTK regime, would beat NTK lower bounds (Wei et al., 2018).

Proposition 1.6. Suppose 2-sparse parity data, meaning inputs are supported on Hd := {±1/√d}^d,
and for any x ∈ Hd, the label is the product of two fixed coordinates d xa xb with a ≠ b.

1. Assumption 1.4 holds with γntk ≥ 1/(50d).

2. Assumption 1.2 holds with γgl ≥ 1/√(8d).

2 Margins at least as good as the NTK


This section collects results which depend on the NTK margin γntk (cf. Assumptions 1.3 and 1.4).
SGD is presented first in Section 2.1, with GF following in Section 2.2. The SGD results will not
establish large margins, only low test error, whereas the GF proofs establish both. As mentioned
before, these results yield the good computation and sample complexity for 2-sparse parity in
Table 1, and are also enough to establish escape from bad KKT points in Section 2.2.

2.1 Stochastic gradient descent


The only SGD guarantee in this work is as follows.

Theorem 2.1. Suppose the data distribution satisfies Assumption 1.4 for some γntk > 0, let time t
be given, and suppose width m and step size η satisfy

    m ≥ ( 64 ln(t/δ)/γntk )²,    η ∈ [ γntk/(10√m), γntk²/6400 ].

Then, with probability at least 1 − 7δ, the SGD iterates (Ws)s≤t with logistic loss ℓ = ℓlog satisfy

    mins<t Pr[ p(x, y; Ws) ≤ 0 ] ≤ 8 ln(1/δ)/t + 2560/(t γntk²),    (test error bound),
    maxs<t ‖Ws − W0‖ ≤ 80η√m/γntk,    (norm bound).

Note that while maxs<t ‖Ws − W0‖ ≤ 80η√m/γntk is only an upper bound, in fact, by Lemma A.3,
the first gradient has norm γntk√m, thus one step (with maximal step size η = γntk²/6400) is enough
to exit the NTK regime. As another incidental remark, note that this proof requires the logistic loss
ℓlog, and in fact breaks with the exponential loss ℓexp. Lastly, for the 2-sparse parity, γntk ≥ 1/(50d)
as in Proposition 1.6, which after plugging in to Theorem 2.1 gives the corresponding row of Table 1.
As discussed previously, the width is only 1/γntk², whereas prior work has 1/γntk⁸ (Ji and Telgarsky,
2020b; Chen et al., 2019). The proof of Theorem 2.1 is in Appendix B.1, but here is a sketch.

1. Sampling a good finite-width comparator W. Perhaps the heart of the proof is showing that the
parameters θ ∈ Rm×(d+1) given by θj := θ(wj) satisfy

    ⟨θ, ∂̂pi(W0)⟩ ≥ γntk√m/2,    ∀i.

More elaborate versions of this are used in the proof, and sometimes require a bit of surprising
algebra (cf. Lemma A.3), which for instance seems able to treat σ as though it were
smooth, and without the usual careful activation-accounting in ReLU NTK proofs.

2. Standard expand-the-square. As is common in optimization proofs, the core potential is a
squared Euclidean norm to a good comparator (in this case denoted by W, and defined in
terms of W0 and θ), whereby expanding the square gives

    ‖Ws+1 − W‖² = ‖Ws − η ∂̂ℓs(Ws) − W‖²
                = ‖Ws − W‖² − 2η ⟨∂̂ℓs(Ws), Ws − W⟩ + η² ‖∂̂ℓs(Ws)‖²
                = ‖Ws − W‖² + 2η ℓ′s(Ws) ⟨∂̂ps(Ws), W − Ws⟩ + η² ℓ′s(Ws)² ‖∂̂ps(Ws)‖².

Applying Σs to both sides and telescoping, the (summations of the) last two terms will need
to be controlled. A worrisome prospect is ‖∂̂ps(Ws)‖²; by the form of ps and the large initial
weight norm, this can be expected to scale as O(m), but η² = O(1); how can this term be
swallowed? This point will be returned to shortly.
Another critical aspect of the proof is that, in order to control various terms, it will be necessary
to maintain maxs<t ‖Ws − W0‖ = O(√m) throughout. The proof handles this in a way which
is common in deep network optimization proofs (albeit with a vastly larger norm here): by
carefully choosing the parameters of the proof, one can let τ denote the first iteration the bound
is violated, and then derive that τ > t, and all the derivations with the assumed bound in fact
hold unconditionally.
The middle term 2η ℓ′s(Ws) ⟨∂̂ps(Ws), W − Ws⟩ will be handled by a similar trick to one used
in (Ji and Telgarsky, 2020b): convexity can still be applied to ℓs (just not to ℓ composed with
p), and the remaining expression can be massaged via homogeneity and the choice of W.

3. Controlling ℓ′s(Ws)² ‖∂̂ps(Ws)‖²: the perceptron argument. Using the first point above, it is
possible to show

    ‖Wt − W0‖ ≥ (1/2) ⟨−θ, Wt − W0⟩ ≥ η (γntk√m/4) Σ_{s<t} |ℓ′s(Ws)|,

which can then be massaged to control the aforementioned squared gradient term. Interestingly,
this proof step is reminiscent of the perceptron convergence proof, and the quantity
Σ_{s<t} |ℓ′s(Ws)| is exactly the analog of the mistake bound quantity central in perceptron proofs
(Novikoff, 1962). Indeed, this proof never ends up caring about the loss terms, and derives the
final test error bound (a zero-one loss!) via this perceptron term.

2.2 Gradient flow


Whereas SGD gave a test error guarantee for free, producing a comparable test error bound with
GF in this section will require much more work. This section also sketches the main steps of the
proof, and then closes with a discussion of escaping bad KKT points.

Theorem 2.2. Suppose the data distribution satisfies Assumption 1.4 for some γntk > 0, and the
GF curve (Ws)s≥0 uses ℓ ∈ {ℓexp, ℓlog} on an architecture whose width m satisfies

    m ≥ ( 640 ln(n/δ)/γntk )².

Then, with probability at least 1 − 15δ, there exists t with ‖Wt − W0‖ = γntk√m/32, and

    γ̊(Ws) ≥ γntk²/4096    and    Pr[ p(x, y; Ws) ≤ 0 ] ≤ O( ln(n)³/(n γntk⁴) + ln(1/δ)/n )    ∀s ≥ t,

and moreover the specified iterate Wt satisfies an improved bound

    Pr[ p(x, y; Wt) ≤ 0 ] ≤ O( ln(n)³/(n γntk²) + ln(1/δ)/n ).

A few brief remarks are as follows. Firstly, for Wt , the sample complexity matches the SGD
sample complexity in Theorem 2.1, though via a much more complicated proof. Secondly, this proof
can handle {`log , `exp } and not just `log . Lastly, an odd point is that a bit of algebra grants a better
generalization bound at iterate Wt , but it is not clear if this improved bound holds for all time
s ≥ t; in particular it is not immediately clear that the nondecreasing margin property established
by Lyu and Li (2019) can be applied.
One interesting comparison is to a leaky ReLU convergence analysis on a restricted form of
linearly separable data due to Lyu et al. (2021). That work, through an extremely technical and
impressive analysis, establishes convergence to a solution which is equivalent to the best linear
predictor. By contrast, while the work here does not recover that analysis, due to γntk being a
constant multiple of the linear margin (cf. Proposition 1.5), the sample complexity is within a
constant factor of the best linear predictor, thus giving a sample complexity comparable to that of
(Lyu et al., 2021) via a simpler proof in a more general setting.
To prove Theorem 2.2, the first step is essentially the same as the proof of Theorem 2.1, however
it yields only a training error guarantee, not a test error guarantee.

Lemma 2.3. Suppose the data distribution satisfies Assumption 1.4 for some γntk > 0, let time t
be given, and suppose width m satisfies

    m ≥ ( 640 ln(t/δ)/γntk )².

Then, with probability at least 1 − 7δ, the GF curve (Ws)s∈[0,t] on empirical risk R̂ with loss
ℓ ∈ {ℓlog, ℓexp} satisfies

    R̂(Wt) ≤ 1/(5t),    (training error bound),
    sups<t ‖Ws − W0‖ ≤ γntk√m/80,    (norm bound).

Note that this bound is morally equivalent to the SGD bound in Theorem 2.1 after accounting
for the γntk² “units” arising from the step size.
The second step of the proof of Theorem 2.2 is an explicit margin guarantee, which is missing
from the SGD analysis.
Lemma 2.4. Let data ((xi, yi))ni=1 be given satisfying Assumption 1.3 with margin γntk > 0, and let
(Ws)s≥0 denote the GF curve resulting from loss ℓ ∈ {ℓlog, ℓexp}. Suppose the width m satisfies

    m ≥ 256 ln(n/δ)/γntk²,

fix a distance parameter R := γntk√m/32, and let time τ be given so that ‖Wτ − W0‖ ≤ R/2 and
R̂(Wτ) < ℓ(0)/n. Then, with probability at least 1 − 7δ, there exists a time t with ‖Wt − W0‖ = R
so that for all s ≥ t,

    ‖Ws − W0‖ ≥ R    and    γ̊(Ws) ≥ γntk²/4096,

and moreover the rebalanced iterate Ŵt := (at/√γntk, Vt√γntk) satisfies p(x, y; Wt) = p(x, y; Ŵt) for
all (x, y), and

    γ̊(Ŵt) ≥ γntk/4096.
Before discussing the proof, a few remarks are in order. Firstly, the final large margin iterate
Wt is stated as explicitly achieving some distance from initialization; needing such a claim is
unsurprising, as the margin definition requires a lot of motion in a good direction to clear the noise

in W0 . In particular, it is unsurprising that moving O( m) is needed to achieve a good margin,

given that the initial weight norm is O( m); analogously, it is not surprising that Lemma 2.3 can
not be used to produce a meaningful lower bound on γ̊(Wτ ) directly.
Regarding the proof, surprisingly it can almost verbatim follow a proof scheme originally designed
for margin rates of coordinate descent (Telgarsky, 2013). Specifically, noting that ∂̄W γ(Ws) and Ẇs
are colinear, the fundamental theorem of calculus (adapted to Clarke differentials) gives

    γ(Wt) − γ(Wτ) = ∫_τ^t (d/ds) γ(Ws) ds = ∫_τ^t ⟨∂̄W γ(Ws), Ẇs⟩ ds = ∫_τ^t ‖∂̄W γ(Ws)‖ · ‖Ẇs‖ ds,

and now the terms can be controlled separately. Assuming the exponential loss for simplicity and
recalling the dual variable notation from Section 1.2, then ∂̄W γ(Ws) = Σi qi ∂̄W pi(Ws). Consequently,
using the same good property of θ discussed in Section 2.1 gives

    ‖∂̄W γ(Ws)‖ = ‖ Σi qi ∂̄W pi(Ws) ‖ ≥ ⟨ θ/‖θ‖, Σi qi ∂̄W pi(Ws) ⟩ ≥ γntk√m/4.


Figure 2. An arrangement of positively labeled points (the blue x’s) where the margin objective
has multiple KKT points, and the gradient flow is able to avoid certain bad ones. Specifically, as
the two cones of data S1 and S2 are rotated away from each other, the linear predictor u may still
achieve a positive margin, but it will become arbitrarily small. By contrast, pointing two ReLUs at
each of S1 and S2 achieves a much better margin. Lemma 2.3 is strong enough to establish this
occurs, at least for some arrangements of the cones, as detailed in Proposition 2.6. This construction
is reminiscent of other bad KKT constructions in the literature (Lyu et al., 2021; Vardi et al., 2021).

This leaves the other term of the integral, which is even easier:

    ∫_τ^t ‖Ẇs‖ ds ≥ ‖ ∫_τ^t Ẇs ds ‖ = ‖Wt − Wτ‖ ≥ R/2,

which completes the proof for the exponential loss. For the logistic loss, the corresponding elements
qi do not form a probability vector, and necessitate the use of a 2-phase analysis which warm-starts
with Lemma 2.3.
To finish the proof of Theorem 2.2, it remains to relate margins to test error, which in order

to scale with 1/n rather than 1/√n makes use of a beautiful refined margin-based Rademacher
complexity bound due to Srebro et al. (2010, Theorem 5), and thereafter uses the special structure
of 2-homogeneity to treat the network as an ℓ1-bounded linear combination of nodes, and thereby
achieve no dependence, even logarithmic, on the width m.

Lemma 2.5. With probability at least 1 − δ over the draw of ((xi, yi))ni=1, for every width m, every
choice of weights W ∈ Rm×(d+1) with γ̊(W) > 0 satisfies

    Pr[ p(x, y; W) ≤ 0 ] ≤ O( ln(n)³/(n γ̊(W)²) + ln(1/δ)/n ).

Combining the preceding pieces yields the proof of Theorem 2.2. To conclude this section, note
that these margin guarantees suffice to establish that GF can escape bad KKT points of the margin
objective. The construction appears in Figure 2 and is elementary, detailed as follows. Consider
data, all of the same label, lying in two narrow cones S1 and S2 . If S1 and S2 are close together, the
global maximum margin network corresponds to a single linear predictor. As the angle between S1
and S2 is increased, eventually the global maximum margin network chooses two separate ReLUs,
one pointing towards each cone; meanwhile, before the angle becomes too large, if S1 and S2 are
sufficiently narrow, there exists a situation whereby a single linear predictor still has positive margin,
but worse than the 2 ReLU solution, and is still a KKT point.
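The following small numpy sketch illustrates the margin comparison behind Figure 2, with hypothetical cone directions and hand-picked balanced weights; it is an illustration of the gap described above, not the construction used in the proof of Proposition 2.6.

```python
import numpy as np

def cone_points(center_deg, half_width_deg, n):
    # unit-norm points clustered around one direction on the circle (a 2d stand-in for a narrow cone)
    angles = np.deg2rad(center_deg + np.linspace(-half_width_deg, half_width_deg, n))
    return np.stack([np.cos(angles), np.sin(angles)], axis=1)

def normalized_margin(a, V, X):
    # gamma(W) = min_i F(x_i; W) / ||W||^2, for positively labeled data
    preds = np.maximum(V @ X.T, 0.0).T @ a          # F(x_i; W) = sum_j a_j relu(v_j^T x_i)
    return preds.min() / (np.sum(a**2) + np.sum(V**2))

# two narrow cones of positively labeled points, 150 degrees apart,
# so cross inner products are below -1/sqrt(2) as in Proposition 2.6
S1, S2 = cone_points(165, 5, 20), cone_points(15, 5, 20)
X = np.vstack([S1, S2])

u  = np.array([np.cos(np.deg2rad(90)),  np.sin(np.deg2rad(90))])    # bisecting linear direction
u1 = np.array([np.cos(np.deg2rad(165)), np.sin(np.deg2rad(165))])
u2 = np.array([np.cos(np.deg2rad(15)),  np.sin(np.deg2rad(15))])

gamma_single = normalized_margin(np.array([1.0]), u[None, :], X)        # one ReLU along u
gamma_two    = normalized_margin(np.array([1.0, 1.0]), np.stack([u1, u2]), X)  # one ReLU per cone
print(f"single-direction margin ~ {gamma_single:.3f}, two-ReLU margin ~ {gamma_two:.3f}")
```

With these particular angles, the two-ReLU normalized margin comes out roughly three times larger than the single-direction one, matching the qualitative picture in Figure 2.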

Proposition 2.6. Let ((xi, yi))ni=1 be given as in Figure 2, where yi = +1, and (xi)ni=1 all have
‖xi‖ = 1, and are partitioned into two sets, S1 and S2, such that max{ ⟨xi, xj⟩ : xi ∈ S1, xj ∈ S2 } ≤
−1/√2. Then there always exists a margin parameter γ̂ > 0 and a single new data point x0 so that
the resulting data S1 ∪ S2 ∪ {x0} satisfies the following conditions.

1. The maximum margin linear predictor u achieves margin γ̂, and is also a KKT point for any
shallow ReLU network of width m ≥ 1.

2. There exists m0 so that for any width m ≥ m0, GF achieves limt γ̊(Wt) > γ̂.

3 Margins beyond the NTK


This section develops two families of bounds beyond the NTK, meaning in particular that the final
margin and sample complexity bounds depend on γgl , rather than γntk as in Section 2. On the
downside, these bounds all require exponentially large width, GF, and moreover Theorem 3.3 forces
the inner layer to never rotate. These results are proved with ℓexp for convenience, though the same
techniques handling ℓlog with GF in Section 2.2 should also work here.

3.1 Neural collapse (NC)


The NC setting has data in groups which are well-separated (Papyan et al., 2020); in particular,
data is partitioned into cones, and all data points outside a cone live within the convex polar to
that cone (Hiriart-Urruty and Lemaréchal, 2001). The formal definition is as follows.
Assumption 3.1. There exist (βk)rk=1 with ‖βk‖ = 1 and αk ∈ {±1/r} and γnc > 0 and ε ∈ (0, γnc)
so that almost surely over the draw of any data ((xi, yi))ni=1, then for any particular (xi, yi, βk):

• either βkT xi yi ≥ γnc and ‖(I − βk βkT)xi‖ ≤ γnc/2 (example i lies in a narrow cone around βk),

• or βkT xi yi ≤ −ε (example i lies in the polar of the cone around βk). ♦


It follows that Assumption 3.1 implies Assumption 1.2 with margin γgl ≥ γnc/r, but the condition
is quite a bit stronger. The corresponding GF result is as follows.
Theorem 3.2. Suppose the data distribution satisfies Assumption 3.1 for some (r, γnc, ε), and let
ℓ = ℓexp be given. If the network width m satisfies

    m ≥ 2 (2/ε)^d ln(r/δ),

then, with probability at least 1 − 3δ, the GF curve (Ws)s≥0 for all large times t satisfies

    γ̊(Wt) ≥ (γnc − ε)/(8r)    and    Pr[ p(x, y; Wt) ≤ 0 ] = O( r² ln(n)³/(n (γnc − ε)²) + ln(1/δ)/n ).

Note that Theorem 3.2 only implies that GF selects a predictor with margins at least as good
as the NC solution, and does not necessarily converge to the NC solution (i.e., rotating all ReLUs
to point in the directions (βk )rk=1 specified by Theorem 3.2). In fact, this may fail to be true, and
Proposition 2.6 and Figure 2 already gave one such construction; moreover, this is not necessarily
bad, as GF may converge to a solution with better margins, and potentially better generalization.
Overall, the relationship of NC to the bias of 2-layer network training in practical regimes remains
open.

The proof of Theorem 3.2 proceeds by developing a potential function that asserts that either
mass grows in the directions (βk )rk=1 , or their margin is exceeded. Within the proof, large width
ensures that the mass in each good direction is initially positive, and thereafter Assumption 3.1 is
used to ensure that the fraction of mass in these directions is increasing. The proof of Theorem 3.2
and of Theorem 3.3 invoke the same abstract potential function lemma, and discussion is momentarily
deferred until after the presentation of Theorem 3.3.
One small point is worth explaining now, however. It may have seemed unusual to use kaj vj k
as a (squared!) norm in Figure 1. Of course, layers asymptotically balance, thus asymptotically
not only is there the Fenchel inequality 2kaj vj k ≤ a2j + kvj k2 = kwj k2 , but also a reverse inequality
2kaj vj k ≳ kwj k2 . Despite this fact, the disagreement between 2kaj vj k and kwj k2 , namely the
imbalance between a2j and kvj k2 , can cause real problems, and one solution used within the proofs
is to replace kwj k2 with kaj vj k.

3.2 Global margin maximization


The final theorem will be on stylized networks where the inner layer is forced to not rotate.
Specifically, the networks are of the form
    x ↦ Σj aj σ(bj vjT x),

where ((aj, bj))mj=1 are trained, but vj are fixed at initialization; the new scalar parameter bj is
effectively the norm of vj (though it is allowed to be negative). As a further simplification, aj and
bj are initialized to have the same norm; this initial balancing is common in many implicit bias
proofs, but is impractical and constitutes a limitation to improve in future work. While it is clearly
unpleasant that (vj)mj=1 can not rotate, Figure 1c provides some hope that this is approximated in
networks of large width.
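A minimal numpy sketch of this stylized parameterization is below: the directions vj are sampled once and frozen, only the scalars (aj, bj) evolve under an Euler discretization of GF with ℓexp, and they are initialized at ±1/m^{1/4}. The independent signs for aj and bj, the width, and the step size are illustrative assumptions, and the width is far below the exponential requirement of Theorem 3.3, so no guarantee is implied; the point is only the parameterization and the gradient flow on (a, b).

```python
import numpy as np

rng = np.random.default_rng(0)
d, n, m = 20, 64, 512

# a 2-sparse parity sample, as in Table 1
X = rng.choice([-1.0, 1.0], size=(n, d)) / np.sqrt(d)
y = d * X[:, 0] * X[:, 1]

# frozen unit directions v_j; only the scalars (a_j, b_j) are trained, initialized at +-1/m^{1/4}
V = rng.standard_normal((m, d))
V /= np.linalg.norm(V, axis=1, keepdims=True)
a = rng.choice([-1.0, 1.0], size=m) / m**0.25
b = rng.choice([-1.0, 1.0], size=m) / m**0.25   # signs drawn independently of a (an assumption)

Z = X @ V.T                                     # v_j^T x_i, fixed for all time

def euler_gf_step(a, b, eta):
    # Euler discretization of GF on (a, b) with the exponential loss
    H = np.maximum(b * Z, 0.0)                  # relu(b_j v_j^T x_i), shape (n, m)
    p = y * (H @ a)                             # p_i((a, b))
    lp = -np.exp(-p) / n                        # l'_exp(p_i) / n  (negative)
    grad_a = H.T @ (lp * y)
    grad_b = a * ((lp * y)[:, None] * (b * Z > 0.0) * Z).sum(axis=0)
    return a - eta * grad_a, b - eta * grad_b

for _ in range(5000):
    a, b = euler_gf_step(a, b, eta=0.1)

p = y * (np.maximum(b * Z, 0.0) @ a)
print("train error:", np.mean(p <= 0),
      "normalized margin:", p.min() / (np.sum(a**2) + np.sum(b**2)))
```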
Theorem 3.3. Suppose the data distribution satisfies Assumption 1.2 for some γgl > 0 with reference
architecture ((αk, βk))rk=1. Consider the architecture x ↦ Σj aj σ(bj vjT x) where ((aj(0), bj(0)))mj=1
are sampled uniformly from the two choices ±1/m^{1/4}, and vj(0) is sampled from the unit sphere
(e.g., first vj0 ∼ Nd/√d, then vj(0) := vj0/‖vj0‖), and

    m ≥ 2 (4/γgl)^d ln(r/δ).

Then, with probability at least 1 − 3δ, for all large t, GF on ((aj, bj))mj=1 with ℓexp satisfies

    γ̊((a(t), b(t))) ≥ γgl/4    and    Pr[ p(x, y; (a(t), b(t))) ≤ 0 ] = O( ln(n)³/(n γgl²) + ln(1/δ)/n ).

The main points of comparison for Theorem 3.3 are the global margin maximization proofs of
Wei et al. (2018) and Chizat and Bach (2020). The analysis by Wei et al. (2018) is less similar,
as it heavily relies upon the benefits to local search arising from weight re-initialization, whereas
the analysis here in some sense is based on the technique in (Chizat and Bach, 2020), but diverges
sharply due to dropping the two key assumptions therein. Specifically, (Chizat and Bach, 2020)
requires infinite width and dual convergence, meaning q(t) converges, which is open even for linear
predictors in general settings. The infinite width assumption is also quite strenuous: it is used to
ensure not just that weights cover the sphere at initialization (a consequence of exponentially large
width), but in fact that they cover the sphere for all times t.

The proof strategy of Theorem 3.3 (and Theorem 3.2) is as follows. The core of the proof scheme
in (Chizat and Bach, 2020) is to pick two weights wj and wk , where wk achieves better margins
than wj in some sense, and consider
    (d/dt) ( ‖wj‖²/‖wk‖² ) = 4Q ( ‖wj‖²/‖wk‖² ) Σi qi ( p̃i(wj) − p̃i(wk) );

as a purely technical aside, it is extremely valuable that this ratio potential automatically normalizes
the margins, leading to the appearance of pei (wj ) not pi (wj ), and a similar idea is used in the proofs
here, albeit starting from ln kwj k − ln kwk k, the idea of ln(·) also appearing in the proofs by Lyu
and Li (2019). Furthermore, this expression already shows the role of dual convergence: if we
can
P assume every qi converges, then we need only pick nodes wk for which the margin surrogate
i qi (∞)e
pi wk (∞) is very large, and the above time derivative becomes negative if wj has bad
margin, which implies mass accumulates in directions with good margin. This is the heart of the
proof scheme due to (Chizat and Bach, 2020), and circumventing or establishing dual convergence
seems tricky.
The approach here is to replace ‖wj‖ and ‖wk‖ with other quantities which can be handled
without dual convergence. First, ‖wj‖ is replaced with ‖W‖, the norm of all nodes, which can
be controlled in an elementary way for arbitrary L-homogeneous networks: as summarized in
Lemma A.5, as soon as ‖W‖ becomes large, then Σi qi p̃i(W) ≈ γ(Wt), essentially by properties of
ln Σ exp.
Replacing ln‖wk‖ is much harder, since qi(t) may oscillate and thus the notion of nodes with
good margin seems to be time-varying. If there is little rotation, then nodes near the reference
directions (βk)rk=1 can be swapped with (βk)rk=1, and the expression Σi qi p̃i(wk) can be swapped
with γgl. A potential function that replaces ln‖wk‖ and allows this swapping need only satisfy a few
abstract but innocuous conditions, as summarized in Lemma A.6. Unfortunately, verifying these
conditions is rather painful, and handling general settings (without explicitly disallowing rotation)
seems to still need quite a few more ideas.

4 Concluding remarks and open problems


This work provides settings where SGD and GF can select good features, but many basic questions
and refinements remain.
Figure 1 demonstrated low rotation with 2-sparse parity; can this be proved, thereby establishing
Theorem 3.3 without forcing nodes to not rotate?
Theorem 2.1 and Theorem 2.2 achieve the same sample complexity for SGD and GF, but via
drastically different proofs, the GF proof being weirdly complicated; is there a way to make the two
more similar?
Looking to Table 1 for 2-sparse parity, the approaches here fail to achieve the lowest width; is
there some way to achieve this with SGD and GF, perhaps even via margin analyses?
The approaches here are overly concerned with reaching a constant factor of the optimal
margins; is there some way to achieve slightly worse margins with the benefit of reduced width and
computation? More generally, what is the Pareto frontier of width, samples, and computation in
Table 1?
The margin analysis here for the logistic loss, namely Theorem 2.2, requires a long warm start
phase. Does this reflect practical regimes? Specifically, does good margin maximization and feature
learning occur with the logistic loss in this early phase?

Acknowledgements
The author thanks Peter Bartlett, Spencer Frei, Danny Son, and Nati Srebro for discussions. The
author thanks the Simons Institute for hosting a short visit during the 2022 summer cluster on
Deep Learning Theory, the audience at the corresponding workshop for exciting and clarifying
participation during a talk presenting this work, and the NSF for support under grant IIS-1750051.

References
Emmanuel Abbe, Enric Boix-Adsera, and Theodor Misiakiewicz. The merged-staircase property: a
necessary and nearly sufficient condition for sgd learning of sparse functions on two-layer neural
networks. arXiv preprint arXiv:2202.08658, 2022.

Alekh Agarwal, Daniel Hsu, Satyen Kale, John Langford, Lihong Li, and Robert Schapire. Taming
the monster: A fast and simple algorithm for contextual bandits. In International Conference on
Machine Learning, pages 1638–1646. PMLR, 2014.

Zeyuan Allen-Zhu, Yuanzhi Li, and Zhao Song. A convergence theory for deep learning via
over-parameterization. arXiv preprint arXiv:1811.03962, 2018.

Sanjeev Arora, Simon S Du, Wei Hu, Zhiyuan Li, and Ruosong Wang. Fine-grained analysis of
optimization and generalization for overparameterized two-layer neural networks. arXiv preprint
arXiv:1901.08584, 2019.

Yu Bai and Jason D Lee. Beyond linearization: On quadratic and higher-order approximation of
wide neural networks. arXiv preprint arXiv:1910.01619, 2019.

Keith M. Ball. An elementary introduction to modern convex geometry. In Flavors of Geometry,


Mathematical Sciences Research Institute Publications, pages 1–58. Cambridge University Press,
1997.

Boaz Barak, Benjamin L Edelman, Surbhi Goel, Sham Kakade, Eran Malach, and Cyril Zhang.
Hidden progress in deep learning: Sgd learns parities near the computational limit. arXiv preprint
arXiv:2207.08799, 2022.

Avrim Blum, John Hopcroft, and Ravindran Kannan. Foundations of data science, 2017. URL
https://www.cs.cornell.edu/jeh/book.pdf.

Bernhard E. Boser, Isabelle M. Guyon, and Vladimir N. Vapnik. A training algorithm for optimal
margin classifiers. In Proceedings of the Fifth Annual Workshop on Computational Learning Theory,
COLT ’92, page 144–152, New York, NY, USA, 1992. Association for Computing Machinery. ISBN
089791497X. doi: 10.1145/130385.130401. URL https://doi.org/10.1145/130385.130401.

Zixiang Chen, Yuan Cao, Difan Zou, and Quanquan Gu. How much over-parameterization is
sufficient to learn deep relu networks? arXiv preprint arXiv:1911.12360, 2019.

Lenaic Chizat and Francis Bach. Implicit bias of gradient descent for wide two-layer neural networks
trained with the logistic loss. arXiv preprint arXiv:2002.04486, 2020.

Alexandru Damian, Jason Lee, and Mahdi Soltanolkotabi. Neural networks can learn representations
with gradient descent. In Conference on Learning Theory, pages 5413–5452. PMLR, 2022.

Amit Daniely and Eran Malach. Learning parities with neural networks. Advances in Neural
Information Processing Systems, 33:20356–20365, 2020.

Simon S Du, Jason D Lee, Haochuan Li, Liwei Wang, and Xiyu Zhai. Gradient descent finds global
minima of deep neural networks. arXiv preprint arXiv:1811.03804, 2018.

Jean-Baptiste Hiriart-Urruty and Claude Lemaréchal. Fundamentals of Convex Analysis. Springer


Publishing Company, Incorporated, 2001.

Arthur Jacot, Franck Gabriel, and Clément Hongler. Neural tangent kernel: Convergence and
generalization in neural networks. In Advances in neural information processing systems, pages
8571–8580, 2018.

Ziwei Ji and Matus Telgarsky. Gradient descent aligns the layers of deep linear networks. arXiv
preprint arXiv:1810.02032, 2018a.

Ziwei Ji and Matus Telgarsky. Risk and parameter convergence of logistic regression.
arXiv:1803.07300v3 [cs.LG], 2018b.

Ziwei Ji and Matus Telgarsky. Characterizing the implicit bias via a primal-dual analysis. arXiv
preprint arXiv:1906.04540, 2019.

Ziwei Ji and Matus Telgarsky. Directional convergence and alignment in deep learning. arXiv
preprint arXiv:2006.06657, 2020a.

Ziwei Ji and Matus Telgarsky. Polylogarithmic width suffices for gradient descent to achieve
arbitrarily small test error with shallow ReLU networks. In ICLR, 2020b.

Ziwei Ji, Matus Telgarsky, and Ruicheng Xian. Neural tangent kernels, transportation mappings,
and universal approximation. In ICLR, 2020.

Yuanzhi Li and Yingyu Liang. Learning overparameterized neural networks via stochastic gradient
descent on structured data. In Advances in Neural Information Processing Systems, pages
8157–8166, 2018.

Kaifeng Lyu and Jian Li. Gradient descent maximizes the margin of homogeneous neural networks.
arXiv preprint arXiv:1906.05890, 2019.

Kaifeng Lyu, Zhiyuan Li, Runzhe Wang, and Sanjeev Arora. Gradient descent on two-layer nets:
Margin maximization and simplicity bias. Advances in Neural Information Processing Systems,
34, 2021.

Atsushi Nitanda and Taiji Suzuki. Refined generalization analysis of gradient descent for over-
parameterized two-layer neural networks with smooth activations on classification problems. arXiv
preprint arXiv:1905.09870, 2019.

Albert B.J. Novikoff. On convergence proofs on perceptrons. In Proceedings of the Symposium on


the Mathematical Theory of Automata, 12:615–622, 1962.

Vardan Papyan, XY Han, and David L Donoho. Prevalence of neural collapse during the terminal
phase of deep learning training. Proceedings of the National Academy of Sciences, 117(40):
24652–24663, 2020.

Robert E. Schapire and Yoav Freund. Boosting: Foundations and Algorithms. MIT Press, 2012.

Robert E. Schapire, Yoav Freund, Peter Bartlett, and Wee Sun Lee. Boosting the margin: A new
explanation for the effectiveness of voting methods. In ICML, pages 322–330, 1997.

Shai Shalev-Shwartz and Shai Ben-David. Understanding Machine Learning: From Theory to
Algorithms. Cambridge University Press, 2014.

Zhenmei Shi, Junyi Wei, and Yingyu Liang. A theoretical analysis on feature learning in neural
networks: Emergence from inputs and advantage over fixed features. In International Conference
on Learning Representations, 2022. URL https://openreview.net/forum?id=wMpS-Z_AI_E.

Daniel Soudry, Elad Hoffer, Mor Shpigel Nacson, Suriya Gunasekar, and Nathan Srebro. The
implicit bias of gradient descent on separable data. arXiv preprint arXiv:1710.10345, 2017.

Nathan Srebro, Karthik Sridharan, and Ambuj Tewari. Smoothness, low noise and fast rates. In
NIPS, 2010.

Matus Telgarsky. Margins, shrinkage, and boosting. In ICML, 2013.

Gal Vardi, Ohad Shamir, and Nathan Srebro. On margin maximization in linear and relu networks.
arXiv preprint arXiv:2110.02732, 2021.

Martin J. Wainwright. High-Dimensional Statistics: A Non-Asymptotic Viewpoint. Cambridge


University Press, 1 edition, 2019.

Colin Wei, Jason D Lee, Qiang Liu, and Tengyu Ma. Regularization matters: Generalization and
optimization of neural nets vs their induced kernel. arXiv preprint arXiv:1810.05369, 2018.

Blake Woodworth, Suriya Gunasekar, Jason D Lee, Edward Moroshko, Pedro Savarese, Itay Golan,
Daniel Soudry, and Nathan Srebro. Kernel and rich regimes in overparametrized models. In
Conference on Learning Theory, pages 3635–3673. PMLR, 2020.

Tong Zhang and Bin Yu. Boosting with early stopping: Convergence and consistency. The Annals
of Statistics, 33:1538–1579, 2005.

Difan Zou, Yuan Cao, Dongruo Zhou, and Quanquan Gu. Stochastic gradient descent optimizes
over-parameterized deep relu networks. arXiv preprint arXiv:1811.08888, 2018.

A Technical preliminaries
The following are basic technical tools used throughout.

A.1 Estimates of γntk and γgl


This section provides estimates of γntk and γgl in various settings. The first estimate is of linear
predictors.

Proof of Proposition 1.5. The proof considers the three settings separately.

1. For any i, first note that

    E_{w∼Nθ} ⟨θ(w), ∂̂w pi(w)⟩ = E_{(a,v)∼Nθ} |a| ūT xi yi σ′(vT xi) 1[‖(a, v)‖ ≤ 2]
                              ≥ γ̂ E_{(a,v)∼Nθ} |a| σ′(vT xi) 1[‖(a, v)‖ ≤ 2].

To control the expectation, note that with probability at least 1/2, then 1/4 ≤ |a| ≤ 2, and
thus by rotational invariance

    E_{(a,v)∼Nθ} |a| σ′(vT xi) 1[‖(a, v)‖ ≤ 2] ≥ (1/8) E_{(a,v)∼Nθ} σ′(vT xi) 1[‖v‖ ≤ √2]
                                              ≥ (1/8) E_{(a,v)∼Nθ} σ′(v1) 1[‖v‖ ≤ √2]
                                              ≥ 1/32.

2. For convenience, fix any example (x, y) ∈ ((xi, yi))ni=1, and write (a, v) = w, whereby w ∼ Nw
means a ∼ Na and v ∼ Nv. With this out of the way, define an orthonormal matrix M ∈ Rd×d
where the first column is ū, the second column is (I − ūūT)x/‖(I − ūūT)x‖, and the remaining
columns are arbitrary so long as M is orthonormal, and note that MT ū = e1 and MT x =
e1 ūT x + e2 r2 where r2 := √(‖x‖² − (ūT x)²). Then, using rotational invariance of the Gaussian,

    Ew ⟨θ(w), ∂̂pi(w)⟩ = y E_{w=(a,v)} sgn(ūT v) σ(vT x) 1[‖w‖ ≤ 2]
        = y E_{‖(a,Mv)‖≤2} α(Mv) σ(vT MT x)
        = E_{‖(a,v)‖≤2} y sgn(v1) σ(v1 ūT x y² + v2 r2)
        = E_{‖(a,v)‖≤2} y sgn(v1) σ( y sgn(v1) |v1| ūT x y + v2 r2 )
        = E_{‖(a,v)‖≤2, y sgn(v1)=1, v2≥0} [ σ(|v1| ūT x y + v2 r2) − σ(−|v1| ūT x y + v2 r2)
            + σ(|v1| ūT x y − v2 r2) − σ(−|v1| ūT x y − v2 r2) ].

Considering cases, the first ReLU argument is always positive, exactly one of the second and
third is positive, and the fourth is negative, whereby

    y E_{‖(a,v)‖≤2} α(v) σ(vT x) = E_{‖(a,v)‖≤2, y sgn(v1)=1, v2≥0} [ |v1| ūT x y + v2 r2 + |v1| ūT x y − v2 r2 ]
        = 2 E_{‖(a,v)‖≤2, y sgn(v1)=1} |v1| ūT x y
        ≥ 2 γ̂ E_{‖v‖≤1, y sgn(v1)=1} |v1|
        ≥ γ̂ Pr[‖(a, v)‖ ≤ 2] E[ |v1| | ‖(a, v)‖ ≤ 2 ],

where Pr[‖(a, v)‖ ≤ 2] ≥ 1/4 since (for example) the χ² random variables corresponding to |a|²
and ‖v‖² have median less than one, and the expectation term is at least 1/(4√d) by standard
Gaussian computations (Blum et al., 2017, Theorem 2.8).

3. For any pair (xi, yi),

    2 yi Σ_{j=1}^2 αj σ(βjT xi) = yi σ(ūT xi) − yi σ(−ūT xi)
                                = 1[yi = 1] σ(yi ūT xi) + 1[yi = −1] σ(yi ūT xi)
                                = yi ūT xi
                                ≥ γ̂.

Next, the construction for 2-sparse parity. As is natural in maximum margin settings, but in
contrast with most studies of sparse parity, only the support of the distribution matters (and the
labeling), but not the marginal distribution of the inputs.

Proof of Proposition 1.6. This proof shares ideas with (Wei et al., 2018; Ji and Telgarsky, 2020b),
though with some adjustments to exactly fit the standard 2-sparse parity setting, and to shorten
the proofs.
Without loss of generality, due to the symmetry of the data distribution about the origin, suppose
a = 1 and b = 2, meaning for any x ∈ Hd , the correct label is dx1 x2 , the product of the first two
coordinates. Both proofs will use the global margin construction (the parameters for γgl), given as
follows: p(x, y; (α, β)) = y Σ_{j=1}^4 αj σ(βjT x), where α = (1/4, −1/4, −1/4, 1/4) and

    β1 := ( 1/√2,  1/√2, 0, . . . , 0) ∈ Rd,
    β2 := ( 1/√2, −1/√2, 0, . . . , 0) ∈ Rd,
    β3 := (−1/√2,  1/√2, 0, . . . , 0) ∈ Rd,
    β4 := (−1/√2, −1/√2, 0, . . . , 0) ∈ Rd.
Note moreover that for any x ∈ Hd , then βjT x > 0 for exactly one j, which will be used for both γntk
and γgl . The proof now splits into the two different settings, and will heavily use symmetry within
Hd and also within (α, β).

1. Consider the transport mapping

    θ(a, v) = ( 0, (sgn(a)/2) Σ_{j=1}^4 βj 1[βjT v ≥ 0] );

note that this satisfies the condition ‖θ(w)‖ ≤ 1 thanks to the factor 1/2, since each βj gets
a hemisphere, and (β1, β4) together partition the sphere once, and (β2, β3) similarly together
partition the sphere once.
Now let any x be given, which as above has label y = d x1 x2. By rotational symmetry of the data
and also the transport mapping, suppose β1 is the unique choice with β1T x > 0, which
implies y = 1, and also β2T x = 0 = β3T x, however β4T x = −β1T x. Using these observations,
and also rotational invariance of the Gaussian,

    Ea,v ⟨θ(a, v), ∂̄p(x, y; w)⟩ = Ea,v (|a|/2) Σ_{j=1}^4 βjT x 1[βjT v ≥ 0] · 1[vT x ≥ 0]
        = β1T x Ea(|a|/2) · ( Ev 1[β1T v ≥ 0] · 1[vT x ≥ 0] − Ev 1[−β1T v ≥ 0] · 1[vT x ≥ 0] ).

Now consider Ev 1[β1T v ≥ 0] · 1[vT x ≥ 0]. A standard Gaussian computation is to introduce a
rotation matrix M whose first column is β1, whose second column is (I − β1β1T)x/‖(I − β1β1T)x‖,
and the rest are orthogonal, which by rotational invariance and the calculation β1T x = √(2/d)
gives

    Ev 1[β1T v ≥ 0] · 1[vT x ≥ 0] = Ev 1[β1T Mv ≥ 0] · 1[vT MT x ≥ 0]
                                  = Ev 1[v1 ≥ 0] · 1[v1 β1T x + v2 √(1 − (β1T x)²) ≥ 0]
                                  = Ev 1[v1 ≥ 0] · 1[v1 + v2 √(d/2 − 1) ≥ 0].

Performing a similar calculation for the other term (arising from β4T x) and plugging all of this
back in,
¯
Ea,v θ(a, v), ∂p(x, y; w)
r  
2 |a|  p p 
= Ea · Ev 1[v1 ≥ 0] 1[v1 + v2 d/2 − 1 ≥ 0] − 1[−v1 + v2 d/2 − 1 ≥ 0] .
d 2

To finish, a few observations suffice.pWhenever v1 ≥ 0 (which is enforced by the common


first term), then −v1 + v2 τ ≤ v1 + v2 d/2 − 1, so the first indicator is 1 whenever the second
indicator is 1, thus their difference is nonnegative,
p and to lower bound the overall
p quantity, it
suffices to asses the probability that v1 + v2 d/2 − 1 ≥ 0 whereas −v1 + v2 d/2 − 1 ≤ 0. To
lower bound this event, it suffices to lower bound
p p p
Pr[v1 ≥ 0 ∧ v2 ≥ 0 ∧ v1 ≥ v2 d/2 − 1] ≥ Pr[v1 ≥ 1/2] · Pr[0 ≤ v2 ≤ 1/d.

The first term is at least 1/5, and the second can be calculated via brute force:
√ √
√ 1/ d 1/ d
Z Z  
1 2 1 1 1 1
Pr[v2 ≥ 1/ d] = √ exp(−x ) dx ≥ √ exp(−1/d) dx ≥ √ √ ,
2π 0 2π 0 2π d e

which completes the proof after similarly using Ea |a| ≥ 1, and simplifying the various constants.

2. Let any x ∈ Hd be given, and as above note that βjT x > 0 for exactly one j. By symmetry,
suppose it is β1 , whereby y = x1 x2 = 1, and
 
X 1 2 1
γgl ≥ p(x, y; (α, β)) = y αj σ(βjT x) = |α1 | · β1T x = √ =√ .
4 2d 8d
j

21
Lastly, an estimate of γntk in a simplified version of Assumption 3.1, which is used in the proof
of Proposition 2.6.
Lemma A.1. Suppose Assumption 1.1 holds for data ((xi , yi ))ni=1 with reference solution ((αk , βk ))rk=1
and margin γgl > 0, and additionally
√ αk > 0 and yi = +1 and kxi k = 1. Then Assumption 1.3 holds
with margin γntk ≥ γgl /(8 d).
Pr 0 T
 n
Proof. Define θ(a, v) := k=1 αk σ (βk v), 0 1[k(a, v)k ≤ 2]. Fix any x ∈ (xi )i=1 , and for each k
define orthonormal matrix Mk with first column βk and second p column (I − βk βkT )x/k(I − βk βkT )xk,
whereby Mk βk = e1 and Mk x = e1 βk x + e2 rj where rj := 1 − (βk x)2 . Then, using rotational
T T

invariance of the Gaussian and Jensen’s inequality applied to the ReLU,


r
X
yEw∼Nθ ¯ i (w) = Ekwk≤2
θ(w), ∂σp αk σ 0 (βkT v)σ(v T x)
k=1
X
= αk EkM T vk≤2 σ 0 (v T Mk uk )σ(v T Mk x)
j
k
X
= αk Ekvk≤1 σ(v1 βkT x + v2 r2 )
k v1 ≥0
X  
≥ αk σ Ekvk≤1 v1 uTk x + v2 r2
k v1 ≥0
X  
= αk βkT xσ Ekvk≤1 v1
k v1 ≥0

≥ γgl Ekvk≤1 v1 ,
v1 ≥0

which is bounded below by γgl /(8 d) via similar arguments to those in the proof of Proposition 1.5.

A.2 Gaussian concentration


The first concentration inequalities are purely about the initialization.
√ √
Lemma A.2. Suppose a ∼ Nm / m and V ∼ Nm×d / d.
p
− δ, then kak ≤ 1 + 2 ln(1/δ)/m; similarly, with probability at least
1. With probability at least 1 p

1 − δ, then kV k ≤ m + 2 ln(1/δ)/d.
2. Let examples (x1 , . . . , xn ) be given with kxi k ≤ 1. With probability at least 1 − 4δ,

X
max aj σ(vjT xi ) ≤ 4 ln(n/δ).
i
j

√ √ √
Proof. 1. Rewrite ã := a m, so that ã ∼ Nm . Since ã 7→ kãk/ m = kak is (1/ m)-Lipschitz,
then by Gaussian concentration, (Wainwright, 2019, Theorem 2.26),

kak = kãk/ m
√ p
≤ Ekãk/ m + 2 ln(1/δ)/m
p √ p
≤ Ekãk2 / m + 2 ln(1/δ)/m
p
= 1 + 2 ln(1/δ)/m.

22

Similarly for V , defining Ṽ := V d whereby Ṽ ∼ Nm×d , Gaussian concentration grants
√ √ p
kV k = kṼ k/ d ≤ m + 2 ln(1/δ)/d.

2. Fix any example xi , and constants 1 > 0 and 2 > 0 to be optimized at the end of the proof,
and define di := d/kxi k2 for convenience. By rotational invariance √ of Gaussians
√ and since xi
is fixed, then σ(V xi ) is equivalent
√ in
√ distribution to kx i kσ(g)/ d = σ(g)/ di where g ∼ Nm .

Meanwhile, g 7→ kσ(g)k/ di is (1/ di )-Lipschitz with Ekσ(g)k ≤ m, and so, by Gaussian
concentration (Wainwright, 2019, Theorem 2.26),
!
√ p √ −di 21
Pr[kσ(V xi )k ≥ 1 + m] = Pr[kσ(g)k/ di ≥ 1 + m] ≤ exp .
2

Next consider the original expression aT σ(V xi ). To simplify handling of the 1/m variance

of the coordinates of a, define another Gaussian h := a m, and a new constant ci := mdi
for convenience, whereby aT σ(Vi ) is equivalent in distribution to equivalent in distribution to

hT σ(g)/ ci since a and V are independent (and thus h and V are independent). Conditioned
on g, since Eh = 0, then E[hT σ(g)|g] = 0. As such, applying Gaussian concentration to this
√ √
conditioned random variable, since h 7→ hT σ(g)/ ci is (kσ(g)k/ ci )-Lipschitz, then
!
T
√ −ci 22
Pr[h σ(g)/ ci ≥ 2 g] ≤ exp .
2kσ(g)k2

Returning to the original expression, it can now be controlled via the two preceding bounds,
conditioning, and the tower property of conditional expectation:

Pr[hT σ(g)/ ci ≥ 2 ]
h √ p √ i h p √ i
≤ Pr hT σ(g)/ ci ≥ 2 kσ(g)k/ di ≤ 1 + m · Pr kσ(g)k/ di ≤ 1 + m
h √ p √ i h p √ i
+ Pr hT σ(g)/ ci ≥ 2 kσ(g)k/ di > 1 + m · Pr kσ(g)k/ di > 1 + m
h √ p √ i p √
= E Pr[hT σ(g)/ ci ≥ 2 | g] kσ(g)k/ di ≤ 1 + m Pr[kσ(g)k/ di ≤ 1 + m]
√ p √ p √
+ Pr[hT σ(g)/ ci ≥ 2 | kσ(g)k/ di > 1 + m]Pr[kσ(g)k/ di > 1 + m]
 ! 
−ci 22 p √  
2
≤ E exp kσ(g)k/ d i ≤  1 + m + exp −d 
i 1 /2
2kσ(g)k2
!
−ci 22 
2

≤ exp + exp −d 
i 1 /2 .
4di 21 + 4di m
p p
As such, choosing 2 := 4 ln(n/δ) mdi /ci = 4 ln(n/δ) and 1 := 2 ln(n/δ)/di gives

√ δ δ
Pr[aT σ(V xi ) ≥ 2 ] = Pr[hT σ(g)/ ci ≥ 2 ] ≤ + ,
n n
which is a sub-exponential concentration bound. Union bounding over the reverse inequality
and over all n examples and using maxi kxi k ≤ 1 gives the final bound.

23
Next comes a key tool in all the proofs using γntk : guarantees that the infinite-width margin
assumptions imply the existence of good finite-width networks. These bounds are stated for
Assumption 1.3, however they will also be applied with Assumption 1.4, since Assumption 1.4
implies that Assumption 1.3 holds almost surely over any finite sample.

Lemma A.3. Let examples ((xi , yi ))ni=1 be given, and suppose Assumption 1.3 holds, with corre-
sponding θ : Rd+1 → Rd+1 and γntk > 0 given.

1. With probability at least 1 − δ over the draw of (wj )m
j=1 , defining θ j := θ(wj )/ m, then
XD E √
ˆ i (wj ) ≥ γntk m − 32 ln(n/δ).
p
min θj , ∂p
i
j

2. With probability at least 1 − 7δ over the draw of W with rows (wj )m j=1 with m ≥ 2 ln(1/δ),

defining rows θj := θ(wj )/ m of θ ∈ R m×(d+1) , then for any W and any R ≥ kW − W 0 k and
0

any rθ ≥ 0 and rw ≥ 0,
D E √ hp i
ˆ i (W 0 ) − rw pi (W 0 ) ≥ γntk rθ m − rθ
rθ θ + rw W, ∂p 32 ln(n/δ) + 8R + 4
 √ √ 
− rw 4 ln(n/δ) + 2R + 2R m + 4 m ,

and moreover, writing W = (a, V ), then kak ≤ 2 and kV k ≤ 2 m. For the particular choice
rθ := R/8 and rw = 1, if R ≥ 8 and m ≥ (64 ln(n/δ)/γntk )2 , then

D
ˆ 0
E
0 γntk rθ m
rθ θ + W, ∂pi (W ) − pi (W ) ≥ − 160rθ2 .
2

Proof. 1. Fix any example (xi , yi ), and define


D E
µ := ˆ
Ew θ(w), ∂pi (w) ,

where µ ≥ γntk by assumption. By the various conditions on θ, it holds for any (a, v) := w ∈ Rd+1
and corresponding (ā, v) := θ(w) ∈ Rd+1 that
D E
ˆ i (w) ≤ āσ(v T xi ) + v, axi σ 0 (v T xi )
θ(w), ∂p

≤ |ā| · 1[kvk ≤ 2] · kvk · kxi k + kvk · |a| · 1[|a| ≤ 2] · kxi k


≤ 4.

and therefore, by Hoeffding’s inequality, with probability at least 1 − δ/n over the draw of m
iid copies of this random variable,
XD E
ˆ i (wj ) ≥ mµ − 32m ln(n/δ) ≥ mγntk − 32m ln(n/δ),
p p
θ(wj ), ∂p
j

√ √
which gives the desired bound after dividing by m, recalling θj := θ(wj )/ m, and union
bounding over all n examples.

24
2. First, suppose with probability at least 1 − 7δ that the consequences of Lemma A.2 and the

preceding part of the current lemma hold, whereby simultaneously kak ≤ 2, and kV k ≤ 2 m,
and
XD E √
ˆ i (wj ) ≥ γntk m − 32 ln(n/δ).
p
min pi (W ) ≥ −4 ln(n/δ), and min θj , ∂p
i i
j

The remainder of the proof proceeds by separately lower bounding the two right hand terms in
D E D E D E
0 0
ˆ i (W ) − rw pi (W ) = rθ θ, ∂p
ˆ i (W ) + θ, ∂p 0
ˆ i (W ) − ∂p
ˆ i (W )
rθ θ + rw W, ∂p
D E 
ˆ 0 0
+ rw W, ∂pi (W ) − rw pi (W ) .

For the first term, writing (ā, V ) = θ and noting kāk ≤ 2 and kV k ≤ 2, then for any
W 0 = (a0 , V 0 ),
D E X  
ˆ i (W 0 ) − ∂p
θ, ∂p ˆ i (W ) ≤ āj σ(xTi vj0 ) − σ(vjT xi )
j

X  
+ xTi v j a0j σ 0 (xT vj0 ) − aj σ 0 (xT vj )
j
sX s 2
X
≤ ā2j σ(xTi vj0 ) − σ(vjT xi )
j j
X
+ |xTi v j | · a0j σ 0 (xT vj0 ) − aj σ 0 (xT vj0 )
j
X
+ |xTi v j | · aj σ 0 (xT vj0 ) − aj σ 0 (xT vj0 )
j

≤ kāk · kV 0 − V k + ka0 − ak · kV k + kak · kV k


≤ 4R + 4.
For the second term,
D E D E D E D E
ˆ i (W 0 ) − pi (W 0 ) = a, ∂ˆa pi (W 0 ) + V, ∂ˆV pi (W 0 ) − V 0 , ∂ˆV pi (W 0 )
W, ∂p

X X D E
≤ aj σ(xTi vj0 ) + a0j vj − vj0 , xi σ 0 (xTi vj )
j j

X   X
≤ pi (w) + yi aj σ(xTi vj0 ) − σ(xTi vj ) + a0j · kvj − vj0 k
j j

≤ 4 ln(nδ) + kak · kV − V k + ka − a + ak · kV − V 0 k
0 0

≤ 4 ln(nδ) + 4R + R2 .
Multiplying through by rθ and r and combining these inequalities gives, for every i,
D E √ hp i
ˆ i (W 0 ) − rw pi (W 0 ) ≥ γntk rθ m − rθ
rθ θ + rw W, ∂p 32 ln(n/δ) + 4R + 4
h i
− rw 4 ln(n/δ) + 4R + R2 ,

25
which establishes the first inequality. For the particular choice rθ := R/8 with R ≥ 8 and
rw = 1, and using m ≥ (64 ln(n/δ)/γntk )2 , the preceding bound simplifies to
" √ #
D E √ γ m
ˆ i (W 0 ) − rw pi (W 0 ) ≥ γntk rθ m − rθ ntk
rθ θ + rw W, ∂p + 32rθ + 32rθ
8
" √ #
γntk m 2
− + 32rθ + 64rθ
16

γntk rθ m
≥ − 160rθ2 .
2

A.3 Basic properties of L-homogeneous predictors


This subsection collects a few properties of general L-homogeneous predictors in a setup more general
than the rest of the work, and used in all large margin calculations. Specifically, suppose general
parameters ut with some unspecified initial condition u0 , and thereafter given by the differential
equation

u̇t = −∂¯u R(p(u


b t )), p(u) := (p1 (u), . . . , pn (u)) ∈ Rn , (A.1)
pi (u) := yi F (xi ; u),
F (xi ; cu) = cL F (xi ; u) ∀c ≥ 0.

The first property is that norms increase once there is a positive margin.

Lemma A.4 (Restatement of (Lyu and Li, 2019, Lemma B.1)). Suppose the setting of eq. (A.1)
and also ` ∈ {`exp , `log }. If R(uτ ) < `(0)/n, then, for every t ≥ τ ,

d
kut k > 0 and hut , u̇t i > 0,
dt
and moreover limt kut k = ∞.

Proof. Since Rb is nonincreasing during gradient flow, it suffices to consider any us with R(u
b s) <
`(0)/n. To apply (Lyu and Li, 2019, Lemma B.1), first note that both the exponential and logistic
losses can be handled, e.g., via the discussion of the assumptions at the beginning of (Lyu and Li,
2019, Appendix A.1). Next, the statement of that lemma is

d
ln kus k > 0,
ds

but note that kus k > 0 (otherwise R(u


b s ) < `(0)/n is impossible), and also that

d hus , u̇s i d hus , u̇s i


kus k = , and ln kus k = ,
ds kus k ds kus k2

which together with (d/ ds) ln kus k > 0 from (Lyu and Li, 2019, Lemma B.1) imply the main part
of the statement; all that remains to show is kus k → ∞, but this is given by (Lyu and Li, 2019,
Lemma B.6).

26
Next, even without the assumption R(ub s ) < `(0)/n (which at a minimum requires a two-phase
proof, and certain other annoyances), note that once kus k is large, then the gradient can be related
to margins, even if they are negative, which will be useful in circumventing the need for dual
convergence and other assumptions present in prior work (e.g., as in (Chizat and Bach, 2020)).

Lemma A.5 (See also (Ji and Telgarsky, 2020a, Proof of Lemma C.5)). Suppose the setting of
eq. (A.1) and also ` = `exp . Then, for any u and any ((xi , yi ))ni=1 (and corresponding R),
b
D E
u, −n∂ˆu R(u)
b 
ln n
 
ln n

≤ Q γ̊(u) + ≤ Q γ(u) + .
LkukL kukL kukL
P
Proof. Define π(p) = − ln exp(−p) = γ e(u), whereby q = ∇p π(p). Since π is concave in p,
D E X X
u, −n∂ˆu R(u)
b = |`0i | u, ∂p
¯ i = LQ qi pi = LQ ∇p π(p), p
i i
 
= LQ ∇p π(p), p − 0 ≤ LQ π(p) − π(0) = LQ [e
γ + ln n] .

Moreover, by standard properties of π, letting k be the index of any example with pk (u) = mini pi (u),
X
e = − ln
γ exp(−p) ≤ − ln exp(−pk ) = pk = γ(u)kukL .

Combining these inequalities and dividing by LkukL gives the desired bounds.

Lastly, a key abstract potential function lemma: this potential function is a proxy for mass
accumulating on certain weights with good margin, and once it satisfies a few conditions, large
margins are implied directly. This is the second component needed to remove dual convergence
from (Chizat and Bach, 2020).

Lemma A.6. Suppose the setup of eq. (A.1) with L = 2, and additionally that there exists a
b > 0, a time τ , and a potential function Φ(u) so that Φ(uτ ) > −∞, and for all t ≥ τ ,
constant γ
1
Φ(u) ≤ ln kuk,
L
d
Φ(u) ≥ Q(u)bγ.
dt
Rt
Then it follows that R(u)
b → ∞, and kuk → ∞, and τ Q(us ) ds = ∞, and lim inf t γ(ut ) ≥ γ
b.

Proof. First it is shown that if infRs R(u


b s ) > 0 (which is well-defined since since R
b is nonincreasing

and bounded below by 0), then τ Qs ds = ∞ and kuk → ∞. Since R(us ) = n1 Qs , this implies
b
inf s Qs > 0, and consequently Z ∞
Qs ds = ∞,
τ
which also implies
Z t
1 d
lim inf ln kut k ≥ lim inf Φ(ut ) − Φ(uτ ) + Φ(uτ ) = Φ(uτ ) + lim inf Φ(us ) ds
t L t t τ ds
Z t
≥ Φ(uτ ) + lim inf bQs ds = ∞,
γ
t τ

27
thus kus k → ∞. On the other hand, if inf s R(u b s ) = 0, then there exists t1 so that for all t ≥ t1 ,
then γt > 0 (also making use of nondecreasing margins (Lyu and Li, 2019)), which is only possible
if kus k → ∞, and thus moreover we can take t2 ≥ t1 so that additionally kut k ≥ ln(n)/γt1 , (which
will hold for all t0 > t2 by Lemma A.4), and by Lemma A.5 and the restriction L = 2 means
P 0 ¯
i |`i | u, ∂pi (u)
P
d QL i qi pi (u)
ln kuk = = ≤ QL(γt + γt1 ) ≤ 2LQγt ,
dt kuk2 kuk2
R∞  
which after integrating and upper bounding γt ≤ 1 means t2 Qs ds ≥ limt ln kut k − ln kut2 k = ∞.
b s ) = 0, then still kus k → ∞ and ∞ Qs ds = ∞.
R
As such, independent of whether inf s R(u τ
This now suffices to complete the proof. If lim inf t γt ≥ γ b, then in fact limt γt is well-defined
(by non-decreasing margins and kuk → ∞) and limt γt ≥ γ b > 0, whereby lim supt R(u b t) ≤
L
lim supt `(−γt kwt k ) = 0. Alternatively, suppose contradictorily that lim inf t γt < γ b, and choose
any  ∈ (0, γ b/4) so that lim inf t γt < γ b − 3; noting that γt is monotone once there exists some
γs > 0, then, choosing t3 large enough so that kut k2 ≥ ln(n)/ for all t ≥ t3 and γt < γ b − 2 for all
t ≥ t3 , it follows by Lemma A.5 that
 
1
0 ≤ lim inf ln kut k − Φ(ut )
t L
Z t  
1 d 1
≤ ln kut3 k − Φ(ut3 ) + lim inf ln kus k − Φ(us ) ds
L t t3 ds L
Z t
1  
≤ ln kut3 k − Φ(ut3 ) + lim inf Q(b
γ − ) − Qb γ ds
L t t3
Z t
1
≤ ln kut3 k − Φ(ut3 ) + lim inf [−Q] ds
L t t3
= −∞,

a contradiction, and since  ∈ (0, γ


b/4) was arbitrary, it follows that lim inf γt ≥ γ
b.

B Proofs for Section 2


This section contains proofs with a dependence on γntk .

B.1 SGD proofs


The following application of Freedman’s inequality is used to obtain the test error bound.
Lemma B.1 (Nearly identical to (Ji and Telgarsky, 2020b, P Lemma 4.3)). Define Q(W ) :=
Ex,y |`0 (p(x, y; W ))| and Qi (W ) := |`0 (p(xi , yi ; W ))|. Then
 
i<t Q(W i ) − Q i (W i ) is a martin-
gale difference sequence, and with probability at least 1 − δ,
X X
Q(Wi ) ≤ 4 Qi (Wi ) + 4 ln(1/δ),
i<t i<t

Proof. This proof is essentially a copy of one due to Ji and Telgarsky (2020b, Lemma 4.3); that one
is stated for the analog of pi used there, and thus need to be re-checked.
Let Fi := {((xj , yj )) : j < i} denote the σ-field of all information until time i,whereby xi is inde- 
pendent of FP i , whereas
  after conditioning on Fi . Consequently, E Q(Wi ) − Qi (Wi )|Fi =
wi deterministic
0, whereby i<t Q(Wi ) − Qi (Wi ) is a martingale difference sequence.

28
The high probability bound will now follow via a version of Freedman’s inequality (Agarwal
et al., 2014, Lemma 9). To apply this bound, the conditional variances must be controlled: noting
that |`0 (z)| ∈ [0, 1], then Q(Wi ) − Qi (Wi ) ≤ 1, and since Qi (Wi ) ∈ [0, 1], then Qi (Wi )2 ≤ Qi (Wi ),
and thus
h 2 i h i
E Q(Wi ) − Qi (Wi ) Fi = E Qi (Wi )2 Fi − Q(Wi )2
 
≤ E Qi (Wi ) Fi − 0
= Q(Wi ).
As such, by the aforementioned version of Freedman’s inequality (Agarwal et al., 2014, Lemma 9),
X  X h 2 i
Q(Wi ) − Qi (Wi ) ≤ (e − 2) E Q(Wi ) − Qi (Wi ) Fi + ln(1/δ)
i<t i<t
X
≤ (e − 2) Q(Wi ) + ln(1/δ),
i<t

which rearranges to give


X X
(3 − e) Q(Wi ) ≤ Qi (Wi ) + ln(1/δ),
i<t i<t

which gives the result after multiplying by 4 and noting 4(3 − e) ≥ 1.

With Lemma B.1 and the Gaussian concentration inequalities from Appendix A.2 in hand, a
proof of a generalized form of Theorem 2.1 is as follows.

Proof of Theorem 2.1. Let (wj )mj=1 be given with corresponding (āj , v j ) := θ j := θ(wj )/ m
(whereby kθj k ≤ 2 by construction), and define
√ √ √ √
10η m γ m 80η m γ m
r := ≤ , R := 8r = ≤ , W := rθ + W,
γ 640 γ 80
which implies r ≥ 1, and R ≥ 1, and η ≤ R/16. Since Assumption 1.4 implies that Assumption 1.3
holds for ((xi , yi ))i≤t with probability 1, for the remainder of the proof, rule out the 7δ failure
probability associated with the second part of Lemma A.3 (which is stated in terms of Assumption 1.3
not Assumption 1.4), whereby simultaneously for every kW 0 − W0 k ≤ R
√ √ 2
ˆ i (W 0 ) ≥ rγ m − 160r2 ≥ rγ m = γ m ≥ ln(t),
D E
min W , ∂p (B.1)
i 2 4 2560
√ √ √
D
0
E √ √ γ m γ m γ m
ˆ
p
min θ, ∂pi (W ) ≥ γ m − 32 ln(n/δ) − 4R − 4 ≥ γ m − − ≥ , (B.2)
i 8 10 2

and also ka0 k ≤ 2 and kV0 k ≤ 2 m.
The proof now proceeds as follows. Let τ denote the first iteration where kWτ − W0 k ≥ R,
whereby τ > 0 and maxs<τ kWs − W0 k ≤ R. Assume contradictorily that τ ≤ t; it will be shown
that this implies kWτ − W0 k ≤ R.
Consider any iteration s < τ . Expanding the square,
ˆ s (Ws ) − W k2
kWs+1 − W k2 = kWs − η ∂`
D E 2
= kWs − W k2 − 2η ∂` ˆ s (Ws ), Ws − W + η 2 ∂`ˆ s (Ws )
D E 2
= kWs − W k2 + 2η`0s (Ws ) ∂p ˆ s (Ws ), W − Ws + η 2 `0 (Ws )2 ∂p
s
ˆ s (Ws ) .

29
By convexity, kWs − W0 k ≤ R, and eq. (B.1),
D  !
D E E
`0s (Ws ) ∂p
ˆ s (Ws ), W − Ws = `0s (Ws ) ˆ s (Ws ), W − ps (Ws ) − ps (Ws )
∂p
D E 
≤ `s ˆ
∂ps (Ws ), W − ps (Ws ) − `s (Ws )

≤ ln(1 + exp(− ln(t))) − `s (Ws ),


1
≤ − `s (Ws ),
t
which combined with the preceding display gives
 
1 2
2
kWs+1 − W k ≤ kWs − W k + 2η 2
− `s (Ws ) + η 2 `0s (Ws )2 ∂p
ˆ s (Ws ) .
t
P
Since this inequality holds for any s < τ , then applying the summation s<τ and rearranging gives
X X 2
kWτ − W k2 + 2η `s (Ws ) ≤ kW0 − W k2 + 2η + η 2 `0s (Ws )2 ∂p
ˆ s (Ws ) .
s<τ s<τ

To simplify the last term, using kV0 k ≤ 2 m and ka0 k ≤ 2 and kWs − W0 k ≤ R gives
2 X 2
ˆ s (W )
∂p = kσ(Vs xs )k2 + ej ai,j σ 0 (vi,j
T
xs )xs
j
2
≤ σ(Vs xs ) +kas k2
≤ 2kVs − V0 k2 + 2kV0 k2 + 2kas − a0 k2 + 2ka0 k2
≤ 2R2 + 8m + 8,
≤ 10m,

and moreover the first term can be simplified via


D E
kWτ − W k2 = kWτ − W0 k2 − 2 Wτ − W0 , W − W0 + kW − W0 k2
≥ kWτ − W0 k2 − 2rkWτ − W0 k + kW − W0 k2 ,

whereby combining these all gives


X
kWτ − W0 k2 − 2rkWτ − W0 k + kW − W0 k2 + 2η `s (Ws )
s<τ
X
≤ kWτ − W k2 + 2η `s (Ws )
s<τ
X 2
≤ kW0 − W k2 + 2η + η 2 `0s (Ws )2 ∂p
ˆ s (Ws )
s<τ
X
≤ kW0 − W k + 2η + 10η 2 m
2
|`0s (Ws )|,
s<τ

which after canceling and rearranging gives


X X
kWτ − W0 k2 + 2η `s (Ws ) ≤ 2rkWτ − W0 k + 2η + 10η 2 m |`0s (Ws )|.
s<τ s<τ

30
To simplify the last term, note by eq. (B.2) that
kWτ − W0 k = sup hW, Wτ − W0 i
kW k≤1
1
≥ −θ̄, Wτ − W0
2
η XD ˆ s (Ws )
E
= −θ, ∂`
2 s<τ
ηX 0 D
ˆ i (Ws )
E
= |`s (Ws )| θ, ∂p
2 s<τ

ηX 0 γ m
≥ |`s (Ws )| , (B.3)
2 s<τ 2

and thus, by the choice of R, and since kWτ − W0 k ≥ 1 and η ≤ R/16,



X 40η mkWτ − W0 k
kWτ − W0 k2 + 2η `s (Ws ) ≤ 2rkWτ − W0 k + 2η +
s<t
γ
 
R R R
≤ + + kWτ − W0 k.
4 8 2
P
Dropping the term 2η s<t `s (Ws ) ≥ 0 and dividing both sides by kWτ − W0 k ≥ R > 0 gives
R R R
kWτ − W0 k ≤ + + < R,
4 8 2
the desired contradiction, thus τ > t and all above derivations hold for all s ≤ t.

To finish the proof, combining eq. (B.3) with kWt − W0 k ≤ R = 80η m/γ gives
X 4kWt − W0 k 320
|`0s (Ws )| ≤ √ ≤ 2 .
s<t
ηγ m γ

Lastly, for the generalization bound, defining Q(W ) := Ex,y |`0 (p(x, y; W ))|, discarding an additional
δ failure probability, by Lemma B.1,
X X 1280
Q(Ws ) ≤ 4 ln(1/δ) + 4 |`0s (Ws )| ≤ 4 ln(1/δ) + 2 .
s<t
γ
i<t

Since 1[ps (Ws ) ≤ 0] ≤ 2|`0s (Ws )|, the result follows.

B.2 GF proofs
This section culminates in the proof of Theorem 2.2, which is immediate once Lemmas 2.3 and 2.4
are established.
Before proceeding with the main proofs, the following technical lemma is used to convert a
bound on `0 to a bound on `.
Lemma B.2. For ` ∈ {`log , `exp }, then |`0 (z)| ≤ 1/8 implies `(z) ≤ 2|`0 (z)|.
Proof. If ` = `exp , then `0 = −`, and thus `(z) ≤ 2|`0 (z)| automatically. If `(z) = `log , the logistic
loss, then |`0 (z)| ≤ 1/8 implies z ≥ 2. By the concavity of ln(·), for any z ≥ 2, since 1 + e−z ≤ 7/6,
then
(7/6)e−z
`(z) = ln(1 + e−z ) ≤ e−z ≤ ≤ 2|`0 (z)|,
1 + e−z
thus completing the proof.

31
Next comes the proof of Lemma 2.3, which follows the same proof plan as Theorem 2.1.

Proof of Lemma 2.3. This proof is basically identical to the SGD in Theorem 2.1. Despite this,
proceeding with amnesia, let rows (wj )m
j=1 of W0 be given with corresponding (āj , v j ) := θ j :=

θ(wj )/ m (whereby kθj k ≤ 2 by construction), and define
√ √
γntk m γntk m
r := , R := 8r = , W := rθ + W,
640 80
with immediate consequences that r ≥ 1 and R ≥ 8. For the remainder of the proof, rule out the 7δ
failure probability associated with the second part of Lemma A.3, whereby simultaneously for every
kW 0 − W0 k ≤ R,
D E rγ √m √
rγntk m γ2 m
ˆ 0 ntk 2
min W , ∂pi (W ) ≥ − 160r ≥ = ntk ≥ ln(t), (B.4)
i 2 4 2560
√ √ √
√ √
ˆ i (W 0 ) ≥ γntk m − 32 ln(n/δ) − 4R − 4 ≥ γntk m − γntk m − γntk m ≥ γntk m .
D E p
min θ, ∂p
i 8 10 2
(B.5)

The proof now proceeds as follows. Let τ denote the earliest time such that kWτ − W0 k = R; since
Ws traces out a continuous curve and since R > 0 = kW0 − W0 k, this quantity is well-defined. As a
consequence of the definition, sups<τ kWs − W0 k ≤ R. Assume contradictorily that τ ≤ t; it will be
shown that this implies kWτ − W0 k < R.
By the fundamental theorem of calculus (and the chain rule for Clarke differentials), convexity
of `, and since kWs − W0 k ≤ R holds for s ∈ [0, τ ), which implies eq. (B.4) holds,
Z τ
2 2 d
kWτ − W k − kW0 − W k = kWs − W k2 ds
0 ds
Z τ D E
= 2 Ẇs , Ws − W ds
0
2 τX 0
Z D E
= ¯
`i (Ws ) ∂pi (Ws ), Ws − W ds
n 0
i
!
2 τX 0
Z D E 
= `i (Ws ) ˆ i (Ws ), W − pi (Ws ) − pi (Ws ) ds
∂p
n 0
i
Z τX D  !
2 E
ˆ i (Ws ), W − pi (Ws ) − `i (Ws ) ds
≤ `i ∂p
n 0
i
Z τ X 
2 1
≤ − `i (Ws ) ds
n 0 t
i
Z τ
≤2−2 R(W
b s ) ds.
0

To simplify the left hand side,


D E
kWτ − W k2 − kW0 − W k2 = kWτ − W0 k2 − 2 Wτ − W0 , W − W0 ≥ kWτ − W0 k2 − 2rkWτ − W0 k,

which after combining, rearranging, and using r ≥ 1 and kWτ − W0 k ≥ R ≥ 1 gives


Z τ
kWτ − W0 k2 + 2 R(Ws ) ds ≤ 2 + 2rkWτ − W0 k ≤ 4rkWτ − W0 k,
0

32
which implies
R
kWτ − W0 k ≤ 2r = < R,
2
a contradiction since Wτ is well-defined as the earliest time with kWτ − W0 k = R, which thus
contradicts τ ≤ t. As such, τ ≥ t, and all of the preceding inequalities follows with τ replaced by t.
To obtain an error bound, similarly to the key perceptron argument before, using eq. (B.5),

kWt − W0 k = sup hW, Wt − W0 i


kW k≤1
1D E
≥ −θ, Wt − W0
2* +
Z t
1
= −θ, ȧs ds
2 0
Z tX
1 D E
= |`0i (Ws )| θ, ∂p
ˆ i (Ws ) ds
2n 0
i
√ Z
γntk m t X 0
≥ |`i (Ws )| ds,
4n 0 i

which implies Z tX
1 4kWt − W0 k 1
|`0i (Ws )| ds ≤ √ ≤ ,
n 0 γntk m 20
i

and in particular
Z tX
1X 0 1 4kWt − W0 k 1
inf |`i (Ws )| ≤ |`0i (Ws )| ds ≤ √ ≤
s∈[0,t] n tn 0 tγntk m 20t
i i

and so there exists k ∈ [0, t] with


1X 0 1
|`i (Wk )| ≤ .
n 10t
i

Since this also implies maxi |`0i (Wk )| ≤ n/(10t) ≤ 1/10, it follows by Lemma B.2 that R(W
b k) ≤
0
1/(5t), and the claim also holds for t ≥ t since the empirical risk is nonincreasing with gradient
flow.

Next, the proof of the Rademacher complexity bound used for all GF sample complexities.

Proof of Lemma 2.5. For any (a, v) ∈ Rd+1 and any x, recalling the notation defining e
a := sgn(a)
and ve := v/kvk,
k(a, v)k2
aσ(v T x) = kavke v T x) ≤
aσ(e v T x),
aσ(e
e
2
and therefore, letting
 
X m 
sconv(S) := pj uj : m ≥ 0, p ∈ Rm , kpk1 ≤ 1, uj ∈ S
 
j=1

33
denote the symmetrized convex hull as used throughout Rademacher complexity (Shalev-Shwartz
and Ben-David, 2014), then
 
 1 X
T m×(d+1)

F := x 7→ aj σ(v j x) : m ≥ 0, W ∈ R
 kW k2 
j
 
 1 X
T m×(d+1)

= x 7→ ka j vj ke
a σ(ev j x) : m ≥ 0, W ∈ R
 kW k2 
j
 
 X k(aj , vj )k2 
T m×(d+1)
⊆ x 7→ aσ(e vj x) m ≥ 0, W ∈ R
:
2kW k2
e
 
j
 
 X 1 
⊆ x 7→ pj σ(uTj x) : m ≥ 0, p ∈ Rm , kpk1 ≤ , kuj k2 = 1
 2 
j
1  
= sconv x 7→ σ(v T x) : kvk2 = 1 .
2
As such, by standard rules of Rademacher complexity (Shalev-Shwartz and Ben-David, 2014),
1   1
Rad(F) ≤ Rad x 7→ σ(v T x) : kvk2 = 1 ≤ √ ,
2 2 n

and thus, by a refined margin bound for Rademacher complexity (Srebro et al., 2010, Theorem
5), with probability at least 1 − δ, simultaneously for all γgl and all m, every W ∈ Rm×(d+1) with
γ̊(W ) ≥ γgl satisfies
 
1 1 !
ln(n) 3 ln ln γ + ln δ ln(n) 3 ln 1
Pr[p(x, y; W ) ≤ 0] = O  2 Rad(F)2 + gl
=O + δ ,
γgl n nγgl2 n

and to finish, instantiating this bound with γ̊(W ) gives the desired form.

The proof of Lemma 2.4 now follows.

Proof of Lemma 2.4. By the second part of Lemma A.3, with probability at least 1 − 7δ, simultane-

ously kak ≤ 2, and kV k ≤ 2 m, and for any kW 0 − W0 k ≤ R, then
D E √ hp i γ √m
ˆ 0 ntk
min θ, ∂pi (W ) ≥ γntk m − 32 ln(n/δ) + 8R + 4 ≥ ,
i 2

where θj := θ(wj )/ m as usual, and kθk ≤ 2; for the remainder of the proof, suppose these bounds,
and discard the corresponding 7δ failure probability. Moreover, for any W 0 with R(W
b 0 ) < `(0)/n and

kW − W0 k ≤ R, as a consequence of the preceding lower bound and also the property i qi (W 0 ) ≥ 1


0
P

34
Ji and Telgarsky (2019, Lemma 5.4, first part, which does not depend on linear predictors),
¯γ (W 0 ) = sup
∂e ¯γ (W 0 )
W, ∂e
kW k≤1
* +
1 X
0
¯ i (W )
≥ θ, qi ∂p
2
i
1X D ¯ E
= qi θ, ∂pi (W 0 )
2
i

γntk m X
≥ qi
4
i

γntk m
≥ .
4

Now consider the given Wτ with R(W


b τ ) < `(0)/n and kWτ − W0 k ≤ R/2. Since s 7→ Ws traces
out a continuous curve and since norms grow monotonically and unboundedly after time τ (cf.
Lemma A.4), then there exists a unique time r with kWt − W0 k = R. Furthermore, since R b is
nonincreasing throughout gradient flow, then R(W
b s ) < `(0)/n holds for all s ∈ [τ, t]. Then
Z tD E
e(Wt ) − γ
γ e(Wτ ) = ¯γ (Ws ), Ẇs ds
∂e
τ
Z t
¯γ (Ws )k · kẆs k ds
k∂e =
τ
√ Z
γntk m t
≥ kẆs k ds
4 τ
√ Z t
γntk m
≥ Ẇs ds
4 τ

γntk m
= kWt − Wτ k ds
4 √
γntk R m

8
2
γ m
≥ ntk .
256
√ √ √ √
Since kW0 k ≤ 3 m, thus kWt k ≤ 3 m + γntk m/32 ≤ 4 m, and the normalized margin satisfies
t
γ 2 m/256 γ2
Z
γ
e(Wτ ) 1 d
γ̊(Wt ) ≥ 2
+ e(Ws ) ds ≥ 0 + ntk
γ = ntk .
kWt k kWt k2 τ ds 16m 4096

Furthermore, it holds that γ̊(Ws ) ≥ γ̊(Wt ) for all s ≥ t (Lyu and Li, 2019), which completes the
proof for Wt under the standard parameterization.
Now consider the rebalanced parameters W ct := (at /√γntk , Vt √γntk ); since m ≥ 256/γ 2 , which
√ ntk
means 16 ≤ γntk m, then
√ √ √
γntk m γntk m γntk m
kat k ≤ ka0 k + kat − a0 k ≤ 2 + R ≤ + ≤ ,
√ 8√ 32 4
kVt k ≤ kV0 k + kVt − V0 k ≤ 2 m + R ≤ 3 m,

35
then the rebalanced parameters satisfy

ct k ≤ kat /√γntk k + kVt √γntk k ≤
kW
γntk m √ √
+ 3 γntk m ≤ 4 γntk m,
4
and thus, for any (x, y), since
X X aj √
p(x, y; w) = aj σ(vjT x) = √ σ( γntk vjT x) = p(x, y; W
c ),
γntk
j j

then
pi (W
c) pi (w) e2 m/256
γ γntk
γ̊(W
c ) = min = min ≥ ≥ ,
i kW k
c 2 i kW k
c 2 16γ ntk m 4096
which completes the proof.

Thanks to Lemmas 2.3 and 2.4, the proof of Theorem 2.2 is now immediate.

Proof of Theorem 2.2. As in the statement, define R := γntk m/32, and note that Assumption 1.3
holds almost surely for any finite sample due to Assumption 1.4. The analysis now uses two stages.
The first stage is handled by Lemma 2.3 run until time τ := n, whereby, with probability at least
1 − 7δ, √
b τ ) ≤ 1 < `(0) ,
R(W kWτ − W0 k ≤
γntk m
≤ .
R
5n n 80 2
The second stage now follows from Lemma 2.4: since Wτ as above satisfies all the conditions
of Lemma 2.4, there exists Wt with kWt − W0 k = R, and γ̊(Ws ) ≥ γntk 2 /4096 for all s ≥ t, and
ct ) ≥ γntk /4096 for the rebalanced iterates W √ √
γ̊(W ct = (at / γntk , Vt γntk ). The first generalization
bound now follows by Lemma 2.5. The second generalization bound uses p(x, y; Wt ) = p(x, y; W ct )
for all (x, y), whereby Pr[p(x, y; Wt ) ≤ 0] = Pr[p(x, y; Wct ) ≤ 0], and the earlier margin-based
generalization bound can be invoked with the improved margin of W ct .

Lastly, here are the details from the construction leading to Proposition 2.6 which give a setting
where GF converges to parameters with higher margin than the maximum margin linear predictor.

Proof of Proposition 2.6. Let u be the maximum margin linear separator for ((xi , yi ))ni=1 ; necessarily,
there exists at least one support vector in each of S1 and S2 , since otherwise u is also the maximum
margin linear predictor over just one of the sets, but then the condition
1
min{ xi , xj : xi ∈ S1 , xj ∈ S2 } ≤ − √ (B.6)
2
implies that u points away from and is incorrect on the set with no support vectors.
As such, let v1 ∈ S1 and v2 ∈ S2 be support vectors, and consider the addition of a single data
point (x0 , +1) from a parameterized family of points {zα : α ∈ [0, 1]}, defined as follows. Define
v0 := (I − v1 v1T )u/k(I − v1 v1T )uk, which is orthogonal to −v1 by construction, and consider the
geodesic between v0 and −v1 : p
zα := αv0 − 1 − α2 v1 ,
which satisfies kzα k = 1 by construction by orthogonality, meaning
p
kzα k2 = α2 − 2α 1 − α2 hv0 , v1 i + (1 − α2 ) = 1.

In order to pick a specific point along the geodesic, here are a few observations.

36
1. First note that −v2 is a good predictor for S1 : eq. (B.6) implies
1
min h−v2 , xi = − max hv2 , xi ≥ √ .
x∈S1 x∈S1 2

It follows similarly that −v1 is a good predictor for S2 : analogously, minx∈S2 h−v1 , xi ≥ 1/ 2.

2. It is also the case that −v1 is a good predictor for every zα with α ≤ 1/ 2:
p 1
h−v1 , zα i = 1 − α2 ≥ √ .
2

3. Now define a 2-ReLU predictor f (x) := (σ(−v1T x) + σ(−v2T x))/2.


√ As a consequence
√ of the two
preceding points, for any x ∈ S1 ∪ S2 ∪ {zα }, then f (x) ≥ 1/√ 2 so long as α ≤ 1/ 2. As such,
by Lemma A.1, Assumption 1.3 is satisfied with γ1 ≥ 1/(64 d).

4. Lastly, as α → 0, then zα → −v1 , but the resulting set of points S1 ∪ S2 ∪ {z0 } is linearly
separably only with margin at most zero (if it is linearly separable at all). Consequently,
there exist choices of α0 ∈ [0, 1] so that the resulting maximum margin linear predictor over
S1 ∪ S2 ∪ {zα } has arbitrarily small yet still positive margins, and for concreteness choose some
α0 so that the resulting linear separability margin γ satisfies γ ∈ (0, c/(215 d)), where c is the
positive constant in Lemma 2.3.

As a consequence of the last two bullets, using label +1 and inputs S1 ∪ S2 ∪ {zα0 }, the maximum
margin linear predictor has margin γ, Assumption 1.3 is satisfied with margin at least γ1 , but most
importantly, by Lemma 2.3 (with m0 given by the width lower bound required there), GF will
achieve a margin γ2 ≥ cγ12 /4 ≥ 2γ.
It only remains to argue that the linear predictor itself is a KKT point (for the margin objective),
but this is direct and essentially from prior work (Lyu and Li, 2019): e.g., taking all outer weights
to be equal and positive, and all inner weights to be equal to u and of equal magnitude and equal to
the outer layer magnitude, and then taking all these balanced norms to infinity, it can be checked
that the gradient and parameter alignment conditions are asymptotically satisfied (indeed, the
ReLU can be ignored since all nodes have positive inner product with all examples), which implies
convergence to a KKT point (Lyu and Li, 2019, Appendix C).

C Proofs for Section 3


This section develops the proofs of Theorem 3.2 and Theorem 3.3. Before proceeding, here is a
quick sampling bound which implies there exist ReLUs pointing in good directions at initialization,
which is the source of the exponentially large widths in the two statements.

Lemma C.1. Let (β1 , . . . , βr ) be given with kβk k = 1, and suppose (vj )m
j=1 are sampled with

vj ∼ Nd / d, with corresponding normalized weights vej . If
 d
2 r
m≥2 ln ,
 δ
2
then maxk minj ke
vj − βk k ≤ , alternatively mink maxj vejT βk ≥ 1 − 2.

37
Proof. By standard sampling estimates (Ball, 1997, Lemma 2.3), for any fixed k and j, then
1  d−1
 
vj − βk k ≤ ] ≥
Pr[ke ,
2 2
and since all (vj )m
j=1 are iid,
m
Pr[∃j  ke
vj − βk k ≤ ] = 1 − Pr[∀j  ke
vj − βk k > ] = 1 − 1 − Pr[ke
v1 − βk k ≤ ]
 d−1 !
m  δ
≥ 1 − exp − ≥1− ,
2 2 r
and union bounding over all (βk )rk=1 gives maxk minj ke vj − βk k ≤ . To finish, the inner product
vj − βk k2 = 2 − 2e
form comes by noting ke vjT βk and rearranging.
First comes the proof of Theorem 3.2, whose entirety is the construction of a potential Φ and a
verification that it satisfies the conditions in Lemma A.6.
Proof of Theorem 3.2. To start the construction of φ, first define per-weight potentials φk,j which
track rotation of mass towards each βk :
 
φk,j := φk (wj ) := φ αek aj σ(vjT βk ) ≥ (1 − )kaj vj k .
d
The goal will be to show that dt φk,j is monotone nondecreasing, which shows that weights get
d
trapped pointing towards each βk once sufficiently close. As such, to develop dt φk,j , note (using
aj := sgn(aj ) and vej := vj /kvj k),
e
 
d T T d T
α
ek aj σ(vj βk ) = α
ek ȧj σ(vj βk ) + aj σ(vj βk )
dt dt
X h i
= −eαk `0i yi σ(vjT xi )σ(vjT βk ) + a2j σ 0 (vjT βk )xTi βk ,
i
X h i
= −e
αk `0i yi kvj k2 σ(e vjT βk ) + a2j σ 0 (vjT βk )xTi βk ,
vjT xi )σ(e
i
d d 1/2
kaj vj k = aj vj , aj vj
dt dt
2 aj vj , ȧj vj + aj v̇j
= 1/2
2 aj vj , aj vj
 i
P 0 h 2 0
aj vj , − i `i yi vj σ(vj xi ) + aj σ (vj xi )xi
T T

=
kaj vj k
P 0 2
− i `i pi (wj )kwj k
=
kaj vj k
X
=− `0i e
aj σ(evjT xi )kwj k2 ,
i
 
d dh i
φj = ek aj σ(vjT βk ) − (1 − )kaj vj k
α when φj ∈ (0, 1)
dt dt
X h  
=− `0i yi kvj k2 e vjT xi ) α
aj σ(e aj σ(e
ek e vjT βk ) − (1 − )
i
 i
+ a2j αek σ 0 (vjT βk )βkT xi − (1 − )e vjT xi ) .
aj σ(e

38
Analyzing the last two terms separately, the first (the coefficient to kvk2 is nonnegative since the
term in parentheses is a rescaling of φj :
  φj
α vjT βk ) − (1 − ) =
aj σ(e
ek e > 0,
kaj vj k

and moreover this holds for any choice of  > 0. The second term (the coefficient of a2j ) is more
complicated; to start, fixing any example (xi , yi and writing zi := cβk + z⊥ , where necessarily c ≥ γnc ,
ej := e
note (using u aj vej )
 2
z⊥ 2 2
,u
ej ≤ (I − βk βkT )e ej ≤ 1 − (1 − )2 = 2 − 2 ≤ 2,
uj = 1 − βk , u
kz⊥ k
and thus, since φj ∈ (0, 1) whereby σ(vjT xi ) = vjT xi , the nonnegativity of the second term follows
from the narrowness of the around βk (cf. Assumption 3.1):

ek βkT xi yi − (1 − )e
α aj vejT xi yi ≥ c − (1 − ) e
aj vej , cβk + z⊥

≥ c − (1 − )kz⊥ k 2
≥ c − (1 − )γnc 
≥ 0,

meaning dφk,j / dt ≥ 0 for every pair (k, j).


With this in hand, define the overall potential as
1X X
Φ(w) := |αk | ln φk,j kaj vj k,
4
k j

whereby
P h d d
i
d 1X j φ k,j dt ka v
j j k + ka v k φ
j j dt k,j
Φ(w) = |αk | P
dt 4
k j φk,j kaj vj k
P 0 P
1X − i `i yi j φk,j e aj σ(evjT xi )kwj k2
≥ |αk | P
4
k j φk,j kaj vj k

vj − βk + βk )T xi kwj k2
P 
1 X X yi j φk,j α ek σ (e
= Q |αk | qi P
4
k i∈Sk j φk,j kaj vj k
P
1 X X yi j φk,j (γnc − ) 2kaj vj k
≥ Q αk qi P
4
k i∈Sk j φk,j kaj vj k
 
γgl − 
≥Q .
r
Written another way, this establishes dΦ(Wt )/ dt ≥ Qb γ with γ b = (γnc − )/k, which is one of the
conditions needed in Lemma A.6. The other properties meanwhile are direct:
1X X 1X 1
Φ(Wt ) ≤ |αk | ln φj,k kwj k2 ≤ |αk | ln kWt k2 = ln kWt k,
4 4 2
k j k

and Φ(W0 ) > −∞ due to random initialization (cf. Lemma C.1), which allows the application of
Lemma A.6 and Lemma 2.5 and completes the proof.

39
To close, the proof of Theorem 3.3.
Proof of Theorem 3.3. Throughout the proof, use W = ((aj , bk ))m j=1 to denote the full collection of
parameters, even in this scalar parameter setting.
To start, note that a2k = b2k for all times t; this follows directly, from the initial condition

ak (0)2 = bk (0)2 = 1/ m, since
ak (t)2 − bk (t)2 = ak (t)2 − ak (0)2 − bk (t)2 + bk (0)2
Z t 
= ak ȧk − bk ḃk ds
0
Z tX
|`0i | ak σ(bk vkT xi ) − ak σ 0 (bk vkT xi )vkT xi ds

=
0 i
= 0.
This also implies that a2k + b2k = 2a2k = 2|ak | · |bk |.
For each βk , choose j so that kebj vej − βk k ≤  = γgl /2 and e aj = sgn(αk ); this holds with
probability at least 1 − 3δ over the draw of W due to the choice of m, first by noting that with
probability at least 1 − δ, there are at least m/4 positive aj and m/4 negative aj , and then with
probability 1 − 2δ by applying Lemma C.1 to (ebj vj )m j=1 (which are equivalent in distribution to
(vj )j=1 ) for each choice of output sign. For the rest of the proof, reorder the weights ((aj , bj , vj ))m
m
j=1
so that each (αk , βk ) is associated with (ak , bk , vk ).
Now consider the potential function
r
1X  
Φ(W ) := |αk | ln a2k + b2k .
4
k=1
The time derivative of this potential can be lower bounded in terms of the reference margin:
d XX ak σ(bk vkT xi )
Φ(W ) = |`0i |yi |αk |
dt
i
a2k + b2k
k
XX |ak |kbk vk kσ(ebk vekT xi )
=Q qi y i α k
2|ak | · |bk |
k i
XX  
=Q qi yi αk σ (ebk vek − βk + βk )T xi
k i
XX  XX
≥Q qi yi αk σ βkT xi − Q qi |αk | ebk vek − βk
k i k i
X XX
≥ Qγgl qi − Q qi |αk |
i k i
X 
=Q qi γgl −  > 0.
i
As an immediate consequence, the signs of all ((ak , bk ))rk=1 never flip (since this would require their
values to pass through 0, which would cause Φ(W ) = −∞ < Φ(W (0))). This implies the preceding
lower bound always holds, and since Φ(W0 ) > −∞ thanks to random initialization, and
r r
1X   1X 1
Φ(W ) = |αk | ln a2k + b2k ≤ |αk | ln kW k2 = ln kW k,
4 4 2
k=1 k=1
all conditions of Lemma A.6 are satisfied, and the proof is complete after applying Lemma 2.5.

40

You might also like