machine-learning transformers



Abstract

This paper introduces a structured memory which can be easily integrated into a neural network. The memory is very large by design and significantly increases the capacity of the architecture, by up to a billion parameters with a negligible computational overhead. Its design and access pattern are based on product keys, which enable fast and exact nearest neighbor search. The ability to increase the number of parameters while keeping the same computational budget lets the overall system strike a better trade-off between prediction accuracy and computation efficiency both at training and test time. This memory layer allows us to tackle very large scale language modeling tasks. In our experiments we consider a dataset with up to 30 billion words, and we plug our memory layer in a state-of-the-art transformer-based architecture. In particular, we found that a memory augmented model with only 12 layers outperforms a baseline transformer model with 24 layers, while being twice as fast at inference time. We release our code for reproducibility purposes.
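
To see where the parameter count comes from (illustrative numbers, not necessarily the paper's exact configuration): with 1,024 sub-keys per half, the product structure addresses 1,024 × 1,024 ≈ 10⁶ memory slots, and if each slot stores a value vector of dimension ~1,000 the memory alone holds on the order of 10⁹ parameters, yet each query only ever touches the top-k selected slots.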

Authors

Contributions

  • We introduce a new layer that provides a large capacity to a neural network for only a slight computational overhead both at train and test time.
  • Our fast indexing strategy offers exact nearest neighbor search by construction, and avoids the pitfall of relying on an indexing structure that needs to be re-learned during training (a sketch of this product-key search follows this list).
  • We demonstrate our method within a large state-of-the-art transformer, composed of 24 layers of dimension 1600. A 12-layer model with a single memory layer outperforms a 24-layer transformer while being twice as fast at inference time. We show that adding more memory layers to transformers of various complexities provides systematic and significant improvements on our target task.
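
The product-key trick is what makes exact nearest neighbor search over a huge key set affordable: the query is split in two halves, each half is scored against a small set of sub-keys, and the best full keys are recovered from the Cartesian product of the two top-k candidate sets. Below is a minimal PyTorch sketch, assuming dot-product scoring and a single unbatched query; the names, shapes, and the omission of batching and the value lookup are my own simplifications, not the released implementation.

```python
import torch

def product_key_topk(query, subkeys1, subkeys2, k):
    """Exact top-k over C*C product keys using two searches over only C sub-keys.

    query:    (d,) vector, split into two halves of size d/2
    subkeys1: (C, d/2) sub-keys for the first half-space
    subkeys2: (C, d/2) sub-keys for the second half-space
    Product key (i, j) has score <q1, subkeys1[i]> + <q2, subkeys2[j]>
    and flat index i * C + j.
    """
    C = subkeys1.shape[0]
    q1, q2 = query.chunk(2)                   # split the query in two halves
    s1, s2 = subkeys1 @ q1, subkeys2 @ q2     # (C,) scores per half
    top1_s, top1_i = s1.topk(k)               # best k sub-keys on each side
    top2_s, top2_i = s2.topk(k)
    # Scores of the k x k Cartesian product of the two candidate sets.
    # The true top-k product keys are guaranteed to lie in this set, because
    # a product key's score is the sum of its two sub-key scores.
    cand = top1_s[:, None] + top2_s[None, :]  # (k, k)
    best_s, flat = cand.view(-1).topk(k)      # exact top-k among candidates
    i = torch.div(flat, k, rounding_mode="floor")
    j = flat % k
    return best_s, top1_i[i] * C + top2_i[j]  # scores and flat key indices
```

The selected indices then address the value table, and a softmax over the returned scores weights the corresponding value vectors to form the layer's output.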

Implementation

Motivation

Neural networks, particularly transformers, need more parameters to handle complex tasks, but adding parameters slows down training and inference.

The goal of this paper is to add a large memory to store more knowledge (parameters) without slowing down computation.

I’m wary of drawing too detailed of a comparison to the brain, but I think an analogy is warranted to motivate the use of disparate kinds of computation in natural language processing models. As humans, the process of recalling a specific fact or event feels nearly instantaneous. Any owner of a brain can attest that recalling specific memories doesn’t require an exhaustive linear search through past memories. 1

Query Network

The query network converts the network's hidden representation of the input into a query vector.

The query network is analogous to the query projection in self-attention – it could be as simple as the linear projection of the hidden state used in self-attention, or it could be parameterized by a small MLP. 1
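
A minimal sketch of what such a query network could look like; the dimensions (d_model, d_query) and layer choices here are illustrative placeholders, not the paper's configuration.

```python
import torch.nn as nn

class QueryNetwork(nn.Module):
    """Maps the transformer hidden state to a memory query.

    Either a single linear projection (like the query projection in
    self-attention) or a small MLP; sizes are placeholders.
    """
    def __init__(self, d_model=1024, d_query=512, use_mlp=False):
        super().__init__()
        if use_mlp:
            self.proj = nn.Sequential(
                nn.Linear(d_model, d_model),
                nn.ReLU(),
                nn.Linear(d_model, d_query),
            )
        else:
            self.proj = nn.Linear(d_model, d_query)

    def forward(self, hidden):        # hidden: (batch, seq_len, d_model)
        return self.proj(hidden)      # query: (batch, seq_len, d_query)
```

The resulting query is then split into two halves for the product-key lookup sketched earlier, so d_query should be even.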

Footnotes

  1. https://www.pragmatic.ml/large-memory-layers-with-product-keys/