This article deals with image classification of cats and dogs using convolutional neural networks (CNNs). Why cats and dogs? The idea came to me when my wife and I were in Herceg Novi, Montenegro over Christmas and New Year. I had finished the first article about craft beer and launched this website, and I was thinking about what topic to cover in the next article. I only knew that I wanted to train a neural network for image classification, but I didn't know which images to use. Then I got the idea of cats and dogs, because Montenegro is the country of cats, unlike Georgia in the Caucasus, which is the country of dogs, but that is another story. Cats are everywhere in Montenegro. So every day I had the opportunity to take pictures of cats.
The first article used many standard artificial intelligence models and methods from Scikit-learn [1] to analyse craft beer. In this article, PyTorch [2] is used to implement and train deep convolutional neural networks. The trained models are compared with state-of-the-art pre-trained models.
The goal is to train a neural network with a set of images from a public database and hopefully use the trained model to classify cats in images taken in Montenegro. The training, testing and validation approach is shown in Figure 1.1.
The first step was to train a model using the images in the training set. After each training epoch, the test set was used to evaluate the accuracy of the trained model. Based on the training and test accuracies, the model was manually adapted to find a compromise between overfitting and underfitting. Both steps were repeated for 50 epochs. After 50 epochs, the model with the best accuracy on the test set was selected as the final model. The validation set was then used to assess the accuracy of the final model.
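As an illustration, a minimal PyTorch sketch of this train-evaluate-select loop could look like the following (the data loaders, the optimiser settings and the single sigmoid output with a 0.5 threshold are assumptions here; the actual code is in the linked Jupyter notebooks):

```python
import copy

import torch

def accuracy(model, loader, device):
    """Fraction of correctly classified images in a data loader."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device).float()
            logits = model(images).squeeze(1)
            preds = (torch.sigmoid(logits) >= 0.5).float()  # < 0.5 -> cat (0), otherwise dog (1)
            correct += (preds == labels).sum().item()
            total += labels.numel()
    return correct / total

def train_and_select(model, train_loader, test_loader, device, epochs=50):
    """Train for a fixed number of epochs and keep the weights
    that give the best accuracy on the test set."""
    criterion = torch.nn.BCEWithLogitsLoss()                      # binary cross entropy
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)     # optimiser settings are assumptions
    best_acc, best_state = 0.0, None
    for epoch in range(epochs):
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device).float()
            optimizer.zero_grad()
            loss = criterion(model(images).squeeze(1), labels)
            loss.backward()
            optimizer.step()
        test_acc = accuracy(model, test_loader, device)
        if test_acc > best_acc:                                   # remember the best model so far
            best_acc, best_state = test_acc, copy.deepcopy(model.state_dict())
    model.load_state_dict(best_state)
    return model, best_acc
```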
There can be confusion in the terminology between the test set and the validation set. Sometimes they are used interchangeably [3], but it seems that the terminology presented here is more common [4] [5] [6]. Nevertheless, the set used to determine the accuracy of the final model should only be used for the final model.
Another important note: the images for the training and test sets come from a single public database, i.e. from one source. Presumably it was one set and the data was split into a training and a test set. The images for the validation set come from a different source, namely the images taken by the author in Montenegro. This is not common. Usually a set from one source is divided into training, test and validation sets [7], or the validation step is skipped entirely and only a training and a test set are used [5][6].
A dataset from Kaggle [8] containing images of cats and dogs was used as the training and test set to train a classifier for cats and dogs.
The dataset contains 25,000 labelled images: 12,500 images of cats and 12,500 images of dogs. The training set contains 20,000 images (10,000 cats and 10,000 dogs) and the test set contains 5,000 images (2,500 cats and 2,500 dogs). The images range in size from 32 x 42 to 500 x 500 pixels.
The dataset was pre-processed by removing all images smaller than 224 x 224 pixels. After preprocessing, the training set contains 17,650 images and is still balanced, with 8,810 cats and 8,840 dogs. The test set contains 4,395 images and is also still balanced, with 2,198 cats and 2,197 dogs.
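A minimal sketch of this filtering step could look like this, assuming the images are stored as individual JPEG files in one folder per class (the folder names are placeholders, not the actual layout used in the notebook):

```python
from pathlib import Path

from PIL import Image

MIN_SIZE = 224  # keep only images that are at least 224 x 224 pixels

def filter_small_images(src_dir: str, dst_dir: str) -> None:
    """Copy every image that is at least MIN_SIZE x MIN_SIZE from src_dir to dst_dir."""
    dst = Path(dst_dir)
    dst.mkdir(parents=True, exist_ok=True)
    for path in Path(src_dir).glob("*.jpg"):
        with Image.open(path) as img:
            if img.width >= MIN_SIZE and img.height >= MIN_SIZE:
                img.save(dst / path.name)

# placeholder folder names
filter_small_images("train/cats", "train_filtered/cats")
filter_small_images("train/dogs", "train_filtered/dogs")
```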
Figure 2.1.1 shows some random samples from the training set and Figure 2.1.2 shows some random samples from the test set.
The samples show cats and dogs in different angles and settings: from the front, from the side, indoors, outdoors, sometimes with people and there are even pictures with several cats and dogs.
I wrote a script to compare the images in the training and test sets to make sure that the two sets didn't contain the same images (see the Jupyter Notebook). Only nine identical images were found in both sets. That's good, but it also suggests that the two sets came from the same source and were simply split in two.
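The script itself is only in the notebook, but a simple way to find identical images is to compare file hashes, for example like this (note that this only detects byte-identical files, not re-encoded copies of the same photo):

```python
import hashlib
from pathlib import Path

def file_hashes(directory: str) -> dict:
    """Map the MD5 hash of every .jpg file in a directory to its path(s)."""
    hashes = {}
    for path in Path(directory).rglob("*.jpg"):
        digest = hashlib.md5(path.read_bytes()).hexdigest()
        hashes.setdefault(digest, []).append(path)
    return hashes

# placeholder folder names
train_hashes = file_hashes("train")
test_hashes = file_hashes("test")

duplicates = set(train_hashes) & set(test_hashes)
print(f"{len(duplicates)} images appear in both the training and the test set")
```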
Figure 2.2.1 shows some random samples from the validation set. Figure 2.2.2 shows the same sample, but cropped, from the cropped validation set.
The samples show cats in different positions, from the front, from the side, rarely with people and rarely several cats. All images were taken outside, sometimes with a very detailed natural background.
After many iterations, I found a CNN (Convolutional Neural Network) architecture that gave very good results in predicting cats and dogs in the test set. Figure 3.1 shows the architecture of the CNN.
The convolutional layers each consist of a convolution followed by batch normalisation, a ReLU activation function and a max pooling layer. The input to the first layer has the shape 224x224x3, the shape of a colour image with a red, green and blue channel. The output of the first convolution has the shape 224x224x32. The output of the following max pooling layer has the shape 112x112x32. In each convolution step the number of channels is doubled to extract more features, and in each max pooling step the height and width of the image are halved to condense the information. The convolution, batch normalisation, ReLU activation and max pooling steps are repeated five times. After five convolutional layers, the output of the last max pooling layer, with a shape of 7x7x512, is flattened to a shape of 1x25088. Feature extraction is now complete.
The 25,088 extracted features are the input to a fully connected neural network with one hidden layer and an output layer of shape 1x1. Usually the size of the output layer equals the number of classes to be predicted, here one class for cats and one for dogs. For binary classification, however, a single output unit is recommended, using binary cross entropy as the criterion to minimise. The output of the network is fed into a sigmoid activation function. If the output of the sigmoid function is < 0.5, the input image is classified as a cat, otherwise as a dog.
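A sketch of this architecture in PyTorch could look as follows. The kernel size of the convolutions and the size of the hidden layer are not stated in the article and are assumptions here; the channel progression (32 to 512), the five blocks and the single output unit follow the description above.

```python
import torch
import torch.nn as nn

def conv_block(in_ch: int, out_ch: int) -> nn.Sequential:
    """Convolution -> batch norm -> ReLU -> max pooling (halves height and width)."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),  # kernel size is an assumption
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=2),
    )

class CatDogCNN(nn.Module):
    def __init__(self, hidden: int = 128):  # hidden layer size is an assumption
        super().__init__()
        # 224x224x3 -> 112x112x32 -> 56x56x64 -> 28x28x128 -> 14x14x256 -> 7x7x512
        self.features = nn.Sequential(
            conv_block(3, 32),
            conv_block(32, 64),
            conv_block(64, 128),
            conv_block(128, 256),
            conv_block(256, 512),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),                      # 7 * 7 * 512 = 25088 features
            nn.Linear(7 * 7 * 512, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, 1),              # single output for binary classification
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))

model = CatDogCNN()
logits = model(torch.randn(1, 3, 224, 224))
prob_dog = torch.sigmoid(logits)               # < 0.5 -> cat, otherwise -> dog
```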
The model was trained for 50 epochs. Figure 4.1 shows the accuracy for the training and test sets plotted against the number of epochs. This graph shows the learning curve. A dot highlights the maximum accuracy for the training and test sets. The model with the highest accuracy on the test set is selected as the final model, highlighted by a dashed line.
The highest accuracy for the test set is 95.77% at epoch 50. The model does not overfit and generalises well. After 20 epochs there is no significant improvement in the classification of the test set.
Figure 4.2 shows the confusion matrix for the test set. The confusion matrix gives a more detailed overview of and insight into the predictions of the trained model. The proportion of misclassified cats is 4.8% (105 out of 105 + 2,093) and the proportion of misclassified dogs is 3.7% (81 out of 81 + 2,116). So the model has no strong tendency to predict one class better than the other.
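For reference, the per-class error rates can be read directly off the confusion matrix counts, for example with scikit-learn (the label vectors below are placeholders for the actual test-set predictions):

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# y_true / y_pred would come from running the final model over the test set
# (0 = cat, 1 = dog); the values here are placeholders.
y_true = np.array([0, 0, 1, 1])
y_pred = np.array([0, 1, 1, 1])

cm = confusion_matrix(y_true, y_pred)            # rows: true class, columns: predicted class
per_class_error = 1 - cm.diagonal() / cm.sum(axis=1)
print(cm)
print(per_class_error)                           # e.g. 105 / (105 + 2093) is roughly 4.8 % for cats
```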
Nice, the accuracy for the test set is very good. It took some time to find a suitable CNN architecture by changing the number of layers and the number of channels in the convolution layers. But it wasn't that hard.
This is where many articles, blogs, books and competitions end. A model has been trained and evaluated using training and test data that (presumably) come from the same source.
Now the model is put into the real world and fed with images from a different source!
Figure 4.1.1 and Figure 4.1.2 show the confusion matrices for the validation set and the cropped validation set. As there are no dogs in the validation sets, the first row of the confusion matrix is always empty.
Compared to the test set, the accuracy for the validation set is much lower, at around 46%. This set contains the originally captured images with a large, detailed background. Classifying the cropped validation set gives better results, with an accuracy of about 64%. However, compared to about 96% for the test set, this is still not satisfactory.
To see why the accuracy for classifying the validation sets is so poor, the learning curves for the validation sets have been added to Figure 4.1.3. Again, the best accuracy for each set is highlighted by a dot.
Figure 4.1.3 gives a deeper insight into how the accuracy for the validation sets changes during the training process. The learning curves for the validation sets oscillate very strongly. It's obvious that an update in the model has a big influence on the accuracy of the validation sets. The learning curve for the validation set does not increase with the number of epochs and only oscillates strongly. In contrast, the learning curve for the cropped validation set tends to increase slightly with the number of epochs and also oscillates strongly.
It was just luck that the best model for the test set, at epoch 50, also led to a good accuracy of about 64% for the cropped validation set. Given the strong oscillation, it could just as well have been 30% or less.
This leads to the problem that it is not possible to select a suitable model to classify cats from Montenegro based only on the analysis of the training and test sets. For validation, images from the intended environment should be used. Furthermore, this leads to the question: How will the chosen model perform in other environments, e.g. with images taken with different backgrounds or in different seasons?
In summary, the model is appropriate and gives very good results on the test set. But it doesn't generalise to images from a different distribution, the validation sets. This effect is called distribution shift and is well known in the literature [9][10]. However, despite the fact that it can render a trained model useless in the real world, distribution shift is rarely addressed. For example, most datasets with training and test sets come from a single source [9]. Furthermore, it is common practice to split a single dataset into training, test and optionally validation sets [5][6][7].
The model's poor performance on validation sets was unexpected. In the remaining analysis I want to find a way to reduce the oscillation and get a higher accuracy for the validation sets.
Have a look at this Jupyter Notebook for a deeper insight.

One way to make a model more robust and generalisable is to train it on augmented images [11]. Eight transformations were applied to each image in the training set: colour jitter, grayscale, inverted colours, zoom out (to account for the fact that the cats in the validation set appear smaller), two different perspectives and two rotations. The training set now contains 158,850 images: 17,650 original images and 141,200 augmented images. Figure 5.1 shows the augmented images for one sample from the training set.
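A sketch of such a fixed set of eight transformations with torchvision could look like this; the concrete parameter values are illustrative, not the ones used in the notebook:

```python
from torchvision import transforms

# One list of eight transformations applied to every training image
# (parameter values are assumptions, not the exact ones used in the notebook).
augmentations = [
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),  # colour jitter
    transforms.Grayscale(num_output_channels=3),                           # grayscale
    transforms.RandomInvert(p=1.0),                                        # inverted colours
    transforms.RandomAffine(degrees=0, scale=(0.5, 0.5)),                  # zoom out (cats appear smaller)
    transforms.RandomPerspective(distortion_scale=0.3, p=1.0),             # perspective 1
    transforms.RandomPerspective(distortion_scale=0.6, p=1.0),             # perspective 2
    transforms.RandomRotation(degrees=(15, 15)),                           # rotation 1
    transforms.RandomRotation(degrees=(-15, -15)),                         # rotation 2
]

def augment(image):
    """Return the eight augmented versions of one PIL image."""
    return [t(image) for t in augmentations]
```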
Figure 5.2 shows the learning curves for 10 epochs. The highest accuracy for the test set is 95.08%, slightly lower than without augmented images. However, with augmented images, higher accuracy is achieved in fewer epochs, although these epochs take much longer because there are more images to process.
Using augmented images doesn't save any computation time, but it does reveal an interesting fact: the learning curves for the test, training and even validation sets oscillate less. The augmentation leads to a more robust model. Unfortunately, the accuracy for the validation sets didn't increase. The accuracy of around 41% for the cropped validation set is not satisfactory.
Figure 5.3.1 shows the pie chart for the validation set and Figure 5.3.2 shows the pie chart for the validation set with cropped images. As there are no dogs in the validation set, a pie chart is sufficient to show the accuracy. Therefore, for the following analysis, a pie chart will be used to visualise the accuracy for the validation sets.
For a deeper insight, see this Jupyter Notebook.
To remove the influence of colour, and with it the influence of different backgrounds and lighting conditions, all images in the training, test and validation sets were converted to grayscale. Figure 6.1 shows the learning curves using grayscale images.
The highest accuracy for the test set is 95.31%. This is almost the same as for colour images or augmented colour images. Obviously, the lack of colour information does not affect the classification accuracy for the training and test sets. However, this is not the case for the validation sets. The learning curve for the cropped validation set does not increase with the number of epochs; it simply oscillates. This means that the missing colour information leads to a worse generalisation of the model.
Figure 6.2.1 shows the pie chart for the validation set and Figure 6.2.2 shows the pie chart for the cropped validation set. The accuracy of around 33% for the cropped validation set is not very satisfactory.
Another approach to dealing with the distribution shift is to equalise the histograms of the images in the training set [10]. Figure 7.1.1 shows examples of training images and Figure 7.1.2 shows the corresponding histograms.
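Histogram equalisation can be done, for example, with PIL; a minimal sketch (the file name is a placeholder):

```python
from PIL import Image, ImageOps

# Equalise the histogram of one grayscale training image.
img = Image.open("train/cats/cat.0.jpg").convert("L")   # "L" = 8-bit grayscale
img_eq = ImageOps.equalize(img)                          # spread intensities over the full 0-255 range

# Compare the histograms before and after equalisation.
hist_before = img.histogram()
hist_after = img_eq.histogram()
```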
Figure 7.3 shows the learning curves using grayscale images with equalised histogram for the training set.
The accuracy for the test set is 93.58%, compared to 95.31% when training on plain grayscale images. Again, the learning curves for the validation sets do not increase and the oscillations are still present. The positive effect, however, is that the learning curves for the validation sets are shifted upwards by approximately 5 percentage points.
Figure 7.4.1 shows the pie chart for the validation set and Figure 7.4.2 shows the pie chart for the cropped validation set. The accuracy of about 54% for the cropped validation set is a big improvement compared to using only grayscale images.
Histogram equalisation has a positive effect on the classification of the validation sets and leads to a more generalised model. However, using grayscale images negates these improvements, as using only colour images gives better results.
For a deeper insight, see this Jupyter Notebook.

Another idea was to calculate the edges of an image and train the model on images with edges. This approach is not like augmenting images; it is more like transforming the original images into another space. The edges were calculated for the training set, the test set and the validation sets.
Figure 8.1.1 shows samples of cropped validation images and Figure 8.1.2 shows the corresponding images with calculated edges.
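The article does not state how the edges were computed; one simple option is PIL's FIND_EDGES filter (a Canny detector from OpenCV would be an alternative). A minimal sketch with placeholder file names:

```python
from PIL import Image, ImageFilter

# Transform one image into its edge representation.
img = Image.open("validation/cat_01.jpg").convert("L")   # edges are usually computed on grayscale
edges = img.filter(ImageFilter.FIND_EDGES)
edges.save("validation_edges/cat_01.jpg")                 # assumes the output folder exists
```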
The learning curves for the cropped validation set increase slightly with the number of epochs and the oscillation is reduced. However, the accuracy is low. Only a few outliers reach more than 40%. Using the edges for classification is not a good idea when the background is very detailed, as in the validation sets.
Figure 8.3.1 shows the pie chart for the validation set and Figure 8.3.2 shows the pie chart for the cropped validation set.
For a deeper insight, see this Jupyter Notebook.
Many approaches were used to improve the prediction accuracy for the validation sets and to reduce the oscillation, with only marginal success. Therefore, some basic unit tests were added to ensure that there were no errors in the training process. Prior to this, the training and test sets had been compared to ensure that there were no errors in the datasets. Finally, to make sure there is no error in the model itself, some state-of-the-art pre-trained models are used to cross-validate the results.
The following four models were used to evaluate the accuracy for the test and validation sets: EfficientNetB0 [14], DenseNet201 [15], ResNet152 [16] and ConvNextLarge [17].
In contrast, my model, let's call it MyModel, has 4,382,467 weights and a file size of 16.7 MB. The pre-trained models are more powerful because they have many more weights to train and use additional concepts and architectures. But they are trained to classify 1000 classes, not two classes like MyModel. All models and pre-trained weights are downloaded from PyTorch [12]. All these models have been trained on the ImageNet database [13]. This dataset covers 1000 object classes and contains 1,281,167 training images, 50,000 validation images and 100,000 test images. This means that one class is represented by approximately 1280 training images.
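As a sketch, one of these models can be loaded from torchvision and used for the cat versus not-cat decision like this (ResNet152 shown as an example; the preprocessing call and the set of ImageNet cat class indices are assumptions, as the article does not state which cat classes were counted):

```python
import torch
from torchvision import models

# Load one of the four pre-trained models with its ImageNet weights.
weights = models.ResNet152_Weights.IMAGENET1K_V2
model = models.resnet152(weights=weights)
model.eval()
preprocess = weights.transforms()              # resizing / normalisation expected by the model

# ImageNet indices of the domestic cat classes (tabby, tiger cat, Persian,
# Siamese, Egyptian cat). The exact set used in the article is not stated.
CAT_CLASSES = {281, 282, 283, 284, 285}

def is_cat(image) -> bool:
    """Classify a PIL image as cat / not cat using the top-1 ImageNet prediction."""
    batch = preprocess(image).unsqueeze(0)
    with torch.no_grad():
        top1 = model(batch).argmax(dim=1).item()
    return top1 in CAT_CLASSES
```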
The four pre-trained models were used to identify cats versus other images. Table 9.1 shows the number of trainable weights (in millions) and the accuracies (in %) for the test and validation sets for MyModel and the pre-trained models.
| Model | Weights (millions) | Test Set (%) | Validation Set (%) | Cropped Validation Set (%) |
|---|---|---|---|---|
| MyModel | 4 | 96 | 46 | 64 |
| EfficientNetB0 | 5 | 91 | 2 | 51 |
| DenseNet201 | 20 | 89 | 0 | 44 |
| ResNet152 | 60 | 92 | 5 | 54 |
| ConvNextLarge | 198 | 92 | 7 | 67 |
The models perform almost equally on the test set. MyModel is the best with 96%. But it's specially trained on cats and dogs. The other models are slightly worse, with accuracies between 89% and 92%, but they are able to predict 991 more classes (because the ImageNet database contains 9 different classes of cats).
On the validation set, the pre-trained models fail completely, with accuracies ranging from 0% to 7%. MyModel performs best with 46%. On the cropped validation set, the pre-trained models perform much better with accuracies ranging from 44% to 67%. The best model is ConvNextLarge with an accuracy of about 67%. MyModel has the second highest accuracy with 64%.
The cross validation shows that the pre-trained models have the same problems with the validation sets as MyModel. They perform very well on the test set, but much worse on the validation sets.
For deeper insights, see this Jupyter Notebook.

The next step was to fine-tune the pre-trained models. To do this, the fully connected layer was redefined to classify only two classes instead of 1000. All trainable weights, except for the fully connected layer, were kept constant. The fully connected layer was trained on the training set for 5 epochs. Figure 10.1 to Figure 10.4 show the learning curves for the fine-tuned pre-trained models.
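A minimal sketch of this fine-tuning set-up, shown for ResNet152 (the other architectures expose their last layer under `classifier` instead of `fc`, so the attribute name differs):

```python
import torch.nn as nn
from torchvision import models

# Load ResNet152 with pre-trained ImageNet weights and freeze everything.
model = models.resnet152(weights=models.ResNet152_Weights.IMAGENET1K_V2)
for param in model.parameters():
    param.requires_grad = False

# Replace the fully connected layer so it predicts two classes (cat, dog).
# Only this new layer has trainable weights.
model.fc = nn.Linear(model.fc.in_features, 2)

# Sanity check: count the weights that will actually be updated.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable weights: {trainable}")       # only the new fc layer

# The optimiser then only receives the parameters of the new layer, e.g.:
# optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)
```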
Table 10.1 lists the number of trainable weights (in millions) and the accuracies (in %) for the test and validation sets for the fine-tuned pre-trained models.
| Model | Weights (millions) | Test Set (%) | Validation Set (%) | Cropped Validation Set (%) |
|---|---|---|---|---|
| EfficientNetB0 | 5 | 97 | 56 | 72 |
| DenseNet201 | 20 | 99 | 69 | 82 |
| ResNet152 | 60 | 99 | 48 | 85 |
| ConvNextLarge | 198 | 100 | 56 | 100 |
The fine-tuned pre-trained models perform very well on the training and test sets. A little too well. The learning curves for the training and test sets are almost identical, and they become more and more similar as the number of weights increases. The ConvNextLarge model achieves 100% accuracy on the training, test and cropped validation sets. The models are too complex. There is overfitting.
Furthermore, all fine-tuned pre-trained models have problems classifying the validation sets. The accuracy for the cropped validation set only improves with a massive increase in weights, i.e. by using more complex models. On the other hand, the learning curves oscillate less, so the models are more stable, but they need far more weights to reach acceptable results of more than 80%. And the cropped validation set remains harder to classify than the test set.
For deeper insights, see these Jupyter Notebooks.

In the previous chapter it was shown that the models were too complex and overfitting. However, to complete this analysis, an EfficientNetB0 model with randomly initialised weights was trained from scratch. This model was chosen because it is the simplest in terms of the number of trainable weights. Figure 11.1 shows the learning curves for 50 epochs.
So even this powerful model could not identify cats from Montenegro. At this point, we can stop training and fine-tuning models. The last approach would be to extend the training and test sets with images of cats from Montenegro. Unfortunately, 61 images of cats from Montenegro are too few to split and add to the 17,650 images in the training set.
The answer to the question of what AI would think about cats in Montenegro is that cats from Montenegro are special! Almost all models trained from scratch think that these cats are dogs. All, and I mean all, models have a much harder time classifying cats from Montenegro than cats from the test set.
The analysis shows the importance of an appropriate validation set. Using a validation set from a different source than the training and test sets is essential for evaluating the generalization and practical use of a trained model.
Check out my GitHub repo. There you will find all the analysis in Jupyter notebooks and the source code.
[1] Scikit-learn: Machine Learning in Python. Accessed April 2024.
[2] PyTorch. Accessed April 2024.
[4] sklearn: train test split. Accessed April 2024.
[5] C. Albon. 2018. Python Machine Learning Cookbook. First Edition. O'Reilly.
[6] A. C. Müller, S. Guido. 2016. Introduction to Machine Learning with Python. First Edition. O'Reilly.
[7] D. Sarkar, R. Bali, T. Sharma. 2018. Practical Machine Learning with Python. Apress. pp. 288-289.
[8] Charm Myae Zaw, Gwenn Tan, Cofycharm. (2022, November). Cats and Dogs, Version 1. Retrieved January 2023 from Kaggle.
[9] Pang Wei Koh, Shiori Sagawa, Henrik Marklund, Sang Michael Xie, Marvin Zhang, Akshay Balsubramani, Weihua Hu, Michihiro Yasunaga, Richard Lanas Phillips, Irena Gao, et al. WILDS: A benchmark of in-the-wild distribution shifts. arXiv:2012.07421, 2020. Accessible on https://arxiv.org/abs/2012.07421.
[10] D. Hendrycks and T. Dietterich. Benchmarking Neural Network Robustness to Common Corruptions and Perturbations. arXiv:1903.12261, 2019. Accessible on https://arxiv.org/abs/1903.12261.
[11] Mingle Xu, Sook Yoon, Alvaro Fuentes, Dong Sun Park. A Comprehensive Survey of Image Augmentation Techniques for Deep Learning. Pattern Recognition. Volume 137. 2023.
[12] PyTorch models and pretrained weights. Accessed April 2024.
[13] ImageNet. Accessed April 2024.
[14] Tan, M. and Le, Q.V. EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks. arXiv:1905.11946, 2019. Accessible on https://arxiv.org/abs/1905.11946.
[15] G. Huang, Z. Liu, L. van der Maaten and K. Q. Weinberger. Densely Connected Convolutional Networks. arXiv:1608.06993, 2018. Accessible on https://arxiv.org/abs/1608.06993.
[16] K. He, X. Zhang, S. Ren, and J. Sun. Deep Residual Learning for Image Recognition. arXiv:1512.03385, 2015. Accessible on https://arxiv.org/abs/1512.03385.
[17] Z. Liu, H. Mao, C.Y. Wu, C. Feichtenhofer, T. Darrell, S. Xie. A ConvNet for the 2020s. arXiv:2201.03545, 2022. Accessible on https://arxiv.org/abs/2201.03545.