
Multi-Class Convolutional Neural Network (CNN)

Computer Vision (AI)

Role: Lead Data Scientist

Duration: Ongoing 2024 

Process Deck 🚧

Contributor: Tyler Gustafson

Overview

This project uses a multi-class Convolutional Neural Network (CNN) to identify and classify Pokémon from images. The model is trained on a diverse dataset of Pokémon images, and separate versions of it predict different labels, such as individual species, type, evolution group, and legendary status. The goal is to achieve high accuracy in multi-class classification, showcasing the potential of CNNs in image recognition tasks.

 

Our key question:

Can we build a robust model capable of accurately identifying different Pokémon species based solely on image data?

Why is it important?


What is computer vision (AI)?

Computer vision is a branch of artificial intelligence (AI) that helps computers understand and make decisions based on visual data, like images and videos. It involves creating algorithms and models that enable machines to analyze and interpret visual information, similar to how humans do. In our case, we are going to apply a Convolutional Neural Network (CNN) architecture.

To ground ourselves here, recall how a CNN works: in the hidden layers, small numeric grids called filters (or kernels) slide across the image to extract features, and the network learns the values, or weights, inside them. These weights are updated through a process called backpropagation, which uses calculus (long chains of derivatives) to adjust each weight so as to minimize the difference between the model's predictions and the actual labels. Finally, a fully connected layer performs the classification. We will go into more detail later about the actual configuration of our model.


Multi-Class Convolutional Neural Network (CNN) Model

Key Question: Can we develop a model that accurately identifies a Pokémon from its image?

How does the model "learn" weights?

This is where the heavy math comes in (feel free to skip this section if you prefer to stay high-level). As I mentioned, the key to "learning" is backpropagation, which involves long, complex sequences of partial derivatives.

Here are the steps the model goes through during the training process:

1. Error Calculation: First, we determine which loss metric we want to use; in our case, it might be a function like Mean Squared Error (MSE).
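As a minimal sketch of this step, the function below computes $\text{MSE} = \frac{1}{n}\sum_{i=1}^{n}(y_i - \hat{y}_i)^2$ between a one-hot label and a vector of predicted probabilities. The numbers are made up for illustration. Note that for softmax classifiers like ours, categorical cross-entropy is the more common choice of loss; MSE is used here only because it keeps the walkthrough simple.

```python
def mean_squared_error(y_true, y_pred):
    """Average of the squared differences between targets and predictions."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


# Hypothetical 3-class example: the true class is index 1 (one-hot encoded).
y_true = [0.0, 1.0, 0.0]
y_pred = [0.2, 0.7, 0.1]  # predicted probabilities from the model
print(round(mean_squared_error(y_true, y_pred), 4))  # 0.0467
```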


2. Forward Pass: Next, we pass the image data forward through the network, from input to output, to produce the initial training prediction.
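To make the forward pass concrete, here is a sketch of a single fully connected layer: each neuron computes a weighted sum $z$ of its inputs plus a bias, then applies an activation to get $a$. The weights, inputs, and the choice of a sigmoid activation are hypothetical and chosen for simplicity; a real CNN forward pass also runs through convolution and pooling layers, and our model uses ReLU/tanh and softmax.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def dense_forward(x, weights, biases):
    """Forward pass of one fully connected layer with a sigmoid activation.

    weights[j][i] connects input i to output neuron j.
    Returns (z, a): weighted sums before activation, activations after.
    """
    z = [sum(w_ji * x_i for w_ji, x_i in zip(w_j, x)) + b_j
         for w_j, b_j in zip(weights, biases)]
    a = [sigmoid(z_j) for z_j in z]
    return z, a


# Hypothetical values: 3 input features feeding 2 neurons.
x = [0.5, -1.0, 2.0]
weights = [[0.1, 0.2, 0.3],
           [-0.4, 0.0, 0.25]]
biases = [0.0, 0.1]
z, a = dense_forward(x, weights, biases)
print(z, a)
```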


3. Backpropagation: Now, we flow backward through the network, calculating how much each weight in our nodes and filters contributed to the MSE loss. This involves computing the gradient of the loss function using partial derivatives and the chain rule:

$$\frac{\partial L}{\partial w} = \frac{\partial L}{\partial a} \cdot \frac{\partial a}{\partial z} \cdot \frac{\partial z}{\partial w}$$

where $a = \sigma(z)$ is the neuron's activation and $z$ is the weighted sum before activation.


Example... this calculation gets very long very quickly
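Here is a small numeric version of that chain rule for a single sigmoid neuron with squared error $L = (a - y)^2$, where $z = \sum_i w_i x_i + b$ and $a = \sigma(z)$. It multiplies the three local derivatives together and then double-checks the result with finite differences. The neuron, inputs, and target are hypothetical; a deep network repeats this multiplication of local derivatives layer by layer, which is why the expressions grow so quickly.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def loss_for(w, b, x, y):
    """Squared error (a - y)^2 of a single sigmoid neuron."""
    z = sum(wi * xi for wi, xi in zip(w, x)) + b
    return (sigmoid(z) - y) ** 2


def chain_rule_gradients(w, b, x, y):
    """dL/dw_i and dL/db built from dL/da * da/dz * dz/dw."""
    z = sum(wi * xi for wi, xi in zip(w, x)) + b  # weighted sum before activation
    a = sigmoid(z)                                 # activation
    dL_da = 2.0 * (a - y)                          # derivative of (a - y)^2
    da_dz = a * (1.0 - a)                          # derivative of the sigmoid
    grad_w = [dL_da * da_dz * xi for xi in x]      # dz/dw_i = x_i
    grad_b = dL_da * da_dz                         # dz/db = 1
    return grad_w, grad_b


def numeric_gradients(w, b, x, y, eps=1e-6):
    """Central finite differences, used only to sanity-check the chain rule."""
    grad_w = []
    for i in range(len(w)):
        w_plus = list(w)
        w_plus[i] += eps
        w_minus = list(w)
        w_minus[i] -= eps
        grad_w.append((loss_for(w_plus, b, x, y) - loss_for(w_minus, b, x, y)) / (2 * eps))
    grad_b = (loss_for(w, b + eps, x, y) - loss_for(w, b - eps, x, y)) / (2 * eps)
    return grad_w, grad_b


# Hypothetical single training example.
w, b = [0.3, -0.2], 0.1
x, y = [1.5, 2.0], 1.0
analytic = chain_rule_gradients(w, b, x, y)
numeric = numeric_gradients(w, b, x, y)
print(analytic)
print(numeric)
```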

4. Weight Update: The weights are then updated to reduce the loss using a method like gradient descent. Gradient descent works by moving the weights in the direction that decreases the loss, with the learning rate (𝜂) controlling the size of each update step.
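The update described above is $w \leftarrow w - \eta \frac{\partial L}{\partial w}$. A minimal sketch on a one-parameter toy loss $L(w) = (w - 3)^2$ (chosen only for illustration, not the model's real loss) shows the rule settling at the minimum when $\eta$ is small, and overshooting further on every step when $\eta$ is too large.

```python
def gradient_descent(grad_fn, w0, learning_rate, steps):
    """Repeatedly apply w <- w - learning_rate * dL/dw."""
    w = w0
    for _ in range(steps):
        w = w - learning_rate * grad_fn(w)
    return w


def loss_gradient(w):
    """dL/dw for the toy loss L(w) = (w - 3)^2, whose minimum is at w = 3."""
    return 2.0 * (w - 3.0)


print(gradient_descent(loss_gradient, w0=0.0, learning_rate=0.1, steps=100))  # close to 3
print(gradient_descent(loss_gradient, w0=0.0, learning_rate=1.1, steps=20))   # diverges
```

With $\eta = 0.1$, each step shrinks the distance to the minimum by a factor of 0.8; with $\eta = 1.1$, each step multiplies it by 1.2 and flips its sign, so the weight moves further away.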


Essentially, what we're trying to do is find the global minimum of our loss function, even though we don't know exactly what this "map" looks like. The two charts illustrate this concept: the one on the left is a simple 2D curve, while the one on the right is a 3D surface. In reality, though, our model operates in far higher dimensions (one per weight), well beyond what we can visualize.


Example of 2-Dimensional Gradient Descent


SOURCE IMAGE: Mohit Mishra on Medium

A common analogy is this: imagine you’re walking down a mountain at night, trying to find the lowest point (the global minimum), but you can’t see the path. You have to feel your way down, step by step, adjusting your direction based on the slope beneath your feet. This is similar to what our model does during training—it makes adjustments based on the gradients (like the slope) to gradually find the optimal weights that minimize the loss, all without having a clear view of the entire landscape.
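The "walking down a mountain at night" analogy also explains why the starting point matters: gradient descent only follows the local slope, so it can settle in a valley that isn't the lowest one. The sketch below uses a hypothetical non-convex toy loss $L(w) = w^4 - 3w^2 + w$, which has a global minimum near $w \approx -1.30$ and a shallower local minimum near $w \approx 1.13$. Starting on the left reaches the global minimum; starting on the right gets stuck in the local one.

```python
def bumpy_loss(w):
    """Toy non-convex loss: global minimum near w = -1.30, local minimum near w = 1.13."""
    return w**4 - 3 * w**2 + w


def bumpy_gradient(w):
    """dL/dw for bumpy_loss."""
    return 4 * w**3 - 6 * w + 1


def descend(w, learning_rate=0.01, steps=1000):
    """Plain gradient descent from starting point w."""
    for _ in range(steps):
        w -= learning_rate * bumpy_gradient(w)
    return w


for start in (-2.0, 2.0):
    end = descend(start)
    print(f"start={start:+.1f} -> w={end:.2f}, loss={bumpy_loss(end):.2f}")
```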

This is just a high-level overview of the forward pass, error calculation, and how that error is backpropagated to update the weights. As we add more hidden layers and fully connected nodes, the process becomes even more complex, and additional factors like biases and activation functions further complicate the calculations. These are computation-heavy processes, and the equations grow in length and complexity as the network architecture becomes more intricate.

Okay, stepping back from the more technical pieces of a CNN: in the following case write-up, I'll walk through the approach, covering exploratory data analysis, preprocessing and augmentation, and the model itself. We'll then examine the learned features and generalization to see how our models perform. We've actually built four different models: one that identifies individual Pokémon, one that classifies by Pokémon type (Fire, Ghost, etc.), one that classifies by evolution group (e.g., Charmander, Charmeleon, and Charizard combined as one class), and one that determines whether or not the image shows a legendary Pokémon.

Approach

1. Data Collection

2. Exploratory Data Analysis (EDA)

3. Data Preprocessing & Augmentation

4. Naive Baseline Model & CNN Training

5. Hyperparameter Tuning

6. Evaluation & Learned Parameters

7. Generalization

  • Model 1: Individual Pokémon (82%)

  • Model 2: By Pokémon Type (82%)

  • Model 3: By Evolution Group (84%)

  • Model 4: By Legendary Pokémon (98%)

1 | Data Collection

Data collection started with a small original dataset containing just one image of each Pokémon. From there, additional images were gathered from two Kaggle datasets, from the Google API (the first 15 images for each Pokémon), and from official Pokémon websites. After pulling everything together, a thorough preliminary preprocessing step cleaned up the dataset by removing unsupported files and poor-quality images.

 

This resulted in a large and diverse database of over 28,000 Pokémon images, captured from multiple angles to create a solid foundation for training the machine learning model.
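The "removing unsupported files" part of that cleanup can be sketched with the standard library alone. This assumes the images sit under one root folder and that "unsupported" means an extension `tf.keras.utils.image_dataset_from_directory` won't load (it accepts `.jpeg`, `.jpg`, `.png`, `.bmp`, and `.gif`). Screening for poor-quality images is a separate, more manual step not covered here.

```python
from pathlib import Path

# Extensions accepted by tf.keras.utils.image_dataset_from_directory.
SUPPORTED_EXTENSIONS = {".jpeg", ".jpg", ".png", ".bmp", ".gif"}


def find_unsupported_files(root):
    """Return every file under `root` whose extension TensorFlow won't load."""
    return sorted(
        path for path in Path(root).rglob("*")
        if path.is_file() and path.suffix.lower() not in SUPPORTED_EXTENSIONS
    )


def remove_unsupported_files(root, dry_run=True):
    """Delete unsupported files, or with dry_run=True only list them."""
    unsupported = find_unsupported_files(root)
    if not dry_run:
        for path in unsupported:
            path.unlink()
    return unsupported
```

Running it with `dry_run=True` first on a hypothetical folder such as `"pokemon_images/"` lists what would be deleted before anything is actually removed.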


2 | Exploratory Data Analysis (EDA)

To initiate our study, we conducted extensive exploratory work to understand class balance, image dimensions, image sizes, and other pertinent characteristics. The details of this analysis are available in the project's GitHub repository.
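As one small example of the class-balance check, the sketch below counts images per class, assuming the one-subfolder-per-class layout that `image_dataset_from_directory` expects. The imbalance ratio (largest class divided by smallest) is a simple summary chosen here for illustration; the analysis in the repository may use different measures.

```python
from collections import Counter
from pathlib import Path


def class_counts(root):
    """Count files in each class subfolder (one subfolder per class)."""
    return Counter({
        class_dir.name: sum(1 for p in class_dir.iterdir() if p.is_file())
        for class_dir in Path(root).iterdir() if class_dir.is_dir()
    })


def imbalance_ratio(counts):
    """Largest class size divided by smallest; 1.0 means perfectly balanced."""
    return max(counts.values()) / min(counts.values())
```

For example, `class_counts("pokemon_images/").most_common(5)` (hypothetical path) would list the five best-represented Pokémon.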

 

However, the key goal of this foundational analysis is to determine whether our image dataset is suitable for modeling. To do this, we need to address two critical questions:


Image Sample (Abra)


Does using color matter?


Are the classes really that different?


It's clear that color really stands out as a key feature in our dataset

  • This is crucial because, if color did not provide meaningful differentiation, it could introduce noise and unnecessarily increase the dimensionality and complexity of our model.

  • This also holds true for our other scenarios, such as Pokémon type (Fire, Ghost, Poison, etc.).

The classes do appear to be quite different. Here's what we're doing:

  • PCA* Transformation reduces the image dimensions to 50 components

  • t-SNE* Transformation further reduces the image dimensions to 2

  • Clusters: Images from the same class will cluster together if they are similar. Distinct clusters indicate that images of different classes have unique features.

We can see that our image set likely has enough distinction to accurately predict classes.


* PCA: Principal Component Analysis

*  t-SNE: t-distributed Stochastic Neighbor Embedding
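The PCA-then-t-SNE pipeline described above can be sketched with scikit-learn. This assumes the images have already been resized to a common shape so they can be flattened into equal-length vectors. The 50 PCA components and 2 t-SNE dimensions come from the text; the perplexity and random seed are assumptions. Plotting the result as a scatter chart coloured by class gives the cluster view.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE


def embed_images_2d(images, n_pca_components=50, perplexity=30.0, random_state=0):
    """Flatten images, reduce them with PCA, then project to 2-D with t-SNE.

    images: array of shape (n_images, height, width, channels), values 0-255.
    Returns an (n_images, 2) array for a scatter plot coloured by class.
    """
    flat = images.reshape(len(images), -1).astype("float32") / 255.0
    reduced = PCA(n_components=n_pca_components,
                  random_state=random_state).fit_transform(flat)
    # t-SNE requires perplexity < n_images.
    return TSNE(n_components=2, perplexity=perplexity,
                random_state=random_state).fit_transform(reduced)
```

Running PCA first is a common way to strip noise and speed up t-SNE, which scales poorly with the number of input dimensions.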

3 | Data Preprocessing & Augmentation

Before diving into model training, it's important to set up our data correctly to give the model the best chance to learn effectively. TensorFlow's image_dataset_from_directory function efficiently loads large image datasets from a folder structure (one subfolder per class, with labels inferred from the folder names) into a tf.data pipeline.

 

In the model code, several key processes are happening:

  • Initial Preprocessing: Images are being standardized to a consistent size of 128x128 pixels.

  • Train/Validation/Test Split: Images are randomly assigned to Train, Validation, and Test groups in a 60%/20%/20% ratio for modeling.

  • Batching: To further optimize processing, images within each group (e.g., Train) are placed into batches.
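One detail worth knowing: `image_dataset_from_directory`'s `validation_split` argument only carves the data into two subsets (training and validation), so a 60%/20%/20% split needs an extra step. One option, sketched below in plain Python, is to shuffle the list of file paths with a fixed seed and slice it. The seed and batch size are assumptions, and this is an illustration of the idea rather than the project's actual code.

```python
import random


def split_60_20_20(items, seed=42):
    """Shuffle items reproducibly and slice them into 60/20/20 train/val/test lists."""
    items = list(items)
    random.Random(seed).shuffle(items)
    n_train = len(items) * 6 // 10   # integer math avoids float rounding surprises
    n_val = len(items) * 2 // 10
    train = items[:n_train]
    val = items[n_train:n_train + n_val]
    test = items[n_train + n_val:]
    return train, val, test


def batches(items, batch_size=32):
    """Yield consecutive batches; the last batch may be smaller."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]
```

Fixing the seed keeps the three groups identical across runs, so test images never leak into training when experiments are repeated.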

 

Data Augmentation
Initially, I performed data augmentation in real time as Keras processed the data, but this led to the validation loss curve sitting below the training loss curve, which isn't ideal. Now, we apply the augmentation before compiling and training the model. Seeing varied versions of the data during training helps the model learn better: by exposing it to different perspectives and distortions, we encourage it to become more adaptable and to recognize patterns even when they appear in new or unexpected ways.

  • Rescaling: We rescale pixel values to the [0,1] range to standardize the input.

  • RandomFlip: Images are randomly flipped horizontally to help the model learn from different orientations.

  • RandomRotation: Images are rotated randomly between -40 and +40 degrees, adding variability that helps the model become more robust during training.

 

This approach ensures that our model gets the most out of the data, learning to generalize better across different scenarios.
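A minimal Keras sketch of the three augmentation steps listed above, assuming they are applied to the training dataset with `Dataset.map` before training (the exact mechanics in the project's code may differ). Note that `RandomRotation` takes its `factor` as a fraction of a full turn, so ±40 degrees becomes `40 / 360`.

```python
import tensorflow as tf

data_augmentation = tf.keras.Sequential([
    tf.keras.layers.Rescaling(1.0 / 255),        # pixels from [0, 255] to [0, 1]
    tf.keras.layers.RandomFlip("horizontal"),    # random left-right flips
    tf.keras.layers.RandomRotation(40 / 360),    # random rotation in [-40, +40] degrees
])


def augment(images, labels):
    """Augment a batch of images; random layers only act when training=True."""
    return data_augmentation(images, training=True), labels
```

Usage would look like `train_ds = train_ds.map(augment, num_parallel_calls=tf.data.AUTOTUNE)` for a hypothetical `train_ds`. Validation and test batches would still need the `Rescaling` step, just without the random flips and rotations.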


4 | The CNN Model

Let's kick off by defining the model structure with TensorFlow / Keras: a series of hidden layers composed primarily of convolutional blocks.


These blocks are designed to extract and learn various features from the input images, which is key to how the model understands and processes visual data.

1. Convolutional Blocks

  • Filter count: 32–256

  • Filter size: 3x3

  • Activation: ReLU/tanh

  • Includes layers for Batch Normalization and Max pooling

2. Fully Connected Block / Output Layer

  • Flatten: 3D feature maps to 1D vector

  • Dense: 512 units

  • Dropout: 0.5 for regularization

  • Activation: softmax (to convert outputs to a probability distribution for each class representing the model's predictions)
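One way to express the architecture above in Keras, as a sketch: the filter progression 32, 64, 128, 256, ReLU in every block, `padding="same"`, the Adam optimizer, and the loss are assumptions, not the tuned Approach F settings. `sparse_categorical_crossentropy` matches the integer labels that `image_dataset_from_directory` produces by default.

```python
import tensorflow as tf
from tensorflow.keras import layers


def build_cnn(num_classes=151, input_shape=(128, 128, 3), filters=(32, 64, 128, 256)):
    """Convolutional blocks (3x3 conv + BatchNorm + max pooling), then a dense head."""
    model = tf.keras.Sequential([tf.keras.Input(shape=input_shape)])
    for n_filters in filters:
        model.add(layers.Conv2D(n_filters, (3, 3), padding="same", activation="relu"))
        model.add(layers.BatchNormalization())
        model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Flatten())                                  # 3-D feature maps -> 1-D vector
    model.add(layers.Dense(512, activation="relu"))
    model.add(layers.Dropout(0.5))                               # regularization
    model.add(layers.Dense(num_classes, activation="softmax"))  # class probabilities
    return model


model = build_cnn()
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
```

The same builder covers the other scenarios by changing `num_classes`, for example `build_cnn(num_classes=2)` for the legendary/not-legendary model.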

From here, the next step is to tune the hyperparameters to optimize the model's performance and ensure it accurately classifies the Pokémon images across the dataset.


5 | Training & Hyperparameter Tuning

Now we can begin training our model. As I mentioned, the next step is extensive hyperparameter tuning. By testing various combinations of (1) layer configurations, (2) filter sizes, (3) optimizers, (4) activations, and (5) learning rates, we were able to settle on a desirable architecture.

Below and to the right is a sample of the configurations we tested (we ultimately chose Approach F's structure).
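Testing "various combinations" is essentially a grid search. The sketch below expands a hypothetical search space over the five hyperparameters named above into every combination and picks the best one under a scoring function. The values listed are illustrative, and `evaluate` stands in for "train a model with this config and return its validation accuracy", which is the expensive part. Libraries such as KerasTuner automate this kind of search.

```python
from itertools import product

# Hypothetical search space; the values actually tested in the project may differ.
search_space = {
    "conv_blocks": [3, 4],
    "filter_size": [3, 5],
    "optimizer": ["adam", "sgd"],
    "activation": ["relu", "tanh"],
    "learning_rate": [1e-3, 1e-4],
}


def all_configs(space):
    """Expand a dict of option lists into every combination (a full grid)."""
    keys = list(space)
    return [dict(zip(keys, values)) for values in product(*space.values())]


def best_config(configs, evaluate):
    """Return the config with the highest score from `evaluate` (e.g. val accuracy)."""
    return max(configs, key=evaluate)
```

Even this small grid has 2^5 = 32 combinations, which is why full grids get expensive quickly and tuning is often done in stages.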


6 | Evaluation & Learned Parameters

So now let's visualize what's happening under the hood. There are two parts to what we're showing here.
 

On the far left, we have an example of the learned filter weights at the first layer. If you remember, our first layer had 32 filters, each 3x3. Each square's learned weight is represented visually by its shading. These learned filters are used to extract features.


Remember, these are actually 3x3x3 filters, since our input images have three color channels (RGB). For visualization, however, they're shown in grayscale using just the red channel, which lets us see the filter patterns the model uses to process the image and extract these initial features.
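The red-channel view can be produced with a few lines of NumPy. Keras stores a `Conv2D` kernel with shape `(kernel_height, kernel_width, input_channels, filters)`, so the first layer's kernel here is `(3, 3, 3, 32)`. The sketch slices channel 0 and rescales each filter to [0, 1] for display. Random numbers stand in for trained weights, and the commented `model.layers[0]` line assumes the first layer is the convolution.

```python
import numpy as np


def red_channel_filters(kernel):
    """Turn a Conv2D kernel into per-filter grayscale images from the red channel.

    kernel: array shaped (3, 3, 3, n_filters), Keras's
            (kernel_height, kernel_width, input_channels, filters) layout.
    Returns an array shaped (n_filters, 3, 3), each filter scaled to [0, 1].
    """
    red = kernel[:, :, 0, :]                          # channel 0 = red for RGB input
    red = np.moveaxis(red, -1, 0)                     # -> (n_filters, 3, 3)
    lo = red.min(axis=(1, 2), keepdims=True)
    hi = red.max(axis=(1, 2), keepdims=True)
    return (red - lo) / np.maximum(hi - lo, 1e-8)     # avoid division by zero


# With a trained model this would be, e.g.:
# kernel, bias = model.layers[0].get_weights()
kernel = np.random.default_rng(0).normal(size=(3, 3, 3, 32))  # stand-in weights
filter_images = red_channel_filters(kernel)
print(filter_images.shape)  # (32, 3, 3)
```

Each `filter_images[i]` can then be drawn with `matplotlib`'s `imshow(..., cmap="gray")`.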

Convolutional Layer Feature Maps


As images pass through the CNN, the layers progressively learn increasingly complex features, from basic edges and textures in early layers to intricate patterns and object parts in deeper layers.

Now, if we move to the main section of this slide, let's take this image of Psyduck, pass it through each of the convolutional layers, and see how the different filters extract features, with more complex patterns emerging in the deeper layers.

In the first layer, you can see the simpler features being extracted. It's a bit hard to see in the darker maps, but this layer is mostly pulling out edges, and you can see outlines starting to emerge. As we go deeper into the layers, we start seeing more complex patterns, which may become unrecognizable to us as humans. But that's the point of the machine learning algorithm: to pull out patterns we might not recognize. These feature maps are a snapshot of what the model extracts before passing it to the fully connected classification layers highlighted earlier.
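To see why early layers light up on edges, here is a hand-built example: a 2-D cross-correlation (the operation a CNN convolution layer actually computes) with no padding, applied to a synthetic image that is dark on the left and bright on the right. The Sobel-like kernel is hand-picked for illustration; a trained CNN learns its own filter weights, though first-layer filters often end up edge-like. After a ReLU, the feature map is nonzero only where the window straddles the boundary.

```python
import numpy as np


def conv2d_valid(image, kernel):
    """2-D cross-correlation with no padding, as computed by a CNN conv layer."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out


# Synthetic 6x6 grayscale image: dark left half, bright right half.
image = np.zeros((6, 6))
image[:, 3:] = 1.0

# Hand-made vertical-edge filter (Sobel-like).
vertical_edge = np.array([[-1.0, 0.0, 1.0],
                          [-2.0, 0.0, 2.0],
                          [-1.0, 0.0, 1.0]])

feature_map = np.maximum(conv2d_valid(image, vertical_edge), 0.0)  # ReLU
print(feature_map)  # every row is [0. 4. 4. 0.]
```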


7 | Generalization

Now we'll revisit the different scenarios we defined earlier to understand how they affect model performance. Each has different use cases that can be helpful in different ways once deployed. Below is how each of the models performed on the validation data.

  • Model 1: Individual Pokémon

  • Model 2: By Pokémon Type

  • Model 3: By Evolution Group

  • Model 4: By Legendary Pokémon


Let’s wrap things up by testing how well these different model structures handle new, unseen data, which will give us a sense of their generalization abilities.

The model hits 82% accuracy across 151 different Pokémon classes, showing it can handle a wide and varied dataset. When we group Pokémon by evolution group, accuracy rises by 2 percentage points to 84%.

 

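Interestingly, sorting by Pokémon type doesn't change performance (82%), even though there are fewer classes to work with. And, as expected, the model nails legendary Pokémon, reaching an impressive 98% accuracy. Because only a handful of the 151 Pokémon are legendary, though, that figure is best read against the naive baseline of always predicting "not legendary."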

[Figures: Model 1 (full list, individual Pokémon) prediction confidence; model test performance for the various class structures; Model 4 (legendary status: Articuno, Zapdos, etc.) prediction confidence for an input image.]

In the image on the left, we have an example of our first model making predictions for an individual Pokémon. The model outputs a percentage for each of the 151 classes, representing its confidence level. For this particular Pokémon, the model's highest-confidence prediction is 30.6%, which turns out to be correct. Even though that confidence isn't overwhelming, it is roughly 46 times the 0.66% (1/151) a uniform guess would assign, and the model still makes the right call.

On the right, we see a different scenario where our fourth model is tasked with predicting whether an image features a legendary Pokémon or not. Here, the model is much more confident, selecting the correct prediction with an 83.1% confidence level. This highlights how the model is more decisive when dealing with binary classifications, like identifying legendary status, compared to the more complex task of predicting individual Pokémon.
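The confidence percentages come from the softmax layer, which turns raw scores into probabilities that sum to 1. This sketch shows softmax plus a top-k readout on hypothetical scores for four classes, along with the chance levels that put 30.6% (151 classes) and 83.1% (2 classes) in context.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1 (numerically stable)."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probs, class_names, k=3):
    """Return the k most probable (class, probability) pairs."""
    ranked = sorted(zip(class_names, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Hypothetical scores for four classes.
names = ["Psyduck", "Golduck", "Slowpoke", "Abra"]
probs = softmax([2.0, 1.5, 0.5, -1.0])
print(top_k(probs, names, k=2))

# Chance level (%) for a uniform guess: 151 classes vs. 2 classes.
print(round(100 / 151, 2), round(100 / 2, 2))  # 0.66 50.0
```

Subtracting the maximum score before exponentiating leaves the probabilities unchanged but prevents overflow for large scores.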


Discussion

To improve our model going forward, we could start by increasing the image dataset, adding more variability to the images, and including scenarios with multiple Pokémon in a single image. Expanding the number of classes by introducing the next generation of Pokémon, beyond the initial 151, would also enhance the model's ability to generalize.

 

Additionally, we could explore transfer learning by leveraging pre-trained models like VGG or ResNet, and then fine-tuning them on our specific dataset.
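A sketch of what that could look like with `tf.keras.applications.ResNet50` as a frozen feature extractor under a new softmax head. The head design, input size, and freezing strategy are assumptions; passing `weights="imagenet"` downloads the pre-trained weights on first use.

```python
import tensorflow as tf


def build_transfer_model(num_classes=151, input_shape=(128, 128, 3), weights="imagenet"):
    """ResNet50 backbone with a new classification head for our classes."""
    base = tf.keras.applications.ResNet50(
        include_top=False,      # drop ImageNet's 1000-class head
        weights=weights,        # "imagenet" downloads pre-trained weights
        input_shape=input_shape,
        pooling="avg",          # global average pooling -> 1-D feature vector
    )
    base.trainable = False      # freeze the backbone for the first training phase
    inputs = tf.keras.Input(shape=input_shape)
    # ResNet50 expects its own preprocessing applied to pixels in [0, 255].
    x = tf.keras.applications.resnet50.preprocess_input(inputs)
    x = base(x, training=False)
    x = tf.keras.layers.Dropout(0.5)(x)
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)
    return tf.keras.Model(inputs, outputs)
```

A common follow-up is to unfreeze the top layers of the ResNet backbone and continue training with a much lower learning rate (fine-tuning).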

 

Another approach could be to build custom models tailored for each type of class predictor. For instance, something seems off with the Pokémon type predictor, as it should theoretically perform better than individual Pokémon predictors, given the smaller class set. Addressing this discrepancy might lead to further improvements.

🚧  Stay tuned for updates as the model's performance continues to improve!

Check out the code and more on my GitHub
