PyTorch LSTM: Text Generation Tutorial

The key elements of LSTM are its ability to work with sequences and its gating mechanism.
Domas Bitvinskas
Jun 15, 2020
Long Short Term Memory (LSTM) is a popular Recurrent Neural Network (RNN) architecture. This tutorial covers using LSTMs in PyTorch to generate text; in this case, pretty lame jokes.
For this tutorial, you need:
  • Basic familiarity with Python, PyTorch, and machine learning
  • A locally installed Python v3+, PyTorch v1+, NumPy v1+, and pandas

What is LSTM?

LSTM is a variant of RNN used in deep learning. You can use LSTMs if you are working on sequences of data.
Here are the most straightforward use cases for LSTM networks you might be familiar with:
  • Time series forecasting (for example, stock prediction)
  • Text generation
  • Video classification
  • Music generation
  • Anomaly detection


Before you start using LSTMs, you need to understand how RNNs work.
RNNs are neural networks that are good with sequential data. The data can be video, audio, text, stock market time series, or even a single image cut into a sequence of its parts.
Standard neural networks (convolutional or vanilla) have one major shortcoming when compared to RNNs: they cannot reason about previous inputs to inform later ones. You cannot solve some machine learning problems without some kind of memory of past inputs.
For example, suppose you have some video frames of a ball moving and want to predict the direction of the ball. A standard neural network sees the problem like this: there is a ball in one image, and then there is a ball in another image. It has no mechanism for connecting these two images as a sequence, so it cannot link two separate images of the ball to the concept of “the ball is moving.” All it sees is that there is a ball in image #1 and a ball in image #2, and it produces a separate output for each.
Convolutional Neural Network prediction
Compare this to the RNN, which remembers the last frames and can use that to inform its next prediction.
Recurrent Neural Network prediction
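To make the idea of “memory” concrete, here is a minimal sketch using PyTorch's nn.RNN. It uses arbitrary toy sizes (3 input features, 5 hidden units) and random data standing in for video frames. The sequence is fed in two chunks, with the hidden state passed along between them, and this gives the same outputs as a single pass over the whole sequence. If the second chunk starts from a fresh state, the outputs differ, because the network has forgotten the earlier frames.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy RNN: 3 features per frame, 5 hidden units, input shaped (batch, time, features)
rnn = nn.RNN(input_size=3, hidden_size=5, batch_first=True)

# One random "video" of 6 frames standing in for real data
frames = torch.randn(1, 6, 3)

# Process all 6 frames in one pass
full_out, _ = rnn(frames)

# Process frames 1-3, then frames 4-6 starting from the remembered hidden state h
first_out, h = rnn(frames[:, :3])
second_out, _ = rnn(frames[:, 3:], h)

# Process frames 4-6 from a fresh (zero) state, i.e. with no memory of frames 1-3
forgetful_out, _ = rnn(frames[:, 3:])

print(torch.allclose(second_out, full_out[:, 3:], atol=1e-5))     # state carried over
print(torch.allclose(forgetful_out, full_out[:, 3:], atol=1e-5))  # memory lost
```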


Typical RNNs can't memorize long sequences because of an effect called “vanishing gradients,” which happens during the backpropagation phase of training. The gradients of cells that carry information from the start of a sequence go through repeated multiplications by small numbers and approach 0 in long sequences. In other words, information at the start of the sequence has almost no effect at the end of the sequence.
You can see that illustrated in the Recurrent Neural Network example. Given a long enough sequence, the information from the first element of the sequence has no impact on the output of the last element of the sequence.
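The shrinking gradient is easy to reproduce. The sketch below uses a deliberately tiny, hypothetical scalar RNN, h_t = tanh(w * h_{t-1}), with an arbitrary w = 0.5. It uses autograd to measure how much the last hidden state depends on the first one. By the chain rule, each step contributes a factor of w * (1 - h_t²), which here is at most 0.5, so the gradient decays roughly geometrically as the sequence gets longer.

```python
import torch


def grad_wrt_first_state(seq_len, w=0.5):
    """d h_T / d h_0 for a toy scalar RNN h_t = tanh(w * h_{t-1}), starting at h_0 = 1."""
    h0 = torch.tensor(1.0, requires_grad=True)
    h = h0
    for _ in range(seq_len):
        # Chain rule: each step multiplies the gradient by w * (1 - h_t ** 2)
        h = torch.tanh(w * h)
    h.backward()
    return h0.grad.item()


grads = {n: grad_wrt_first_state(n) for n in (1, 5, 20, 50)}
for n, g in grads.items():
    print(f"sequence length {n:>2}: gradient {g:.3e}")
```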
LSTM is an RNN architecture that can memorize long sequences, up to hundreds of elements. LSTM has a memory gating mechanism that allows long-term memory to keep flowing through the LSTM cells.
Long Short Term Memory cell
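The gating fits in a few lines of code. This sketch recomputes one step of PyTorch's nn.LSTMCell by hand from the cell's own weight_ih, weight_hh, bias_ih and bias_hh tensors. It uses the gate order (input, forget, cell candidate, output) documented for nn.LSTM and checks the result against the built-in cell. The sizes are arbitrary. The additive cell-state update, c_next = f * c + i * g, is the part usually credited with letting information flow across many steps.

```python
import torch
from torch import nn


def lstm_cell_step(x, h, c, cell):
    """One LSTM step computed by hand from an nn.LSTMCell's parameters."""
    # All four gates come from one affine map of the input x and previous hidden state h
    gates = x @ cell.weight_ih.T + cell.bias_ih + h @ cell.weight_hh.T + cell.bias_hh
    # PyTorch stacks the gates in the order: input, forget, cell candidate, output
    i, f, g, o = gates.chunk(4, dim=1)
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
    g = torch.tanh(g)
    # Forget gate scales the old cell state; input gate adds the new candidate
    c_next = f * c + i * g
    # Output gate controls how much of the squashed cell state becomes the hidden state
    h_next = o * torch.tanh(c_next)
    return h_next, c_next


torch.manual_seed(0)
cell = nn.LSTMCell(input_size=4, hidden_size=3)  # arbitrary toy sizes
x = torch.randn(2, 4)  # a batch of 2 inputs
h = torch.randn(2, 3)
c = torch.randn(2, 3)

h_manual, c_manual = lstm_cell_step(x, h, c, cell)
h_torch, c_torch = cell(x, (h, c))
print(torch.allclose(h_manual, h_torch, atol=1e-5), torch.allclose(c_manual, c_torch, atol=1e-5))
```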

Text generation with PyTorch

You will train a joke text generator using LSTM networks in PyTorch, following best practices. Start by creating a new folder where you'll store the code:
$ mkdir text-generation


To create an LSTM model, create a model.py file in the text-generation folder with the following content:
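import torch
from torch import nn

class Model(nn.Module):
    def __init__(self, dataset):
        super(Model, self).__init__()
        self.lstm_size = 128
        self.embedding_dim = 128
        self.num_layers = 3

        n_vocab = len(dataset.uniq_words)
        # One learnable vector per vocabulary word
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        # Stacked LSTM; nn.LSTM defaults to batch_first=False
        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
        )
        # Maps each LSTM output to one score (logit) per vocabulary word
        self.fc = nn.Linear(self.lstm_size, n_vocab)

    def forward(self, x, prev_state):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        # Zero hidden and cell states, shaped (num_layers, sequence_length, lstm_size)
        return (torch.zeros(self.num_layers, sequence_length, self.lstm_size),
                torch.zeros(self.num_layers, sequence_length, self.lstm_size))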
This is a standard-looking PyTorch model.
The Embedding layer converts word indexes to word vectors.
The LSTM is the main learnable part of the network. PyTorch's implementation has the gating mechanism built into the LSTM cell, which lets it learn long sequences of data.
As described in the earlier What is LSTM? section, RNNs and LSTMs have extra state information that they carry from one step to the next. The forward function has a prev_state argument. This state is kept outside the model and passed in manually.
The model also has an init_state function. Call it at the start of every epoch to initialize the state with the right shape.
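One detail worth checking is the tensor layout. nn.LSTM defaults to batch_first=False, so it reads its input as (time, batch, features). The DataLoader yields x with shape (batch_size, sequence_length), so inside this model the sequence_length dimension plays the role of the LSTM's batch dimension. That is why init_state is sized by sequence_length. The sketch below rebuilds the same layer stack with small, made-up sizes and prints the shape at each stage.

```python
import torch
from torch import nn

# Hypothetical small sizes; the tutorial's Model uses 128, 128 and 3 layers
n_vocab, embedding_dim, lstm_size, num_layers = 50, 16, 16, 3
batch_size, sequence_length = 8, 4

embedding = nn.Embedding(num_embeddings=n_vocab, embedding_dim=embedding_dim)
# batch_first is left at its default (False), as in Model
lstm = nn.LSTM(input_size=embedding_dim, hidden_size=lstm_size, num_layers=num_layers)
fc = nn.Linear(lstm_size, n_vocab)

# A batch of word indexes shaped like the DataLoader output: (batch_size, sequence_length)
x = torch.randint(0, n_vocab, (batch_size, sequence_length))
embed = embedding(x)

# Same shape as Model.init_state(sequence_length): (num_layers, sequence_length, lstm_size)
state = (torch.zeros(num_layers, sequence_length, lstm_size),
         torch.zeros(num_layers, sequence_length, lstm_size))

output, (h, c) = lstm(embed, state)
logits = fc(output)

print(tuple(embed.shape))   # (8, 4, 16)
print(tuple(output.shape))  # (8, 4, 16)
print(tuple(logits.shape))  # (8, 4, 50)
print(tuple(h.shape))       # (3, 4, 16)
```

If you would rather have the LSTM step along the words of each window, pass batch_first=True to nn.LSTM. The state would then need to be sized by the batch size instead of sequence_length, and a smaller final batch would need a matching state.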


For this tutorial, we use the Reddit clean jokes dataset to train the network. Download (139KB) the dataset and put it in the text-generation/data/ folder (the code below reads data/reddit-cleanjokes.csv).
The dataset has 1623 jokes and looks like this:
1,What did the bartender say to the jumper cables? You better not try to start anything.
2,Don't you hate jokes about German sausage? They're the wurst!
3,Two artists had an art contest... It ended in a draw
To load the data into PyTorch, use the PyTorch Dataset class. Create a dataset.py file with the following content:
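import torch
import pandas as pd
from collections import Counter

class Dataset(torch.utils.data.Dataset):
    def __init__(self, args):
        self.args = args
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()

        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load_words(self):
        train_df = pd.read_csv('data/reddit-cleanjokes.csv')
        # Join all jokes into one long string, then split it into words
        text = train_df['Joke'].str.cat(sep=' ')
        return text.split(' ')

    def get_uniq_words(self):
        word_counts = Counter(self.words)
        # Most frequent words get the smallest indexes
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        return len(self.words_indexes) - self.args.sequence_length

    def __getitem__(self, index):
        # Input: sequence_length words; target: the same window shifted one word ahead
        return (
            torch.tensor(self.words_indexes[index:index + self.args.sequence_length]),
            torch.tensor(self.words_indexes[index + 1:index + self.args.sequence_length + 1]),
        )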
Dataset inherits from PyTorch's torch.utils.data.Dataset class and defines two important methods: __len__ and __getitem__. Read more about how Dataset classes work in the PyTorch Data loading tutorial.
The load_words function loads the dataset. The unique words in the dataset define the size of the network's vocabulary, which sets the number of rows in the embedding table and the output size of the final linear layer.
index_to_word and word_to_index convert words to number indexes and vice versa.
This part of the process is tokenization. In the future, the torchtext team plans to improve this part, but they are redesigning it and the new API is too unstable for this tutorial today.
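To see what one training example looks like, here is a sketch that repeats the Dataset steps on a made-up six-word corpus instead of the jokes CSV, with sequence_length = 4. Each item pairs a window of words with the same window shifted one word ahead, so the target at every position is the word that follows it.

```python
from collections import Counter

import torch

# Hypothetical toy corpus standing in for the concatenated jokes
words = "the cat sat on the mat".split(' ')
sequence_length = 4

# Same steps as Dataset.get_uniq_words and the two lookup tables
word_counts = Counter(words)
uniq_words = sorted(word_counts, key=word_counts.get, reverse=True)  # most frequent first
word_to_index = {word: index for index, word in enumerate(uniq_words)}
index_to_word = {index: word for index, word in enumerate(uniq_words)}
words_indexes = [word_to_index[w] for w in words]


def get_item(index):
    # Input window and the same window shifted one word ahead, as in Dataset.__getitem__
    x = torch.tensor(words_indexes[index:index + sequence_length])
    y = torch.tensor(words_indexes[index + 1:index + sequence_length + 1])
    return x, y


n_items = len(words_indexes) - sequence_length  # mirrors Dataset.__len__
for i in range(n_items):
    x, y = get_item(i)
    print([index_to_word[j] for j in x.tolist()], '->', [index_to_word[j] for j in y.tolist()])
```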


Create a training script in the text-generation folder and define a train function in it:
import argparse
import torch
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader
from model import Model
from dataset import Dataset

def train(dataset, model, args):
    model.train()

    dataloader = DataLoader(dataset, batch_size=args.batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(args.max_epochs):
        # Fresh zero state at the start of every epoch
        state_h, state_c = model.init_state(args.sequence_length)

        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()

            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            # CrossEntropyLoss expects the vocabulary (class) dimension second
            loss = criterion(y_pred.transpose(1, 2), y)

            # Keep the state values but cut the graph so backprop stays within this batch
            state_h = state_h.detach()
            state_c = state_c.detach()

            loss.backward()
            optimizer.step()

            print({'epoch': epoch, 'batch': batch, 'loss': loss.item()})
Use the PyTorch DataLoader and Dataset abstractions to load the jokes data.
Use CrossEntropyLoss as the loss function and Adam as the optimizer with default parameters. You can tweak them later.
In his famous post, Andrej Karpathy also recommends keeping this part simple at first.
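The y_pred.transpose(1, 2) call in train is there because nn.CrossEntropyLoss expects the class scores in dimension 1. It wants an input of shape (N, C, d1) with targets of shape (N, d1), while the model outputs (batch, sequence, vocabulary). This sketch uses random toy tensors to show that the transposed call gives the same mean loss as flattening every position into its own row.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical toy sizes
batch_size, sequence_length, n_vocab = 8, 4, 50

y_pred = torch.randn(batch_size, sequence_length, n_vocab)    # stands in for Model output
y = torch.randint(0, n_vocab, (batch_size, sequence_length))  # next-word indexes

criterion = nn.CrossEntropyLoss()  # averages over every (row, position) pair

# Move the vocabulary dimension to position 1: (batch, n_vocab, sequence_length)
loss_transposed = criterion(y_pred.transpose(1, 2), y)

# Same loss: treat each position as its own example with n_vocab class scores
loss_flat = criterion(y_pred.reshape(-1, n_vocab), y.reshape(-1))

print(loss_transposed.item(), loss_flat.item())
```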

Text generation

Add a predict function to the training script:
def predict(dataset, model, text, next_words=100):
    model.eval()

    words = text.split(' ')
    state_h, state_c = model.init_state(len(words))

    for i in range(0, next_words):
        # words[i:] is always the latest len(prompt) words, since one word is added per step
        x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]])
        y_pred, (state_h, state_c) = model(x, (state_h, state_c))

        # Sample the next word from the softmax over the last position's logits
        last_word_logits = y_pred[0][-1]
        p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().numpy()
        word_index = np.random.choice(len(last_word_logits), p=p)
        words.append(dataset.index_to_word[word_index])

    return words
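predict does not simply pick the most likely next word. It turns the last position's logits into probabilities with softmax and samples from them with np.random.choice, so each run produces different jokes. This sketch isolates that sampling step on a hypothetical three-word vocabulary with probabilities 0.1, 0.2 and 0.7. Over many draws, the sampled frequencies should land close to those probabilities. The float64 renormalization inside the helper is an extra safety step that is not in the tutorial's code.

```python
import numpy as np
import torch


def sample_next_index(last_word_logits):
    # Softmax turns the logits into a probability distribution over the vocabulary
    p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().numpy()
    # Extra safety step (not in the tutorial): renormalize in float64 before sampling
    p = p.astype(np.float64)
    p /= p.sum()
    # Draw an index with probability p[i], instead of always taking the argmax
    return np.random.choice(len(p), p=p)


np.random.seed(0)
# Hypothetical logits for a 3-word vocabulary; softmax(log(q)) recovers q
logits = torch.log(torch.tensor([0.1, 0.2, 0.7]))
draws = [sample_next_index(logits) for _ in range(10_000)]
freqs = np.bincount(draws, minlength=3) / len(draws)
print(freqs)  # close to [0.1, 0.2, 0.7], but not exactly
```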

Execute predictions

Add the following code to the end of the training script to execute the defined functions:
parser = argparse.ArgumentParser()
parser.add_argument('--max-epochs', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--sequence-length', type=int, default=4)
args = parser.parse_args()
dataset = Dataset(args)
model = Model(dataset)
train(dataset, model, args)
print(predict(dataset, model, text='Knock knock. Whos there?'))
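Because the settings are ordinary argparse options, you can override them on the command line, for example with --max-epochs 2 for a quicker run. argparse also turns the dashes in each option name into underscores, which is why train reads args.max_epochs and args.batch_size. This sketch passes an explicit argument list to parse_args instead of reading the real command line.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--max-epochs', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--sequence-length', type=int, default=4)

# Equivalent to passing `--max-epochs 2` on the command line
args = parser.parse_args(['--max-epochs', '2'])

# Dashes in option names become underscores in attribute names
print(args.max_epochs, args.batch_size, args.sequence_length)  # 2 256 4
```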
Run the script with:
$ python
The script prints the loss for every batch of every epoch. When training finishes, the model predicts the next 100 words after “Knock knock. Whos there?”. By default, it runs for 10 epochs and takes around 15 minutes to train.
{'epoch': 9, 'batch': 91, 'loss': 5.953955173492432}
{'epoch': 9, 'batch': 92, 'loss': 6.1532487869262695}
{'epoch': 9, 'batch': 93, 'loss': 5.531163215637207}
['Knock', 'knock.', 'Whos', 'there?', '3)', 'moostard', 'bird', 'Book,',
'What', 'when', 'when', 'the', 'Autumn', 'He', 'What', 'did', 'the',
'psychologist?', 'And', 'look', 'any', 'jokes.', 'Do', 'by', "Valentine's",
'Because', 'I', 'papa', 'could', 'believe', 'had', 'a', 'call', 'decide',
'elephants', 'it', 'my', 'eyes?', 'Why', 'you', 'different', 'know', 'in',
'an', 'file', 'of', 'a', 'jungle?', 'Rock', '-', 'and', 'might', "It's",
'every', 'out', 'say', 'when', 'to', 'an', 'ghost', 'however:', 'the', 'sex,',
'in', 'his', 'hose', 'and', 'because', 'joke', 'the', 'month', '25', 'The',
'97', 'can', 'eggs.', 'was', 'dead', 'joke', "I'm", 'a', 'want', 'is', 'you',
'out', 'to', 'Sorry,', 'the', 'poet,', 'between', 'clean', 'Words', 'car',
'his', 'wife', 'would', '1000', 'and', 'Santa', 'oh', 'diving', 'machine?',
'He', 'was']
If you skipped to this part and want to run the code, here's a GitHub repository you can clone.

Next steps

Congratulations! You've written your first PyTorch LSTM network and generated some jokes.
Here's what you can do next to improve the model:
  • Clean up the data by removing non-letter characters.
  • Increase the model capacity by adding more layers.
  • Split the dataset into train, test, and validation sets.
  • Add checkpoints so you don't have to train the model every time you want to run prediction.
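For checkpoints, PyTorch's usual pattern is to save a model's state_dict with torch.save and restore it with load_state_dict. This sketch uses a small nn.Sequential as a hypothetical stand-in for Model, and the file name is chosen only for illustration.

```python
import os
import tempfile

import torch
from torch import nn


def save_checkpoint(model, path):
    # Save only the learned parameters (the state_dict), not the Python object
    torch.save(model.state_dict(), path)


def load_checkpoint(model, path):
    # `model` must already have the same architecture as the saved one
    model.load_state_dict(torch.load(path))
    model.eval()
    return model


def make_model():
    # Hypothetical stand-in for the tutorial's Model
    return nn.Sequential(nn.Embedding(50, 16), nn.Linear(16, 50))


trained = make_model()
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'checkpoint.pt')  # illustrative file name
    save_checkpoint(trained, path)
    restored = load_checkpoint(make_model(), path)

same = all(
    torch.equal(a, b)
    for a, b in zip(trained.state_dict().values(), restored.state_dict().values())
)
print(same)
```

With the tutorial's Model, the restored network must be built from a Dataset with the same vocabulary. This is because the embedding and output layer sizes come from len(dataset.uniq_words).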