Introduction to Image Classification
Image classification is the task of assigning a label to an image from a predefined set of categories. It forms the foundation of many computer vision applications, from identifying objects in photos to medical diagnosis from X-rays. Understanding how to build robust classifiers is essential for any AI practitioner working with visual data.
What is Image Classification?
Image classification is a supervised learning task where a model learns to categorize images into predefined classes based on their visual content. The model takes an image as input and outputs a probability distribution over all possible classes.
Why it matters: Image classification powers applications like photo organization, content moderation, medical imaging, autonomous vehicles, and quality control in manufacturing.
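To make the "probability distribution" idea concrete, here is a minimal sketch (plain NumPy, with illustrative scores) of the softmax function that converts a model's raw class scores into probabilities:

```python
import numpy as np

# Raw class scores (logits) from a hypothetical 4-class model
logits = np.array([2.0, 1.0, 0.1, -1.0])

# Softmax converts scores into a probability distribution
probs = np.exp(logits) / np.sum(np.exp(logits))

print(probs)        # the largest score gets the highest probability
print(probs.sum())  # probabilities always sum to 1.0
```

The predicted label is simply the class with the highest probability (the argmax of this vector).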
Types of Classification Problems
Image classification problems come in different forms depending on the number of classes and whether images can belong to multiple categories. Understanding these distinctions helps you choose the right model architecture and loss function for your specific problem.
Binary Classification
Two classes only (e.g., cat vs. dog, defective vs. intact product). Uses a sigmoid output and binary cross-entropy loss.
Multi-class Classification
Multiple mutually exclusive classes (e.g., digits 0-9). Uses softmax activation and categorical cross-entropy.
Multi-label Classification
Multiple labels per image (e.g., beach, sunset, people). Uses sigmoid activation and binary cross-entropy per label.
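As a sketch of how the three setups above differ in code — only the output layers and loss objects are shown, and the class counts (10 classes, 5 labels) are placeholder assumptions:

```python
from tensorflow import keras
from tensorflow.keras import layers

# Binary: one sigmoid unit + binary cross-entropy
binary_head = layers.Dense(1, activation='sigmoid')
binary_loss = keras.losses.BinaryCrossentropy()

# Multi-class: softmax over all mutually exclusive classes (here 10)
multiclass_head = layers.Dense(10, activation='softmax')
multiclass_loss = keras.losses.SparseCategoricalCrossentropy()

# Multi-label: one independent sigmoid per label (here 5 labels)
multilabel_head = layers.Dense(5, activation='sigmoid')
multilabel_loss = keras.losses.BinaryCrossentropy()
```

The key distinction: softmax forces the outputs to compete (they sum to 1), while independent sigmoids let several labels be "on" at once.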
The Classification Pipeline
Building an image classifier involves a series of well-defined steps. From collecting and preprocessing data to training and deploying the model, each stage requires careful consideration to achieve optimal results.
# High-level image classification pipeline
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Step 1: Load and preprocess data
train_dataset = keras.utils.image_dataset_from_directory(
    'data/train',
    image_size=(224, 224),
    batch_size=32
)

# Step 2: Normalize pixel values
normalization = layers.Rescaling(1./255)
train_dataset = train_dataset.map(lambda x, y: (normalization(x), y))

# Step 3: Build model (simplified)
model = keras.Sequential([
    layers.Conv2D(32, 3, activation='relu', input_shape=(224, 224, 3)),
    layers.MaxPooling2D(),
    layers.Flatten(),
    layers.Dense(10, activation='softmax')
])

# Step 4: Compile and train
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(train_dataset, epochs=10)
Popular Datasets for Practice
Learning image classification is best done with well-known benchmark datasets. These datasets have been extensively studied, making it easy to compare your results with established benchmarks and learn from existing solutions.
| Dataset | Classes | Images | Image Size | Best For |
|---|---|---|---|---|
| MNIST | 10 | 70,000 | 28x28 (grayscale) | Beginners, quick experiments |
| CIFAR-10 | 10 | 60,000 | 32x32 (color) | Basic CNN development |
| CIFAR-100 | 100 | 60,000 | 32x32 (color) | Fine-grained classification |
| ImageNet | 1,000 | 14M+ | Variable (usually 224x224) | Transfer learning, research |
| Fashion-MNIST | 10 | 70,000 | 28x28 (grayscale) | MNIST alternative |
Loading Built-in Datasets
TensorFlow and Keras provide easy access to popular datasets. These datasets are automatically downloaded and cached, making it simple to start experimenting with classification models immediately.
import tensorflow as tf
from tensorflow.keras.datasets import cifar10, mnist, fashion_mnist
# Load CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print(f"CIFAR-10 Training: {x_train.shape}, Labels: {y_train.shape}")
print(f"CIFAR-10 Test: {x_test.shape}, Labels: {y_test.shape}")
# CIFAR-10 class names
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
# Normalize pixel values to [0, 1]
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
# Display sample images
import matplotlib.pyplot as plt
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
for i, ax in enumerate(axes.flat):
    ax.imshow(x_train[i])
    ax.set_title(class_names[y_train[i][0]])
    ax.axis('off')
plt.tight_layout()
plt.show()
Practice: Classification Basics
Task: Load the Fashion-MNIST dataset, print its shape, and display the first 10 images with their class labels.
Show Solution
from tensorflow.keras.datasets import fashion_mnist
import matplotlib.pyplot as plt
# Load dataset
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
# Class names for Fashion-MNIST
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
print(f"Training shape: {x_train.shape}")
print(f"Test shape: {x_test.shape}")
# Display first 10 images
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
for i, ax in enumerate(axes.flat):
    ax.imshow(x_train[i], cmap='gray')
    ax.set_title(class_names[y_train[i]])
    ax.axis('off')
plt.suptitle("Fashion-MNIST Samples")
plt.tight_layout()
plt.show()
Task: Load CIFAR-10, count the number of images per class in the training set, and create a bar chart showing the distribution.
Show Solution
from tensorflow.keras.datasets import cifar10
import matplotlib.pyplot as plt
import numpy as np
# Load dataset
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
# Count images per class
unique, counts = np.unique(y_train, return_counts=True)
class_counts = dict(zip(unique, counts))
# Create bar chart
plt.figure(figsize=(10, 6))
plt.bar(class_names, counts, color='steelblue')
plt.xlabel('Class')
plt.ylabel('Number of Images')
plt.title('CIFAR-10 Training Set Class Distribution')
plt.xticks(rotation=45)
for i, count in enumerate(counts):
    plt.text(i, count + 100, str(count), ha='center')
plt.tight_layout()
plt.show()
print(f"Total training images: {len(y_train)}")
print(f"Images per class: {counts[0]} (balanced dataset)")
Task: Write a function that loads images from a directory structure (class_name/image.jpg), resizes them to 128x128, and returns normalized numpy arrays with labels.
Show Solution
import os
import numpy as np
from PIL import Image
from pathlib import Path
def load_custom_dataset(data_dir, image_size=(128, 128)):
    """Load images from directory with class subfolders."""
    images = []
    labels = []
    # Only subdirectories are classes (skip stray files in data_dir)
    class_names = sorted(
        d for d in os.listdir(data_dir) if (Path(data_dir) / d).is_dir()
    )
    class_to_idx = {name: idx for idx, name in enumerate(class_names)}
    for class_name in class_names:
        class_path = Path(data_dir) / class_name
        for img_path in class_path.glob('*.jpg'):
            try:
                img = Image.open(img_path).convert('RGB')
                img = img.resize(image_size)
                img_array = np.array(img) / 255.0  # Normalize to [0, 1]
                images.append(img_array)
                labels.append(class_to_idx[class_name])
            except Exception as e:
                print(f"Error loading {img_path}: {e}")
    return np.array(images), np.array(labels), class_names

# Usage example
# x_data, y_data, classes = load_custom_dataset('my_dataset/')
# print(f"Loaded {len(x_data)} images from {len(classes)} classes")
CNN Architecture
Convolutional Neural Networks (CNNs) are the backbone of modern image classification. They use specialized layers designed to automatically learn spatial hierarchies of features, from simple edges to complex patterns. Understanding CNN architecture is crucial for building effective image classifiers.
Convolutional Neural Network (CNN)
A CNN is a deep learning architecture designed specifically for processing grid-like data such as images. It uses convolutional layers to automatically learn spatial feature hierarchies, reducing the need for manual feature engineering.
Key Innovation: CNNs use parameter sharing (same filter applied across the image) and local connectivity, making them efficient and effective at detecting patterns regardless of their position in the image (translation invariance).
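Parameter sharing is what keeps CNNs small. A back-of-the-envelope comparison (assumed 224x224 RGB input and 32 output features) shows the difference:

```python
# Conv2D parameters: kernel_h * kernel_w * in_channels * filters + filters (biases).
# The count does not depend on image size — the same 3x3 filters slide everywhere.
conv_params = 3 * 3 * 3 * 32 + 32
print(conv_params)  # 896

# A Dense layer connecting every pixel of a 224x224x3 image to 32 units:
dense_params = (224 * 224 * 3) * 32 + 32
print(dense_params)  # 4,816,928 — over 5000x more
```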
CNN Building Blocks
A typical CNN consists of several types of layers, each serving a specific purpose in the feature extraction and classification pipeline. Understanding these components is crucial for designing effective architectures.
Convolutional Layer
Applies learnable filters to extract local features. Each filter detects specific patterns like edges, textures, or shapes.
Conv2D(filters=32, kernel_size=3)
Pooling Layer
Reduces spatial dimensions while retaining important features. Provides translation invariance and reduces computation.
MaxPooling2D(pool_size=2)
Flatten Layer
Converts 2D feature maps to 1D vector, connecting convolutional layers to fully connected layers for classification.
Flatten()
Dense Layer
Fully connected layers that perform classification based on the extracted features. Final layer outputs class probabilities.
Dense(units=10, activation='softmax')
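Stacked in order, the four building blocks above form a minimal working classifier (a sketch with assumed 32x32 RGB input and 10 classes):

```python
import tensorflow as tf
from tensorflow.keras import layers

model = tf.keras.Sequential([
    layers.Input(shape=(32, 32, 3)),
    layers.Conv2D(32, 3, activation='relu', padding='same'),  # extract local features
    layers.MaxPooling2D(pool_size=2),                         # 32x32 -> 16x16
    layers.Flatten(),                                         # 16*16*32 = 8192 values
    layers.Dense(10, activation='softmax'),                   # class probabilities
])
print(model.output_shape)  # (None, 10)
```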
Understanding Convolution Operations
The convolution operation slides a small filter (kernel) across the input image, computing element-wise multiplications and summing the results. This process extracts local patterns while preserving spatial relationships.
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
# Create a simple 5x5 grayscale "image"
image = np.array([
    [0, 0, 1, 0, 0],
    [0, 1, 1, 1, 0],
    [1, 1, 1, 1, 1],
    [0, 1, 1, 1, 0],
    [0, 0, 1, 0, 0]
], dtype=np.float32)

# Reshape for Conv2D: (batch, height, width, channels)
image = image.reshape(1, 5, 5, 1)

# Define an edge detection filter
edge_filter = np.array([
    [-1, -1, -1],
    [-1, 8, -1],
    [-1, -1, -1]
], dtype=np.float32).reshape(3, 3, 1, 1)

# Apply convolution with the hand-crafted filter
conv_layer = layers.Conv2D(1, kernel_size=3, padding='same', use_bias=False)
conv_layer.build((None, 5, 5, 1))
conv_layer.set_weights([edge_filter])
output = conv_layer(image)

print("Input shape:", image.shape)
print("Output shape:", output.shape)
print("\nEdge detection result:")
print(output.numpy().squeeze())
Key CNN Parameters
Understanding the parameters that control convolutional layers is essential for designing effective networks. These parameters directly impact the model's ability to learn useful features and its computational requirements.
| Parameter | Description | Typical Values | Impact |
|---|---|---|---|
| Filters | Number of feature detectors | 32, 64, 128, 256 | More filters = more features learned |
| Kernel Size | Size of the sliding window | 3x3, 5x5, 7x7 | Larger = captures bigger patterns |
| Stride | Step size when sliding filter | 1, 2 | Larger stride = smaller output |
| Padding | How to handle borders | 'valid', 'same' | 'same' preserves dimensions |
| Activation | Non-linearity function | ReLU, LeakyReLU | Enables learning complex patterns |
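The stride and padding effects in the table follow a standard size formula; a small helper (illustrative, one spatial dimension) makes it explicit:

```python
import math

def conv_output_size(n, k, stride=1, padding='valid'):
    """Output size along one spatial dimension of a conv or pooling layer."""
    if padding == 'same':
        return math.ceil(n / stride)   # 'same': padding added, only stride shrinks
    return (n - k) // stride + 1       # 'valid': no padding

print(conv_output_size(224, 3, 1, 'same'))   # 224 — dimensions preserved
print(conv_output_size(224, 3, 1, 'valid'))  # 222
print(conv_output_size(224, 3, 2, 'valid'))  # 111 — stride 2 roughly halves the size
```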
Feature Hierarchy in CNNs
One of the most powerful aspects of CNNs is their ability to learn hierarchical representations. Early layers detect simple patterns, while deeper layers combine these to recognize complex objects.
# Visualizing feature maps at different layers
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.datasets import cifar10
import matplotlib.pyplot as plt

# Build a simple CNN (untrained here — the maps show what random filters respond to;
# after training, early layers typically learn edge- and texture-like detectors)
model = tf.keras.Sequential([
    layers.Conv2D(32, 3, activation='relu', padding='same', input_shape=(32, 32, 3)),
    layers.MaxPooling2D(),
    layers.Conv2D(64, 3, activation='relu', padding='same'),
    layers.MaxPooling2D(),
    layers.Conv2D(128, 3, activation='relu', padding='same'),
    layers.Flatten(),
    layers.Dense(10, activation='softmax')
])

# Load sample image
(x_train, _), _ = cifar10.load_data()
sample = x_train[0:1] / 255.0

# Create a model that exposes intermediate conv outputs
layer_outputs = [layer.output for layer in model.layers if 'conv' in layer.name]
feature_model = Model(inputs=model.input, outputs=layer_outputs)

# Get feature maps
features = feature_model.predict(sample)

# Display feature maps from first conv layer
fig, axes = plt.subplots(4, 8, figsize=(12, 6))
for i, ax in enumerate(axes.flat):
    if i < features[0].shape[-1]:
        ax.imshow(features[0][0, :, :, i], cmap='viridis')
    ax.axis('off')
plt.suptitle("Feature Maps - First Conv Layer")
plt.tight_layout()
plt.show()
Pooling Operations Explained
Pooling reduces the spatial dimensions of feature maps, making the network more efficient and providing some degree of translation invariance. The two most common types are max pooling and average pooling.
# Comparing Max Pooling vs Average Pooling
import numpy as np
from tensorflow.keras import layers
# Sample feature map
feature_map = np.array([
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12],
    [13, 14, 15, 16]
], dtype=np.float32).reshape(1, 4, 4, 1)
# Max Pooling (takes maximum value in each window)
max_pool = layers.MaxPooling2D(pool_size=2)
max_output = max_pool(feature_map)
# Average Pooling (takes average value in each window)
avg_pool = layers.AveragePooling2D(pool_size=2)
avg_output = avg_pool(feature_map)
print("Original (4x4):")
print(feature_map.squeeze())
print("\nMax Pooling (2x2):")
print(max_output.numpy().squeeze())
print("\nAverage Pooling (2x2):")
print(avg_output.numpy().squeeze())
# Output:
# Max Pooling:     Average Pooling:
# [[ 6.  8.]       [[ 3.5  5.5]
#  [14. 16.]]       [11.5 13.5]]
Practice: CNN Architecture
Task: Given a 224x224x3 input image, calculate the output shape after applying Conv2D(64, kernel_size=3, padding='same') followed by MaxPooling2D(pool_size=2).
Show Solution
# Output dimension calculation
from tensorflow.keras import layers, Sequential

model = Sequential([
    layers.Conv2D(64, 3, padding='same', input_shape=(224, 224, 3)),
    layers.MaxPooling2D(pool_size=2)
])
model.summary()

# Manual calculation:
# After Conv2D (padding='same', stride=1): 224 x 224 x 64
# After MaxPooling2D (pool_size=2): 112 x 112 x 64
#
# Formula for 'valid' padding: output = floor((input - kernel) / stride) + 1
# Formula for 'same' padding:  output = ceil(input / stride)
# MaxPooling:                  output = floor(input / pool_size)
Task: Create a CNN with 3 convolutional blocks (Conv2D + ReLU + MaxPooling), doubling filters at each block (32→64→128), for 32x32 RGB input and 10 output classes.
Show Solution
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
model = Sequential([
    # Block 1: 32 filters
    Conv2D(32, 3, activation='relu', padding='same', input_shape=(32, 32, 3)),
    MaxPooling2D(pool_size=2),  # Output: 16x16x32
    # Block 2: 64 filters
    Conv2D(64, 3, activation='relu', padding='same'),
    MaxPooling2D(pool_size=2),  # Output: 8x8x64
    # Block 3: 128 filters
    Conv2D(128, 3, activation='relu', padding='same'),
    MaxPooling2D(pool_size=2),  # Output: 4x4x128
    # Classification head
    Flatten(),  # 4*4*128 = 2048
    Dense(256, activation='relu'),
    Dense(10, activation='softmax')
])
model.summary()
print(f"\nTotal parameters: {model.count_params():,}")
Task: Write a pure NumPy function that performs 2D convolution on a grayscale image with a given kernel. Handle 'valid' padding (no zero-padding).
Show Solution
import numpy as np
def convolve2d(image, kernel):
    """
    Perform 2D convolution with 'valid' padding.

    Note: like deep learning frameworks, this computes cross-correlation
    (the kernel is not flipped).

    Args:
        image: 2D numpy array (H, W)
        kernel: 2D numpy array (kH, kW)
    Returns:
        2D numpy array of shape (H-kH+1, W-kW+1)
    """
    img_h, img_w = image.shape
    ker_h, ker_w = kernel.shape
    out_h = img_h - ker_h + 1
    out_w = img_w - ker_w + 1
    output = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            region = image[i:i+ker_h, j:j+ker_w]
            output[i, j] = np.sum(region * kernel)
    return output

# Test with edge detection
image = np.array([
    [10, 10, 10, 0, 0, 0],
    [10, 10, 10, 0, 0, 0],
    [10, 10, 10, 0, 0, 0],
    [10, 10, 10, 0, 0, 0],
    [10, 10, 10, 0, 0, 0]
], dtype=np.float32)

sobel_x = np.array([
    [-1, 0, 1],
    [-2, 0, 2],
    [-1, 0, 1]
], dtype=np.float32)

result = convolve2d(image, sobel_x)
print("Vertical edge detection result:")
print(result)
Building Classification Models
Building an image classifier involves several key steps: preparing your dataset, designing the model architecture, configuring training parameters, and implementing proper data augmentation. In this section, you will build a complete image classifier from scratch using TensorFlow and Keras.
The Model Building Pipeline
Building a robust image classifier requires a systematic approach: data loading and preprocessing, model architecture design, compilation with appropriate loss and optimizer, training with callbacks, and evaluation on held-out test data.
Success Factors: Good data quality, appropriate architecture complexity, sufficient training data, and proper regularization techniques are key to building classifiers that generalize well to new images.
Step 1: Data Loading and Preprocessing
Proper data preparation is crucial for training effective models. This includes loading images, resizing to consistent dimensions, normalizing pixel values, and creating train/validation splits.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
# Method 1: Load from directory structure
train_ds = keras.utils.image_dataset_from_directory(
    'data/train',
    image_size=(224, 224),
    batch_size=32,
    validation_split=0.2,
    subset='training',
    seed=42
)

val_ds = keras.utils.image_dataset_from_directory(
    'data/train',
    image_size=(224, 224),
    batch_size=32,
    validation_split=0.2,
    subset='validation',
    seed=42
)

# Get class names
class_names = train_ds.class_names
print(f"Classes: {class_names}")
print(f"Number of training batches: {len(train_ds)}")
print(f"Number of validation batches: {len(val_ds)}")

# Optimize dataset performance with prefetching
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
This code demonstrates loading images from a directory structure using image_dataset_from_directory(), which automatically infers class labels from folder names. The validation_split and subset parameters create train/validation splits from the same directory. The seed=42 ensures reproducible splits. Finally, we optimize performance using cache() (keeps data in memory after first load), shuffle() (randomizes order each epoch), and prefetch() (loads next batch while GPU processes current batch).
Step 2: Data Augmentation
Data augmentation artificially increases the diversity of your training data by applying random transformations. This helps prevent overfitting and improves model generalization, especially when you have limited training data.
# Create data augmentation layer
data_augmentation = keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.1),
    layers.RandomContrast(0.1),
], name='data_augmentation')

# Visualize augmentation effects
import matplotlib.pyplot as plt

# Get a sample batch
for images, labels in train_ds.take(1):
    plt.figure(figsize=(12, 6))
    # Original image
    plt.subplot(2, 4, 1)
    plt.imshow(images[0].numpy().astype('uint8'))
    plt.title("Original")
    plt.axis('off')
    # Augmented versions (training=True forces the random transforms to apply
    # outside of model.fit; otherwise these layers pass images through unchanged)
    for i in range(7):
        augmented = data_augmentation(tf.expand_dims(images[0], 0), training=True)
        plt.subplot(2, 4, i + 2)
        plt.imshow(augmented[0].numpy().astype('uint8'))
        plt.title(f"Aug {i+1}")
        plt.axis('off')
    plt.tight_layout()
    plt.show()
We create a Sequential model containing augmentation layers that apply random transformations: RandomFlip mirrors images horizontally, RandomRotation(0.1) rotates by up to 10% of a full circle (±36°), RandomZoom(0.1) zooms in or out by up to 10%, and RandomContrast(0.1) randomly adjusts contrast. The visualization loop shows the original image alongside 7 augmented versions, demonstrating how each training example can look different every epoch—effectively multiplying your dataset size and teaching the model to recognize objects regardless of orientation or lighting.
Step 3: Building the Model Architecture
Now let's build a complete CNN classifier with all the components: normalization, augmentation, convolutional blocks, and the classification head with dropout for regularization.
# Complete CNN classifier with all components
def create_classifier(input_shape, num_classes):
    """Build a CNN image classifier with best practices."""
    inputs = keras.Input(shape=input_shape)
    # Normalization layer
    x = layers.Rescaling(1./255)(inputs)
    # Data augmentation (active only during training)
    x = data_augmentation(x)
    # Block 1
    x = layers.Conv2D(32, 3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.MaxPooling2D()(x)
    # Block 2
    x = layers.Conv2D(64, 3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.MaxPooling2D()(x)
    # Block 3
    x = layers.Conv2D(128, 3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.MaxPooling2D()(x)
    # Block 4
    x = layers.Conv2D(256, 3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.GlobalAveragePooling2D()(x)  # Fewer parameters than Flatten
    # Classification head with dropout
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.Dropout(0.3)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    return keras.Model(inputs, outputs, name='image_classifier')
# Create model
model = create_classifier(input_shape=(224, 224, 3), num_classes=10)
model.summary()
This architecture follows a proven pattern: Rescaling(1./255) normalizes pixel values from [0-255] to [0-1]. Each convolutional block uses Conv2D→BatchNormalization→ReLU→Pooling, progressively increasing filters (32→64→128→256) while spatial dimensions decrease via pooling. BatchNormalization stabilizes training by normalizing activations. GlobalAveragePooling2D replaces Flatten by averaging each feature map to a single value—reducing parameters and overfitting. The classification head uses Dropout(0.5) and Dropout(0.3) to randomly disable neurons during training, forcing the network to learn redundant representations.
Step 4: Compiling the Model
Compilation configures the model for training by specifying the optimizer, loss function, and metrics. Choosing the right combination is crucial for effective training.
| Component | Common Choices | When to Use |
|---|---|---|
| Optimizer | Adam, SGD with momentum | Adam for most cases; SGD for fine-tuning |
| Loss (Multi-class) | categorical_crossentropy, sparse_categorical_crossentropy | sparse_ when labels are integers |
| Loss (Binary) | binary_crossentropy | Two classes only |
| Metrics | accuracy, AUC, precision, recall | accuracy for balanced; F1/AUC for imbalanced |
# Compile with learning rate scheduler
initial_learning_rate = 0.001

# Option 1: Simple compilation
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=initial_learning_rate),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Option 2: With learning rate schedule
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=initial_learning_rate,
    decay_steps=1000,
    decay_rate=0.9
)
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy', keras.metrics.SparseTopKCategoricalAccuracy(k=5)]
)
Option 1 shows basic compilation with the Adam optimizer at a fixed learning rate. Option 2 uses ExponentialDecay, which automatically reduces the learning rate over time—larger steps for fast initial progress, then smaller steps for fine-tuning. With decay_steps=1000 and decay_rate=0.9, the rate is multiplied by 0.9 every 1000 batches. We use sparse_categorical_crossentropy because our labels are integers (0, 1, 2, ...) rather than one-hot vectors; for the same reason, the top-5 metric must be the sparse variant, SparseTopKCategoricalAccuracy(k=5), which measures whether the correct class appears among the top 5 predictions—useful for datasets with many similar classes.
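To see the schedule's effect numerically, the decayed rate can be computed by hand (same constants as above; ExponentialDecay with its default continuous decay, i.e. staircase=False):

```python
initial_lr, decay_rate, decay_steps = 0.001, 0.9, 1000

def decayed_lr(step):
    # ExponentialDecay formula: lr = initial * decay_rate ** (step / decay_steps)
    return initial_lr * decay_rate ** (step / decay_steps)

print(decayed_lr(0))      # 0.001
print(decayed_lr(1000))   # 0.0009
print(decayed_lr(10000))  # about 0.00035 after ten decay periods
```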
Step 5: Training with Callbacks
Callbacks provide powerful functionality during training: saving the best model, reducing learning rate on plateaus, early stopping, and logging to TensorBoard for visualization.
# Define comprehensive callbacks
callbacks = [
    # Save best model based on validation accuracy
    keras.callbacks.ModelCheckpoint(
        'best_model.keras',
        monitor='val_accuracy',
        save_best_only=True,
        mode='max',
        verbose=1
    ),
    # Reduce learning rate when validation loss plateaus
    keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=3,
        min_lr=1e-7,
        verbose=1
    ),
    # Stop training if no improvement
    keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True,
        verbose=1
    ),
    # TensorBoard logging
    keras.callbacks.TensorBoard(
        log_dir='./logs',
        histogram_freq=1
    )
]

# Train the model
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=50,
    callbacks=callbacks,
    verbose=1
)
Four essential callbacks: ModelCheckpoint saves the model whenever val_accuracy improves, ensuring you keep the best version even if training continues past the optimal point. ReduceLROnPlateau halves the learning rate (factor=0.5) if validation loss doesn't improve for 3 epochs—helping escape local minima. EarlyStopping terminates training after 10 epochs without improvement, with restore_best_weights=True loading the best model at the end. TensorBoard logs metrics and histograms for interactive visualization. The model.fit() call trains for up to 50 epochs, but EarlyStopping typically ends it sooner.
Step 6: Visualizing Training Progress
Plotting training history helps diagnose issues like overfitting (training accuracy much higher than validation) or underfitting (both accuracies low).
def plot_training_history(history):
    """Plot training and validation metrics."""
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    # Accuracy plot
    axes[0].plot(history.history['accuracy'], label='Train', linewidth=2)
    axes[0].plot(history.history['val_accuracy'], label='Validation', linewidth=2)
    axes[0].set_title('Model Accuracy', fontsize=14)
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Accuracy')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    # Loss plot
    axes[1].plot(history.history['loss'], label='Train', linewidth=2)
    axes[1].plot(history.history['val_loss'], label='Validation', linewidth=2)
    axes[1].set_title('Model Loss', fontsize=14)
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Loss')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
    # Print final metrics
    print(f"\nFinal Training Accuracy: {history.history['accuracy'][-1]:.4f}")
    print(f"Final Validation Accuracy: {history.history['val_accuracy'][-1]:.4f}")

plot_training_history(history)
This visualization function creates two subplots: accuracy over epochs (left) and loss over epochs (right). By plotting both training and validation metrics together, you can diagnose model behavior: parallel lines indicate good generalization, diverging lines (training improving, validation stagnating) indicate overfitting. The grid(True, alpha=0.3) adds subtle gridlines for easier reading. Final metrics are printed for quick reference. Healthy training shows both curves improving together, then validation plateauing while training continues—at which point EarlyStopping should trigger.
Practice: Building Models
Task: Build and train a simple CNN on CIFAR-10 for 10 epochs. Use the built-in dataset, normalize the data, and report the final validation accuracy.
Show Solution
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import cifar10

# Load data
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Build simple model
model = keras.Sequential([
    layers.Conv2D(32, 3, activation='relu', padding='same', input_shape=(32, 32, 3)),
    layers.MaxPooling2D(),
    layers.Conv2D(64, 3, activation='relu', padding='same'),
    layers.MaxPooling2D(),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

history = model.fit(x_train, y_train, epochs=10,
                    validation_data=(x_test, y_test))

print(f"\nFinal Validation Accuracy: {history.history['val_accuracy'][-1]:.4f}")
Task: Improve the CIFAR-10 classifier by adding data augmentation layers and batch normalization. Compare the validation accuracy with the simple model.
Show Solution
# Enhanced model with augmentation and batch norm
# (reuses x_train, y_train, x_test, y_test from the previous task)
augmentation = keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.1),
])

model_enhanced = keras.Sequential([
    layers.InputLayer(input_shape=(32, 32, 3)),
    augmentation,
    layers.Conv2D(32, 3, padding='same'),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(64, 3, padding='same'),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(128, 3, padding='same'),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.GlobalAveragePooling2D(),
    layers.Dropout(0.5),
    layers.Dense(64, activation='relu'),
    layers.Dense(10, activation='softmax')
])

model_enhanced.compile(optimizer='adam',
                       loss='sparse_categorical_crossentropy',
                       metrics=['accuracy'])

history = model_enhanced.fit(x_train, y_train, epochs=20,
                             validation_data=(x_test, y_test),
                             callbacks=[keras.callbacks.EarlyStopping(patience=5)])

print(f"\nEnhanced Model Accuracy: {history.history['val_accuracy'][-1]:.4f}")
Task: Implement a custom training loop using tf.GradientTape instead of model.fit(). Include batch processing, loss calculation, and metric tracking.
Show Solution
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Create model (without compile); x_train, y_train, x_test, y_test as loaded above
model = keras.Sequential([
    layers.Conv2D(32, 3, activation='relu', padding='same', input_shape=(32, 32, 3)),
    layers.MaxPooling2D(),
    layers.Flatten(),
    layers.Dense(10, activation='softmax')
])

# Define optimizer and loss
optimizer = keras.optimizers.Adam(learning_rate=0.001)
loss_fn = keras.losses.SparseCategoricalCrossentropy()

# Metrics
train_acc = keras.metrics.SparseCategoricalAccuracy()
val_acc = keras.metrics.SparseCategoricalAccuracy()

# Create datasets
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(10000).batch(32)

# Custom training loop
epochs = 5
for epoch in range(epochs):
    train_acc.reset_state()
    val_acc.reset_state()  # reset both metrics each epoch
    for step, (x_batch, y_batch) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            predictions = model(x_batch, training=True)
            loss = loss_fn(y_batch, predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        train_acc.update_state(y_batch, predictions)
    # Validation
    val_predictions = model(x_test, training=False)
    val_acc.update_state(y_test, val_predictions)
    print(f"Epoch {epoch+1}: Train Acc = {train_acc.result():.4f}, "
          f"Val Acc = {val_acc.result():.4f}")
Transfer Learning
Transfer learning allows you to leverage models pretrained on large datasets like ImageNet to solve your specific classification problem. This technique dramatically reduces training time and often achieves better results than training from scratch, especially when you have limited data.
What is Transfer Learning?
Transfer learning is a machine learning technique where a model trained on one task is repurposed as the starting point for a model on a different task. For image classification, this typically means using a network trained on ImageNet (1000 classes, 14M+ images) as a feature extractor for your specific problem.
Key Advantage: Pretrained models have already learned powerful visual features (edges, textures, shapes, objects) that transfer well to new tasks, requiring far less data and training time.
Popular Pretrained Models
Several architectures pretrained on ImageNet are available in Keras. Each offers different trade-offs between accuracy, speed, and model size.
| Model | Parameters | Top-1 Accuracy | Size (MB) | Best For |
|---|---|---|---|---|
| MobileNetV2 | 3.4M | 71.3% | 14 | Mobile deployment, speed |
| ResNet50 | 25.6M | 74.9% | 98 | Good balance accuracy/speed |
| VGG16 | 138M | 71.3% | 528 | Simple, interpretable |
| EfficientNetB0 | 5.3M | 77.1% | 29 | Efficiency, modern arch |
| InceptionV3 | 23.9M | 77.9% | 92 | Multi-scale features |
Transfer Learning Strategies
When using a pretrained model, you have two main choices for how to use it: feature extraction and fine-tuning. Think of it like hiring an expert—you can either let them work exactly as they were trained (feature extraction), or you can slightly retrain them for your specific needs (fine-tuning).
The decision between these approaches depends on two factors: how much data you have and how similar your images are to ImageNet (the dataset these models were originally trained on, containing 1000 everyday objects like dogs, cars, and furniture).
Feature Extraction
What it means: You "freeze" the pretrained model so its weights never change. You only train a small new layer on top that learns to classify your specific categories.
Analogy: Like using a camera that's already perfectly focused—you just point it at your subject and take the photo.
Best when: You have a small dataset (hundreds to a few thousand images), or your images are similar to everyday objects (animals, vehicles, household items).
Advantages: Very fast training (often just minutes), impossible to accidentally ruin the pretrained knowledge, works well even with limited data.
Fine-Tuning
What it means: After training the classification head, you "unfreeze" some of the pretrained layers and continue training with a very small learning rate, allowing the model to adapt its features to your specific data.
Analogy: Like adjusting the focus on a camera—small tweaks to get a sharper image for your specific scene.
Best when: You have a large dataset (tens of thousands of images), or your images are quite different from everyday objects (medical scans, satellite images, microscopy).
Advantages: Can achieve higher accuracy, allows the model to learn features specific to your domain.
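The decision factors above can be folded into a tiny helper. This is purely an illustrative sketch of the rule of thumb, not a Keras API; the 5,000-image cutoff is an assumption, not a hard boundary:

```python
def suggest_strategy(num_images: int, similar_to_imagenet: bool) -> str:
    """Illustrative rule of thumb: dataset size dominates the choice,
    and similarity to ImageNet breaks the tie.
    (The 5,000-image threshold is an assumption for illustration.)"""
    if num_images < 5000:
        return "feature extraction"  # too little data to safely unfreeze layers
    if similar_to_imagenet:
        return "feature extraction, then light fine-tuning"
    return "fine-tuning"  # plenty of data in a different domain

print(suggest_strategy(800, similar_to_imagenet=True))     # feature extraction
print(suggest_strategy(50000, similar_to_imagenet=False))  # fine-tuning
```

In practice the two strategies are often combined, as the examples below show: train the head with a frozen base first, then fine-tune if validation accuracy plateaus.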
Feature Extraction Example
Let's use MobileNetV2 as a feature extractor for a custom 5-class classification problem. We'll freeze all pretrained layers and add our own classification head.
We start by importing the pretrained MobileNetV2 model from Keras applications. Setting include_top=False removes the original 1000-class ImageNet classification layers, giving us access to the powerful feature extraction backbone. The weights='imagenet' parameter loads the pretrained weights.
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras import layers, Model
# Load pretrained model without top layers
base_model = MobileNetV2(
    input_shape=(224, 224, 3),
    include_top=False,   # Remove classification layers
    weights='imagenet'   # Use pretrained weights
)
By setting base_model.trainable = False, we freeze all layers in MobileNetV2. This means the pretrained weights won't be updated during training—the model acts purely as a fixed feature extractor. This is crucial for the feature extraction approach and makes training much faster.
# Freeze base model
base_model.trainable = False
We use the Keras Functional API to build our model. The input is defined with the expected image dimensions. Each pretrained model has its own preprocessing function that normalizes pixel values appropriately—for MobileNetV2, this scales pixels to the range [-1, 1].
# Build new model with custom classification head
inputs = keras.Input(shape=(224, 224, 3))
# Preprocessing for MobileNetV2
x = keras.applications.mobilenet_v2.preprocess_input(inputs)
We pass the preprocessed input through the frozen base model to extract features. Setting training=False is important because it tells BatchNormalization layers to use their learned statistics rather than batch statistics, ensuring consistent behavior during inference.
# Get features from base model
x = base_model(x, training=False) # training=False for BatchNorm layers
The classification head transforms the extracted features into class predictions. GlobalAveragePooling2D reduces spatial dimensions while preserving feature information. Dropout layers prevent overfitting, and the final Dense layer with softmax outputs probabilities for our 5 custom classes.
# Add classification head
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.5)(x)
x = layers.Dense(256, activation='relu')(x)
x = layers.Dropout(0.3)(x)
outputs = layers.Dense(5, activation='softmax')(x) # 5 classes
model = Model(inputs, outputs)
Finally, we compile the model with the Adam optimizer and sparse categorical cross-entropy loss (for integer labels). We can use a standard learning rate since we're only training the classification head. The summary shows how many parameters are trainable versus frozen.
# Compile
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)
print(f"Total params: {model.count_params():,}")
# Convert each weight's shape product to a plain int before formatting
trainable_params = sum(int(tf.reduce_prod(w.shape)) for w in model.trainable_weights)
print(f"Trainable params: {trainable_params:,}")
Fine-Tuning Example
After training the classification head, we can unfreeze the top layers of the base model and continue training with a lower learning rate. This allows the model to adapt the pretrained features to your specific task.
The first step is to train the model with the base model frozen. This allows the randomly initialized classification head to learn meaningful weights without disrupting the pretrained features. We typically train for 5-10 epochs until the validation accuracy stabilizes.
# First: Train with frozen base (feature extraction)
print("Phase 1: Training classification head...")
history1 = model.fit(train_ds, validation_data=val_ds, epochs=10)
After the classification head is trained, we can begin fine-tuning. First, we set base_model.trainable = True to unfreeze all layers. Then, we selectively re-freeze the early layers by iterating through base_model.layers[:-20]. Early layers contain generic features (edges, textures) that transfer well to any task, so we keep them frozen.
# Second: Unfreeze top layers for fine-tuning
print("\nPhase 2: Fine-tuning top layers...")
base_model.trainable = True
# Freeze all layers except the last 20
for layer in base_model.layers[:-20]:
    layer.trainable = False
# Check trainable layers
trainable_layers = [l.name for l in base_model.layers if l.trainable]
print(f"Fine-tuning {len(trainable_layers)} layers")
After changing which layers are trainable, we must recompile the model. The critical change here is the learning rate—we use 1e-5 instead of the typical 1e-3. This 100x reduction prevents the optimizer from making large updates that could destroy the carefully learned pretrained features.
# Recompile with lower learning rate (important!)
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-5),  # 100x lower than the typical 1e-3
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)
Finally, we continue training with the unfrozen layers. The EarlyStopping callback monitors validation loss and stops training if it doesn't improve for 3 consecutive epochs, preventing overfitting. The restore_best_weights=True parameter ensures we keep the model from the epoch with the best validation performance.
# Continue training
history2 = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=10,
    callbacks=[
        keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True)
    ]
)
Complete Transfer Learning Pipeline
Here's a complete, production-ready example showing the full workflow from data loading to fine-tuning with EfficientNetB0.
The first step is to load a pretrained model. EfficientNetB0 is an excellent choice as it provides a good balance between accuracy and computational efficiency. We set include_top=False to remove the original classification head, allowing us to add our own custom layers for our specific number of classes.
def create_transfer_learning_model(num_classes):
    """Create a complete transfer learning model with EfficientNetB0."""
    # Load EfficientNetB0 without its classification head
    base_model = keras.applications.EfficientNetB0(
        input_shape=(224, 224, 3),
        include_top=False,
        weights='imagenet'
    )
Data augmentation is essential for preventing overfitting, especially when working with smaller datasets. By applying random transformations like horizontal flips, rotations, and zooms during training, we artificially increase the diversity of our training data and help the model learn more robust features.
    # Data augmentation
    data_augmentation = keras.Sequential([
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.2),
        layers.RandomZoom(0.2),
    ])
Using the Keras Functional API, we connect the augmentation pipeline to the base model and add a custom classification head. The GlobalAveragePooling2D layer reduces the spatial dimensions while preserving feature information. We include BatchNormalization and Dropout layers to stabilize training and prevent overfitting.
    # Build model
    inputs = keras.Input(shape=(224, 224, 3))
    x = data_augmentation(inputs)
    x = keras.applications.efficientnet.preprocess_input(x)
    x = base_model(x, training=False)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.Dropout(0.3)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    model = keras.Model(inputs, outputs)
Initially, we freeze all layers in the base model by setting trainable = False. This preserves the pretrained ImageNet features and allows us to train only the new classification head first, which is much faster and prevents catastrophic forgetting of the learned representations.
    # Initially freeze base model
    base_model.trainable = False
    return model, base_model

# Create model
model, base_model = create_transfer_learning_model(num_classes=10)
In the first training phase (feature extraction), we train only the newly added classification layers while keeping the base model frozen. This allows the model to quickly learn how to map the pretrained features to our specific classes. A standard learning rate works well here since we're only training a small number of parameters.
# Phase 1: Feature extraction
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(train_ds, validation_data=val_ds, epochs=5)
In the second phase (fine-tuning), we unfreeze the top layers of the base model and continue training with a much lower learning rate (typically 10-100x smaller). This allows the model to adapt the pretrained features to our specific task. We keep the early layers frozen because they contain generic features (edges, textures) that transfer well to any image task.
# Phase 2: Fine-tuning
base_model.trainable = True
for layer in base_model.layers[:100]:
    layer.trainable = False
model.compile(
    optimizer=keras.optimizers.Adam(1e-5),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)
model.fit(train_ds, validation_data=val_ds, epochs=10)
Practice: Transfer Learning
Task: Load VGG16, ResNet50, and MobileNetV2 (without top layers) and compare their parameter counts and memory sizes.
Show Solution
from tensorflow.keras.applications import VGG16, ResNet50, MobileNetV2
models_info = {
    'VGG16': VGG16(include_top=False, weights='imagenet'),
    'ResNet50': ResNet50(include_top=False, weights='imagenet'),
    'MobileNetV2': MobileNetV2(include_top=False, weights='imagenet')
}
print("Model Comparison:")
print("-" * 50)
for name, model in models_info.items():
    params = model.count_params()
    size_mb = params * 4 / (1024 ** 2)  # 4 bytes per float32
    print(f"{name:15} Parameters: {params:>12,} Size: {size_mb:>6.1f} MB")
# Output:
# VGG16 Parameters: 14,714,688 Size: 56.1 MB
# ResNet50 Parameters: 23,587,712 Size: 90.0 MB
# MobileNetV2 Parameters: 2,257,984 Size: 8.6 MB
Task: Use ResNet50 with transfer learning to classify the TensorFlow flowers dataset (5 classes). Use feature extraction first, then fine-tune the top 30 layers.
Show Solution
import tensorflow_datasets as tfds
# Load flowers dataset
(train_ds, val_ds), info = tfds.load(
    'tf_flowers',
    split=['train[:80%]', 'train[80%:]'],
    with_info=True,
    as_supervised=True
)
# Preprocess
def preprocess(image, label):
    image = tf.image.resize(image, [224, 224])
    image = keras.applications.resnet50.preprocess_input(image)
    return image, label

train_ds = train_ds.map(preprocess).batch(32).prefetch(tf.data.AUTOTUNE)
val_ds = val_ds.map(preprocess).batch(32).prefetch(tf.data.AUTOTUNE)
# Build model
base_model = keras.applications.ResNet50(include_top=False, weights='imagenet')
base_model.trainable = False
model = keras.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(256, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(5, activation='softmax')
])
# Phase 1: Feature extraction
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(train_ds, validation_data=val_ds, epochs=5)
# Phase 2: Fine-tune top 30 layers
base_model.trainable = True
for layer in base_model.layers[:-30]:
    layer.trainable = False
model.compile(optimizer=keras.optimizers.Adam(1e-5),
              loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(train_ds, validation_data=val_ds, epochs=5)
Model Evaluation
Evaluating image classifiers goes beyond simple accuracy. Understanding metrics like precision, recall, F1-score, and confusion matrices helps you identify model weaknesses and improve performance. This section covers comprehensive evaluation techniques and strategies for handling common issues.
Beyond Accuracy
While accuracy tells you the overall percentage of correct predictions, it can be misleading for imbalanced datasets. A model predicting "not cancer" 99% of the time achieves 99% accuracy if only 1% of samples are cancerous—yet it's completely useless.
Better Approach: Use a combination of precision, recall, F1-score, and confusion matrices to get a complete picture of model performance across all classes.
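A toy example, in plain Python with made-up numbers, makes the cancer scenario above concrete: a model that always predicts the majority class scores 99% accuracy while detecting zero positives.

```python
# Hypothetical imbalanced dataset: 1% positive ("cancer"), 99% negative
y_true = [1] * 10 + [0] * 990
y_pred = [0] * 1000  # a "model" that always predicts "not cancer"

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
recall = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred)) / sum(y_true)

print(f"Accuracy: {accuracy:.1%}")  # 99.0% -- looks impressive
print(f"Recall:   {recall:.1%}")    # 0.0%  -- misses every positive case
```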
Classification Metrics Explained
Understanding these metrics is essential for proper model evaluation. Each metric captures different aspects of model performance and is suited for different use cases.
| Metric | Formula | What it Measures | Prioritize When |
|---|---|---|---|
| Accuracy | (TP + TN) / Total | Overall correctness | Balanced classes |
| Precision | TP / (TP + FP) | Quality of positive predictions | Cost of false positives high |
| Recall | TP / (TP + FN) | Ability to find all positives | Cost of false negatives high |
| F1-Score | 2 × (P × R) / (P + R) | Balance of precision & recall | Imbalanced classes |
| AUC-ROC | Area under ROC curve | Ranking quality | Threshold-independent eval |
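To make the formulas in the table concrete, here is a worked example for a single positive class using made-up confusion counts (TP=80, FP=10, FN=20, TN=90); the numbers are purely illustrative:

```python
# Made-up confusion counts for one positive class
TP, FP, FN, TN = 80, 10, 20, 90

accuracy = (TP + TN) / (TP + TN + FP + FN)  # overall correctness
precision = TP / (TP + FP)                  # quality of positive predictions
recall = TP / (TP + FN)                     # coverage of actual positives
f1 = 2 * precision * recall / (precision + recall)  # harmonic mean

print(f"Accuracy:  {accuracy:.3f}")   # 0.850
print(f"Precision: {precision:.3f}")  # 0.889
print(f"Recall:    {recall:.3f}")     # 0.800
print(f"F1-Score:  {f1:.3f}")         # 0.842
```

Note how precision and recall diverge: this model is slightly better at being right when it says "positive" (0.889) than at finding all positives (0.800), and F1 sits between the two.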
Computing Evaluation Metrics
Scikit-learn provides comprehensive tools for computing classification metrics. Here's how to generate a complete evaluation report for your image classifier.
First, we import the necessary metrics functions and get predictions from our trained model. The model.predict() returns probability distributions across all classes, so we use np.argmax() to convert these probabilities into class indices by selecting the class with the highest probability for each sample.
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np
# Get predictions
y_pred_probs = model.predict(x_test)
y_pred = np.argmax(y_pred_probs, axis=1)
y_true = y_test.flatten() # Ensure 1D array
We define human-readable class names that correspond to the numeric labels (0-9) in our dataset. These names make our evaluation reports much more interpretable than raw numbers, allowing us to understand which real-world categories the model struggles with.
# Class names for CIFAR-10
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
The classification_report() function generates a comprehensive table showing precision, recall, and F1-score for each class, plus overall averages. Precision tells us what percentage of positive predictions were correct, recall shows what percentage of actual positives we found, and F1-score is their harmonic mean—useful when you need a single metric that balances both.
# Generate classification report
print("Classification Report:")
print("=" * 60)
print(classification_report(y_true, y_pred, target_names=class_names))
Finally, we compute overall summary metrics. accuracy_score gives the percentage of correct predictions. For multi-class problems, we calculate two versions of F1-score: macro (simple average across classes, treating all classes equally) and weighted (weighted by class frequency, giving more importance to common classes). Use macro F1 when all classes matter equally, and weighted F1 when you care more about overall performance on your actual data distribution.
# Key metrics summary
from sklearn.metrics import accuracy_score, f1_score
accuracy = accuracy_score(y_true, y_pred)
f1_macro = f1_score(y_true, y_pred, average='macro')
f1_weighted = f1_score(y_true, y_pred, average='weighted')
print(f"\nOverall Accuracy: {accuracy:.4f}")
print(f"Macro F1-Score: {f1_macro:.4f}")
print(f"Weighted F1-Score: {f1_weighted:.4f}")
Confusion Matrix Visualization
A confusion matrix shows exactly where your model makes mistakes—which classes it confuses with others. This is invaluable for identifying patterns in misclassifications.
We start by importing visualization libraries and computing the confusion matrix. The confusion_matrix() function creates a 2D array where rows represent true labels and columns represent predictions. Each cell [i,j] shows how many samples of class i were predicted as class j. We also create a normalized version by dividing each row by its sum, converting counts to percentages.
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
def plot_confusion_matrix(y_true, y_pred, class_names):
    """Create a confusion matrix heatmap (raw counts and normalized)."""
    cm = confusion_matrix(y_true, y_pred)
    # Normalize each row so cells show per-class percentages
    cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
The first heatmap displays raw counts using the 'Blues' colormap. The annot=True parameter displays the actual numbers in each cell, and fmt='d' formats them as integers. Raw counts are useful for understanding the absolute volume of errors—important when class sizes differ significantly.
    # Raw counts
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=axes[0])
    axes[0].set_title('Confusion Matrix (Counts)', fontsize=14)
    axes[0].set_xlabel('Predicted')
    axes[0].set_ylabel('True')
The second heatmap shows normalized percentages using 'RdYlGn' (red-yellow-green) colormap, where green indicates high accuracy and red indicates poor performance. The fmt='.2%' formats values as percentages with 2 decimal places. Normalized view is better for comparing performance across classes with different sample sizes.
    # Normalized (percentages)
    sns.heatmap(cm_normalized, annot=True, fmt='.2%', cmap='RdYlGn',
                xticklabels=class_names, yticklabels=class_names, ax=axes[1])
    axes[1].set_title('Confusion Matrix (Normalized)', fontsize=14)
    axes[1].set_xlabel('Predicted')
    axes[1].set_ylabel('True')
    plt.tight_layout()
    plt.show()
This bonus analysis automatically identifies the most commonly confused class pair. We zero out the diagonal (correct predictions) using np.fill_diagonal(), then find the maximum remaining value with np.argmax(). The np.unravel_index() converts the flat index back to 2D coordinates, telling us which true class is most often mistaken for which predicted class.
    # Find most confused pairs
    np.fill_diagonal(cm, 0)  # Ignore correct predictions
    idx = np.unravel_index(np.argmax(cm), cm.shape)
    print(f"\nMost confused: '{class_names[idx[0]]}' → '{class_names[idx[1]]}' ({cm.max()} times)")

plot_confusion_matrix(y_true, y_pred, class_names)
Per-Class Performance Analysis
Identifying which classes perform poorly helps focus your improvement efforts. You might need more training data, better augmentation, or architectural changes for specific classes.
We use precision_recall_fscore_support() from scikit-learn, which computes all four metrics in one call. This function returns arrays where each element corresponds to one class. Precision measures how many predicted positives were correct, recall measures how many actual positives were found, F1-score is their harmonic mean, and support is the number of true samples for each class.
from sklearn.metrics import precision_recall_fscore_support
def analyze_per_class_performance(y_true, y_pred, class_names):
    """Detailed per-class analysis with improvement suggestions."""
    precision, recall, f1, support = precision_recall_fscore_support(y_true, y_pred)
We organize the metrics into a pandas DataFrame for easier analysis and visualization. Each row represents one class with its metrics. Sorting by F1-Score puts the worst-performing classes at the top, immediately highlighting where your model needs the most improvement.
    # Create DataFrame for easy analysis
    import pandas as pd
    results = pd.DataFrame({
        'Class': class_names,
        'Precision': precision,
        'Recall': recall,
        'F1-Score': f1,
        'Support': support
    }).sort_values('F1-Score')
    print("Per-Class Performance (sorted by F1-Score):")
    print("=" * 60)
    print(results.to_string(index=False))
This diagnostic code automatically identifies classes performing below a threshold (F1 < 0.7) and provides actionable insights. By comparing precision and recall, we can diagnose the specific problem: low precision means too many false positives (model is over-predicting this class), while low recall means too many false negatives (model is missing instances of this class). This distinction guides your improvement strategy—add negative examples for low precision, or add more positive examples for low recall.
    # Identify problematic classes
    print("\n⚠️ Classes needing improvement (F1 < 0.7):")
    weak_classes = results[results['F1-Score'] < 0.7]
    for _, row in weak_classes.iterrows():
        if row['Precision'] < row['Recall']:
            print(f"  • {row['Class']}: Low precision - too many false positives")
        else:
            print(f"  • {row['Class']}: Low recall - missing many true instances")
    return results

results = analyze_per_class_performance(y_true, y_pred, class_names)
Visualizing Misclassifications
Seeing actual misclassified images helps understand why the model fails. Are the images ambiguous? Mislabeled? Or does the model lack certain features?
We define a function that takes the test data, true labels, predictions, prediction probabilities, and class names. The np.where(y_true != y_pred) comparison finds all indices where the prediction doesn't match the true label. We add an early return check in case the model achieves perfect accuracy (unlikely but good practice).
def plot_misclassifications(x_test, y_true, y_pred, y_pred_probs, class_names, n=12):
    """Display misclassified images with confidence scores."""
    # Find misclassifications
    misclassified_idx = np.where(y_true != y_pred)[0]
    if len(misclassified_idx) == 0:
        print("No misclassifications found!")
        return
Since there may be thousands of misclassifications, we randomly sample a subset to display. The np.random.choice() function selects n random indices without replacement. The min(n, len(misclassified_idx)) ensures we don't request more samples than available. We create a 3×4 grid of subplots to display 12 images.
    # Sample random misclassifications
    sample_idx = np.random.choice(misclassified_idx,
                                  min(n, len(misclassified_idx)), replace=False)
    fig, axes = plt.subplots(3, 4, figsize=(14, 10))
We iterate through each subplot using axes.flat which flattens the 2D array of axes into a 1D iterator. For each position, we retrieve the corresponding image, extract the true and predicted labels using class name lookup, and calculate the model's confidence by accessing the predicted class probability from y_pred_probs (multiplied by 100 for percentage).
    for i, ax in enumerate(axes.flat):
        if i >= len(sample_idx):
            ax.axis('off')
            continue
        idx = sample_idx[i]
        img = x_test[idx]
        true_label = class_names[y_true[idx]]
        pred_label = class_names[y_pred[idx]]
        confidence = y_pred_probs[idx][y_pred[idx]] * 100
Finally, we display each image with an informative title showing both the true label and the model's prediction with its confidence score. The red color immediately signals these are errors. High confidence on wrong predictions indicates the model is confidently wrong—often more problematic than low-confidence mistakes. The plt.tight_layout() ensures proper spacing between subplots.
        ax.imshow(img)
        ax.set_title(f"True: {true_label}\nPred: {pred_label} ({confidence:.1f}%)",
                     color='red', fontsize=10)
        ax.axis('off')
    plt.suptitle("Misclassified Images", fontsize=14)
    plt.tight_layout()
    plt.show()

plot_misclassifications(x_test, y_true, y_pred, y_pred_probs, class_names)
Model Inference and Deployment
Once your model is trained and evaluated, you need to use it for predictions on new images. Here's how to create a clean prediction pipeline.
We create a factory function that returns a prediction function. This closure pattern captures the model, class_names, and image_size parameters, making the returned function simple to use with just an image path. This is a common design pattern for production deployment where you want to hide complexity from the user.
def create_prediction_pipeline(model, class_names, image_size=(224, 224)):
    """Create a production-ready prediction function."""
    def predict_image(image_path):
The preprocessing pipeline loads the image from disk using Keras utilities. The load_img() function automatically resizes the image to the target size expected by the model. Then img_to_array() converts the PIL image to a NumPy array, and np.expand_dims() adds a batch dimension at position 0, transforming the shape from (height, width, channels) to (1, height, width, channels) as required by the model.
        # Load and preprocess
        img = keras.utils.load_img(image_path, target_size=image_size)
        img_array = keras.utils.img_to_array(img)
        img_array = np.expand_dims(img_array, axis=0)
        # Predict
        predictions = model.predict(img_array, verbose=0)
Instead of returning just the top prediction, we return the top 3 most likely classes with their confidence scores. The argsort() method returns indices that would sort the array, [-3:] takes the last 3 (highest values), and [::-1] reverses to get descending order. We build a list of dictionaries with human-readable class names and confidence scores, making the output easy to use in any application.
        # Get top 3 predictions
        top_indices = predictions[0].argsort()[-3:][::-1]
        results = []
        for idx in top_indices:
            results.append({
                'class': class_names[idx],
                'confidence': float(predictions[0][idx])
            })
        return results

    return predict_image
Here's how to use the prediction pipeline. We call the factory function once to create our predict function, then use it repeatedly on any image. The function returns a list of dictionaries that we can easily iterate over or convert to JSON for API responses. This clean interface makes it trivial to integrate your model into web applications, mobile apps, or batch processing pipelines.
# Usage
predict = create_prediction_pipeline(model, class_names)
results = predict('test_image.jpg')
for r in results:
    print(f"{r['class']}: {r['confidence']*100:.1f}%")
Practice: Model Evaluation
Task: Train a simple CNN on CIFAR-10 and generate a complete classification report showing precision, recall, and F1-score for each class.
Show Solution
from tensorflow.keras.datasets import cifar10
from sklearn.metrics import classification_report
# Load and prepare data
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# Quick model training
model = keras.Sequential([
    layers.Conv2D(32, 3, activation='relu', input_shape=(32, 32, 3)),
    layers.MaxPooling2D(),
    layers.Conv2D(64, 3, activation='relu'),
    layers.Flatten(),
    layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5, validation_split=0.1, verbose=1)
# Evaluate
y_pred = np.argmax(model.predict(x_test), axis=1)
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
# Flatten labels: CIFAR-10 labels are shape (N, 1)
print(classification_report(y_test.flatten(), y_pred, target_names=class_names))
Task: Create a function that plots confusion matrix and identifies the top 5 most confused class pairs with suggestions for improvement.
Show Solution
def analyze_confusion(y_true, y_pred, class_names, top_n=5):
    """Analyze confusion matrix and suggest improvements."""
    cm = confusion_matrix(y_true, y_pred)
    # Plot
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.tight_layout()
    plt.show()
    # Find top confused pairs
    cm_copy = cm.copy()
    np.fill_diagonal(cm_copy, 0)
    print(f"\nTop {top_n} Most Confused Class Pairs:")
    print("=" * 50)
    for i in range(top_n):
        idx = np.unravel_index(np.argmax(cm_copy), cm_copy.shape)
        count = cm_copy[idx]
        if count == 0:
            break
        true_class = class_names[idx[0]]
        pred_class = class_names[idx[1]]
        print(f"{i+1}. {true_class} → {pred_class}: {count} errors")
        cm_copy[idx] = 0  # Zero out to find next

analyze_confusion(y_test.flatten(), y_pred, class_names)
Task: Build a function that creates a multi-panel figure showing: accuracy/loss curves, confusion matrix, per-class F1 scores bar chart, and ROC curves for each class.
Show Solution
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize
def evaluation_dashboard(model, history, x_test, y_true, class_names):
    """Comprehensive evaluation dashboard."""
    fig = plt.figure(figsize=(16, 12))
    # Get predictions
    y_pred_probs = model.predict(x_test)
    y_pred = np.argmax(y_pred_probs, axis=1)
    y_true_flat = y_true.flatten()

    # 1. Training curves
    ax1 = fig.add_subplot(2, 2, 1)
    ax1.plot(history.history['accuracy'], label='Train Acc')
    ax1.plot(history.history['val_accuracy'], label='Val Acc')
    ax1.set_title('Training Progress')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # 2. Confusion matrix
    ax2 = fig.add_subplot(2, 2, 2)
    cm = confusion_matrix(y_true_flat, y_pred)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax2)
    ax2.set_title('Confusion Matrix')

    # 3. Per-class F1 scores
    ax3 = fig.add_subplot(2, 2, 3)
    _, _, f1, _ = precision_recall_fscore_support(y_true_flat, y_pred)
    ax3.barh(class_names, f1, color='steelblue')
    ax3.set_xlim(0, 1)
    ax3.set_title('Per-Class F1 Score')
    ax3.axvline(x=0.7, color='red', linestyle='--', label='Threshold')

    # 4. ROC curves (one-vs-rest per class)
    ax4 = fig.add_subplot(2, 2, 4)
    y_true_bin = label_binarize(y_true_flat, classes=range(len(class_names)))
    for i in range(len(class_names)):
        fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_pred_probs[:, i])
        roc_auc = auc(fpr, tpr)
        ax4.plot(fpr, tpr, label=f'{class_names[i]} (AUC={roc_auc:.2f})')
    ax4.plot([0, 1], [0, 1], 'k--')
    ax4.set_title('ROC Curves')
    ax4.legend(loc='lower right', fontsize=8)

    plt.tight_layout()
    plt.show()

# Usage: evaluation_dashboard(model, history, x_test, y_test, class_names)
Interactive Demo
CNN Architecture Explorer
Explore how different CNN configurations affect model capacity and computational requirements. Adjust the parameters below to see real-time calculations.
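The explorer itself can't run in a static version of this page, but the core calculation it performs is easy to reproduce. Below is a minimal sketch of the standard per-layer parameter-count formulas (weights plus biases); the example layer sizes are arbitrary:

```python
def conv2d_params(kernel_size, in_channels, filters):
    """Conv2D parameters: (k * k * in_channels + 1) * filters.
    The +1 accounts for one bias per filter."""
    return (kernel_size * kernel_size * in_channels + 1) * filters

def dense_params(in_units, out_units):
    """Dense parameters: (in_units + 1) * out_units (weights + biases)."""
    return (in_units + 1) * out_units

# Example: Conv2D(32, 3) on an RGB image, Conv2D(64, 3), then Dense(10) on 256 features
print(conv2d_params(3, 3, 32))   # 896
print(conv2d_params(3, 32, 64))  # 18496
print(dense_params(256, 10))     # 2570
```

These are the same numbers `model.summary()` reports for the corresponding Keras layers, which makes the formulas easy to sanity-check.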
Key Takeaways
Classification Fundamentals
Image classification assigns labels to images from predefined categories using deep learning models that learn discriminative features automatically
CNN Architecture
Convolutional layers extract features, pooling layers reduce dimensions, and dense layers perform classification in a hierarchical manner
Data Augmentation
Techniques like rotation, flipping, and scaling artificially expand training data to improve model generalization and reduce overfitting
Transfer Learning
Pretrained models like VGG16, ResNet, and EfficientNet provide powerful feature extractors that can be fine-tuned for specific tasks
Evaluation Metrics
Accuracy, precision, recall, F1-score, and confusion matrices provide comprehensive insights into classifier performance
Model Optimization
Learning rate scheduling, early stopping, and regularization techniques help achieve optimal model performance