Long Short Term Memory (LSTM) is a popular Recurrent Neural Network (RNN)
architecture. This tutorial covers using LSTMs in PyTorch to generate
text; in this case, pretty lame jokes.
For this tutorial you need:
- Basic familiarity with Python, PyTorch, and machine learning
- A local installation of Python v3+, PyTorch v1+, and NumPy v1+ (the data loader below also uses pandas)
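One common way to install the Python dependencies is shown below; the exact torch command for your platform may differ, so check the PyTorch installation page if in doubt.
pip install torch numpy pandas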
What is LSTM?
LSTM is a variant of the RNN architecture used in deep learning. You can use LSTMs if you
are working with sequences of data.
Here are the most straightforward use-cases for LSTM networks you might be
familiar with:
- Time series forecasting (for example, stock prediction)
- Text generation
- Video classification
- Music generation
- Anomaly detection
RNN
Before you start using LSTMs, you need to understand how RNNs work.
RNNs are neural networks that are good with sequential data. That data can be video,
audio, text, a stock market time series, or even a single image cut into a
sequence of its parts.
Standard neural networks (convolutional or vanilla) have one major shortcoming
when compared to RNNs - they cannot reason about previous inputs to inform
later ones. You cannot solve some machine learning problems without
some kind of memory of past inputs.
For example, you might run into a problem when you have some video frames
of a ball moving and want to predict the direction of the ball.
The way a standard neural network sees the problem is: you have a ball in
one image and then you have a ball in another image.
It does not have a mechanism for connecting these two images as a sequence.
Standard neural networks cannot connect two separate images of the ball to
the concept of “the ball is moving.” All the network sees is that there is a ball in
image #1 and a ball in image #2, and its outputs for the two images
are separate.
Convolutional Neural Network prediction
Compare this to the RNN, which remembers the last frames and can use that
to inform its next prediction.
Recurrent Neural Network prediction
LSTM vs RNN
Typical RNNs can't memorize long sequences. The effect, called “vanishing
gradients,” happens during the backpropagation phase of training the RNN.
The gradients of cells that carry information from the start of a
sequence go through repeated matrix multiplications by small
numbers and shrink toward zero over long sequences.
In other words, information at the start of the sequence has almost
no effect at the end of the sequence.
You can see that illustrated in the Recurrent Neural Network example.
Given a long enough sequence, the information from the first element
of the sequence has almost no impact on the output of the last element of
the sequence.
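To get a feel for the numbers, here is a toy sketch; it is not an RNN, just repeated scaling by factors below one, which is roughly what backpropagating through many time steps does to the gradient:
# Toy illustration of vanishing gradients: scale a gradient by a small
# factor once per time step and watch it collapse toward zero.
gradient = 1.0
for step in range(100):
    gradient *= 0.5
print(gradient)  # about 7.9e-31, so the earliest steps receive almost no learning signal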
LSTM is an RNN architecture that can memorize long sequences - up to hundreds of
elements in a sequence. LSTM has a memory gating
mechanism that allows the long-term memory to continue flowing into the LSTM
cells.
Long Short Term Memory cell
Text generation with PyTorch
You will train a joke text generator using LSTM networks in PyTorch and follow
best practices. Start by creating a new folder where you'll store the code:
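mkdir text-generation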
Model
To create an LSTM model, create a model.py file in the text-generation folder
with the following content:
import torch
from torch import nn


class Model(nn.Module):
    def __init__(self, dataset):
        super(Model, self).__init__()
        self.lstm_size = 128
        self.embedding_dim = 128
        self.num_layers = 3

        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            # the LSTM consumes the embedded word vectors, so input_size
            # must match embedding_dim (both are 128 here)
            input_size=self.lstm_size,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=0.2,
        )
        self.fc = nn.Linear(self.lstm_size, n_vocab)

    def forward(self, x, prev_state):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        # (h_0, c_0), each of shape (num_layers, sequence_length, lstm_size)
        return (torch.zeros(self.num_layers, sequence_length, self.lstm_size),
                torch.zeros(self.num_layers, sequence_length, self.lstm_size))
This is a standard-looking PyTorch model. The Embedding layer converts word
indexes to word vectors. The LSTM is the main learnable part of the
network - the PyTorch implementation has the gating mechanism built into the
LSTM cell, which is what lets it learn long sequences of data.
As described in the earlier What is LSTM? section, RNNs and
LSTMs have extra state information that they carry between training episodes.
That is why the forward function has a prev_state argument: this state is kept
outside the model and passed in manually.
The model also has an init_state function. Call it at the start of every epoch
to initialize the state with the right shape.
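As a quick, hypothetical sanity check (this snippet is not one of the tutorial's files, and the four-word vocabulary is made up), you can push a dummy batch through the model and inspect the output shape:
import torch
from types import SimpleNamespace

from model import Model

# stand-in for the real Dataset: the model only needs uniq_words here
fake_dataset = SimpleNamespace(uniq_words=['a', 'b', 'c', 'd'])
model = Model(fake_dataset)

x = torch.randint(0, 4, (2, 5))              # 2 sequences of 5 word indexes each
state = model.init_state(sequence_length=5)
logits, state = model(x, state)
print(logits.shape)                          # torch.Size([2, 5, 4]): one score per vocabulary word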
Dataset
For this tutorial, we use the Reddit clean jokes dataset to train the network.
Download the dataset (139 KB) and put it in the
text-generation/data/ folder.
The dataset has 1623 jokes and looks like this:
ID,Joke
1,What did the bartender say to the jumper cables? You better not try to start anything.
2,Don't you hate jokes about German sausage? They're the wurst!
3,Two artists had an art contest... It ended in a draw
…
To load the data into PyTorch, use the PyTorch Dataset class. Create a
dataset.py file with the following content:
import torch
import pandas as pd
from collections import Counter


class Dataset(torch.utils.data.Dataset):
    def __init__(self, args):
        self.args = args
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()

        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load_words(self):
        train_df = pd.read_csv('data/reddit-cleanjokes.csv')
        text = train_df['Joke'].str.cat(sep=' ')
        return text.split(' ')

    def get_uniq_words(self):
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        return len(self.words_indexes) - self.args.sequence_length

    def __getitem__(self, index):
        # the target sequence is the input sequence shifted by one word
        return (
            torch.tensor(self.words_indexes[index:index + self.args.sequence_length]),
            torch.tensor(self.words_indexes[index + 1:index + self.args.sequence_length + 1]),
        )
This Dataset inherits from PyTorch's torch.utils.data.Dataset class
and defines two important methods, __len__ and __getitem__. Read more about
how Dataset classes work in the PyTorch Data loading tutorial.
The load_words function loads the dataset. The unique words in the dataset are
counted to define the size of the network's vocabulary and of the embedding table.
index_to_word and word_to_index convert words to number indexes and vice versa.
This is part of the tokenization process.
In the future, the torchtext team plans to improve this part, but they are
re-designing it and the new API is too unstable for this tutorial today.
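If you want to poke at the Dataset in isolation, a quick hypothetical check like the following works; run it from the text-generation folder with the CSV in data/, and note that argparse.Namespace merely stands in for the parsed command-line arguments used later:
import argparse

from dataset import Dataset

args = argparse.Namespace(sequence_length=4)
dataset = Dataset(args)

print(len(dataset.uniq_words))   # vocabulary size
x, y = dataset[0]
print(x, y)                      # y is x shifted one word to the right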
Training
Create a train.py file and define a train function:
import argparse
import torch
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader
from model import Model
from dataset import Dataset


def train(dataset, model, args):
    model.train()

    dataloader = DataLoader(dataset, batch_size=args.batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(args.max_epochs):
        state_h, state_c = model.init_state(args.sequence_length)

        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()

            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            # CrossEntropyLoss expects the class (vocabulary) dimension second
            loss = criterion(y_pred.transpose(1, 2), y)

            # detach the state so gradients don't flow back through earlier batches
            state_h = state_h.detach()
            state_c = state_c.detach()

            loss.backward()
            optimizer.step()

            print({'epoch': epoch, 'batch': batch, 'loss': loss.item()})
Use the PyTorch DataLoader and Dataset abstractions to load the jokes data.
Use CrossEntropyLoss as the loss function and Adam as the optimizer with
default params. You can tweak them later.
In his famous post, Andrej Karpathy also recommends keeping this part simple at first.
Text generation
Add a predict function to the train.py file:
def predict(dataset, model, text, next_words=100):
    model.eval()

    words = text.split(' ')
    state_h, state_c = model.init_state(len(words))

    for i in range(0, next_words):
        x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]])
        y_pred, (state_h, state_c) = model(x, (state_h, state_c))

        last_word_logits = y_pred[0][-1]
        # turn the logits into a probability distribution and sample the next word
        p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().numpy()
        word_index = np.random.choice(len(last_word_logits), p=p)
        words.append(dataset.index_to_word[word_index])

    return words
Execute predictions
Add the following code to the train.py file to execute the defined functions:
parser = argparse.ArgumentParser()
parser.add_argument('--max-epochs', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--sequence-length', type=int, default=4)
args = parser.parse_args()
dataset = Dataset(args)
model = Model(dataset)
train(dataset, model, args)
print(predict(dataset, model, text='Knock knock. Whos there?'))
Run the train.py script with:
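python train.py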
You will see the loss printed along with the epochs. The model
predicts the next 100 words after Knock knock. Whos there? when the training
finishes. By default, it runs for 10 epochs and
takes around 15 minutes to finish training.
{'epoch': 9, 'batch': 91, 'loss': 5.953955173492432}
{'epoch': 9, 'batch': 92, 'loss': 6.1532487869262695}
{'epoch': 9, 'batch': 93, 'loss': 5.531163215637207}
['Knock', 'knock.', 'Whos', 'there?', '3)', 'moostard', 'bird', 'Book,',
'What', 'when', 'when', 'the', 'Autumn', 'He', 'What', 'did', 'the',
'psychologist?', 'And', 'look', 'any', 'jokes.', 'Do', 'by', "Valentine's",
'Because', 'I', 'papa', 'could', 'believe', 'had', 'a', 'call', 'decide',
'elephants', 'it', 'my', 'eyes?', 'Why', 'you', 'different', 'know', 'in',
'an', 'file', 'of', 'a', 'jungle?', 'Rock', '-', 'and', 'might', "It's",
'every', 'out', 'say', 'when', 'to', 'an', 'ghost', 'however:', 'the', 'sex,',
'in', 'his', 'hose', 'and', 'because', 'joke', 'the', 'month', '25', 'The',
'97', 'can', 'eggs.', 'was', 'dead', 'joke', "I'm", 'a', 'want', 'is', 'you',
'out', 'to', 'Sorry,', 'the', 'poet,', 'between', 'clean', 'Words', 'car',
'his', 'wife', 'would', '1000', 'and', 'Santa', 'oh', 'diving', 'machine?',
'He', 'was']
If you skipped to this part and want to run the code, here's a GitHub
repository you can clone.
Next steps
Congratulations! You've written your first PyTorch LSTM network and generated
some jokes.
Here's what you can do next to improve the model:
- Clean up the data by removing non-letter characters.
- Increase the model capacity by adding more Linear or LSTM layers.
- Split the dataset into train, test, and validation sets.
- Add checkpoints so you don't have to train the model every time you want
to run prediction (see the sketch below).
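For the checkpoint idea, a minimal sketch could look like the following; the file name checkpoint.pth is just an example, and the snippet assumes it runs in train.py where Model, model, and dataset are already defined:
# save the trained weights once training finishes (hypothetical file name)
torch.save(model.state_dict(), 'checkpoint.pth')

# later, rebuild the model and load the saved weights instead of retraining
model = Model(dataset)
model.load_state_dict(torch.load('checkpoint.pth'))
model.eval()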