Stochastic Gradient Descent (SGD) with PyTorch and fastai
In previous posts we implemented Gradient Descent in Ruby to approximate a quadratic function, extended it to any function with the help of PyTorch tensors and gradient calculation, and built a simple digit classifier.
In this post we'll build on those posts to introduce Stochastic Gradient Descent, while leveraging more of PyTorch and fastai to simplify the code we've been using so far.
Stochastic Gradient Descent
According to the Cambridge Dictionary, a stochastic process or system is connected with random probability. Why would we add randomness to the Gradient Descent we've seen so far? It's all about making the process faster, especially for large datasets.
As a reminder, in previous posts we generated an x and y pair and ran all our inputs x through our neural network to get predictions preds, which we then compared to our expected outputs y:
preds = model(x)
In our last post, model was at first a linear function:
def linear1(x): return x@weights + bias
which we later switched to a neural network to improve our predictions:
def simple_neural_net(x):
    res = x@w1 + b1
    res = res.sigmoid()
    res = res@w2 + b2
    return res
We can see that the number of parameters (weights and biases) and activation functions (sigmoid in this sample) varies depending on which model function we use in our error optimisation through Gradient Descent. As we add hidden layers to a neural network, we increase the number of gradients we need to calculate for each input. Additionally, the dimensionality of our input can grow (1 for a quadratic function, 784 for a simple 28x28 grayscale PNG image).
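To put a number on that growth, here is a small sketch counting the weights and biases of a fully connected network. dense_params is a hypothetical helper, and the 784-50-1 layout assumes the 50-unit hidden layer used for the neural network later in this post.

```python
def dense_params(sizes):
    """Count parameters of a fully connected net with the given layer widths.

    Each layer mapping n_in -> n_out has an n_in x n_out weight matrix
    plus n_out biases.
    """
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes, sizes[1:]))


print(dense_params([28 * 28, 1]))      # linear1: 785
print(dense_params([28 * 28, 50, 1]))  # one hidden layer of 50 units: 39301
```

Adding a single 50-unit hidden layer multiplies the parameter count (and therefore the gradients computed per step) by roughly 50.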
The simple digit classifier sample (which trains on ~12k samples) already takes 13.8 seconds to train on my MacBook Pro without GPU acceleration. With such a small dataset my CPU is already struggling!
That's where Stochastic Gradient Descent kicks in: it uses randomised chunks (mini-batches) of our x and y input pairs to calculate the gradient of the loss, and adjusts our parameters before seeing the whole dataset. E.g. if we split our dataset into 100 chunks, Stochastic Gradient Descent updates the parameters 100 times during a single pass over the data. That means that by the time we finish a full epoch (one run through all the data), we have already updated our weights and biases 100 times instead of just once. It still converges because each random mini-batch gradient is an unbiased (if noisy) estimate of the full-dataset gradient, and, by the law of large numbers, larger batches make that estimate less noisy.
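To make the "noisy but unbiased estimate" idea concrete, here is a minimal pure-Python sketch. It assumes a toy one-parameter model pred = w * x with a mean squared error loss (not the model from the previous posts), shuffles the data, splits it into equally sized mini-batches, and checks that averaging the per-batch gradients recovers the full-dataset gradient, while each batch on its own would drive one parameter update.

```python
import random


def mse_grad(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


rng = random.Random(0)
# Toy data (hypothetical): y = 3x plus a little uniform noise
xs = [i / 10 for i in range(100)]
ys = [3 * x + rng.uniform(-0.5, 0.5) for x in xs]

w = 0.0
full_grad = mse_grad(w, xs, ys)

# Shuffle the sample indices and cut them into 10 mini-batches of 10 samples
idx = list(range(len(xs)))
rng.shuffle(idx)
batch_size = 10
batches = [idx[i:i + batch_size] for i in range(0, len(idx), batch_size)]
batch_grads = [mse_grad(w, [xs[i] for i in b], [ys[i] for i in b]) for b in batches]

print(len(batches))                                    # 10 mini-batches -> 10 updates per epoch
print(full_grad, sum(batch_grads) / len(batch_grads))  # equal up to float rounding
print(min(batch_grads), max(batch_grads))              # individual batches are noisy
```

The equality only holds for a fixed w. In real SGD, w changes after every batch, so each later batch's gradient is evaluated at updated weights; that is why the SGD path is noisier than full-batch descent rather than identical to it.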
We can use PyTorch's DataLoader and Dataset classes to get an iterator over our x and y pairs.
In its simplest form, a Dataset is just a list of (input, label) tuples (there is also a Dataset class for more structured and efficient data handling):
x = [1, 2, 3, 4]
y = [1, 0, 1, 1]
dset = list(zip(x,y))
# [(1, 1), (2, 0), (3, 1), (4, 1)]
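The plain list works because a DataLoader only needs len() and integer indexing from a map-style dataset. The more structured alternative is to subclass torch.utils.data.Dataset; here is a minimal sketch, where the class name PairDataset is hypothetical:

```python
import torch
from torch.utils.data import Dataset


class PairDataset(Dataset):
    """Hypothetical map-style dataset holding (input, label) pairs as tensors."""

    def __init__(self, x, y):
        assert len(x) == len(y), "inputs and labels must have the same length"
        self.x = torch.as_tensor(x)
        self.y = torch.as_tensor(y)

    def __len__(self):
        # The DataLoader's sampler uses this to know how many samples exist
        return len(self.x)

    def __getitem__(self, i):
        # Return a single (input, label) pair for index i
        return self.x[i], self.y[i]


dset = PairDataset([1, 2, 3, 4], [1, 0, 1, 1])
print(len(dset))  # 4
print(dset[1])    # (tensor(2), tensor(0))
```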
A DataLoader then gives us an iterator that goes through all the data, shuffled randomly, in batches of a fixed size (the last batch may be smaller):
from torch.utils.data import DataLoader
dl = DataLoader(dset, batch_size=2, shuffle=True)
# (tensor([3, 1]), tensor([1, 1])), # random batch 1
# (tensor([4, 2]), tensor([1, 0])) # random batch 2
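Conceptually, DataLoader(..., shuffle=True) draws a fresh permutation of the indices at the start of every pass and slices it into consecutive batches. A rough pure-Python approximation follows; iterate_minibatches is a hypothetical function that ignores tensor collation, worker processes and custom samplers.

```python
import random


def iterate_minibatches(dataset, batch_size, shuffle=True, seed=None):
    """Yield (xs, ys) lists of at most batch_size pairs, covering the dataset once."""
    indices = list(range(len(dataset)))
    if shuffle:
        # A real DataLoader reshuffles every epoch; pass a different seed
        # (or None) per epoch to mimic that.
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch = [dataset[i] for i in indices[start:start + batch_size]]
        xs, ys = zip(*batch)  # split the pairs into inputs and labels
        yield list(xs), list(ys)


dset = list(zip([1, 2, 3, 4, 5], [1, 0, 1, 1, 0]))
for xb, yb in iterate_minibatches(dset, batch_size=2, seed=0):
    print(xb, yb)  # three batches: two of size 2 and a final one of size 1
```

Every sample is visited exactly once per pass and labels stay paired with their inputs; only the order changes.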
With that, we are ready to go from running through the whole dataset each epoch:
loss = loss_fx(model(x), y)  # predictions for the whole dataset at once
loss.backward()
with torch.no_grad():
    for p in parameters:
        p -= p.grad * lr
        p.grad.zero_()
to processing it in batches, which adjusts our parameters roughly data_length / batch_size times (once per batch) before we complete a full epoch:
ds = list(zip(x, y))
dl = DataLoader(ds, batch_size=256, shuffle=True)
for xb, yb in dl:
    loss = loss_fx(model(xb), yb)  # predictions for one mini-batch only
    loss.backward()
    with torch.no_grad():
        for p in parameters:
            p -= p.grad * lr  # update after every mini-batch
            p.grad.zero_()
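Putting those pieces together, here is a self-contained sketch that runs the mini-batch loop end to end. It assumes synthetic data (y = 2x + 1 plus noise), an nn.Linear model and mean squared error as loss_fx; these are stand-ins, not the digit classifier from the previous post.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader

torch.manual_seed(0)

# Synthetic 1-D regression data: y = 2x + 1 + noise (hypothetical)
x = torch.rand(1000, 1)
y = 2 * x + 1 + 0.05 * torch.randn(1000, 1)

model = nn.Linear(1, 1)
parameters = list(model.parameters())
loss_fx = nn.functional.mse_loss
lr = 0.5

dl = DataLoader(list(zip(x, y)), batch_size=64, shuffle=True)

for epoch in range(20):
    for xb, yb in dl:
        loss = loss_fx(model(xb), yb)
        loss.backward()  # fills p.grad for every parameter
        with torch.no_grad():
            for p in parameters:
                p -= p.grad * lr  # one SGD step per mini-batch
                p.grad.zero_()    # gradients accumulate, so reset them

w, b = model.weight.item(), model.bias.item()
print(f"w={w:.2f}, b={b:.2f}")  # should end up close to w=2, b=1
```

With 1000 samples and a batch size of 64, each epoch performs 16 parameter updates (15 full batches plus one of 40 samples).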
Applying only this change, I re-ran the digit classifier sample and reduced the runtime from 14 to 4 seconds! A side effect of making multiple adjustments per epoch is that we need fewer epochs to optimise our parameters well enough to reach good accuracy. In this sample I achieved the same accuracy (94%) with 2 epochs and a batch size of 256 as I did with 100 epochs and no batching at all (with a fixed learning rate).
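A quick back-of-the-envelope count hints at why so few epochs can be enough: each mini-batch produces one parameter update. Assuming roughly 12,000 training samples (the ~12k figure above; the exact count is an assumption), the hypothetical helper below compares the number of updates in the two runs.

```python
import math


def updates(n_samples, batch_size, epochs):
    """Number of parameter updates: one per batch, last partial batch included."""
    return math.ceil(n_samples / batch_size) * epochs


n = 12_000  # approximate dataset size from the text (assumption)
print(updates(n, batch_size=n, epochs=100))  # full batch: 100 updates
print(updates(n, batch_size=256, epochs=2))  # mini-batches: 94 updates
```

Both runs take a similar number of gradient steps, but each mini-batch step only processes 256 samples instead of the whole dataset.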
PyTorch and fastai constructs
PyTorch and fastai provide plenty of modules that simplify the process we've described in the previous posts.
The first simplification we can do is around how we define our linear model:
-def init_params(size, std=1.0): return (torch.randn(size)*std).requires_grad_()
-def linear1(x): return x@weights + bias
-weights = init_params((28*28,1))
-bias = init_params(1)
+linear1 = nn.Linear(28*28,1)
+weights, bias = linear1.parameters()
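One detail worth knowing when swapping in nn.Linear: PyTorch stores the weight as (out_features, in_features), i.e. (1, 784) here rather than the (784, 1) shape init_params created, and computes x @ weight.T + bias. A minimal sketch checking that equivalence on a made-up batch:

```python
import torch
from torch import nn

torch.manual_seed(0)
linear1 = nn.Linear(28 * 28, 1)
weights, bias = linear1.parameters()
print(weights.shape, bias.shape)  # torch.Size([1, 784]) torch.Size([1])

xb = torch.rand(4, 28 * 28)     # a fake batch of 4 flattened 28x28 images
manual = xb @ weights.T + bias  # the same computation as the old linear1
print(torch.allclose(linear1(xb), manual, atol=1e-5))  # True
```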
Same principle applies to our neural network:
-def init_params(size, std=1.0): return (torch.randn(size)*std).requires_grad_()
-def simple_net(xb):
-    res = xb@w1 + b1
-    res = res.sigmoid()
-    res = res@w2 + b2
-    return res
-w1 = init_params((28*28,50))
-b1 = init_params(50)
-w2 = init_params((50,1))
-b2 = init_params(1)
+simple_net = nn.Sequential(
+    nn.Linear(28*28,50),
+    nn.Sigmoid(),
+    nn.Linear(50,1)
+)
+w1, b1, w2, b2 = simple_net.parameters()
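The same transposition applies to the Sequential version: w1 has shape (50, 784) and w2 has shape (1, 50). A sketch checking that nn.Sequential computes exactly the old hand-written forward pass, using its own parameters (manual_net is a hypothetical name for the rewritten function):

```python
import torch
from torch import nn

torch.manual_seed(0)
simple_net = nn.Sequential(
    nn.Linear(28 * 28, 50),
    nn.Sigmoid(),
    nn.Linear(50, 1),
)
w1, b1, w2, b2 = simple_net.parameters()
print(w1.shape, w2.shape)  # torch.Size([50, 784]) torch.Size([1, 50])


def manual_net(xb):
    # The earlier hand-written network, with weights transposed to match nn.Linear
    res = xb @ w1.T + b1
    res = res.sigmoid()
    return res @ w2.T + b2


xb = torch.rand(4, 28 * 28)
print(torch.allclose(simple_net(xb), manual_net(xb), atol=1e-5))  # True
```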
Another useful abstraction is in the optimisation step, where fastai provides an SGD optimizer to handle it:
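-with torch.no_grad(): parameters -= parameters.grad * lr
-parameters.grad = None
+opt = SGD(simple_net.parameters(), lr)  # create once, before the training loop
+opt.step()       # inside the loop: update the parameters
+opt.zero_grad()  # inside the loop: reset the gradients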
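Under the hood, an optimizer like this packages the update and the gradient reset behind step() and zero_grad(). A hedged sketch: the class PlainSGD is hypothetical (similar in spirit to fastai's SGD without momentum or weight decay), and torch.optim.SGD is used only as a reference to check that one step lands on the same parameters.

```python
import torch
from torch import nn


class PlainSGD:
    """Hypothetical minimal optimizer: p <- p - lr * p.grad."""

    def __init__(self, params, lr):
        self.params, self.lr = list(params), lr

    def step(self):
        with torch.no_grad():
            for p in self.params:
                p -= p.grad * self.lr

    def zero_grad(self):
        for p in self.params:
            p.grad = None


torch.manual_seed(0)
net_a = nn.Linear(3, 1)
net_b = nn.Linear(3, 1)
net_b.load_state_dict(net_a.state_dict())  # identical starting weights

xb, yb = torch.rand(8, 3), torch.rand(8, 1)
for net, opt in [(net_a, PlainSGD(net_a.parameters(), 0.1)),
                 (net_b, torch.optim.SGD(net_b.parameters(), lr=0.1))]:
    loss = nn.functional.mse_loss(net(xb), yb)
    loss.backward()
    opt.step()
    opt.zero_grad()

print(torch.allclose(net_a.weight, net_b.weight))  # True
```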
We can go one step further with fastai and use a Learner for the training process, with no custom training code at all:
-for xb, yb in dl:
-    preds = simple_net(xb)
-    loss = loss_fx(preds, yb)
-    loss.backward()
-    with torch.no_grad(): parameters -= parameters.grad * lr
-    parameters.grad.zero_()
+dls = DataLoaders(dl, ()) # skipping the validation set and accuracy
+learn = Learner(dls, simple_net, opt_func=SGD,
+                loss_func=mnist_loss, lr=0.01)
+learn.fit(1)  # one epoch, like the single pass through dl above
Full example
The entire digit classifier example would turn into the following (non-refactored for the sake of explicitness) code:
Import dependencies
from fastai.vision.all import *
Data setup
path = untar_data(URLs.MNIST)
# fives
fives_filenames = (path/'training'/'5').ls().sorted()
fives_tensors = [tensor(Image.open(o)) for o in fives_filenames]
fives = torch.stack(fives_tensors).float()/255
# fours
fours_filenames = (path/'training'/'4').ls().sorted()
fours_tensors = [tensor(Image.open(o)) for o in fours_filenames]
fours = torch.stack(fours_tensors).float()/255
# validation fives and fours (from the testing folder)
validation_fives_filenames = (path/'testing'/'5').ls().sorted()
validation_fives_tensors = [tensor(Image.open(o)) for o in validation_fives_filenames]
validation_fives = torch.stack(validation_fives_tensors).float()/255
validation_fours_filenames = (path/'testing'/'4').ls().sorted()
validation_fours_tensors = [tensor(Image.open(o)) for o in validation_fours_filenames]
validation_fours = torch.stack(validation_fours_tensors).float()/255
# x and y for training and validation
x = torch.cat([fives, fours]).view(-1, 28*28)
y = tensor([1]*len(fives) + [0]*len(fours)).unsqueeze(1)
validation_x = torch.cat([validation_fives, validation_fours]).view(-1, 28*28)
validation_y = tensor([1]*len(validation_fives) + [0]*len(validation_fours)).unsqueeze(1)
# datasets and dataloaders
dset = list(zip(x, y))
dl = DataLoader(dset, batch_size=256, shuffle=True)
valid_dset = list(zip(validation_x, validation_y))
valid_dl = DataLoader(valid_dset, batch_size=256)
dls = DataLoaders(dl, valid_dl)
Functions used during training
def mnist_loss(predictions, targets):
    predictions = predictions.sigmoid()
    return torch.where(targets==1, 1-predictions, predictions).mean()
def batch_accuracy(xb, yb):
    preds = xb.sigmoid()
    correct = (preds>0.5) == yb
    return correct.float().mean()
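To see what these two functions compute, here is a small check on three made-up raw predictions and labels (1 = five, 0 = four); the values are invented for illustration and the definitions are copied from above.

```python
import torch


def mnist_loss(predictions, targets):
    # Squash raw outputs into (0, 1), then measure the distance to the target label
    predictions = predictions.sigmoid()
    return torch.where(targets == 1, 1 - predictions, predictions).mean()


def batch_accuracy(xb, yb):
    # A prediction counts as "five" (label 1) when its sigmoid exceeds 0.5
    preds = xb.sigmoid()
    correct = (preds > 0.5) == yb
    return correct.float().mean()


# Made-up raw outputs for three images and their true labels
preds = torch.tensor([[2.0], [-1.0], [0.5]])
targets = torch.tensor([[1], [0], [0]])

print(mnist_loss(preds, targets))      # average distance from the correct label
print(batch_accuracy(preds, targets))  # 2 of 3 predictions fall on the right side of 0.5
```

Unlike accuracy, mnist_loss changes smoothly as the predictions move, which is what makes it usable for computing gradients.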
Training with simple neural net
simple_net = nn.Sequential(
    nn.Linear(28*28,50),
    nn.Sigmoid(),
    nn.Linear(50,1)
)
learn = Learner(dls, simple_net, opt_func=SGD,
                loss_func=mnist_loss, metrics=batch_accuracy, lr=0.01)
learn.fit(2)  # train for 2 epochs
Leveraging fastai further
We could go a few steps further and use fastai's vision_learner together with its ImageDataLoaders, which handle all the data loading (for all ten digits), training, accuracy calculation and fitting for us in only 3 lines of code (and an 18-layer architecture):
dls = ImageDataLoaders.from_folder(path, train='training', valid='testing')
learn = vision_learner(dls, resnet18, pretrained=False,
                       loss_func=F.cross_entropy, metrics=accuracy)
learn.fit_one_cycle(1, 0.1) # fit with learning rate scheduling for quicker convergence
I've also deployed another sample that leverages fastai's abstractions further: it fine-tunes a pretrained resnet18 model to recognise LOTR characters from photos, with a working custom deployment.
