
Higher-Order Runge-Kutta in DNNs

The document discusses applying higher-order Runge-Kutta methods to train neural networks. Specifically, it aims to model deep neural network training as an optimal control problem to simplify network design, analyze stability and generalization, and develop variational frameworks. It motivates higher-order methods by noting that ResNets and their skip connections are equivalent to the forward Euler scheme. It outlines the fourth-order Runge-Kutta scheme and provides numerical results demonstrating its effectiveness in training a simple convolutional network model.


APPLYING HIGHER-ORDER RUNGE-KUTTA METHODS TO NEURAL NETWORKS

Derek Onken and Lars Ruthotto, Department of Mathematics and Computer Science, Emory University

OBJECTIVES

Broader goals: model the training of deep neural networks (DNNs) as an optimal control problem.

1. simplify the design of DNNs (≈ discretize a PDE)
2. analyze stability and generalization (≈ vanishing/exploding gradients)
3. develop a variational framework (⇒ multilevel and multiscale learning)
4. design reversible dynamics (⇒ memory-free learning)

Current focus:

1. research: model order reduction, efficient optimization, stable dynamics, time-integrators [1]
2. community: free MATLAB/Julia software
3. accessibility: building models in PyTorch

MOTIVATION

Since the community recognizes the effectiveness of ResNets and their skip connections (shown to be equivalent to forward Euler), wouldn't higher-order Runge-Kutta schemes assist in training?

MODEL

[Figure: model architecture. An opening layer feeds a dynamic unit that applies the Runge-Kutta scheme (looped back twice), followed by a connecting layer and a dense layer that produces the class prediction (e.g., "Dog").]
DNNS MEET OPTIMAL CONTROL

Goal: Find a function f : R^n × R^p → R^m and its parameter θ ∈ R^p such that f(y_k, θ) ≈ c_k for training data y_1, ..., y_s ∈ R^n and labels c_1, ..., c_s ∈ R^m.

Model y_k^N = f(y_k, θ) as output of a Residual Neural Network (ResNN) with N layers. Let y_k^0 = y_k and

    y_k^{i+1} = y_k^i + h g(y_k^i, θ_i),   for all i = 0, ..., N − 1,

where g transforms the features, e.g., g(y, θ) = tanh(K(θ) y).

Note that the ResNN is a forward Euler discretization [2] of the initial value problem (t ∈ [0, T])

    ∂_t y_k(t, θ) = g(y_k(t, θ), θ(t)),   y_k(0, θ) = y_k.
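To make the ResNN-as-forward-Euler correspondence concrete, here is a minimal PyTorch sketch of one such step; the class name, channel count, and the choice of a 3×3 convolution with tanh for g(y, θ) = tanh(K(θ)y) are illustrative assumptions rather than the poster's exact implementation.

```python
import torch
import torch.nn as nn

class ForwardEulerBlock(nn.Module):
    """One ResNet-style step y <- y + h * g(y, theta), with
    g(y, theta) = tanh(K(theta) y) realized as a 3x3 convolution."""

    def __init__(self, channels, h=1.0):
        super().__init__()
        self.K = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.h = h  # step size of the forward Euler discretization

    def forward(self, y):
        return y + self.h * torch.tanh(self.K(y))

# apply N steps to mimic an N-layer ResNN (here the steps share one theta)
y = torch.randn(8, 16, 32, 32)             # a batch of feature maps
step = ForwardEulerBlock(channels=16, h=0.5)
for _ in range(4):
    y = step(y)
```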
Learning: Find θ and the weights W of the classifier by solving

    min_{θ, W}  (1/s) Σ_{k=1}^{s} loss(y_k(T, θ) W, c_k) + regularizer(θ, W).

learning ≈ mass transport, trajectory planning
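As a rough illustration of this objective, the sketch below evaluates the sample-average loss of a linear classifier W applied to the final states y_k(T, θ), plus a simple weight-decay regularizer; cross-entropy and the L2 penalty are assumed here only for illustration, since the poster does not specify the loss or the regularizer.

```python
import torch
import torch.nn.functional as F

def objective(y_T, labels, W, theta_params, alpha=1e-4):
    """(1/s) * sum_k loss(y_k(T, theta) W, c_k) + regularizer(theta, W).

    y_T          -- final states y_k(T, theta), one row per sample
    W            -- classifier weights
    theta_params -- iterable of control-weight tensors
    """
    logits = y_T @ W                            # y_k(T, theta) W for every sample k
    data_fit = F.cross_entropy(logits, labels)  # mean over the s samples
    reg = alpha * (W.pow(2).sum() + sum(p.pow(2).sum() for p in theta_params))
    return data_fit + reg
```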
RUNGE-KUTTA SCHEMES

Goal: Improve training while maintaining few parameters and controlling conditioning.

Recall the fourth-order Runge-Kutta scheme. Defining the length of the j-th time interval by

    h_j = t_{j+1} − t_j,

the update scheme reads

    u_{j+1} = u_j + (h_j / 6) [ f(θ(t_j), z_1) + 2 f(θ(t_{j+1/2}), z_2) + 2 f(θ(t_{j+1/2}), z_3) + f(θ(t_{j+1}), z_4) ],

where f is the primary layer in the dynamic unit, evaluated as a function of the controls θ(t_k) and the intermediate states z_i, which are computed as follows:

    z_1 = u_j
    z_2 = u_j + (h_j / 2) f(θ(t_j), z_1)
    z_3 = u_j + (h_j / 2) f(θ(t_{j+1/2}), z_2)
    z_4 = u_j + h_j f(θ(t_{j+1/2}), z_3)
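A compact sketch of one such RK4 step is below, written to mirror the stage equations above; the call signature f(theta, z) and the separate arguments for the controls at t_j, t_{j+1/2}, and t_{j+1} are assumptions made for illustration.

```python
def rk4_step(f, theta_j, theta_mid, theta_jp1, u_j, h_j):
    """One fourth-order Runge-Kutta step of the dynamic unit.

    f is the primary layer, called as f(theta, z); theta_j, theta_mid,
    and theta_jp1 are the control weights at t_j, t_{j+1/2}, and t_{j+1}.
    """
    k1 = f(theta_j,   u_j)                    # slope at z_1 = u_j
    k2 = f(theta_mid, u_j + (h_j / 2) * k1)   # slope at z_2
    k3 = f(theta_mid, u_j + (h_j / 2) * k2)   # slope at z_3
    k4 = f(theta_jp1, u_j + h_j * k3)         # slope at z_4
    return u_j + (h_j / 6) * (k1 + 2 * k2 + 2 * k3 + k4)
```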


NUMERICAL RESULTS

From this RK4 scheme for f, we build a dynamic unit as part of a simple model to compare different time-steppings when f is a layer of type:

• Double / ResNN: σ_2 ∘ N_2 ∘ K_{θ_2} ∘ σ_1 ∘ N_1 ∘ K_{θ_1} (Y)
• Preactivated Double: N_2 ∘ K_{θ_2} ∘ σ_2 ∘ N_1 ∘ K_{θ_1} ∘ σ_1 (Y)
• Double Sym / Parabolic [3]: −K_θ^T ∘ σ ∘ N ∘ K_θ (Y) (sketched in code below)

for activation functions σ, normalizations N, and convolution operators K defined by the weights θ.

We train a simple model consisting of a convolutional opening layer, three blocks containing the RK scheme that double the channels on each pass, and one fully connected layer. The dynamic unit is the only portion that we vary. Our learning strategy uses 120 epochs of SGD with momentum, with an initial learning rate of 0.1 that is reduced by a factor of 10 after epochs 60, 80, and 100.

Results:

[Figure: validation accuracy vs. wall-clock time on STL-10 and CIFAR-10 with a Double Sym layer, comparing RK4 [2], RK4 [1], RK1 [1], RK1 [.5], RK1 [.25], and RK1 [.125].]
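For concreteness, here is a minimal PyTorch sketch of the Double Sym / parabolic layer −K_θ^T ∘ σ ∘ N ∘ K_θ(Y); reusing the convolution weights in a transposed convolution for K_θ^T, and choosing batch normalization for N and tanh for σ, are illustrative assumptions rather than the poster's exact configuration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleSymLayer(nn.Module):
    """Double Sym / parabolic layer: f(Y) = -K^T( sigma( N( K(Y) ) ) ).
    Batch norm for N and tanh for sigma are illustrative choices."""

    def __init__(self, channels, kernel_size=3):
        super().__init__()
        self.weight = nn.Parameter(
            0.01 * torch.randn(channels, channels, kernel_size, kernel_size))
        self.norm = nn.BatchNorm2d(channels)
        self.pad = kernel_size // 2

    def forward(self, y):
        ky = F.conv2d(y, self.weight, padding=self.pad)   # K(theta) Y
        a = torch.tanh(self.norm(ky))                     # sigma(N(.))
        # transposed convolution with the same weights acts as K(theta)^T
        return -F.conv_transpose2d(a, self.weight, padding=self.pad)
```

Sharing the weights between K_θ and its transpose gives the layer its symmetric structure while keeping the parameter count low.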

NOISY STOCHASTIC SHIFTS

Goal: Analyze the network when the time-stepping is varied every epoch.

Fixing the control time steps t_θ = [0, 1, 2, 3, 4] and the state time steps t_Y = [0, 1, 2, 3, 4] or [0, 2, 4], we draw noise ε from a uniform distribution at every epoch. This varies the interpolation of the control weights used to obtain the state weights.

Results: for a Double Sym layer in the dynamic unit:

[Figure: validation accuracy per epoch on CIFAR-10 with a noisy Double Sym layer, for t_Y = [0, 1, 2, 3, 4] and t_Y = [0, 2, 4], comparing no noise with noise drawn from U[-.1, .1], U[-.2, .2], U[-.3, .3], U[-.4, .4], and U[-.5, .5].]
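One plausible reading of this procedure, sketched below, is to perturb the state time points by the drawn noise and then linearly interpolate the control weights at the shifted times; the function name, the piecewise-linear interpolation, and keeping the endpoints fixed are all assumptions, since the poster does not spell out the interpolation rule.

```python
import numpy as np

def noisy_state_weights(theta, t_theta, t_Y, noise_level, rng):
    """Per-epoch noisy shift: perturb the interior state time points and
    linearly interpolate the control weights at the shifted times.

    theta   -- control weights, one row per control time step in t_theta
    t_Y     -- state time steps, e.g. [0, 1, 2, 3, 4] or [0, 2, 4]
    """
    eps = rng.uniform(-noise_level, noise_level, size=len(t_Y))
    eps[0] = eps[-1] = 0.0           # keep t = 0 and t = T fixed (illustrative choice)
    t_shifted = np.asarray(t_Y, dtype=float) + eps
    # interpolate each weight component at the shifted state times
    return np.stack([np.interp(t_shifted, t_theta, theta[:, j])
                     for j in range(theta.shape[1])], axis=1)

rng = np.random.default_rng(0)
theta = rng.standard_normal((5, 3))  # controls at t_theta = [0, 1, 2, 3, 4]
state_w = noisy_state_weights(theta, [0, 1, 2, 3, 4], [0, 2, 4], noise_level=0.3, rng=rng)
```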

TEAM

• Eldad Haber (UBC, Vancouver)
• Eran Treister (Ben Gurion, Israel)
• Simion Novikov (Ben Gurion, Israel)

SOFTWARE

GitHub:
• Meganet.m: academic and teaching tool
• Meganet.jl: high-performance distributed computing
• PyTorch implementations in the works

FUNDING

Supported by the National Science Foundation awards DMS 1522599 and CAREER DMS 1751636 and by NVIDIA Corporation.

FUTURE DIRECTIONS

• Loss Landscape Analysis
• Adversarial Vulnerability Analysis
• Adaptive Time-Stepping
• Adams-Bashforth Methods

REFERENCES

[1] Chen et al. Neural Ordinary Differential Equations. NeurIPS, 2018.
[2] E. Haber, L. Ruthotto. Stable Architectures for Deep Neural Networks. Inverse Problems, 2017.
[3] L. Ruthotto, E. Haber. Deep Neural Networks Motivated by Partial Differential Equations. arXiv, 2018.