DendroNN, or the Resurrection of the Tree Classification Project

TreeID Dataset

Collecting and processing the tree data was a big learning experience. That’s the encouraging, optimistic way of saying that I made a ton of mistakes that I will never, ever repeat because I cringe at how obvious they feel in retrospect. This is still under active development and has attracted some outside interest, so you can follow along at the GitHub repo. So, what are we looking at here?

How was the data collected?

The beginning of this project was a collection of nearly 5,000 photos taken around my own neighborhood in Cleveland at the end of 2020. Not being an arborist, or even someone who leaves the house much, I hadn’t anticipated how difficult it would be to learn to identify trees from bark alone. Because the project was only of personal interest at the time, it was put on ice.

At the end of 2022, I learned of a couple of helpful projects. The first, and the primary source of my data, is a tree inventory the City of Pittsburgh paid for a few years ago. The species of each tree, its GPS-tagged location, and other relevant data can be accessed as part of the city’s Burghs Eye View program. It’s worth noting that Burghs Eye View isn’t just about trees; it’s an admirable civic data resource in general. The second, which I unfortunately didn’t have time to use, is Falling Fruit, which has a bit of a different aim.

Using the Burghs Eye View map, I took a trip to Pittsburgh and systematically collected several thousand photos of tree bark using a normal smart phone camera. The procedure was simple:

  1. Start taking regularly spaced photos from the roots, less than a meter away from the tree bark, and track upward.
  2. When the phone is at arm’s reach, angle it upward to get one more shot of the canopy.
  3. Step to another angle of the tree and repeat, usually capturing about six angles of an adult tree.

Mistake 1 – Images too Large

Problem

The photos taken in Pittsburgh had a resolution of 3000×4000. An extremely common preprocessing technique in deep learning is scaling images down to, e.g., 224×224 or 384×384. Jeremy Howard in the Fast.ai course even plays with this, and developed a technique called progressive resizing: images are scaled to 224×224, a model is trained on them, and then the same model is trained on versions scaled to a higher resolution like 384×384.

I spent a lot of time trying to make this work, to the point where I started using cloud compute services to handle much larger images, to no avail. Ready to give up, I scoured my notebooks and noticed that some of the trees the model confused most often looked very similar when scaled down to that size. It occurred to me that a lot of the important fine detail in the bark was probably getting lost in that kind of downscaling. Okay, but so what?

Solution

Cutting the images up. I suspect this solution is somewhat specific to problems like this one. It doesn’t seem like it would be useful for classification purposes to cut photos of fruit or houses into many much smaller patches. But tree bark is basically just a repeating texture. Even before realizing the next mistake, I noticed improvements by scaling down to 500×500 patches, then a more drastic improvement by going down to 250×250.

In some ways this gave me a lot more data. If you follow the math, a 3000×4000 image becomes at most 192 usable 250×250 patches. At first I thought this was a little suspicious, but I looked around, and doing it this way isn’t without precedent. There is a Kaggle competition for histopathological cancer detection where this technique comes up, for instance.
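The math above can be sketched in a few lines. This is a minimal illustration of the patch-cutting idea, not the project’s actual preprocessing code:

```python
import numpy as np

def cut_patches(image, patch=250):
    """Cut an H x W x C array into non-overlapping patch x patch tiles,
    dropping any partial tiles at the edges."""
    h, w = image.shape[:2]
    tiles = []
    for top in range(0, h - patch + 1, patch):
        for left in range(0, w - patch + 1, patch):
            tiles.append(image[top:top + patch, left:left + patch])
    return tiles

# A 3000x4000 photo divides evenly into 12 * 16 = 192 tiles of 250x250.
photo = np.zeros((3000, 4000, 3), dtype=np.uint8)
tiles = cut_patches(photo)
print(len(tiles))  # 192
```

In practice, not all 192 tiles survive; patches full of soil, blur, and so on still have to be filtered out, which is what the next mistake is about.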

Mistake 2 – Too Much Extraneous Data

Problem

A non-trivial amount of this data got thrown out. At the time, I was working under the impression that the AI would need to take in as much data as it could. What I hadn’t considered was that some kinds of tree data, even some kinds of bark data, would be substantially more useful than others in classifying the trees. To borrow a data science idea, I hadn’t done a principal component analysis, and I wasted a lot of time. Many early training sessions were spent trying to get the model to classify trees based on images that included:

  • Roots and irregular branches
  • Soil and stray leaves
  • Tree canopy shots that weren’t specific to the species
  • Excessively blurry images
  • Tree bark that was covered in lichen or moss, damaged, diseased
  • Potentially useful, but at an angle or a distance that just got in the way

The end result was that it might have required an inhumanly large dataset to achieve meaningful learning. Early models could somewhat distinguish photos of oaks and pines this way, but the results were too poor to be worth reporting.

Solution

Well, I ended up with two solutions.

The first was training a binary classifier to detect usable bark. The first time this came up, I was still working with the 500×500 patches, which worked out to just shy of 200,000 images. Not exactly a manual number, but I’ve never seen an ocean I didn’t think I could boil. I spent an afternoon sorting, realized I had only gotten through 30,000 images, and then realized I had accidentally made a halfway decent training set for a binary classifier.

That classifier sorted the remainder of those images in under an hour.
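The sorting pass itself can be sketched roughly like this. Everything here is hypothetical, including the column names and the 0.5 cutoff; the real sorter is a trained neural network, and only its resulting path/confidence CSVs appear later in this post:

```python
import pandas as pd

def split_by_prediction(df, usable_col="p_usable"):
    # p_usable is a stand-in for the classifier's predicted probability
    # that a patch is usable bark.
    accept = df[df[usable_col] >= 0.5].copy()
    reject = df[df[usable_col] < 0.5].copy()
    # Record the model's confidence in whichever decision it made.
    accept["confidence"] = accept[usable_col]
    reject["confidence"] = 1.0 - reject[usable_col]
    return accept[["path", "confidence"]], reject[["path", "confidence"]]

demo = pd.DataFrame({
    "path": ["patch_a.jpg", "patch_b.jpg", "patch_c.jpg"],
    "p_usable": [0.98, 0.40, 0.62],
})
accept, reject = split_by_prediction(demo)
print(len(accept), len(reject))  # 2 1
```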

The second solution was, unfortunately, quite a bit less dramatic. It involved simply going through the original photos, removing the ones that weren’t basically platonic bark taken at about torso level, and then repeating the process of dividing the rest into patches and feeding those into the usable bark sorter.

So, what does it actually look like?

Alright, we’re getting into a little of the code. First, we need to import some standard libraries. I’ll do some explaining for the uninitiated.

pandas handles data manipulation, here mostly CSV files. If you’ve done any work with neural networks before, you might have seen images being loaded into folders, one per class. I personally prefer CSV files because they make it easier to include other relevant things besides just the specific class of the image. I’ll show you an example shortly.
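As a hypothetical illustration of why I like CSVs: a labels file can carry metadata well beyond the class itself. The columns here are made up for the example; the real files appear below.

```python
import pandas as pd
from io import StringIO

# A made-up labels CSV: the class, plus whatever else is useful.
csv = StringIO(
    "path,common_name,family,confidence\n"
    "dataset/0001/patch_0.jpg,norway_maple,maple,0.99\n"
    "dataset/0004/patch_3.jpg,red_maple,maple,0.87\n"
)
df = pd.read_csv(csv)
print(df.columns.tolist())  # ['path', 'common_name', 'family', 'confidence']
```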

matplotlib is a library for creating and displaying charts.

numpy is a library for more easily handling matrix operations.

PIL is a library for processing images. It’s used here mostly for loading images, which is why we import the Image class from it.

In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

import os

Something that was really helpful was the use of confidence values. For every patch that the usable bark sorter classified, it also returned its confidence in the classification. Looking at the results, it had a strong inclination to reject perfectly usable bark, but it also wasn’t especially confident about those decisions. We can use that to our advantage by only taking bark that it was very confident about.

In [2]:
def confidence_threshold(df, thresh=0.9, below=True):
    if below:
        return df[df['confidence'] < thresh]
    else:
        return df[df['confidence'] > thresh]

Loading the Data

What these files will ultimately look like is still in flux. Today, I’ll be walking you through some older ones that still have basically what we need right now. First, we load the partitions that were accepted or rejected by the usable bark sorter. We also do a little cleanup of the DataFrame.

In [49]:
reject_file = "../torso_reject.csv"
accept_file = "../torso_accept.csv"

reject_df = pd.read_csv(reject_file, index_col=0)
accept_df = pd.read_csv(accept_file, index_col=0)

accept_df
Out[49]:
path confidence
0 dataset0/pittsburgh_torso-250×250/0001/2023010… 0.776984
1 dataset0/pittsburgh_torso-250×250/0001/2023010… 0.999994
2 dataset0/pittsburgh_torso-250×250/0001/2023010… 0.890233
3 dataset0/pittsburgh_torso-250×250/0001/2023010… 0.503523
4 dataset0/pittsburgh_torso-250×250/0001/2023010… 0.999959
108057 dataset0/pittsburgh_torso-250×250/0079/2023010… 0.998933
108058 dataset0/pittsburgh_torso-250×250/0079/2023010… 0.999988
108059 dataset0/pittsburgh_torso-250×250/0079/2023010… 0.998984
108060 dataset0/pittsburgh_torso-250×250/0079/2023010… 1.000000
108061 dataset0/pittsburgh_torso-250×250/0079/2023010… 0.999999

108062 rows × 2 columns

I goofed when generating the files by accidentally leaving a redundant “.jpg” in a script. The files have since been renamed, but because we’re using older CSVs, we need to make a quick fix here. The below is just a quick helper function that pandas can use to map over each path in the DataFrame.

In [50]:
def fix_path(path):
    path_parts = path.split('/')

    # Remove the redundant '.jpg' from the filename
    fn = path_parts[-1]
    fn = fn.split('_')
    fn[1] = fn[1].split('.')[0]
    fn = '_'.join(fn)
    path_parts[-1] = fn

    # Point at the dataset relative to this notebook
    path_parts[0] = "../dataset"
    return '/'.join(path_parts)
In [51]:
accept_df['path'] = accept_df['path'].map(lambda x: fix_path(x))
reject_df['path'] = reject_df['path'].map(lambda x: fix_path(x))

And just to make sure it actually worked okay:

In [52]:
accept_df['path'].iloc[0]
Out[52]:
'../dataset/pittsburgh_torso-250x250/0001/20230106_113732(0)_7_1.jpg'

Alright, so what’s this?

In [53]:
accept_df
Out[53]:
path confidence
0 ../dataset/pittsburgh_torso-250×250/0001/20230… 0.776984
1 ../dataset/pittsburgh_torso-250×250/0001/20230… 0.999994
2 ../dataset/pittsburgh_torso-250×250/0001/20230… 0.890233
3 ../dataset/pittsburgh_torso-250×250/0001/20230… 0.503523
4 ../dataset/pittsburgh_torso-250×250/0001/20230… 0.999959
108057 ../dataset/pittsburgh_torso-250×250/0079/20230… 0.998933
108058 ../dataset/pittsburgh_torso-250×250/0079/20230… 0.999988
108059 ../dataset/pittsburgh_torso-250×250/0079/20230… 0.998984
108060 ../dataset/pittsburgh_torso-250×250/0079/20230… 1.000000
108061 ../dataset/pittsburgh_torso-250×250/0079/20230… 0.999999

108062 rows × 2 columns

In [54]:
 reject_df
Out[54]:
path confidence
0 ../dataset/pittsburgh_torso-250×250/0001/20230… 1.000000
1 ../dataset/pittsburgh_torso-250×250/0001/20230… 0.999809
2 ../dataset/pittsburgh_torso-250×250/0001/20230… 0.999998
3 ../dataset/pittsburgh_torso-250×250/0001/20230… 0.997332
4 ../dataset/pittsburgh_torso-250×250/0001/20230… 0.999979
93149 ../dataset/pittsburgh_torso-250×250/0079/20230… 0.651545
93150 ../dataset/pittsburgh_torso-250×250/0079/20230… 0.999999
93151 ../dataset/pittsburgh_torso-250×250/0079/20230… 0.982236
93152 ../dataset/pittsburgh_torso-250×250/0079/20230… 0.991354
93153 ../dataset/pittsburgh_torso-250×250/0079/20230… 0.897549

93154 rows × 2 columns

Around 108K patches were accepted, and around 93K were rejected. Looking at the confidence column, we can see that the model is somewhat uncertain about a few of these. Let’s take a look at that.

First, we need a couple more helper functions. This one opens an image, converts it to RGB, resizes it to something presentable, and converts it to a numpy array.

In [22]:
def image_reshape(path):
    image = Image.open(path).convert("RGB")
    image = image.resize((224, 224))
    image = np.asarray(image)
    return image

Next, it will be helpful to be able to see a few of these patches at once. This next function samples 16 patches at random; we’ll arrange them in a 4×4 grid with matplotlib in a moment.

In [23]:
def get_sample(path_list):
    print("Generating new sample")
    new_sample = np.random.choice(path_list, 16, replace=False)
    
    samples = []
    paths = []
    for image in new_sample:
        samples.append(image_reshape(image))
        paths.append(image)
    return samples, paths

get_sample() takes a list of paths, so let’s extract those from the DataFrame.

In [24]:
reject_paths = reject_df['path'].values.tolist()
accept_paths = accept_df['path'].values.tolist()
In [26]:
samples, paths = get_sample(reject_paths)
plt.imshow(samples[1])
Generating new sample
Out[26]:
<matplotlib.image.AxesImage at 0x7fb632dda470>

Okay, let’s see a whole grid of rejects!

In [27]:
def show_grid(sample):
    rows = 4
    cols = 4
    img_count = 0
    fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(15, 15))
    
    for i in range(rows):
        for j in range(cols):
            if img_count < len(sample):
                axes[i, j].imshow(sample[img_count])
                img_count += 1
In [28]:
def new_random_grid(path_list):
    sample, paths = get_sample(path_list)
    show_grid(sample)
    return sample, paths
In [30]:
sample, paths = new_random_grid(reject_paths)
Generating new sample

These are samples of the whole of the rejects list. I encourage you to run this cell a bunch of times if you’re following along at home. You should see a goodly number of things that aren’t remotely bark: bits of streets, signs, grass, dirt, and other stuff like that. You’ll also see a number of patches that are bark, but aren’t very good for training. Some bark is blurry, out of focus, or a part of the tree that I later learned wasn’t actually useful.

A fun thing to try is to see the stuff that the model rejected, but wasn’t very confident about.

In [35]:
reject_df_below_85 = confidence_threshold(reject_df, thresh=0.85, below=True)
reject_df_below_85_paths = reject_df_below_85['path'].values.tolist()
sample, paths = new_random_grid(reject_df_below_85_paths)
Generating new sample

Running this just once or twice usually reveals images that are a bit more… on the edge of acceptability. It’s not always clear why something got rejected. I haven’t done this yet, but something in the works is using these low-confidence rejects in either training or, more likely, testing of the model. Mercifully, there is enough unambiguously accepted data that doing so hasn’t been a pressing need.

Speaking of, we should take a look at what got accepted, too.

In [39]:
sample, paths = new_random_grid(accept_paths)
Generating new sample

As above, these are all of the accepted bark patches, not just those that have a high confidence. Let’s see what happens when we look at the low-confidence accepted patches.

In [45]:
accept_df_below_65 = confidence_threshold(accept_df, thresh=0.65, below=True)
accept_df_below_65_paths = accept_df_below_65['path'].values.tolist()
sample, paths = new_random_grid(accept_df_below_65_paths)
Generating new sample

Here we see the bias I was talking about before. Note that when we were looking at the rejected bark, we used a threshold of 0.85 and were already seeing a lot of patches that could easily have been accepted. Here, we are looking at a confidence threshold of 0.65 and are still not seeing many that would definitely be rejected.

The cause of the bias is unknown. I made a special point of training the usable bark sorter on a roughly even split of acceptable and unacceptable bark. Because this was just a secondary tool for the real project, I haven’t had time to deeply investigate why this might be. I suspect there is some deep information-theoretic reason for it, perhaps one that will be painfully obvious to any high schooler once the field is more mature. The important thing now is that it’s a quirk of the model I caught early enough to use.

And what do these patches represent?

Having seen the images, it might be a good idea to look at what species we’re actually working with.

In [55]:
specimen_df = pd.read_csv("../specimen_list.csv", index_col=0)
specimen_df
Out[55]:
common_name latin_name family
id
1 norway_maple Acer_platanoides maple
2 norway_maple Acer_platanoides maple
3 norway_maple Acer_platanoides maple
4 red_maple Acer_rubrum maple
5 freeman_maple Acer_x_freemanii maple
75 northern_red_oak Quercus_rubra beech
76 northern_red_oak Quercus_rubra beech
77 white_oak Quercus_alba beech
78 bur_oak Quercus_macrocarpa beech
79 swamp_white_oak Quercus_bicolor beech

79 rows × 3 columns

Okay, I took photos of 79 different trees. It was actually 81, but the GPS signal on the last two was too spotty to match them on the map, and they had to be excluded. How can we break this down?

In [76]:
family_names = specimen_df.family.value_counts()
family_names = family_names.to_dict()
f_names = list(family_names.keys())
f_values = list(family_names.values())

header_font = {'family': 'serif', 'color': 'black', 'size': 20}
axis_font = {'family': 'serif', 'color': 'black', 'size': 15}
plt.rcParams['figure.figsize'] = [10, 5]

plt.bar(range(len(family_names)), f_values, tick_label=f_names)
plt.title("Breakdown of Specimens Collected in Pittsburgh, by Family",
         fontdict=header_font)
plt.xlabel("Family", fontdict=axis_font)
plt.ylabel("Number of Specimens", fontdict=axis_font)
plt.show()

When I started training, it made sense to focus on the family level. A family will inherently have at least as many images to work with as any of its species, and usually many more, and I assumed that variation would be smaller within a family. Interestingly enough, at least within this dataset, the difference in the quality of the model at the family and species levels has so far been negligible.

In [80]:
common_names = specimen_df.common_name.value_counts()
common_names = common_names.to_dict()
c_names = list(common_names.keys())
c_values = list(common_names.values())
In [83]:
plt.bar(range(len(common_names)), c_values, tick_label=c_names)
plt.title("Breakdown of Specimens Collected in Pittsburgh, by Species",
         fontdict=header_font)
plt.xlabel("Species (common names)", fontdict=axis_font)
plt.ylabel("Number of Specimens", fontdict=axis_font)
plt.xticks(rotation=45, ha='right')
plt.show()

The Model and Dataset Code

Full code can be found on the GitHub repo, but here are some important parts of the training code. First, because we are using a custom dataset, we need to make a class that will tell the dataloaders what to do. Some of this might need a little bit of explaining.

We have to import some things to make this part of the notebook work. BarkDataset inherits from PyTorch’s Dataset class. To initialize it, we only have to pass a given DataFrame, e.g. accept_df, into the class. I’ve shown before that CSVs let us work with a lot of other data that supports the interpretation of the dataset, but BarkDataset itself only needs two things: the column of image paths and the column that defines their labels.

You might be a little confused about the line self.labels = df["factor"].values. The full code converts either the species-level or family-level specimens into a numerical class. For example:

"eastern_white_pine": 0

The label is the 0. When making predictions, we will convert back from this label for clarity to the user, but that isn’t how the model sees it.
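A mapping like that can be built by enumerating the sorted class names. This is a sketch of the idea, with a few species names reused as examples, not the project’s actual factor table:

```python
# Map class names to integer labels and back again.
species = ["norway_maple", "eastern_white_pine", "red_maple"]
factor = {name: i for i, name in enumerate(sorted(set(species)))}
inverse = {i: name for name, i in factor.items()}

print(factor["eastern_white_pine"])  # 0
print(inverse[0])                    # eastern_white_pine
```

The inverse dictionary is what lets us turn a predicted integer back into a readable species name for the user.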

After loading the DataFrame, we also define a set of transforms in self.train_transforms. At minimum, this is where we scale images down to 224×224 and normalize them. If you’re wondering, the normalization values are ImageNet statistics, standard in the field.

In addition to these standard changes, transforms also offers a wide variety of operations that facilitate data augmentation. We can use data augmentation to get more information out of a base dataset; we just need to keep in mind to introduce only the kinds of variability that would actually occur in the collection of more data.

You’ll notice two other methods in this class: __len__ and __getitem__. The former just returns the number of items in the DataFrame. The latter is where a single image is actually loaded using Image from the PIL library, and the label is matched from that image’s location in the DataFrame. The transforms are then applied to the image, and we get both the loaded image and its label returned in a dictionary.

In [86]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm

class BarkDataset(Dataset):
    def __init__(self, df):
        self.df = df
        self.fns = df["path"].values
        self.labels = df["factor"].values

        self.train_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(60),
            transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
            ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image = Image.open(row.path).convert("RGB")
        label = self.labels[idx]

        image = self.train_transforms(image)

        return {
            "image": image,
            "label": torch.tensor(label, dtype=torch.long)
        }

Next, we’ll look at the model. The model can be expanded in a lot of ways with a class of its own, but at this stage of the project’s development, we are just starting with the pretrained weights and unchanged architecture as provided by timm.

In [ ]:
import timm
import torch.nn as nn
import torch.optim as optim

model = timm.create_model("deit3_base_patch16_224", pretrained=True, num_classes=N_CLASSES)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=100, T_mult=1,
                                                           eta_min=0.00001, last_epoch=-1)
model, history = training(model, optimizer, criterion, scheduler, device=device, num_epochs=100)
score = testing(test_df, model)

Now, I’ve gone through a lot of iterations of the various components here. Just off the top of my head:

  • I first started with an EfficientNet architecture, but got curious to see how one based on vision transformers would compare, and it wasn’t really a contest.
  • I originally used an Adam optimizer. Training with vanilla SGD proved slower, but also gave more consistent results.
  • I’ve settled on cosine annealing as a learning rate scheduler, but have also had intermittent success with CyclicLR and MultiCyclicLR.
  • Most recently, I’ve been wondering if cross-entropy is the right loss function. I quickly replaced accuracy with ROC AUC as the evaluation metric because there are multiple, imbalanced classes in the full dataset. I suspect this also has implications for training, but I have so far not found a loss function that works better.

Changing the parameters here is still an active area of my research.

Results

This is a confusion matrix of the results so far. It plots the predicted label against the actual label, making it a little easier to see where things are getting mixed up. Ideally, there would only be nonzero values along the diagonal.
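The notebook below loads a precomputed matrix from confusion_stuff, but building one is simple enough to sketch with numpy (hypothetical labels, just to show the mechanics):

```python
import numpy as np

def confusion_matrix(y_true, y_pred, n_classes):
    # Rows are actual labels, columns are predictions;
    # correct predictions accumulate on the diagonal.
    m = np.zeros((n_classes, n_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        m[t, p] += 1
    return m

m = confusion_matrix([0, 0, 1, 2, 2], [0, 1, 1, 2, 2], 3)
print(m.trace(), m.sum())  # 4 5 -> four of five predictions correct
```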

In [152]:
#!pip install seaborn
import confusion_stuff
import seaborn as sn

matrix = confusion_stuff.matrix

ind = confusion_stuff.individuals
ind = {k: v for k, v in sorted(ind.items(), key=lambda item: item[1])}
ind_keys = ind.keys()
ind_vals = ind.values()

df_cm = pd.DataFrame(confusion_stuff.matrix)
sn.heatmap(df_cm, cmap="crest")
Out[152]:
<AxesSubplot:>

The average ROC AUC for this model across all 21 species is about 0.80. For reference, 1.0 would be a perfect score, and 0.5 would be random guessing. The graph is somewhat muddled because there are so many classes, so you can see the scores for individual classes below.
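For a single class scored one-vs-rest, ROC AUC has a handy interpretation: the probability that a randomly chosen positive example gets a higher score than a randomly chosen negative one. A minimal sketch of that computation (not the project’s evaluation code):

```python
import numpy as np

def roc_auc(labels, scores):
    # Fraction of positive/negative pairs ranked correctly,
    # counting ties as half credit.
    labels = np.asarray(labels)
    scores = np.asarray(scores)
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    wins = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (wins + 0.5 * ties) / (len(pos) * len(neg))

print(roc_auc([1, 1, 0, 0], [0.9, 0.4, 0.6, 0.2]))  # 0.75
```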

In [153]:
for i, j in zip(ind_keys, ind_vals):
    print(f"{i}: {j}")
common_pear: 0.508
pin_oak: 0.542
red_maple: 0.611
colorado_spruce: 0.69
kentucky_coffeetree: 0.736
chestnut_oak: 0.741
japanese_zelkova: 0.784
northern_red_oak: 0.823
swamp_white_oak: 0.838
scotch_pine: 0.847
sugar_maple: 0.848
callery_pear: 0.857
norway_maple: 0.864
austrian_pine: 0.873
thornless_honeylocust: 0.875
bur_oak: 0.888
eastern_white_pine: 0.888
ginkgo: 0.893
white_oak: 0.912
amur_corktree: 0.921
freeman_maple: 0.922
In [154]:
ind_keys = sorted(list(ind_keys))
ind_vals = sorted(list(ind_vals))
ind_keys.reverse()
ind_vals.reverse()
In [155]:
plt.bar(range(len(ind_keys)), ind_vals, tick_label=ind_keys)
plt.title("ROC of Individual Species in Best Model",
         fontdict=header_font)
plt.xlabel("Species", fontdict=axis_font)
plt.ylabel("ROC", fontdict=axis_font)
plt.xticks(rotation=45, ha='right')
plt.show()

Having plotted these, it’s not hard to see where the model is weak.

Next Steps

I’m still trying new things and learning a lot from this project. Some things on the horizon:

  • Trying one of the larger DeiT models.
  • Ensemble method: training a number of classifiers with a smaller number of classes and having them vote using their confidence scores.
  • Ensemble method: break one large test image into patches, have models vote on each of the patches, and use the majority, the highest confidence, or some other metric as the prediction.
  • Gathering more data, especially expanding to include more species.
  • Weird idea: information theoretical analysis of the tree bark.

Early days!

For the ftagn of it

My first entry going over lesson one was basically a whirlwind tour of how to do image classification. Naturally, a lot of interesting content was left out. Today, I’m just poking at the classification system to see if anything changes.

First, I mentioned in the first post that the model was trained using ResNet-34. If you haven’t had a chance to look at Deep Residual Learning for Image Recognition, you might not know that there are ResNet models of various sizes that can be used here. I’m going to give them a shot. You have an idea what the code should look like. Note that for my purposes, I’m using the dataset with nine classes.

In [2]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

from fastai.vision import *
from fastai.metrics import error_rate

bs = 64

# Leaving off most of the infrastructure for making and populating the folders for readability.

path = Path("data/fruit")
classes = ['apples', 'bananas', 'oranges', 'grapes', 'pears', 'pineapples', 'nectarines', 'kiwis', 'plantains']

data = ImageDataBunch.from_folder(path, train=".", valid_pct=0.2,
                                 ds_tfms=get_transforms(), size=224, num_workers=4).normalize(imagenet_stats)

Now for the first thing we’re doing differently: training with ResNet-50. If it weren’t apparent from the name, it uses 50 layers instead of 34. This theoretically gives it a chance to learn more nuance, but it also introduces the risk of overfitting. We go!

In [3]:
learn = cnn_learner(data, models.resnet50, metrics=error_rate)
learn.fit_one_cycle(4)
Total time: 01:20

epoch train_loss valid_loss error_rate time
0 1.540670 0.397189 0.143713 00:38
1 0.851569 0.252047 0.071856 00:13
2 0.590861 0.260350 0.077844 00:14
3 0.448579 0.258741 0.083832 00:13

Differences? Here’s what jumps out at me:

  • This network took about twice as long to train, which is understandable given that it has many more layers.
  • The training time here is not uniformly distributed.
    • Running the other notebook again revealed that the timing for the equivalent using ResNet-34 also wasn’t uniformly distributed; it was just a fluke. I’m still curious as to why it might change, though.
  • We see a substantial improvement in training loss (~0.65 vs. ~0.41), and slight improvement in validation loss (~0.37 vs. 0.32) and error rate (~0.12 vs. ~0.1).

Okay, not bad. On closer inspection, the ResNet-50 example uses a resolution of 299 instead of 224 for its input images. Additionally, it halved the batch size and used the default number of workers (I am operating under the hypothesis that this is for memory reasons). How do those compare?

In [4]:
data = ImageDataBunch.from_folder(path, train=".", valid_pct=0.2,
                                 ds_tfms=get_transforms(), size=299, bs=bs//2).normalize(imagenet_stats)
learn = cnn_learner(data, models.resnet50, metrics=error_rate)
learn.fit_one_cycle(4)
Total time: 01:51

epoch train_loss valid_loss error_rate time
0 1.252562 0.303597 0.089820 00:42
1 0.704307 0.329760 0.101796 00:23
2 0.469999 0.270663 0.125749 00:23
3 0.361943 0.265046 0.119760 00:23

With the higher resolution images, we still see improvement in the training loss, but also some fluctuation in the validation loss and error rate. Granted, my source ran this for eight cycles and saw a similar fluctuation in the initial four. Easily remedied.

In [5]:
learn.fit_one_cycle(4)
Total time: 01:35

epoch train_loss valid_loss error_rate time
0 0.120617 0.297528 0.107784 00:23
1 0.134901 0.270895 0.077844 00:23
2 0.124303 0.296516 0.101796 00:23
3 0.117052 0.297157 0.113772 00:23

I’m seeing training loss of ~0.12, validation loss of ~0.30, and an error rate of ~0.11. What does all of this mean?

Relationship of Training Loss, Validation Loss, and Error Rate

All of this is very abstract, so I did some searching. Andrej Karpathy has a GitHub page that, among other things, includes tips for interpreting these values. Using this as a reference: because our training loss is much smaller than our validation loss, this model is probably overfitting.

In fact, inferring from this advice, we actually want our training loss to be higher than our validation loss. From the looks of it, this learner has been overfitting since epoch 3.

Why?

Simply, loss is an expression of an incorrect prediction; a loss of 0.0 would mean that all of our predictions were correct. Going from Karpathy’s post above, why do we want our training loss to be larger than our validation loss? In that situation, it would mean that our training was stringent enough and our model confident enough that it will perform better than expected when being validated.
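To make that concrete, here is the cross-entropy loss for a single prediction, reduced to its core (a sketch of the underlying formula, not fastai’s implementation):

```python
import math

def cross_entropy(p_correct):
    # The negative log of the probability the model assigned to the
    # true class: 0.0 when p_correct is 1.0, and growing without
    # bound as p_correct approaches 0.
    return -math.log(p_correct)

print(cross_entropy(0.9))  # confident and correct: small loss
print(cross_entropy(0.1))  # confident but wrong: large loss
```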

Why Don’t You Fix Your Little Problem and Light This Candle?

I recently had the luxury/annoyance of having a short layover in the Hartsfield-Jackson Atlanta International Airport. Having basically not heard of this airport (long story), it was not apparent upon landing that I would be walking for half an hour to get to my next flight, or that I would be waiting among crowds of people from Sweden or Nigeria. Amidst the chaos and stress, I was reminded that the Earth is an immense planet that one person can never completely explore, and that this is true in more ways than one.

You’ll find from reading more of Horseless that I’m a generalist. It is of nothing. Right now, I’m doing the fast.ai course v3, and today I’m following along with lesson 1. Special thanks to Salamander, because frankly, AWS has some work to do on the front end.

It’s been a little while since I’ve used Jupyter notebooks (I’m writing all of this in one right now to get the hang of it again), so I totally forget about the existence of magics. There are tons, but the ones that are relevant to this image classification example are as follows.

We use these magics to reload modules automatically if they have been changed. If that sounds weird, there’s a brief, straightforward example in the docs.

In [1]:
%reload_ext autoreload
%autoreload 2

A lot of magics can also be used in IPython, but this one is specific to Jupyter notebooks. This just allows us to display matplotlib graphs in the notebook itself.

In [2]:
%matplotlib inline

Naturally, we have a couple of library imports. I’ve gotten into the habit of at least peeking at the docs whenever I run into a new library, especially when we’re importing everything. The documentation for fastai.vision is actually pretty short and readable.

In [3]:
from fastai.vision import *
from fastai.metrics import error_rate

bs just sets the batch size.

In [4]:
bs = 64

Deception!

This blog entry isn’t completely following the pets example. I’m also looking at an image download guide that I guess is from the next lesson. For this, I wanted to build a fruit classifier using scraped images and see how well it scales with the number of categories and such.

Well, that’s what’s going on here. Basically, we make a Path object to hold all of our different classes. Classes here are assigned new folders; this snippet will make a new folder for each class if it doesn’t already exist. It looks like overkill here, and even the original tutorial in the image download tutorial only had three different classes for bears, but part of my exploration of lesson one will be looking at how the model responds to different numbers of classes.

In [5]:
path = Path("data/fruit1")

folders = ['apples','bananas', 'oranges']
for folder in folders:
    dest = path/folder
    dest.mkdir(parents=True, exist_ok=True)

If you don’t recognize ls as a CLI command, this just gets a list of all the non-hidden files and folders in the current working directory. I’ve probably run this notebook half a dozen times, so my original view of it was pretty thoroughly populated. You’ll see in a minute why having multiple folders for each of these runs was a good idea.

In [15]:
path.ls()
Out[15]:
[PosixPath('data/fruit1/apples'),
 PosixPath('data/fruit1/bananas'),
 PosixPath('data/fruit1/urls_apples.txt'),
 PosixPath('data/fruit1/.ipynb_checkpoints'),
 PosixPath('data/fruit1/oranges'),
 PosixPath('data/fruit1/urls_oranges.txt'),
 PosixPath('data/fruit1/urls_bananas.txt')]

This next bit is where we get into some annoying stuff in the tutorial that I want to come back and fix later. We need data to many anything happen here, and a quick way to get data for our purposes here is to run a search on Google Images for, let’s say, apples. The tutorial wants us to save those URLs, and to do that, we need a bit of code:

urls = Array.from(document.querySelectorAll('.rg_di .rg_meta')).map(el=>JSON.parse(el.textContent).ou);
window.open('data:text/csv;charset=utf-8,' + escape(urls.join('\n')));

It’s JavaScript, my old nemesis. The tutorial just says that this is to be entered into the browser’s ‘Console’, and suggests that Ctrl + Shift + J brings this up for Firefox. Maybe I spent too much time messing with the settings or something, because that brought up a (for this) worthless display console. For Firefox, I needed to go to Tools -> Web Developer -> Web Console to enter this code, which I only later learned was Shift + Ctrl + K.

Anyway, this needs to be repeated for each class. It’s automatish. The video links to a different set of tips for downloading images using a tool called googleimagesdownload that I had a chance to use once, but getting that working involved enough installation, weird complications, and documentation that would be better suited for another entry. We’re not fighting Raktabija today.

This snippet is straightforward enough. Go through each of the classes and use download_images to download up to 200 images for each one. Yes, I’m leaving those prints in. Full disclosure, I was looking at three different examples and was way more confused than I should have been for about 20 minutes. Those things are heroes.

That said, funny story, true story. For some reason, my urls_apples.txt file has 200 links, twice as many as the others here. I don’t know why. Maybe I scrolled down farther in the Google Images results than the others and the code included the extra ones that were dynamically loaded on the page. Maybe I was playing with a different piece of code when I was first doing this at 0230 and forgot about it. Maybe I’m a crazy man on the Internet who doens’t know what he’s talking about because it works on your machine.

The thing is, this isn’t trivial! My original run had twice as many examples of apples to draw from as anything else, which could easily change its performance. My point is, be careful. There’s a lot of work to do here.

In [16]:
for folder in folders:
    file = "urls_" + folder + ".txt"
    print(file)
    print(folder)
    print(path)
    download_images(path/file, path/folder, max_pics = 100)
urls_apples.txt
apples
data/fruit1
100.00% [100/100 00:14<00:00]
Error https://www.washingtonpost.com/resizer/Q2AWXkiQTsKIXHT_91kZWGIFMRY=/1484x0/arc-anglerfish-washpost-prod-washpost.s3.amazonaws.com/public/YCFJ6ZPYZQ2O5ALQ773U5BA7GU.jpg HTTPSConnectionPool(host='www.washingtonpost.com', port=443): Read timed out. (read timeout=4)
urls_bananas.txt
bananas
data/fruit1
100.00% [100/100 00:05<00:00]
Error https://www.kroger.com/product/images/xlarge/front/0000000004237 HTTPSConnectionPool(host='www.kroger.com', port=443): Read timed out. (read timeout=4)
urls_oranges.txt
oranges
data/fruit1
100.00% [100/100 00:03<00:00]

verify_images is used to do a bit of preliminary cleanup. Here, it will delete images that are broken and resize them so that each one is no larger than 500 pixels on a side. The documentation promises that the original ratios are preserved, so we should be good to go.

In [17]:
for f in folders:
    print(f)
    verify_images(path/f, delete=True, max_size=500)
apples
100.00% [97/97 00:03<00:00]
cannot identify image file <_io.BufferedReader name='data/fruit1/apples/00000096.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit1/apples/00000050.png'>
cannot identify image file <_io.BufferedReader name='data/fruit1/apples/00000088.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit1/apples/00000068.png'>
cannot identify image file <_io.BufferedReader name='data/fruit1/apples/00000031.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit1/apples/00000015.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit1/apples/00000055.jpg'>
bananas
100.00% [98/98 00:02<00:00]
cannot identify image file <_io.BufferedReader name='data/fruit1/bananas/00000015.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit1/bananas/00000061.jpg'>
oranges
100.00% [100/100 00:02<00:00]
cannot identify image file <_io.BufferedReader name='data/fruit1/oranges/00000009.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit1/oranges/00000011.png'>
cannot identify image file <_io.BufferedReader name='data/fruit1/oranges/00000048.jpg'>
/home/ubuntu/anaconda3/lib/python3.7/site-packages/PIL/Image.py:965: UserWarning: Palette images with Transparency   expressed in bytes should be converted to RGBA images
  ' expressed in bytes should be converted ' +
cannot identify image file <_io.BufferedReader name='data/fruit1/oranges/00000078.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit1/oranges/00000090.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit1/oranges/00000092.jpg'>

Seriously, that many?! I’m going to be using this thing all the time.

Get on with it!

We set the random seed for reproducibility. I remember this being wonky when I did this in Keras, but we’ll roll with it for now.

I do recommend having a look at the documentation for ImageDataBunch. What’s going on here? Images from the specified folder are collected. valid_pct holds out 20% of the images for validation (0.2 is provided as a default if you leave it out). ds_tfms=get_transforms() is gives us a set of default data augmentation features like flipping the images on the axis.

In [6]:
np.random.seed(42)
data = ImageDataBunch.from_folder(path, train=".", valid_pct=0.2,
                                 ds_tfms=get_transforms(), size=224, num_workers=4).normalize(imagenet_stats)

Just a quick peek to see where we are that we have the right classes and see if anything obvious weird happened with our images.

In [19]:
data.classes
Out[19]:
['apples', 'bananas', 'oranges']
In [20]:
data.show_batch(rows=3, figsize=(7,8))

This is the juicy part (but it does look a bit light if you’re familiar with Keras). Here we put our dataset into a convolutional neural network. The model that we’re using here is the pretrained ResNet-34, a midsized residual network. If you’re into the background, I recommend familiarizing yourself with the paper Deep Residual Learning for Image Recognition, which I’ll be talking about in a later post. There is a lot to unpack here.

In [7]:
learn = cnn_learner(data, models.resnet34, metrics=error_rate)

Now, we just train the learner for four epochs.

In [71]:
learn.fit_one_cycle(4)
Total time: 00:16

epoch train_loss valid_loss error_rate time
0 1.376529 0.612187 0.232143 00:04
1 0.873382 0.127812 0.000000 00:04
2 0.622132 0.040325 0.000000 00:04
3 0.467340 0.029021 0.000000 00:03

Now we save these weights, because we’re about to see what happens when we train the entire model. Up until now, we’ve only trained a portion of the network, but unfreeze() will allow us to train on all of the weights. lr_find() will train with a number of different learnings rates.

In [72]:
learn.save('fruit-1')
learn.unfreeze()
learn.fit_one_cycle(1)
Total time: 00:04

epoch train_loss valid_loss error_rate time
0 0.097448 0.104915 0.053571 00:04
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.

This bit of code will graph the results of lr_find().

In [74]:
learn.lr_find()
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.

We run fit one cycle, using values from the graph where the loss starts to be minimized.

In [75]:
learn.fit_one_cycle(3, max_lr=slice(3e-4, 3e-3))
Total time: 00:13

epoch train_loss valid_loss error_rate time
0 0.068045 0.043532 0.017857 00:04
1 0.059558 0.174751 0.053571 00:04
2 0.046345 0.114111 0.035714 00:04

… Okay! The training loss and error rate are decreasing, but the validation loss is on the rise. The network might be overfitting a bit here, but we shall see how it performs. I’m not especially confident that we have a total improvement here, so we will save this as a separate model.

In [76]:
learn.save('fruit-2')

ClassificationInterpretation can take a Learner and plot a confusion matrix, a grid indicating what the model predicted a given image was versus what it actually was. Because of the way the graphic is set up, it is visually easy to see that a model has achieved a low error rate by looking at its characteristic solid diagonal line with maybe a few missed cases here and there.

In [79]:
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()

Expanded number of classes.

Not bad, but we can try something a bit more complicated. Let’s repeat the above steps using nine categories instead of three.

In [8]:
path = Path("data/fruit")

folders = ['apples', 'bananas', 'oranges', 'grapes', 'pears', 'pineapples', 'nectarines', 'kiwis', 'plantains']

for folder in folders:
    dest = path/folder
    dest.mkdir(parents=True, exist_ok=True)
In [50]:
path.ls()
Out[50]:
[PosixPath('data/fruit/apples'),
 PosixPath('data/fruit/urls_kiwis.txt'),
 PosixPath('data/fruit/bananas'),
 PosixPath('data/fruit/grapes'),
 PosixPath('data/fruit/urls_plantains.txt'),
 PosixPath('data/fruit/urls_pineapples.txt'),
 PosixPath('data/fruit/urls_apples.txt'),
 PosixPath('data/fruit/models'),
 PosixPath('data/fruit/kiwis'),
 PosixPath('data/fruit/.ipynb_checkpoints'),
 PosixPath('data/fruit/oranges'),
 PosixPath('data/fruit/urls_pears.txt'),
 PosixPath('data/fruit/nectarines'),
 PosixPath('data/fruit/urls_nectarines.txt'),
 PosixPath('data/fruit/urls_oranges.txt'),
 PosixPath('data/fruit/urls_grapes.txt'),
 PosixPath('data/fruit/pears'),
 PosixPath('data/fruit/plantains'),
 PosixPath('data/fruit/pineapples'),
 PosixPath('data/fruit/urls_bananas.txt')]
In [51]:
for folder in folders:
    file = "urls_" + folder + ".txt"
    download_images(path/file, path/folder, max_pics = 200)
100.00% [100/100 00:13<00:00]
Error https://www.washingtonpost.com/resizer/Q2AWXkiQTsKIXHT_91kZWGIFMRY=/1484x0/arc-anglerfish-washpost-prod-washpost.s3.amazonaws.com/public/YCFJ6ZPYZQ2O5ALQ773U5BA7GU.jpg HTTPSConnectionPool(host='www.washingtonpost.com', port=443): Read timed out. (read timeout=4)
100.00% [100/100 00:05<00:00]
Error https://www.kroger.com/product/images/xlarge/front/0000000004237 HTTPSConnectionPool(host='www.kroger.com', port=443): Read timed out. (read timeout=4)
100.00% [100/100 00:03<00:00]
100.00% [100/100 00:06<00:00]
Error https://www.washingtonpost.com/resizer/c8KGxTwvfOVZuG4hd_vlgUIHdwU=/1484x0/arc-anglerfish-washpost-prod-washpost.s3.amazonaws.com/public/XXA47NWIRII6NC7OKTUAB3ZKMM.jpg HTTPSConnectionPool(host='www.washingtonpost.com', port=443): Read timed out. (read timeout=4)
Error https://www.kroger.com/product/images/xlarge/front/0000000094022 HTTPSConnectionPool(host='www.kroger.com', port=443): Read timed out. (read timeout=4)
Error https://www.washingtonpost.com/rf/image_982w/2010-2019/WashingtonPost/2017/10/11/Food/Images/SheetPanSausageDinner-1733.jpg HTTPSConnectionPool(host='www.washingtonpost.com', port=443): Read timed out. (read timeout=4)
100.00% [100/100 00:03<00:00]
100.00% [100/100 00:04<00:00]
Error https://cdn.teachercreated.com/20180323/covers/900sqp/2156.png HTTPSConnectionPool(host='cdn.teachercreated.com', port=443): Max retries exceeded with url: /20180323/covers/900sqp/2156.png (Caused by SSLError(SSLError("bad handshake: Error([('SSL routines', 'tls_process_server_certificate', 'certificate verify failed')])")))
100.00% [100/100 00:09<00:00]
100.00% [100/100 00:04<00:00]
100.00% [100/100 00:04<00:00]
Error https://blog.heinens.com/wp-content/uploads/2016/03/Plantains_Blog-Feature.jpg HTTPSConnectionPool(host='blog.heinens.com', port=443): Max retries exceeded with url: /wp-content/uploads/2016/03/Plantains_Blog-Feature.jpg (Caused by SSLError(SSLError("bad handshake: Error([('SSL routines', 'tls_process_server_certificate', 'certificate verify failed')])")))
Error https://www.washingtonpost.com/resizer/dJ0Bg2LQrSAaSo6UlPbKm6J4SJ8=/1484x0/arc-anglerfish-washpost-prod-washpost.s3.amazonaws.com/public/Y6C5GNMXEA2XJL5TQXJTF6DXPQ.jpg HTTPSConnectionPool(host='www.washingtonpost.com', port=443): Read timed out. (read timeout=4)
In [52]:
for f in folders:
    print(f)
    verify_images(path/f, delete=True, max_size=500)
apples
100.00% [97/97 00:03<00:00]
cannot identify image file <_io.BufferedReader name='data/fruit/apples/00000050.png'>
cannot identify image file <_io.BufferedReader name='data/fruit/apples/00000096.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit/apples/00000088.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit/apples/00000068.png'>
cannot identify image file <_io.BufferedReader name='data/fruit/apples/00000031.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit/apples/00000015.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit/apples/00000055.jpg'>
bananas
100.00% [98/98 00:02<00:00]
cannot identify image file <_io.BufferedReader name='data/fruit/bananas/00000015.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit/bananas/00000061.jpg'>
oranges
100.00% [100/100 00:02<00:00]
cannot identify image file <_io.BufferedReader name='data/fruit/oranges/00000009.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit/oranges/00000011.png'>
cannot identify image file <_io.BufferedReader name='data/fruit/oranges/00000048.jpg'>
/home/ubuntu/anaconda3/lib/python3.7/site-packages/PIL/Image.py:965: UserWarning: Palette images with Transparency   expressed in bytes should be converted to RGBA images
  ' expressed in bytes should be converted ' +
cannot identify image file <_io.BufferedReader name='data/fruit/oranges/00000078.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit/oranges/00000090.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit/oranges/00000092.jpg'>
grapes
100.00% [97/97 00:02<00:00]
cannot identify image file <_io.BufferedReader name='data/fruit/grapes/00000022.png'>
cannot identify image file <_io.BufferedReader name='data/fruit/grapes/00000097.jpeg'>
cannot identify image file <_io.BufferedReader name='data/fruit/grapes/00000027.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit/grapes/00000092.png'>
/home/ubuntu/anaconda3/lib/python3.7/site-packages/PIL/Image.py:965: UserWarning: Palette images with Transparency   expressed in bytes should be converted to RGBA images
  ' expressed in bytes should be converted ' +
pears
100.00% [100/100 00:03<00:00]
cannot identify image file <_io.BufferedReader name='data/fruit/pears/00000045.png'>
cannot identify image file <_io.BufferedReader name='data/fruit/pears/00000058.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit/pears/00000084.png'>
/home/ubuntu/anaconda3/lib/python3.7/site-packages/PIL/Image.py:965: UserWarning: Palette images with Transparency   expressed in bytes should be converted to RGBA images
  ' expressed in bytes should be converted ' +
/home/ubuntu/anaconda3/lib/python3.7/site-packages/PIL/Image.py:965: UserWarning: Palette images with Transparency   expressed in bytes should be converted to RGBA images
  ' expressed in bytes should be converted ' +
cannot identify image file <_io.BufferedReader name='data/fruit/pears/00000068.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit/pears/00000029.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit/pears/00000031.png'>
/home/ubuntu/anaconda3/lib/python3.7/site-packages/PIL/Image.py:965: UserWarning: Palette images with Transparency   expressed in bytes should be converted to RGBA images
  ' expressed in bytes should be converted ' +
cannot identify image file <_io.BufferedReader name='data/fruit/pears/00000081.jpg'>
pineapples
100.00% [99/99 00:04<00:00]
cannot identify image file <_io.BufferedReader name='data/fruit/pineapples/00000048.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit/pineapples/00000032.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit/pineapples/00000012.jpg'>
/home/ubuntu/anaconda3/lib/python3.7/site-packages/PIL/Image.py:1018: UserWarning: Couldn't allocate palette entry for transparency
  warnings.warn("Couldn't allocate palette entry " +
cannot identify image file <_io.BufferedReader name='data/fruit/pineapples/00000034.jpg'>
nectarines
100.00% [100/100 00:02<00:00]
cannot identify image file <_io.BufferedReader name='data/fruit/nectarines/00000011.png'>
cannot identify image file <_io.BufferedReader name='data/fruit/nectarines/00000051.png'>
cannot identify image file <_io.BufferedReader name='data/fruit/nectarines/00000088.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit/nectarines/00000005.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit/nectarines/00000062.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit/nectarines/00000098.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit/nectarines/00000095.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit/nectarines/00000080.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit/nectarines/00000038.jpg'>
kiwis
100.00% [98/98 00:02<00:00]
cannot identify image file <_io.BufferedReader name='data/fruit/kiwis/00000084.png'>
cannot identify image file <_io.BufferedReader name='data/fruit/kiwis/00000078.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit/kiwis/00000056.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit/kiwis/00000068.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit/kiwis/00000075.jpg'>
plantains
100.00% [97/97 00:02<00:00]
cannot identify image file <_io.BufferedReader name='data/fruit/plantains/00000052.jpg'>
local variable 'photoshop' referenced before assignment
cannot identify image file <_io.BufferedReader name='data/fruit/plantains/00000083.jpg'>
cannot identify image file <_io.BufferedReader name='data/fruit/plantains/00000044.jpg'>
In [12]:
data2 = ImageDataBunch.from_folder(path, train=".", valid_pct=0.2,
                                 ds_tfms=get_transforms(), size=224, num_workers=4).normalize(imagenet_stats)
In [13]:
learn2 = cnn_learner(data2, models.resnet34, metrics=error_rate)
In [14]:
learn2.fit_one_cycle(4)
Total time: 00:31

epoch train_loss valid_loss error_rate time
0 2.030864 0.831149 0.221557 00:08
1 1.239848 0.415597 0.119760 00:07
2 0.863067 0.377215 0.119760 00:07
3 0.657382 0.371677 0.113772 00:07
In [93]:
learn2.save('large-fruit-1')
interp2 = ClassificationInterpretation.from_learner(learn2)
interp2.plot_confusion_matrix()

Much more interesting. Even without the unfreezing and optimization, we see decent results with a larger number of categories. Additionally, where you see some more prominent spots of confusing, it is worth actually thinking about what might be tripping the model up. Depending on the image, a nectarine can sort of look like an apple. Being from the US, plantains to me are basically bananas that are always green.

You’ll notice that we didn’t do that unfreeze-and-optimize strategy from before, so let’s see what that looks like.

In [94]:
learn2.unfreeze()
learn2.fit_one_cycle(1)
Total time: 00:10

epoch train_loss valid_loss error_rate time
0 0.366309 0.297740 0.113772 00:10
In [95]:
learn2.lr_find()
learn2.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.

Well, that sure looks like a different loss graph.

In [97]:
learn2.save('large-fruit-2')
In [106]:
learn2.load('large-fruit-2')
learn2.fit_one_cycle(3, max_lr=slice(1e-5, 1e-3))
Total time: 00:30

epoch train_loss valid_loss error_rate time
0 0.226404 0.285319 0.101796 00:10
1 0.167561 0.261312 0.083832 00:10
2 0.128132 0.265344 0.071856 00:09
In [107]:
interp2 = ClassificationInterpretation.from_learner(learn2)
interp2.plot_confusion_matrix()

Training loss has been cut significantly, but validation loss and error rates are holding steady. There is a lot of unpacking to be done. Early days!