Understanding Long Short-Term Memory

Posted on 6/9/2025

Recurrent Neural Networks (RNNs) are a type of neural network designed to process sequential data by retaining an internal memory of the previous inputs. This makes them particularly suited for problems such as speech recognition, non-Markovian control, or music composition.

However, in practice, training RNNs successfully was a challenge. Methods like Backpropagation Through Time (BPTT) or Real-Time Recurrent Learning (RTRL) suffer from vanishing or exploding gradients, making it difficult for RNNs to learn long-range dependencies.

This is the problem that Long Short-Term Memory (LSTM) networks were designed to solve when they were introduced in the seminal Long Short-Term Memory paper by Sepp Hochreiter and Jürgen Schmidhuber in 1997. LSTMs modified the recurrent architecture to create a constant error flow (the Constant Error Carousel, or CEC) across time, allowing gradients to propagate over hundreds of steps without vanishing.

In this post, we’ll walk through the motivations, mathematical foundations, and original architectural decisions behind the LSTM formulation. We’ll start by examining why standard RNNs fail to retain long-term dependencies, explain the concept of error signals, and then follow the derivation that led the authors to the idea of the Constant Error Carousel (CEC), the key mechanism that allows LSTMs to overcome the vanishing gradient problem. From there, we’ll explore the full LSTM architecture, its advantages and limitations, how it has been used in practice, and its current relevance.

Exponentially Decaying Error in Conventional BPTT

First, the authors state that the contribution of the output error to the update of $w_{kl}$ is computed as $\alpha \vartheta_k(t)\, y^l(t-1)$, with $\alpha$ the learning rate, $l$ an arbitrary unit connected to $k$, and $\vartheta_k$ the error signal at unit $k$ and time $t$, which is given by

$$\vartheta_k(t)=f_k'(net_k(t))(d_k(t)-y^k(t))$$

for an output unit. For hidden units, the error signal is given by

$$\vartheta_k(t)=f_k'(net_k(t))\sum_i w_{ik}\,\vartheta_i(t+1)$$

Yes, that’s a rough start. Let’s take a step back and understand where these formulas come from. First of all, we have to understand the difference between the error and the error signal. The error at an output unit $k$ is

$$E_k(t)=\frac{1}{2}(d_k(t)-y^k(t))^2$$

where $d_k(t)$ is the desired output of unit $k$ at time $t$ and $y^k(t)$ is the output or activation of unit $k$ at time $t$.

The error signal, on the other hand, is the derivative of the error with respect to the unit’s net input:

$$\vartheta_k(t) = \frac{\partial E}{\partial net_k(t)}$$

Error signal for output units

We can then expand the error signal using the chain rule:

$$\vartheta_k(t) = \frac{\partial E}{\partial net_k(t)} = \frac{\partial E}{\partial y^k(t)}\frac{\partial y^k(t)}{\partial net_k(t)}$$

where $net_k$ is the weighted sum of all incoming signals before applying the activation function

$$net_k(t) = \sum_j w_{kj}\, y^j(t-1)$$

and $y^j(t)$ and $y^j(t-1)$ are the activations or outputs of unit $j$ at time $t$ and at the previous time step $t-1$, respectively. Note that a real implementation of the RNN would also include the external input $x(t)$; however, the external input does not affect the error signal, so for clarity it is omitted in the paper. We can then compute each of the partial derivatives that make up the error signal after applying the chain rule. First, we take the partial derivative of the error function with respect to $y^k(t)$:

$$\frac{\partial E}{\partial y^k(t)} = -(d_k(t) - y^k(t))$$

and second, the partial derivative of $y^k(t)=f_k(net_k(t))$ with respect to $net_k(t)$:

$$\frac{\partial y^k(t)}{\partial net_k(t)} = f'_k(net_k(t))$$

where $f'_k$ is the derivative of the activation function at unit $k$. By multiplying them we obtain the error signal for an output unit $k$:

$$\boxed{\vartheta_k(t) = f'_k(net_k(t))(d_k(t) - y^k(t))}$$

Note that the minus sign is not included in the paper’s expression because it cancels out with the negative sign from gradient descent, which results in a positive term in the weight update rule.
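To make this concrete, here is a minimal NumPy sketch (not from the paper; the network size, sigmoid activation, and random values are illustrative assumptions) of the forward pass of a fully recurrent layer and the output-unit error signal $\vartheta_k(t) = f'_k(net_k(t))(d_k(t) - y^k(t))$:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sigmoid_prime(x):
    s = sigmoid(x)
    return s * (1.0 - s)

n = 4                                     # number of units (illustrative)
rng = np.random.default_rng(0)
W = rng.normal(scale=0.5, size=(n, n))    # W[k, l] = w_kl, connection from unit l to unit k
y_prev = rng.normal(size=n)               # y^l(t-1), activations at the previous time step

net = W @ y_prev                          # net_k(t) = sum_l w_kl * y^l(t-1)
y = sigmoid(net)                          # y^k(t) = f_k(net_k(t))

d = rng.normal(size=n)                    # d_k(t), desired outputs (made up for the example)

# theta_k(t) = f'_k(net_k(t)) * (d_k(t) - y^k(t))
theta_out = sigmoid_prime(net) * (d - y)
print(theta_out)
```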

Computing the output error signal was relatively easy as we have access to the desired outputs and the error function. But hidden units don’t have any of these, so how do we compute their error signals?

Error signal for hidden units

We can compute the error signal for hidden units by backpropagating the error from the units they feed into; this is the idea behind backpropagation through time (BPTT). The difficult part when computing the error signal for a hidden unit is the term $\frac{\partial E}{\partial y^k(t)}$, as there is no error function to use directly; instead, we have to use the error signals from $t+1$ to backpropagate this term.

The outputs of all hidden units at time $t$ become part of the input to unit $k$ at time $t+1$, such that

$$\sum_j w_{kj}\, y^j(t)=net_k(t+1)$$

If we take the partial derivative of $net_k(t+1)$ with respect to the output $y^j(t)$ of a unit $j$, we just retain the $j$th element of the summation:

$$\frac{\partial net_k(t+1)}{\partial y^j(t)} = w_{kj}$$

We can then compute $\frac{\partial E}{\partial y^k(t)}$ by applying the chain rule through the net inputs at the next time step, $net_i(t+1)$:

$$\frac{\partial E}{\partial y^k(t)}=\sum_i\frac{\partial E}{\partial net_i(t+1)}\frac{\partial net_i(t+1)}{\partial y^k(t)}$$

We are essentially using the error signals of the next time step to obtain the current one. We can then compute each of the terms. First, by definition:

$$\frac{\partial E}{\partial net_i(t+1)} = \vartheta_i(t+1)$$

and the second term was already computed above. By multiplying them we get

$$\frac{\partial E}{\partial y^k(t)}=\sum_i \vartheta_i(t+1)\, w_{ik}$$

which allows us to derive the error signal for a hidden unit as

$$\boxed{\vartheta_k(t) = \frac{\partial E}{\partial y^k(t)}\frac{\partial y^k(t)}{\partial net_k(t)} = f'_k(net_k(t)) \sum_i \vartheta_i(t+1)\, w_{ik}}$$
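The same backward step can be sketched in NumPy, assuming the error signals $\vartheta_i(t+1)$ at the next time step are already available (again an illustrative sketch with made-up values, not the paper’s setup):

```python
import numpy as np

def sigmoid_prime(x):
    s = 1.0 / (1.0 + np.exp(-x))
    return s * (1.0 - s)

n = 4
rng = np.random.default_rng(1)
W = rng.normal(scale=0.5, size=(n, n))    # W[i, k] = w_ik, connection from unit k to unit i

net_t = rng.normal(size=n)                # net_k(t), net inputs at time t
theta_next = rng.normal(size=n)           # theta_i(t+1), error signals already computed at t+1

# theta_k(t) = f'_k(net_k(t)) * sum_i theta_i(t+1) * w_ik
theta_t = sigmoid_prime(net_t) * (W.T @ theta_next)
print(theta_t)
```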

Note that the weight update equation is normally defined as

$$\Delta w_{ij}(t) = -\alpha \frac{\partial E}{\partial w_{ij}}$$

where, by using the chain rule again,

$$\frac{\partial E}{\partial w_{ij}(t)} = \frac{\partial E}{\partial net_i(t)} \frac{\partial net_i(t)}{\partial w_{ij}(t)}$$

and since $net_i(t)=\sum_j w_{ij}\, y^j(t-1)$,

$$\frac{\partial net_i(t)}{\partial w_{ij}} = y^j(t-1)$$

so that, finally, the weight update equation is

$$\Delta w_{ij}(t) = \alpha\, \vartheta_i(t)\, y^j(t-1)$$

Note that, as discussed above, the minus sign from gradient descent was absorbed into the output error signal, which is why the update appears with a positive sign.
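As a sketch, the update for the whole weight matrix can be written as a single outer product (the names and values are illustrative):

```python
import numpy as np

n = 4
alpha = 0.1                               # learning rate
rng = np.random.default_rng(2)
W = rng.normal(scale=0.5, size=(n, n))    # W[i, j] = w_ij, connection from unit j to unit i

theta_t = rng.normal(size=n)              # theta_i(t), error signals at time t
y_prev = rng.normal(size=n)               # y^j(t-1), activations at the previous time step

# Delta w_ij(t) = alpha * theta_i(t) * y^j(t-1), applied to the whole matrix at once
W += alpha * np.outer(theta_t, y_prev)
```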

Next, we analyze how the error signal of a unit $u$ at time $t$ propagates backward through time to a unit $v$ at time $t-q$. When there is only one time step of difference ($q=1$), we can compute the partial derivative directly from the hidden-unit error signal; for larger lags it is defined recursively:

$$\frac{\partial \vartheta_v(t-q)}{\partial \vartheta_u(t)}=
\begin{cases}
f'_v(net_v(t-1))\, w_{uv} & q=1\\
f'_v(net_v(t-q)) \displaystyle\sum_{l}\frac{\partial \vartheta_l(t-q+1)}{\partial \vartheta_u(t)}\, w_{lv} & q>1
\end{cases}$$

By unrolling the recursion, the following general formula is obtained:

$$\boxed{\frac{\partial \vartheta_v(t-q)}{\partial \vartheta_u(t)} = \sum_{l_1=1}^{n} \cdots \sum_{l_{q-1}=1}^{n} \prod_{m=1}^{q} f'_{l_m}\left( net_{l_m}(t-m) \right) w_{l_m l_{m-1}}}$$

with $l_q=v$ and $l_0=u$. This formula sums, over all possible paths of length $q$ connecting $u$ to $v$, the product of the weights along the path and the corresponding activation-function derivatives. There are $n^{q-1}$ such product terms in total; since they can have different signs, increasing the number of units $n$ does not necessarily increase the error flow.

From this formula we can see that if $|f'_{l_m}(net_{l_m}(t-m))\, w_{l_m l_{m-1}}| > 1$ for all $m$, the product term grows exponentially with $q$ (the error explodes). Conversely, if $|f'_{l_m}(net_{l_m}(t-m))\, w_{l_m l_{m-1}}| < 1$ for all $m$, the product term decays exponentially with $q$ (the error vanishes). In both cases, learning dependencies that span many time steps, i.e. the deeper layers of the unrolled network, becomes practically impossible.

The vanishing local error flow problem extends to the global error as well, since the global error is simply the sum of the local errors over the output units.
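A toy numerical illustration (not an experiment from the paper): backpropagating through a single self-connected logistic unit multiplies the error signal by $f'(net)\, w$ at every step, so after many steps the signal vanishes, stays constant, or explodes depending on whether $|f'(net)\, w|$ is below, equal to, or above 1.

```python
import numpy as np

def sigmoid_prime(x):
    s = 1.0 / (1.0 + np.exp(-x))
    return s * (1.0 - s)

# Single path through a self-connected logistic unit: every backward step
# through time multiplies the error signal by f'(net) * w.
net = 0.0                                 # f'(0) = 0.25 for the logistic sigmoid
for w in (1.0, 4.0, 8.0):                 # |f' * w| = 0.25, 1.0, 2.0
    signal = 1.0
    for _ in range(50):                   # backpropagate over q = 50 time steps
        signal *= sigmoid_prime(net) * w
    print(f"w = {w}: error signal after 50 steps = {signal:.3e}")
# w = 1.0 vanishes (~7.9e-31), w = 4.0 stays constant (1.0), w = 8.0 explodes (~1.1e+15)
```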

NAIVE CONSTANT ERROR FLOW

The error signal for a unit $k$ that is connected only to itself is given by

$$\vartheta_k(t)=f_k'(net_k(t))\, w_{kk}\, \vartheta_k(t+1)$$

If we enforce a constant error flow through this unit, i.e. $\vartheta_k(t)=\vartheta_k(t+1)$, we need

$$f_k'(net_k(t))\, w_{kk}=1$$

We can then integrate this equation to obtain

$$f_k(net_k(t)) = \frac{net_k(t)}{w_{kk}}$$

which means that $f_k$ must be linear. In the original paper this is called the Constant Error Carousel (CEC), and it will be the key element of LSTM. Since CEC units will not only be connected to themselves but also to other units, a more sophisticated mechanism is required to control the read and write operations on these units; otherwise, conflicts may arise during weight updates. These interference and stability problems become worse as the time lag increases.
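A minimal sketch of this naive CEC, assuming a linear (identity) activation and a self-connection weight $w_{kk}=1$: the backpropagated error signal stays constant no matter how many steps we unroll.

```python
# Naive CEC: a single self-connected unit with the identity activation f(x) = x,
# so f'(x) = 1, and a self-connection weight w_kk = 1.
w_kk = 1.0
f_prime = 1.0                             # derivative of the identity activation

theta = 1.0                               # error signal arriving at some later time step
for _ in range(100):
    theta = f_prime * w_kk * theta        # theta_k(t) = f'_k(net_k(t)) * w_kk * theta_k(t+1)
print(theta)                              # still 1.0 after 100 steps: constant error flow
```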

LONG SHORT-TERM MEMORY

The naive Constant Error Carousel (CEC) enables constant error flow, but it lacks control over when information is written to or read from memory. To solve this, a multiplicative input gate was introduced to protect the stored memory contents from being overwritten by irrelevant inputs. Similarly, a multiplicative output gate was added to prevent the memory content from affecting other units when it’s not needed. The resulting unit is called a memory cell, and it is built around the constant error carousel idea. The $j$th memory cell is referred to as $c_j$. Its gate activations are

$$y^{out_j}(t)=f_{out_j}(net_{out_j}(t)), \qquad y^{in_j}(t)=f_{in_j}(net_{in_j}(t))$$

where

$$net_{out_j}(t) = \sum_u w_{out_j u}\, y^u(t-1), \qquad net_{in_j}(t) = \sum_u w_{in_j u}\, y^u(t-1), \qquad net_{c_j}(t) = \sum_u w_{c_j u}\, y^u(t-1)$$

Then, at time $t$, the output $y^{c_j}$ of $c_j$ is computed as

$$y^{c_j}(t) = y^{out_j}(t)\, h(s_{c_j}(t))$$

where the internal state is

$$s_{c_j}(0) = 0, \qquad s_{c_j}(t) = s_{c_j}(t-1) + y^{in_j}(t)\, g(net_{c_j}(t)) \quad \text{for } t > 0$$

The function $g$ squashes the cell input $net_{c_j}(t)$, producing a candidate value to be written into the memory cell. The function $h$ scales the output based on the internal state $s_{c_j}(t)$. We can distinguish two components in the update of $s_{c_j}(t)$. The first is the internal state from the previous time step, $s_{c_j}(t-1)$. The second is the product $y^{in_j}(t) \cdot g(net_{c_j}(t))$, where $g(net_{c_j}(t))$ generates the data to write and $y^{in_j}(t)$ decides whether this data should be written or ignored. Even though the output is passed through the non-linearity $h$, the Constant Error Carousel is still respected, because the internal state $s_{c_j}(t)$ flows linearly through time within the memory cell.

Architecture of the memory cell. Source: Original paper.
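To tie the equations together, here is a minimal NumPy sketch of the forward pass of a single memory cell in this original, forget-gate-free formulation. The weight vectors, the random inputs, and the choice of logistic gates with tanh for $g$ and $h$ are illustrative assumptions (the paper uses scaled logistic squashing functions).

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(3)
n_in = 5                                  # number of units feeding the cell and its gates
T = 10                                    # number of time steps

w_in = rng.normal(scale=0.5, size=n_in)   # weights into the input gate
w_out = rng.normal(scale=0.5, size=n_in)  # weights into the output gate
w_c = rng.normal(scale=0.5, size=n_in)    # weights into the cell input

s = 0.0                                   # internal state, s_cj(0) = 0
for t in range(T):
    y_prev = rng.normal(size=n_in)        # stand-in for the other units' activations y^u(t-1)

    y_in = sigmoid(w_in @ y_prev)         # input gate activation  y^{in_j}(t)
    y_out = sigmoid(w_out @ y_prev)       # output gate activation y^{out_j}(t)
    g = np.tanh(w_c @ y_prev)             # squashed cell input g(net_cj(t))

    s = s + y_in * g                      # s_cj(t) = s_cj(t-1) + y^{in_j}(t) * g(net_cj(t))
    y_c = y_out * np.tanh(s)              # y^{c_j}(t) = y^{out_j}(t) * h(s_cj(t))
    print(f"t={t}: state={s:+.3f}, output={y_c:+.3f}")
```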

Architecture

The networks used in the original paper consist of:

The authors noted that, during early training, the network may overuse the memory cells by keeping the gates open and exploiting the extra capacity. It is also possible for multiple cells to redundantly store the same information. To address these issues, they proposed:

The paper includes a detailed description of six experiments of increasing complexity, where LSTM is benchmarked against traditional RNN training algorithms such as BPTT and RTRL. LSTM demonstrated a stronger capability to capture longer temporal dependencies.

From these experiments, the authors highlight the following advantages and limitations of LSTM:

Disadvantages

Advantages

EVOLUTION, ACHIEVEMENTS AND PRESENT STATE

After the publication of this paper, considerable work was dedicated to further improve the LSTM architecture:

LSTMs have been used in a wide variety of applications, ranging from speech and handwriting recognition to machine translation, time-series forecasting, and creative generation. Bidirectional LSTMs were employed in speech and online handwriting tasks, while Google’s first neural translation system used stacked LSTMs to significantly reduce error rates. They remain widely used in industrial forecasting and on-device anomaly detection, especially when long-range memory is needed but transformers are too resource-intensive.

However, since the rise of transformers in 2017, which scale better on modern hardware and achieve higher accuracy, LSTMs have declined in popularity. Furthermore, newer state-space models like Mamba can capture longer dependencies with fewer parameters.

Footnotes

  1. Gers, F. A., Schmidhuber, J., & Cummins, F. (2000). Learning to forget: Continual prediction with LSTM. Neural Computation, 12(10), 2451-2471.

  2. Gers, F. A., Schraudolph, N. N., & Schmidhuber, J. (2002). Learning precise timing with LSTM recurrent networks. Journal of Machine Learning Research, 3(Aug), 115-143.

  3. Graves, A., & Schmidhuber, J. (2005). Framewise phoneme classification with bidirectional LSTM and other neural network architectures. Neural Networks, 18(5-6), 602-610.

  4. Beck, M., Pöppel, K., Spanring, M., Auer, A., Prudnikova, O., Kopp, M., Klambauer, G., Brandstetter, J., & Hochreiter, S. (2024). xLSTM: Extended Long Short-Term Memory. arXiv preprint arXiv:2405.04517.
