
Existing large language models (LLMs) can only afford fixed-size inputs due to the input length limit,

preventing them from utilizing rich long-context information from past inputs.

 LongMem : memorize long history

A novel decoupled network architecture:

 Original backbone LLM frozen as encoder.


 Adaptive residual side-network as memory retriever and reader

 easily cache and update long-term past contexts for memory retrieval without suffering from memory
staleness.

Typically, LONGMEM can enlarge the long-form memory to 65k tokens and thus cache many-shot extra
demonstration examples as long-form memory for in-context learning.

Our method outperforms strong long-context models on ChapterBreak, a challenging long-context modeling benchmark.  Effective in helping language models to memorize and utilize long-form content.

1. INTRODUCTION

LLMs have revolutionized NLP with great successes in advancing the SOTA on various understanding and
generation tasks with emergent abilities such as zero-shot prompting, in-context learning, and Chain-of-
Thought (CoT) reasoning.

Limitations: these models struggle when it comes to handling long-form information: they can't process information beyond a fixed-size session, which is a problem for real-world applications that require long-term planning.

 most straightforward method: increase input context length, similar to what GPT-3 did when it
doubled its input length from 1k to 2k tokens as compared to GPT-2

But this is not an ideal solution, as it demands heavy computational power and still suffers from the cost of in-context dense attention.

Alternative approach: develop in-context sparse attention to sidestep this problem, but it still requires substantial computation and typically has to be trained from scratch.

Memorizing Transformer (MemTRM): uses dense attention over both in-context and memorized tokens; it handles up to 65k tokens and significantly improves the ability to handle long text. However, as the model parameters are updated, the representations in memory may become outdated, affecting the overall effectiveness of memory augmentation.

 LongMem : allows models to store previous context or knowledge in a non-differentiable memory bank.

It uses a residual side-network (SideNet), decoupled from the frozen backbone LLM, to prevent the issue of memory staleness.
We extract attention keys and values from previous contexts using the frozen backbone LLM and store them in the memory bank  the attention query generated from the current input is used to retrieve these cached keys and values from the memory, which are then integrated into the learned hidden states through a joint-attention mechanism.

Additionally, we have designed cross-network residual connections between SideNet and frozen
backbone LLM for better knowledge transfer.

2 main benefits:

 Separate the process of encoding previous inputs into memory from the process of memory retrieval and fusion  resolving the issue of memory staleness.
 Adapting the entire LLM with memory augmentations is computationally inefficient and prone to
catastrophic forgetting. As the backbone LLM is frozen during the efficient memory-augmented
adaptation stage, LongMem can effectively utilize pre-trained knowledge without the risk of
catastrophic forgetting.

Test with 2 representative cases: language modeling with full-length book contexts + in-context learning
with thousands of task-relevant examples.

our model consistently outperforms the strong baselines in terms of long-text modeling and in-context
learning abilities.

LongMem improves the LLM's long-context language modeling capabilities significantly and achieves SOTA performance on a challenging long-context modeling benchmark.

In-context learning improvements on popular NLU tasks were also notable when compared with
MemTRM and non-memory-augmented baselines.
SUMMARY: LLMs have limitations in processing long-form information beyond a fixed context length. Scaling up
input context length incurs computation-intensive training and the in-context attention is still heavily
constrained by the quadratic computation complexity of Transformer self-attention. To address this, the
proposed framework for LMs LongMem enables language models to cache long-form previous context or
knowledge into the non-differentiable memory bank, and further take advantage of them via a
decoupled memory module to address the memory staleness problem. LongMem outperforms strong
baselines in terms of long-text modeling and in-context learning abilities.

2. METHODS

Augment the frozen backbone LLM with a decoupled memory module.

Design a novel lightweight residual SideNet, which can be continually trained in an efficient way.

 The problem formulation of language modeling with memory augmentations


 Residual SideNet for adapting the frozen pretrained LLM to jointly attend over local input context
and retrieved memory context
 Processes of how past memory is encoded, stored, recalled and fused for language modeling.
2.1. Language Models Augmented with Long-Term Memory

LongMem model is built on the Transformer architecture with 3 key components: frozen backbone LLM,
SideNet and Cache Memory Bank.

In most pre-existing LLMs, the input has a fixed size, meaning they can only consider a specific portion of
a longer sequence (e.g., a book). This portion that fits within a size limit is referred to as the current
input. Previous segments of the sequence that couldn’t fit into the input are regarded as previous inputs
and are used for memory augmentations. We use the frozen backbone LLM to encode both the current and previous inputs, but we extract different kinds of information from them. From the previous inputs, we take the key-value pairs from the Transformer's self-attention at a specific layer and store them in the Cache Memory Bank. For the current inputs, we keep the hidden states from each decoder layer of the LLM and transfer these to the SideNet. The SideNet is trained to integrate the current input context with relevant cached previous contexts from the memory.
The process starts with a forward pass through the frozen backbone LLM for a given fixed-size input text sequence, without any gradient calculation. The input sequence is first mapped into the embedding space, producing the initial hidden states. Each subsequent layer of the LLM decoder uses the hidden states from the previous layer to compute new hidden states.

The key-value pairs used for self-attention at a certain decoder layer are stored in the Cache Memory Bank (CMB), which maintains the latest key-value pairs from previous inputs and is used as memory augmentation for future inputs. Once the memory retrieval and fusion process is complete, the CMB removes the key-value pairs of the oldest sequence and appends the current sequence to the cached vector bank.  ensures causality in language modeling at the sequence level and allows the memory bank to always keep records of the nearest previous context for the current inputs.
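A minimal sketch of such a segment-level FIFO memory bank (a sketch only; the class name, tensor shapes, and use of PyTorch are assumptions for illustration, not the authors' implementation):

import torch
from collections import deque

class CachedMemoryBank:
    """Holds attention key-value pairs of the most recent previous segments."""
    def __init__(self, memory_size=65536, segment_length=1024):
        capacity = memory_size // segment_length   # number of cached segments
        self.keys = deque(maxlen=capacity)         # oldest segment is evicted automatically
        self.values = deque(maxlen=capacity)

    def update(self, segment_keys, segment_values):
        # Called after retrieval/fusion for the current segment is complete,
        # so the bank always stores the nearest previous context.
        self.keys.append(segment_keys)
        self.values.append(segment_values)

    def all_pairs(self):
        # Concatenate along the token dimension: [heads, memory_tokens, head_dim]
        return torch.cat(list(self.keys), dim=1), torch.cat(list(self.values), dim=1)

bank = CachedMemoryBank()
k = torch.randn(16, 1024, 64)   # 16 heads, 1024 tokens, head dimension 64 (illustrative)
bank.update(k, torch.randn(16, 1024, 64))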

Next, the SideNet module takes hidden states from the current input and past k-v pairs from CMB to
compute memory-augmented representations. It is composed of several standard Transformer decoder
layers, with one special memory-augmented decoder layer. This layer takes a memory-augmented input,
including the top relevant k-v pairs in memory and the hidden states from the current input. The k-v
pairs are retrieved using a token-based memory retrieval module. The SideNet then computes the
output using memory-augmented input.

Finally, the token probability is computed from the last SideNet hidden states via the output embedding (language modeling head), which is shared by the frozen LLM and the SideNet. The aim is to maximize the likelihood of the next token given the left context. We perform memory-augmented adaptation training for LONGMEM so that it learns to utilize the decoupled memory. The training objective follows generative unsupervised pre-training, and the adaptation corpus is randomly sampled from the text corpus used for pre-training.
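In symbols (notation introduced here for clarity; $\tilde{\mathcal{M}}$ denotes the retrieved memory, and only the SideNet parameters receive gradients):

$$\max_{\Theta_{\text{Side}}} \sum_{t} \log P\!\left(x_t \mid x_{<t},\, \tilde{\mathcal{M}};\ \Theta_{\text{Side}},\, \Theta_{\text{LLM}}\right), \qquad \Theta_{\text{LLM}} \text{ frozen.}$$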

2.2. Residual SideNet

SideNet Architecture and Initialization: based on the Transformer. To establish a balance between complexity and performance, we employ a reduction factor to determine the number of decoder layers in SideNet. The number of decoder layers L in SideNet equals the number of layers L′ in the backbone LLM divided by the reduction factor (a layer reduction factor of 2 throughout this work, so L′ = 2L).

The weights of each decoder layer in SideNet are initialized from the corresponding pre-trained decoder
layer of the backbone LLM with the same depth.

SideNet takes the output of the backbone LLM's embedding layer and reuses the language modeling head layer of the backbone LLM (frozen during the continual adaptation stage).

All other parameters of SideNet are updated accordingly based on the training signal.

As a result, SideNet, while being lightweight, can learn quickly by leveraging the knowledge transferred
from the pre-trained parameters.
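A minimal sketch of this depth-matched initialization (the function name is illustrative; the l → 2l layer mapping is an assumption consistent with the cross-network residual connections described below):

import copy

def init_sidenet_from_backbone(backbone_layers, reduction_factor=2):
    """Initialize SideNet layer l from backbone layer l * reduction_factor (1-indexed)."""
    num_side_layers = len(backbone_layers) // reduction_factor
    sidenet_layers = []
    for l in range(1, num_side_layers + 1):
        source = backbone_layers[l * reduction_factor - 1]   # the list itself is 0-indexed
        sidenet_layers.append(copy.deepcopy(source))          # copy pre-trained weights
    return sidenet_layers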

Cross-Network Residual Connections: enhance the learning capacity of SideNet by fusing representations from the backbone LLM into SideNet.

Add the difference between output hidden states at 2l-th and (2l − 2)-th layers of the backbone LLM as
the residual connections to the output hidden states at l-th layer of SideNet.
The input to the next (l + 1)-th layer of SideNet is the sum of the original hidden state forwarded through
the previous layer and the calculated difference of the hidden states from the backbone LLM.

It's important to highlight that the established practices for residual connections in the decoder layer of
a Transformer model remain in place and work in parallel with our proposed cross-network residual
connections.
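Written out (notation introduced here for clarity), the input fed to the (l + 1)-th SideNet layer is

$$\hat{H}_{\text{Side}}^{l} \;=\; H_{\text{Side}}^{l} \;+\; \left(H_{\text{LLM}}^{2l} - H_{\text{LLM}}^{2l-2}\right),$$

where $H_{\text{Side}}^{l}$ is the output of the l-th SideNet layer and $H_{\text{LLM}}^{2l}$ is the output of the 2l-th backbone layer.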

2.3. Memory Retrieval and Fusion

This module allows LongMem to access long-term memory and is made up of 2 components: token-to-chunk memory retrieval and memory fusion.

Token-to-Chunk Memory Retrieval: Instead of performing token-to-token retrieval  token-to-chunk retrieval, for acceleration and integrity.

A chunk is an n-gram structure of csz (chunk-size) contiguous tokens.

The memory bank stores cached key-value pairs at the level of token chunks. Divide the memory bank into M/csz attention key-value paired chunks; the mean-pooled vector of each chunk's keys along the chunk-size dimension serves as the chunk's key vector for retrieval.

retrieve the top-(K/csz) attention key-value chunks w.r.t the dot product between the attention query of
the current input token and the mean-pooled attention key of a candidate chunk.

squeeze the chunk-size dimension for retrieved key-value paired chunks and flatten them into K key-
value pairs at token-level.

 reduce size of retrieval index, accelerate the process, improve accuracy. Chunk-size can be adjusted
based on the downstream tasks.
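A minimal sketch of this token-to-chunk retrieval (single-head; function name, shapes, and use of PyTorch are illustrative assumptions):

import torch

def retrieve_chunks(query, mem_keys, mem_values, csz=4, k_tokens=64):
    """query: [d]; mem_keys, mem_values: [M, d]. Returns K token-level key-value pairs."""
    M, d = mem_keys.shape
    n_chunks = M // csz
    # Mean-pool keys over the chunk-size dimension -> one retrieval key per chunk.
    chunk_keys = mem_keys[: n_chunks * csz].view(n_chunks, csz, d).mean(dim=1)
    # Rank chunks by dot product with the current token's attention query.
    scores = chunk_keys @ query                        # [n_chunks]
    top_chunks = scores.topk(k_tokens // csz).indices  # top-(K/csz) chunks
    # Flatten the retrieved chunks back into K token-level key-value pairs.
    token_idx = (top_chunks[:, None] * csz + torch.arange(csz)).reshape(-1)
    return mem_keys[token_idx], mem_values[token_idx]  # each [K, d]

# Usage with the defaults reported later: csz=4, K=64 -> 16 chunks retrieved per query token.
q = torch.randn(64)
K_mem, V_mem = torch.randn(65536, 64), torch.randn(65536, 64)
k_ret, v_ret = retrieve_chunks(q, K_mem, V_mem)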

Memory Fusion: performed within a special memory-augmented layer. This layer extends the traditional Transformer decoder layer's multi-head self-attention into a joint-attention mechanism  enabling each token to attend to both local contexts and retrieved memory contexts.

Take the output from the previous layer and the retrieved attention key-value pairs  generate hidden states for the current memory-augmented layer.

The retrieved attention key-value pairs in cached memory are distinct for each token. This uniqueness ensures that the model can tailor its understanding to each specific token. Projection matrices transform the output from the previous layer into queries, keys, and values for the current layer's attention mechanism.
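A minimal sketch of one way such a joint attention could be realized (single-head, one softmax over concatenated local and memory positions; the paper's exact gating and normalization may differ):

import math
import torch
import torch.nn.functional as F

def joint_attention(q, local_k, local_v, mem_k, mem_v):
    """q, local_k, local_v: [T, d]; mem_k, mem_v: [T, K, d] (retrieved per token)."""
    d = q.size(-1)
    # Causal scores over the local context: [T, T]
    local_scores = (q @ local_k.t()) / math.sqrt(d)
    causal_mask = torch.triu(torch.ones_like(local_scores, dtype=torch.bool), diagonal=1)
    local_scores = local_scores.masked_fill(causal_mask, float("-inf"))
    # Scores over each token's own K retrieved memory pairs: [T, K]
    mem_scores = torch.einsum("td,tkd->tk", q, mem_k) / math.sqrt(d)
    # Joint softmax over local + memory positions, then weighted sum of the values.
    probs = F.softmax(torch.cat([local_scores, mem_scores], dim=-1), dim=-1)
    p_local, p_mem = probs[:, :local_k.size(0)], probs[:, local_k.size(0):]
    return p_local @ local_v + torch.einsum("tk,tkd->td", p_mem, mem_v)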

3. EXPERIMENTS
Evaluate LongMem on different tasks that require in-memory long-contexts

a. long-text language modeling and language understanding when loading the past long-context
into cached memory
b. infinite-length in-context learning when loading a large number of demonstration examples into cached memory
3.1. Training Setup

Batchfying the training corpora:

We typically prepare or 'batchify' large data sets by cutting them into equal-length text segments and
shuffling these pieces to create mini-batches.

In contrast, LONGMEM must disable global shuffling and ensure the global causality at segment level:

Divide long documents into equivalent-length groups and then shuffle these documents within each
group.

concatenate shuffled documents and truncate them into ordered segments.

To ensure that two consecutive segments of one long document land in two consecutive input batches after batchfying, we select one segment from each document group with the same inner-group index  each mini-batch is made up of one segment from each document group.

As training iterates over steps, the cached attention key-value pairs in the memory bank are exactly the previous context of the current inputs within the same document.
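A minimal sketch of this batchifying scheme (the function name and the token-list representation of documents are assumptions for illustration):

import random

def batchify_with_segment_causality(documents, num_groups, segment_length, seed=0):
    """documents: list of token lists. Returns mini-batches whose i-th element is always
    the next ordered segment of the i-th document group."""
    rng = random.Random(seed)
    # 1. Divide documents into equal-sized groups, shuffling only within each group.
    group_size = len(documents) // num_groups
    groups = [documents[g * group_size:(g + 1) * group_size] for g in range(num_groups)]
    for group in groups:
        rng.shuffle(group)
    # 2. Concatenate each group's documents and truncate into ordered segments.
    group_segments = []
    for group in groups:
        tokens = [tok for doc in group for tok in doc]
        group_segments.append([tokens[i:i + segment_length]
                               for i in range(0, len(tokens), segment_length)])
    # 3. The t-th mini-batch takes the t-th segment of every group, so consecutive
    #    segments of a document appear in consecutive batches.
    num_batches = min(len(segs) for segs in group_segments)
    return [[segs[t] for segs in group_segments] for t in range(num_batches)]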

Training corpus and hyperparameters:

Sample subset of The Pile


Use GPT-2 (407M parameters) as the pre-trained backbone LLM. A slight adjustment was made to the position embedding because the original GPT-2 did not perform well when learning long-distance dependencies.

The backbone LLM: 24 layers, 16 heads, per-head dimension of 64

The SideNet: 12 layers, 16 heads, per-head dimension of 64

Trained on 26B tokens, with a global batch size of 256 and a sequence length of 1024.

The chunk-size csz is 4 tokens and the memory size M is 65k key-value pairs of tokens. Fetch K = 64 attention key-value pairs for augmentation, i.e., K/csz = 16 text chunks.

The memory-augmentation layer is the 9th layer of SideNet.

Attention keys and values from the 18th layer of the backbone LLM are stored for future use.

Memory Retrieval Module:

The fixed memory-size of cached memory bank in one GPU is 65536 key-value pairs of tokens.

Use the faiss toolkit to construct an exact-search index for efficient retrieval.

The retrieval takes about 15 ms per 1k tokens, which is about 55% of the time cost of the backbone LLM forward pass.
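A minimal sketch of such an exact-search index with faiss (the dimensions and the inner-product metric are illustrative assumptions consistent with the dot-product retrieval above):

import faiss
import numpy as np

d = 64                                                        # per-head key dimension (assumed)
chunk_keys = np.random.rand(65536 // 4, d).astype("float32")  # one mean-pooled key per chunk

index = faiss.IndexFlatIP(d)       # exact (non-approximate) inner-product search
index.add(chunk_keys)

queries = np.random.rand(1024, d).astype("float32")   # one attention query per token
scores, chunk_ids = index.search(queries, 16)         # top-(K/csz) = 16 chunks per query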

Baselines:

Pretrained GPT-2 + MemTRM – another memory-augmented adaptation baseline.

Use the same settings for both baselines as for our own model.

3.2. Long-Context Language Modeling

This kind of memory storage can come in handy when dealing with large bodies of text, like novels, as
the stored information can help models better understand and generate subsequent parts of the story
based on what they've learned from previous context and character relationships.

Evaluation Setting:

Compare LONGMEM and baselines on 3 long-context modeling datasets. The majority of the books or papers included in these datasets have a length of at least 16k tokens. All listed datasets are evaluated in a zero-shot manner without any task-specific tuning.

Project Gutenberg 2020-2022 Language Modeling Dataset: we crawled and cleaned books published between 2020 and 2022  PG-22, which is distinct from the training subset PG-19 in both domains and writing styles, given that the latter includes books published prior to 1919. We provide different validation splits of PG-22 based on length range.

ArXiv Dataset: papers in the areas of Math, Computer Science, and Physics. This subset of the Pile is excluded from our training and serves as an out-of-distribution dataset. Perplexity is reported for both the PG-22 and ArXiv datasets.
ChapterBreak Benchmark: a challenging suffix identification dataset that requires LLMs to distinguish
the beginning of the ground-truth next chapter from a set of hard negative segments sampled from the
same book, given the long context of previous chapters. Even state-of-the-art x-formers for long text
processing fail to effectively leverage long-range context to perform well on it.

The Archive of Our Own (AO3) split contains fan-fiction pieces and is divided into 8 sections based on prefix length, from 0.5k to 8k tokens. Evaluation uses the 4k, 6k, and 8k prefixes. If an LLM cannot process over 4k tokens  the front of the prefix is abandoned (truncated).

For evaluation, the difficulty of each candidate segment is measured by perplexity, and the candidate with the lowest perplexity is selected as the model's prediction  suffix identification accuracy is the primary evaluation metric.
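A minimal sketch of this perplexity-based selection (assumes a HuggingFace-style causal LM whose forward pass returns .logits; function and variable names are illustrative):

import math
import torch
import torch.nn.functional as F

@torch.no_grad()
def pick_suffix_by_perplexity(model, prefix_ids, candidate_suffixes):
    """Return the index of the candidate suffix with the lowest perplexity given the prefix."""
    best_idx, best_ppl = None, float("inf")
    for i, suffix_ids in enumerate(candidate_suffixes):
        ids = torch.cat([prefix_ids, suffix_ids]).unsqueeze(0)   # [1, prefix + suffix]
        logits = model(ids).logits[0]
        # Score only the suffix tokens, conditioned on the full prefix.
        suffix_logits = logits[len(prefix_ids) - 1:-1]
        nll = F.cross_entropy(suffix_logits, suffix_ids)
        ppl = math.exp(nll.item())
        if ppl < best_ppl:
            best_idx, best_ppl = i, ppl
    return best_idx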

Results:

LONGMEM significantly outperforms all considered baselines on the long-text language modeling datasets, with improvements of -1.38 to -1.62 perplexity on different length splits of PG-22 and -1.0 perplexity on the ArXiv dataset.

Achieves SOTA performance of 40.5% accuracy on the AO3 suffix identification benchmark, outdoing even powerful long-context transformers and the latest GPT-3 with a 313x larger parameter count.

 LONGMEM can comprehend the past long context held in cached memory to better carry out language modeling on future inputs.

3.3. Memory-Augmented In-Context Learning

LLMs have an emergent capability of in-context learning (ICL), which involves learning from a few examples given in the context.

BUT traditional ICL is heavily restricted by input context length, limiting the amount of supervision that
can be gathered from demonstration examples within the training set.

LongMem, with its unlimited-length memory augmentation, can overcome the limit on the number of demonstration examples in the local context + even attend over the whole training set by loading it into the cached memory.

 LONGMEM goes beyond conventional few-shot in-context learning and realizes memory-augmented in-context learning with thousands of auxiliary demonstration examples.

Evaluation Setting: evaluate on 5 NLU datasets, under two few-shot settings: 4-shot (data-insufficient) and 20-shot (fills the 1k input length  sufficient contextual self-supervision). The k-shot examples are transformed into semantically meaningful demonstration examples via a fixed text template:

di="Review: xi Sentiment: yi", ∀ {( x i , yi ) }i=1 ∈ D trainfor sentiment analysis tasks


k

evaluate 3-shot ICL on question-answering tasks of SQuAD

The demonstration examples are separated by newlines, and the predicted label is generated using greedy decoding conditioned on the demonstration examples and the test case in context.
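A minimal sketch of building such a prompt with the fixed template and newline separation (example reviews and label strings are illustrative):

def build_icl_prompt(demos, test_review):
    """demos: list of (review, sentiment) pairs from the training set."""
    lines = [f"Review: {x} Sentiment: {y}" for x, y in demos]   # k demonstration examples
    lines.append(f"Review: {test_review} Sentiment:")           # test case; label is decoded greedily
    return "\n".join(lines)

demos = [("A moving, beautifully shot film.", "positive"),
         ("Tedious and badly paced.", "negative")]
print(build_icl_prompt(demos, "An instant classic."))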

Prediction accuracy = evaluation metric.

report the mean and standard deviation of 6 runs with different random seeds to overcome the
randomness in selecting k-shot demonstration examples.

The retrieval chunk-size can impact performance  it is tuned on the validation set of SST-2. The best chunk-size is 2.

Results: LONGMEM achieves remarkable improvements on all NLU tasks in the 20-shot setting, with a +8.0 average score increase over pretrained GPT-2* and MemTRM.

Performance is also enhanced in the 4-shot setting.

LONGMEM also improves the in-context learning capabilities of LLMs on open-ended generation tasks, achieving a +4.5 Exact Match (EM) score increase on SQuAD.

 demonstration examples loaded in cached memory can be regarded as auxiliary contextual demonstrations, aiding in-context learning. LongMem can utilize the task-relevant knowledge from both local contextual demonstrations and in-memory augmented demonstrations for improved in-context learning.

3.4. Ablation Studies

The configuration of the memory bank, specifically hyperparameters like memory size (msz) and chunk-size (csz), plays a crucial role in how well the system performs.
 perform a series of ablation studies to evaluate the effects of these hyperparameters on task performance.

Effects of Chunk-Size: fig 4a  chunk size of 2 yields the best performance on in-context learning tasks
on five NLU datasets.

(Note: chunk-sizes 2, 4, and 8 are mentioned in the text, but Figure 4a shows 1, 2, and 4.)

Effects of Memory Size: the memory size should be compatible with the average length of the documents or contexts. Test with msz in {8k, 16k, 32k, 65k} during the inference stage on PG-22.

To model books with an average length of 8k-50k tokens, the smaller memory size of 16k, which is consistent with the average length of the target books, yields the best perplexity.

smaller chunk sizes work best for fine-grained tasks, while memory size should be aligned with the
average length of the documents or contexts being handled.

4. RELATED WORK

LLMs: GPT-2, GPT-3, OPT, and BLOOM have revolutionized NLP research and exhibit "emergent abilities"
like few-shot in-context learning and multi-step reasoning.

x-formers: variants of Transformers designed to handle long-range contexts.

X-formers, such as Transformer-XL, LongFormer, and Routing Transformer, propose various sparse attention mechanisms to decrease complexity, but their efficiency gains are not remarkable for modeling book-level length sequences.

Side-Tuning: Side-Tuning is a task-specific tuning method for pre-trained models, while LongMem
proposes to augment LLMs with decoupled memory for memorizing long past inputs without task-
specific tuning.
