I recently had the luxury/annoyance of having a short layover in the Hartsfield-Jackson Atlanta International Airport. Having basically not heard of this airport (long story), it was not apparent upon landing that I would be walking for half an hour to get to my next flight, or that I would be waiting among crowds of people from Sweden or Nigeria. Amidst the chaos and stress, I was reminded that the Earth is an immense planet that one person can never completely explore, and that this is true in more ways than one.
You’ll find from reading more of Horseless that I’m a generalist. Think nothing of it. Right now, I’m doing the fast.ai course v3, and today I’m following along with lesson 1. Special thanks to Salamander, because frankly, AWS has some work to do on the front end.
It’s been a little while since I’ve used Jupyter notebooks (I’m writing all of this in one right now to get the hang of it again), so I had totally forgotten about the existence of magics. There are tons, but the ones that are relevant to this image classification example are as follows.
We use these magics to reload modules automatically if they have been changed. If that sounds weird, there’s a brief, straightforward example in the docs.
%reload_ext autoreload
%autoreload 2
A lot of magics can also be used in plain IPython, but this one is most relevant to Jupyter notebooks: it just allows us to display matplotlib graphs in the notebook itself.
%matplotlib inline
Naturally, we have a couple of library imports. I’ve gotten into the habit of at least peeking at the docs whenever I run into a new library, especially when we’re importing everything. The documentation for fastai.vision is actually pretty short and readable.
from fastai.vision import *
from fastai.metrics import error_rate
bs just sets the batch size.
bs = 64
Deception!
This blog entry isn’t completely following the pets example. I’m also looking at an image download guide that I guess is from the next lesson. For this, I wanted to build a classifier of fruit using scraped images and see how well it scales with the number of categories and such.
So here’s what’s going on below. We make a Path object to hold all of our different classes. Each class gets its own folder; this snippet will create the folder for each class if it doesn’t already exist. It looks like overkill here, and even the original image download tutorial only had three different classes of bears, but part of my exploration of lesson one will be looking at how the model responds to different numbers of classes.
path = Path("data/fruit1")
folders = ['apples','bananas', 'oranges']
for folder in folders:
dest = path/folder
dest.mkdir(parents=True, exist_ok=True)
If you don’t recognize ls as a CLI command, this just gets a list of all the non-hidden files and folders in the given directory. I’ve probably run this notebook half a dozen times, so my original view of it was pretty thoroughly populated. You’ll see in a minute why having multiple folders for each of these runs was a good idea.
path.ls()
This next bit is where we get into some annoying stuff in the tutorial that I want to come back and fix later. We need data to make anything happen here, and a quick way to get data for our purposes is to run a search on Google Images for, let’s say, apples. The tutorial wants us to save those URLs, and to do that, we need a bit of code:
urls = Array.from(document.querySelectorAll('.rg_di .rg_meta')).map(el=>JSON.parse(el.textContent).ou);
window.open('data:text/csv;charset=utf-8,' + escape(urls.join('\n')));
It’s JavaScript, my old nemesis. The tutorial just says that this is to be entered into the browser’s ‘Console’, and suggests that Ctrl + Shift + J brings this up in Firefox. Maybe I spent too much time messing with the settings or something, because that brought up a (for this purpose) useless display console. In Firefox, I needed to go to Tools -> Web Developer -> Web Console to enter this code, a menu path I only later learned was Ctrl + Shift + K.
Anyway, this needs to be repeated for each class. It’s automatic-ish. The video links to a different set of tips for downloading images using a tool called googleimagesdownload that I had a chance to use once, but getting that working involved enough installation steps, weird complications, and documentation-diving that it’s better suited for another entry. We’re not fighting Raktabija today.
This snippet is straightforward enough. Go through each of the classes and use download_images to download up to 100 images for each one. Yes, I’m leaving those prints in. Full disclosure, I was looking at three different examples and was way more confused than I should have been for about 20 minutes. Those things are heroes.
That said, funny story, true story. For some reason, my urls_apples.txt file has 200 links, twice as many as the others here. I don’t know why. Maybe I scrolled down farther in the Google Images results than for the others, and the code included the extra ones that were dynamically loaded on the page. Maybe I was playing with a different piece of code when I was first doing this at 0230 and forgot about it. Maybe I’m a crazy man on the Internet who doesn’t know what he’s talking about because it works on your machine.
The thing is, this isn’t trivial! My original run had twice as many examples of apples to draw from as anything else, which could easily change its performance. My point is, be careful. There’s a lot of work to do here.
for folder in folders:
    file = "urls_" + folder + ".txt"
    print(file)
    print(folder)
    print(path)
    download_images(path/file, path/folder, max_pics=100)
verify_images is used to do a bit of preliminary cleanup. Here, it will delete images that are broken and resize them so that each one is no larger than 500 pixels on a side. The documentation promises that the original ratios are preserved, so we should be good to go.
for f in folders:
    print(f)
    verify_images(path/f, delete=True, max_size=500)
Seriously, that many?! I’m going to be using this thing all the time.
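After all the downloading and deleting, it’s also worth checking what actually survived in each folder; that’s how an imbalance like the apples one gets caught before training instead of after. This isn’t from the lesson, just a sanity check I’d bolt on, assuming path and folders are defined as above:

# Sanity check (my addition, not the lesson's): count the images that actually
# landed in each class folder, so imbalances show up before training.
for folder in folders:
    image_count = len([f for f in (path/folder).iterdir() if not f.name.startswith('.')])
    print(folder, image_count)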
Get on with it!
We set the random seed for reproducibility. I remember this being wonky when I did this in Keras, but we’ll roll with it for now.
I do recommend having a look at the documentation for ImageDataBunch. What’s going on here? Images from the specified folder are collected. valid_pct holds out 20% of the images for validation (0.2 is provided as a default if you leave it out). ds_tfms=get_transforms() gives us a set of default data augmentation features, like randomly flipping the images horizontally.
np.random.seed(42)
data = ImageDataBunch.from_folder(path, train=".", valid_pct=0.2,
                                  ds_tfms=get_transforms(), size=224,
                                  num_workers=4).normalize(imagenet_stats)
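Side note: the bs we set back at the top never actually gets passed in here. That turns out to be harmless, since 64 is also fastai’s default batch size, but if you wanted it to matter, the call would look something like this (my tweak, not the notebook’s):

# Hypothetical variant that actually wires bs into the DataBunch; everything
# else is unchanged from the cell above.
data = ImageDataBunch.from_folder(path, train=".", valid_pct=0.2,
                                  ds_tfms=get_transforms(), size=224, bs=bs,
                                  num_workers=4).normalize(imagenet_stats)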
Just a quick peek to confirm that we have the right classes and to see if anything obviously weird happened with our images.
data.classes
data.show_batch(rows=3, figsize=(7,8))
This is the juicy part (but it does look a bit light if you’re familiar with Keras). Here we put our dataset into a convolutional neural network. The model that we’re using here is the pretrained ResNet-34, a midsized residual network. If you’re into the background, I recommend familiarizing yourself with the paper Deep Residual Learning for Image Recognition, which I’ll be talking about in a later post. There is a lot to unpack here.
learn = cnn_learner(data, models.resnet34, metrics=error_rate)
Now, we just train the learner for four epochs.
learn.fit_one_cycle(4)
Now we save these weights, because we’re about to see what happens when we train the entire model. Up until now, we’ve only trained a portion of the network, but unfreeze() will allow us to train all of the weights. lr_find() will train with a number of different learning rates.
learn.save('fruit-1')
learn.unfreeze()
learn.fit_one_cycle(1)
This bit of code will graph the results of lr_find().
learn.lr_find()
learn.recorder.plot()
We run fit_one_cycle again, using a range of learning rates taken from the part of the graph where the loss starts to be minimized.
learn.fit_one_cycle(3, max_lr=slice(3e-4, 3e-3))
… Okay! The training loss and error rate are decreasing, but the validation loss is on the rise. The network might be overfitting a bit here, but we shall see how it performs. I’m not especially confident that this is a net improvement, so we will save it as a separate model.
learn.save('fruit-2')
ClassificationInterpretation can take a Learner and plot a confusion matrix: a grid showing what the model predicted for each image versus what it actually was. Because of the way the graphic is set up, it’s visually easy to tell when a model has achieved a low error rate: look for the characteristic solid diagonal line, with maybe a few missed cases here and there.
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()
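The confusion matrix tells you which classes get mixed up, but not which images. If memory serves from the lesson, the same interp object can also show you the individual worst offenders:

# Show the nine images with the highest loss, annotated with predicted class,
# actual class, loss, and the probability assigned to the actual class.
interp.plot_top_losses(9, figsize=(15, 11))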
Expanded number of classes.
Not bad, but we can try something a bit more complicated. Let’s repeat the above steps using nine categories instead of three.
path = Path("data/fruit")
folders = ['apples', 'bananas', 'oranges', 'grapes', 'pears', 'pineapples', 'nectarines', 'kiwis', 'plantains']
for folder in folders:
dest = path/folder
dest.mkdir(parents=True, exist_ok=True)
path.ls()
for folder in folders:
    file = "urls_" + folder + ".txt"
    download_images(path/file, path/folder, max_pics=200)
for f in folders:
    print(f)
    verify_images(path/f, delete=True, max_size=500)
data2 = ImageDataBunch.from_folder(path, train=".", valid_pct=0.2,
                                   ds_tfms=get_transforms(), size=224,
                                   num_workers=4).normalize(imagenet_stats)
learn2 = cnn_learner(data2, models.resnet34, metrics=error_rate)
learn2.fit_one_cycle(4)
learn2.save('large-fruit-1')
interp2 = ClassificationInterpretation.from_learner(learn2)
interp2.plot_confusion_matrix()
Much more interesting. Even without the unfreezing and optimization, we see decent results with a larger number of categories. Additionally, where you see some more prominent spots of confusion, it is worth actually thinking about what might be tripping the model up. Depending on the image, a nectarine can sort of look like an apple. And to someone from the US like me, plantains are basically bananas that are always green.
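If you’d rather have that what’s-tripping-the-model-up list as text instead of squinting at the matrix, ClassificationInterpretation has a helper for that too. A minimal sketch using the interp2 object from above; min_val just hides pairs confused fewer than that many times:

# List (actual, predicted, count) tuples for the most frequently confused
# class pairs, skipping anything that happened fewer than 2 times.
interp2.most_confused(min_val=2)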
You’ll notice that we didn’t do that unfreeze-and-optimize strategy from before, so let’s see what that looks like.
learn2.unfreeze()
learn2.fit_one_cycle(1)
learn2.lr_find()
learn2.recorder.plot()
Well, that sure looks like a different loss graph.
learn2.save('large-fruit-2')
learn2.load('large-fruit-2')
learn2.fit_one_cycle(3, max_lr=slice(1e-5, 1e-3))
interp2 = ClassificationInterpretation.from_learner(learn2)
interp2.plot_confusion_matrix()
Training loss has been cut significantly, but validation loss and error rates are holding steady. There is a lot of unpacking to be done. Early days!