Cameron Wolfe · Sep 17 · 19 min read

Critical Learning Periods in Deep Networks

Why the first epochs matter the most...

[Figure: Final accuracy of a DNN as a function of the epoch at which a training deficit is removed, compared with visual acuity in kittens as a function of age at deficit removal.]

Critical periods are defined as times of early post-natal (i.e., after birth) development, during which impairments to learning (e.g., sensory deficits) can lead to permanent impairment of one's skills [5]. For example, vision impairments at a young age — a critical period for the development of one's eyesight — often lead to problems like amblyopia in adult humans.

Although I am far from a biological expert (in fact, I haven't taken a biology class since high school), this concept of critical learning periods is still curiously relevant to deep learning, as the same behavior is exhibited within the learning process for neural networks. If a neural network is subjected to some impairment (e.g., only shown blurry images or not regularized properly) during the early phase of learning, the resulting network (after training is fully complete) will generalize more poorly relative to a network that never received such an early learning impairment, even given an unlimited training budget. Recovering from this early learning impairment is not possible.

In analyzing this curious behavior, researchers have found that neural network training seems to progress in two phases. During the first phase — the critical period that is sensitive to learning deficits — the network memorizes the data and passes through a bottleneck in the optimization landscape, eventually finding a more well-behaved region within which convergence can be achieved.
From here, the network goes through a forgetting process and learns generalizable features rather than memorizing the data. In this phase, the network exists within a region of the loss landscape in which many equally-performant local optima exist, and eventually converges to one of these solutions.

Critical learning periods are fundamental to our understanding of deep learning as a whole. Within this overview, I will embrace the fundamental nature of the topic by first overviewing basic components of the neural network learning process. Given this background, my hope is that the resulting overview of critical learning periods will provide a more nuanced perspective that reveals the true complexity of training deep networks.

This section overviews the learning process for neural networks and critical periods during learning. The overviews within this section are quite broad and may take time to truly grasp, so I provide further links for those who need more depth.

Neural Network Training

Neural network training is a fundamental aspect of deep learning. Covering the full depth of this topic is beyond the scope of this overview. However, to understand critical learning periods, one must have at least a basic grasp of the training procedure for neural networks.

The goal of neural network training is — starting from a neural network with randomly-initialized weights — to learn a set of parameters that allow the neural network to accurately produce a desired output given some input. Such input and output can take many forms — key points predicted on an image, a classification of text, object detections in a video, and more. Additionally, the neural network architecture oftentimes changes depending on the type of input data and problem being solved. Despite the variance in neural network definitions and applications, however, the basic concepts of model training remain (more or less) the same.
[Figure: Depiction of the input-output map learned by neural networks in different problem domains (created by author)]

To learn this map between input and desired output, we need a (preferably large) training dataset of input-output pairs. Why? So that we can:

1. Make predictions on the data
2. See how the model's predictions compare to the desired output
3. Update the model's parameters to make predictions better

This process of updating the model's parameters over training data to better match known labels is the crux of the learning process. For deep networks, this learning process is performed for several epochs, defined as full passes through the training dataset.

To determine the quality of model predictions, we define a loss function. The goal of training is to minimize this loss function, thus maximizing the quality of model parameters. To minimize the loss, deep networks are typically trained with stochastic gradient descent (SGD). Each iteration of SGD:

* Computes the loss function over a mini-batch of training data
* Uses the chain rule of calculus to compute the gradient of the loss with respect to every parameter within the model
* Subtracts the gradient, scaled by a learning rate, from each parameter

Although a bit complicated to understand in detail, SGD at an intuitive level is quite simple — each iteration just determines the direction that model parameters should be updated to decrease the loss and takes a small step in this direction. We perform optimization over the network's parameters to minimize training loss. See below for a schematic depiction of this process, where the learning rate setting controls the size of each SGD step.
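The SGD update rule can be sketched in a few lines of plain Python. This is a minimal, illustrative sketch of one SGD iteration on a toy one-parameter problem — the function name `sgd_step` and the toy loss are my own choices for illustration, not code from any of the papers discussed here.

```python
def sgd_step(params, grads, lr=0.1):
    """One SGD iteration: move each parameter a small step against its
    gradient. Real frameworks (e.g., torch.optim.SGD) add mini-batching,
    momentum, and more, but the core update is just this subtraction."""
    return [p - lr * g for p, g in zip(params, grads)]

# Toy example: minimize f(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
w = [0.0]
for _ in range(100):
    grads = [2 * (w[0] - 3)]
    w = sgd_step(w, grads, lr=0.1)
# w[0] is now very close to the minimizer, 3
```

Note how the learning rate `lr` scales the step size: too small and convergence is slow, too large and the iterates can overshoot the minimum.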
[Figure: Minimization of the training loss with different learning rates (created by author)]

Ultimately, we want to learn a model that generalizes — meaning that it also performs well on unseen testing data.

[Figure: Basic illustration of the steps within a neural network training iteration (created by author)]

See the figure above for a high-level depiction of the neural network training process. There are many more details that go into neural network training, but the purpose of this overview is to understand critical learning periods, not to take a deep dive into neural network training. Thus, I provide below some links to useful articles that can be used to understand key neural network training concepts in greater detail for the interested reader.

* Neural Network Training Overview [blog] [video]
* Understanding Backpropagation [blog] [video]
* SGD (and Other Optimization Algorithms) [blog]
* Basic Neural Network Training in PyTorch [notebook]
* What is generalization? [blog]

Ultimately, we care about a model's performance on unseen testing data when it is deployed into the real world. A model that performs well on such unseen data is said to generalize well. Minimizing loss on the training data does not guarantee that a model will generalize. For example, a model could just "memorize" each training example, thus preventing it from learning generalizable patterns that can be applied to unseen data.

To ensure good generalization, deep learning practitioners typically utilize regularization techniques. Many such techniques exist, but the most relevant for the purposes of this post are weight decay and data augmentation. Weight decay is a technique that is commonly applied to the training of machine learning models (even beyond neural networks). The idea is simple.
During training, adjust your loss function to penalize the model for learning parameters with large magnitude. Then, optimizing the loss function becomes a joint goal of (i) minimizing loss over the training set and (ii) keeping network parameters low in magnitude. The strength of weight decay during training can be adjusted to find different tradeoffs between these two goals — it is another hyperparameter of the learning process that can be tweaked/modified (similar to the learning rate). To learn more, I suggest reading this article.

Data augmentation takes many different forms depending on the domain and setting in which it is being applied. But, the fundamental idea behind data augmentation remains constant — each time your model encounters some data during training, one should randomly change the data a little bit in a way that still preserves the data's output label. Thus, your model never sees the same data example twice. Rather, the data is always slightly perturbed, preventing the model from simply memorizing examples from the training set. Although data augmentation can take many different forms, numerous survey papers and explanations exist that can be used to better understand these techniques.

* Data Augmentation for Computer Vision [blog] [survey]

Beyond regularization, one will also frequently encounter the concepts of pre-training and fine-tuning for deep networks. All of these methods follow the same learning process outlined above — pre-training and fine-tuning are just terms that refer to a specific, slightly-modified setup for the same training process. Pre-training typically refers to training a model from scratch (i.e., random initialization) over a very large dataset.
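Before moving on to pre-training, the two regularizers described above can be sketched in plain Python. This is a hypothetical, minimal illustration: `total_loss` shows weight decay as an L2 penalty added to the data loss, and `augment` shows a toy label-preserving perturbation — neither is code from the papers discussed here.

```python
import random

def l2_penalty(params):
    """Sum of squared parameter magnitudes (the weight decay penalty)."""
    return sum(w * w for w in params)

def total_loss(data_loss, params, decay=1e-4):
    """Joint objective: (i) fit the training data and (ii) keep the
    parameters small. `decay` sets the tradeoff between the two goals;
    1e-4 is just an illustrative default, not a recommendation."""
    return data_loss + decay * l2_penalty(params)

def augment(image, max_shift=2, rng=random):
    """Toy augmentation: randomly rotate a 1-D 'image' by a few pixels.
    The label is preserved, but the model never sees the exact same
    input twice."""
    shift = rng.randint(-max_shift, max_shift)
    return image[shift:] + image[:shift]

# The penalty is added on top of however badly the model fits the data:
print(total_loss(1.0, [3.0, -4.0], decay=0.1))  # 1.0 + 0.1 * 25 = 3.5
```

Increasing `decay` pushes the optimizer toward smaller weights at the cost of a worse fit to the training set — exactly the tradeoff described above.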
Although such training over large pre-training datasets is computationally expensive, model weights learned from pre-training can be very useful, as they contain patterns that have been learned from training over a lot of data that may generalize elsewhere (e.g., learning how to detect edges, understanding shapes/textures, etc.).

[Figure: Illustration of the differences between pre-training, fine-tuning, and normal training of a neural network (created by author)]

Pre-trained model parameters are often used as a "warm start" for performing training on other datasets, often referred to as the downstream or target dataset. Instead of initializing model parameters randomly when performing downstream training, we can set model parameters equal to the pre-trained weights and fine-tune — or further train — these weights on the downstream dataset; see the figure above. If the pre-training dataset is sufficiently large, such an approach yields improved performance relative to training from scratch.

The papers overviewed below demonstrate the existence of critical learning periods in deep networks. The first paper studies the impact of data blurring on the learning process, while the following papers study learning behavior with respect to model regularization and data distributions during training. Despite taking different approaches, each of these works follows a similar approach of:

* Applying some impairment to a portion of the learning process
* Analyzing how such a deficit impacts model performance after training

Critical Learning Periods in Deep Networks [1]

Figure 1: DNNs exhibit critical periods.
(A) Final accuracy achieved by a CNN trained with a cataract-like deficit as a function of the training epoch N at which the deficit is removed (solid line). Performance is permanently impaired if the deficit is not corrected early enough, regardless of how much additional training is performed. As in animal models, critical periods coincide with the early learning phase during which, in the absence of deficits, test accuracy would rapidly increase (dashed). (B) For comparison, we report acuity for kittens monocularly deprived since birth and tested at the time of eye-opening (solid), and normal visual acuity development (in kittens) as a function of their age (dashed) (Giffin & Mitchell, 1978; Mitchell, 1988). Sensitivity during learning: (C) Final test accuracy of a DNN as a function of the onset of a short 40-epoch deficit. The decrease in the final performance can be used to measure the sensitivity to deficits. The most sensitive epochs correspond to the early rapid learning phase, before the test error (dashed line) begins to plateau. Afterwards, the network is largely unaffected by the temporary deficit. (D) This can be compared with changes in the degree of functional disconnection (normalized numbers of V1 monocular cells disconnected from the contralateral eye) as a function of the kittens' age at the onset of a 10-12-day deficit window (Olson & Freeman, 1980). Dashed lines are as in A and B respectively, up to a re-scaling of the y-axis. (from [1])

Main Idea. Applying a learning impairment (e.g., blurring of images) to the training of deep neural networks, even for only a short period of time, can result in degraded performance. Going further, the extent of the damage to performance depends on when and how long the impairment is applied — a finding that mirrors the behavior of biological systems.
For example, if the impairment is applied at the beginning of training, there exists a sufficient number of impaired learning epochs, beyond which the deep network's performance will never recover. Biological neural networks demonstrate similar properties with respect to early impairments to learning. Namely, experiencing an impairment to learning for too long during early stages of development can have permanent consequences (e.g., amblyopia). The figure above demonstrates the impact of critical learning periods in both artificial and biological systems.

At a high level, the discovery within this paper can be simply stated as follows:

If one impairs a deep network's training process in a sustained fashion during the early epochs of training, the network's performance cannot recover from this impairment.

To better understand this phenomenon, authors quantitatively study the connectivity of the network's weight matrices, finding that learning is comprised of a two-step process of "memorizing", then "forgetting". More specifically, the network memorizes data during the early learning period, then reorganizes/forgets such data as it begins to learn more efficient, generalizable patterns. During the early memorization period, the network navigates a bottleneck in the loss landscape — the network is quite sensitive to learning impairments as it traverses this narrow landscape. Eventually, however, the network escapes this bottleneck to discover a wider valley that contains many high-performing solutions — the network is more robust to learning impairments within this region.

Methodology. Within this work, authors train a convolutional neural network on the CIFAR-10 dataset, applying an impairment (i.e., image blurring) during a portion of the learning process. By studying the impact of such impairments on network performance, the authors discover that:

* If the impairment is not removed sufficiently early during training, then network performance will be permanently damaged.
* Sensitivity to such learning impairments peaks during the early period of learning (i.e., the first 20% of epochs).

To further explore the properties of critical learning periods in deep networks, authors measure the Fisher information within the model's parameters, which quantitatively describes the connectivity between network layers, or the amount of "useful information" contained within network weights. Fisher information is found to increase rapidly during early training stages, then decay throughout the remainder of training. Such a trend reveals that the model first memorizes information during the early learning phase, then slowly reorganizes or reduces this information — even as classification performance improves — by removing redundancy and establishing robustness to non-relevant variability in the data.

When an impairment is applied, the Fisher information grows and remains much higher than normal, even after the deficit is removed, revealing that the network is less capable of learning generalizable data representations in this case. See the figure below for an illustration of this trend.

Figure 5: Normalized quantity of information contained in the weights of each layer as a function of the training epoch. (Top Left) In the absence of deficits, the network relies mostly on the middle layers (3-4-5) to solve the task. (Top Right) In the presence of an image blur deficit until epoch 100, more resources are allocated to the higher layers (6-7) rather than to the middle layers. The blur deficit destroys low- and mid-level features processed by those layers, leaving only the global features of the image, which are processed by the higher layers. Even if the deficit is removed, the middle layers remain underdeveloped.
(Top Center) When the deficit is removed at an earlier epoch, the layers can partially reconfigure (notice, e.g., the fast loss of information of layer 6), resulting in less severe long-term consequences. We refer to the redistribution of information and the relative changes in effective connectivity as "Information Plasticity". (Bottom row) Same plots, but using a vertical flip deficit, which does not induce a critical period. As expected, the quantity of information in the layers is not affected. (from [1])

Findings.

* Network performance is most sensitive to impairments during the early stage of training. If image blurring is not removed within the first 25-40% of training epochs (i.e., the exact ratio depends on network architecture and training hyperparameters) for the deep network, then network performance will be permanently damaged.
* High-level changes to data (e.g., vertical flipping of images, permutation of output labels) do not have any impact on network performance. Additionally, performing impaired training with white noise does not damage network performance — complete sensory deprivation (i.e., this parallels dark rearing in biological systems) does not induce a critical period.
* In normal training, Fisher information concentrates within the middle network layers, where low- and mid-level image features can be most efficiently processed. Impairments to the learning process lead to a concentration of Fisher information in the final network layer, which contains no lower or mid-level features, unless the deficit is removed sufficiently early in training.

Time Matters in Regularizing Deep Networks: Weight Decay and Data Augmentation Affect Early Learning Dynamics, Matter Little Near Convergence [2]

Figure 2: Intermediate application or removal of regularization affects the final solution: (A-C) L2 norm of the weights as a function of the training epoch (corresponding to Figure 1 (Top)).
The weights of the network move after application or removal of regularization, which can be seen by the change in their norm. The correlation between the norm of the weights and generalization properties is not as straightforward as lower norm implying better generalization. For instance, (C) applying weight decay only at the beginning (curve 0) reduces the norm only during the critical period, and yields higher norm asymptotically than, for example, curve 25. Yet it has better generalization. This suggests that having a lower norm mostly helps only during the critical period. (D) PCA-projection of the training paths obtained removing weight decay at different times (see Appendix A.1). Removing WD before the end of the critical period (curves 25, 50) makes the network converge to different regions of the parameter space. Removing WD after the critical period (curves 75 to 200) still sensibly changes the final point (in particular, critical periods are not due to the optimization being stuck in a local minimum), but all points lie in a similar area, supporting the Critical Period interpretation of [1]. (E) Same plots, but for DA, which unlike WD does not have a sharp critical period: all training paths converge to a similar area. (from [2])

Main Idea. Conventional intuition states that regularization biases the learning process such that the final point of convergence is smooth/flat in the loss landscape, which is (arguably) indicative of good generalization performance. Whether such an intuition is correct is a subject of hot debate — one can read several interesting articles about the connection between local curvature and generalization online.

This paper proposes an alternative perspective of regularization, going beyond these basic intuitions. The authors find that removing regularization (i.e., weight decay and data augmentation) after the early epochs of training does not alter network performance.
On the other hand, if regularization is only applied during the later stages of training, it does not benefit network performance — the network performs just as poorly as if regularization were never applied. Such results collectively demonstrate the existence of a critical period for regularizing deep networks that is indicative of final performance; see the figure above.

Such a result reveals that regularization does not simply bias network optimization towards final solutions that generalize well. If this intuition were correct, removing regularization during the later training periods — when the network begins to converge to its final solution — would be problematic. Rather, regularization is found to have an impact on the early learning transient, biasing the network optimization process towards regions of the loss landscape that contain numerous solutions with good generalization to be explored later in training.

Methodology. Similarly to previous work, the impact of regularization on network performance is studied using convolutional neural network architectures on the CIFAR-10 dataset. In each experiment, the authors apply regularization (i.e., weight decay and data augmentation) to the learning process for the first t epochs, then continue training without regularization. When comparing the generalization performance of networks with regularization applied for different durations at the beginning of training, authors find that good generalization can be achieved by only performing regularization during the earlier phase of training. Applying regularization only during later epochs, in contrast, provides no benefit in terms of final performance. This parallels the findings in [1], as the lack of regularization can be viewed as a form of learning deficit that impairs network performance.

Findings.

* The effect of regularization on final performance is maximal during the initial, "critical" training epochs.
* The critical period behavior of weight decay is more pronounced than that of data augmentation.
Data augmentation impacts network performance similarly throughout training, while weight decay is most effective when applied during earlier epochs of training.

* Performing regularization for the entire duration of training yields networks that achieve comparable generalization performance to those that only receive regularization during the early learning transient (i.e., first 50% of training epochs).
* Using regularization or not during later training periods results in different points of convergence (i.e., the final solution is not identical), but the resulting generalization performance is the same. Such a result reveals that regularization "directs" training during the early period toward regions with multiple, different solutions that perform equally well.

On Warm-Starting Neural Network Training [3]

Figure 1: A comparison between ResNets trained using a warm start and a random initialization on CIFAR-10. Blue lines are models trained on 50% of CIFAR-10 for 350 epochs, then trained on 100% of the data for a further 350 epochs. Orange lines are models trained on 100% of the data from the start. The two procedures produce similar training performance but differing test performance. (from [3])

Main Idea. In real-world machine learning systems, it is common for new data to arrive in an incremental fashion. Generally, one will begin with some aggregated dataset, then over time, as new data becomes available, this dataset grows and evolves. In such a case, sequences of deep learning models are trained over each version of the dataset, where each model takes advantage of all the data that is available so far.
Given such a setup, however, one may begin to wonder whether a "warm start" could be formulated, such that each model in this sequence begins training with the parameters of the previous model, mimicking a form of pre-training that allows model training to be more efficient and high-performing. In [3], the authors find that simply initializing model parameters with the parameters of a previously-trained model is not sufficient to achieve good generalization performance. Although final training losses are similar, models that are first pre-trained over a smaller subset of data, then fine-tuned on the full dataset achieve degraded test accuracy in comparison to models that are randomly initialized and trained over the full dataset from the start. However, the authors propose a simple warm-starting technique that can be used to avoid such deteriorations in test accuracy.

Methodology. Consider a setup where new data arrives into a system once each day. In such a system, one would ideally re-train their model when this new data arrives each day. Then, to minimize training time, a naive warm starting approach could be implemented by initializing the new model's parameters with the parameters of the previous day's model prior to training/fine-tuning. Interestingly, however, such a warm starting approach is found to yield models that generalize poorly, revealing that pre-training over an incomplete subset of data is a form of learning impairment when applied during a critical period.

To overcome the impact of this impairment, the authors propose a simple technique called Shrink, Perturb, Repeat that:

1. Shrinks model weights towards zero.
2. Adds a small amount of noise to model weights.

If such a procedure is applied to the weights of a previous model trained over an incomplete subset of data, then the parameters of this model can be used to warm start training over the full dataset without causing any deterioration in generalization performance.
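The shrink-and-perturb transformation itself amounts to only a couple of lines. Below is a minimal sketch in plain Python, assuming a model's weights are represented as a flat list of floats; `shrink` and `noise_scale` are the method's two new hyperparameters, and the values used here are illustrative defaults rather than the tuned settings from [3].

```python
import random

def shrink_perturb(weights, shrink=0.5, noise_scale=0.01, rng=random):
    """Warm-start trick: shrink each weight toward zero, then add a
    small amount of Gaussian noise. Applied to the previous model's
    weights before training on the grown dataset."""
    return [shrink * w + rng.gauss(0.0, noise_scale) for w in weights]

# Instead of copying yesterday's weights verbatim, transform them first:
old_weights = [1.0, -2.0, 0.5]
warm_start = shrink_perturb(old_weights)
```

Intuitively, shrinking dampens what was already learned from the incomplete data, while the small noise restores some plasticity for the subsequent round of training.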
Although the amount of shrinking and scale of the noise introduce new hyperparameters to the training process, this simple trick yields remarkable computational savings — due to the ability to warm start, and thus speed up, model training — with no deterioration to network performance.

To elucidate the efficacy of this approach, authors explain that a naive warm start methodology experiences significant imbalances between the gradients of new and old data. Such imbalances are known to negatively impact the learning process [4]. However, shrinking and noising model parameters prior to training both (i) preserves network predictions and (ii) balances the gradient contributions of new and old data.

Findings.

* Although critical learning periods associated with incomplete datasets are demonstrated in deep networks, simpler models (e.g., logistic regression) do not experience such an effect (i.e., likely because training is convex).
* The degradation in test accuracy due to naive warm starting cannot be alleviated via tuning of hyperparameters like batch size or learning rate.
* Only a small amount of training (i.e., a few epochs) over an incomplete subset of data is necessary to damage the test accuracy of a model trained over the full dataset, further revealing that training over incomplete data is a form of learning impairment with connections to critical learning periods.
* Leveraging the Shrink, Perturb, Repeat method completely eliminates the generalization gap between randomly-initialized and warm-started models, enabling significant computational savings.

Is deep learning theory missing the mark?

The existence of critical learning periods gives birth to an interesting perspective of the learning process for deep neural networks. Namely, the fact that such networks cannot recover from impairments applied during the early epochs of training reveals that learning progresses in two distinct phases, each of which has interesting properties and behavior.

1.
Critical Learning Period: the memorization period. The network must navigate a narrow/bottlenecked region of the loss landscape.
2. Converging to a Final Solution: the forgetting period. After traversing a bottlenecked region of the loss landscape, the network enters a wide valley of many equally-performant solutions to which it can converge.

The critical learning period during the early learning transient plays a key role in determining the quality of the final solution. Most theoretical analysis of deep networks focuses on asymptotic behavior near convergence, within which no critical learning periods or different phases of learning appear. The convincing empirical results that outline the existence of critical learning periods within deep networks hint that there is more to deep learning than is revealed by current, asymptotic analysis. Theoretical analysis that truly captures the complexity of learning within deep networks is yet to come.

Takeaways

The takeaways from the overview can be stated pretty simply:

* Neural network training seems to proceed in two major phases — memorization and forgetting.
* Impairing the learning process during the first, early phase is not good.

To be a bit more specific, learning impairments during the first phase aren't just bad... they are seemingly catastrophic. One cannot recover from these impairments during the second phase, and the resulting network is doomed to poor performance in most cases. The work overviewed here has demonstrated this property in numerous domains, showing that the following impairments applied during the first phase of learning can degrade network generalization:

* Sufficiently blurred images
* Lack of regularization (i.e., data augmentation or weight decay)
* Lack of sufficient data

Critical learning periods provide a unique perspective on neural network training that makes even seasoned researchers question their intuitions.
This two-phase view of neural network training defies commonly-held beliefs and is not reflected within much of the theoretical analysis of deep networks, revealing that much more work is to be done if we are to collectively arrive at a more nuanced understanding of deep learning.

Conclusion

Thanks so much for reading this article. I hope that you enjoyed it and learned something new. I am Cameron R. Wolfe, a research scientist at Alegion and PhD student at Rice University studying the empirical and theoretical foundations of deep learning. If you liked this post, please follow my Deep (Learning) Focus newsletter, where I pick a single, bi-weekly topic in deep learning research, provide an understanding of relevant background information, then overview a handful of popular papers on the topic. You can also check out my other writings!

Bibliography

[1] Achille, Alessandro, Matteo Rovere, and Stefano Soatto. "Critical learning periods in deep networks." International Conference on Learning Representations. 2018.

[2] Golatkar, Aditya Sharad, Alessandro Achille, and Stefano Soatto. "Time matters in regularizing deep networks: Weight decay and data augmentation affect early learning dynamics, matter little near convergence." Advances in Neural Information Processing Systems 32 (2019).

[3] Ash, Jordan, and Ryan P. Adams. "On warm-starting neural network training." Advances in Neural Information Processing Systems 33 (2020): 3884-3894.

[4] Yu, Tianhe, et al. "Gradient surgery for multi-task learning." Advances in Neural Information Processing Systems 33 (2020): 5824-5836.

[5] Eric R. Kandel, James H. Schwartz, Thomas M. Jessell, Steven A. Siegelbaum, and A. James Hudspeth. Principles of Neural Science. McGraw-Hill, New York, NY, 5th edition, 2013.
