Module 11.3

Data Visualization

Matplotlib turns data into visual insights with powerful plotting. From simple line charts to complex multi-panel figures, it provides complete control over every aspect of your visualizations. It is the foundation of Python plotting.

50 min
Intermediate
Hands-on
What You'll Learn
  • Creating basic plots
  • Chart types and when to use them
  • Customizing appearance
  • Subplots and layouts
  • Saving figures
01

Plot Anatomy

Every Matplotlib figure has a hierarchy: Figure contains Axes, Axes contain plot elements. Understanding this structure is key to creating and customizing plots.

Key Concept

Figure and Axes

A Figure is the entire window or canvas. An Axes is a single plot within the figure. One figure can contain multiple axes (subplots).

Why it matters: Knowing the difference lets you control layout and create complex multi-panel figures.

[Figure: Anatomy of a Plot. Matplotlib's layered structure, Figure → Axes → Artists, with the plot title, X and Y labels, tick marks, and a data artist labeled.]
Plot Components
  • Figure = Canvas
  • Axes = Plot area
  • Artists = Drawn elements
  • Ticks = Axis markers
Key Insight

One Figure can contain multiple Axes (subplots). Each Axes has its own title, labels, and artists.
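
To make the hierarchy concrete, the sketch below builds one Figure with two Axes and checks how the objects refer to each other. The variable names (ax_left, ax_right) are illustrative, and the Agg backend is selected only so the snippet runs without a display; drop that line in a notebook or desktop session.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so this runs without a display
import matplotlib.pyplot as plt

# One Figure holding two Axes side by side
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(8, 3))

# plot() returns a list of Line2D artists; unpack the single line
(line,) = ax_left.plot([0, 1, 2], [0, 1, 4])
ax_left.set_title("Left panel")
ax_right.set_title("Right panel")

# The Figure tracks its Axes; each Axes tracks its own artists
print(len(fig.axes))           # number of Axes in the Figure
print(line.axes is ax_left)    # the line belongs to the left Axes
print(len(ax_right.lines))     # the right Axes has no lines yet
```

Each Axes keeps its own title and artists, which is why drawing on ax_left leaves ax_right untouched.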


Two Interfaces

pyplot (Quick)
import matplotlib.pyplot as plt

# Quick and simple
plt.plot([1, 2, 3], [1, 4, 9])
plt.title('Quick Plot')
plt.xlabel('X')
plt.ylabel('Y')
plt.show()
Object-Oriented (Flexible)
import matplotlib.pyplot as plt

# More control
fig, ax = plt.subplots()
ax.plot([1, 2, 3], [1, 4, 9])
ax.set_title('OO Plot')
ax.set_xlabel('X')
ax.set_ylabel('Y')
plt.show()

The pyplot interface is quick for simple plots and works like MATLAB. The object-oriented interface gives more control and is recommended for complex figures. When you call plt.subplots(), it returns both the figure and axes objects, giving you direct access to all plot elements. This approach makes it easier to modify specific parts of your visualization and is essential when creating multiple subplots or customizing individual elements.
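
The two interfaces operate on the same objects. As a small sketch, the snippet below shows that pyplot functions act on a "current" figure and axes, which are the same objects returned by plt.subplots(). The Agg backend line is an assumption so the code runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so this runs without a display
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# pyplot keeps track of the current figure and axes
same_fig = plt.gcf() is fig
same_ax = plt.gca() is ax
print(same_fig, same_ax)

# A pyplot call is forwarded to the current axes
plt.title("Set through pyplot")
print(ax.get_title())
```

With several subplots, "current" becomes ambiguous, which is one reason explicit ax objects are easier to reason about.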

Installing Matplotlib

Before creating visualizations, you need to install matplotlib. It comes with many data science distributions, but you can also install it separately using pip or conda.

# Install with pip
# pip install matplotlib

# Install with conda
# conda install matplotlib

# Verify installation
import matplotlib
print(matplotlib.__version__)  # Output: e.g., 3.8.2

# Import the pyplot module (standard convention)
import matplotlib.pyplot as plt
import numpy as np  # NumPy is commonly used with matplotlib

The standard convention is to import matplotlib.pyplot as plt. This gives you access to all the common plotting functions with a short, convenient alias. NumPy is typically imported alongside matplotlib because most visualizations involve numerical data that benefits from NumPy's array operations. The combination of these two libraries forms the foundation of Python's data visualization ecosystem.

The Plotting Workflow

Creating a visualization follows a consistent workflow: create figure, plot data, customize appearance, and display or save.

Matplotlib Workflow: 6 Steps
  1. Create: fig, ax = plt.subplots()
  2. Plot: ax.plot(x, y), ax.scatter(), ax.bar()
  3. Customize: ax.set_title(), ax.set_xlabel(), ax.legend()
  4. Refine: plt.tight_layout(), ax.grid(), ax.set_xlim()
  5. Show/Save: plt.show(), plt.savefig()
  6. Done: your visualization is ready.
Pro Tips
  • Always use subplots(), even for single plots
  • Plot before you customize; it helps you see each change
  • Use tight_layout() to prevent clipped labels
  • Save before show(); show() can clear the figure
  • Set dpi for quality: dpi=150 or higher
Follow this workflow for consistent, well-structured visualizations.
# Complete workflow example
import matplotlib.pyplot as plt
import numpy as np

# 1. Create figure and axes
fig, ax = plt.subplots(figsize=(10, 6))

# 2. Generate and plot data
x = np.linspace(0, 10, 100)
y = np.sin(x)
ax.plot(x, y, color='steelblue', linewidth=2, label='sin(x)')

# 3. Customize appearance
ax.set_title('Sine Wave Visualization', fontsize=14, fontweight='bold')
ax.set_xlabel('X-axis', fontsize=12)
ax.set_ylabel('Y-axis', fontsize=12)
ax.legend(loc='upper right')

# 4. Refine layout
ax.grid(True, alpha=0.3)
plt.tight_layout()

# 5. Save and/or show
plt.savefig('sine_wave.png', dpi=150, bbox_inches='tight')
plt.show()

This example demonstrates the complete plotting workflow from start to finish. First, we create a figure with specified dimensions using figsize. Then we generate data with NumPy and plot it with customized styling. The set_title, set_xlabel, and set_ylabel methods add descriptive text, and legend displays the label given to the line. The grid and tight_layout calls improve readability and spacing. Finally, savefig exports the figure before show displays it. Always save before showing: in a script, the figure is typically closed once the plt.show() window is dismissed, so a later plt.savefig() call can write a blank image.
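
One way to make the save step less dependent on call order is to save through the Figure object, which states explicitly which figure is written. The sketch below is a minimal example under assumptions: the output path is a hypothetical temporary file, and the Agg backend is chosen so the code runs without a display.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so this runs without a display
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(6, 4))
ax.plot([0, 1, 2], [0, 1, 0])
ax.set_title("Saved through the Figure object")

# Hypothetical output location in a temporary directory
out_path = os.path.join(tempfile.mkdtemp(), "example.png")

# fig.savefig writes this specific figure, regardless of pyplot's current figure
fig.savefig(out_path, dpi=150, bbox_inches="tight")
plt.close(fig)  # release the figure once it is no longer needed

print(os.path.getsize(out_path) > 0)
```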

Understanding Coordinates

Matplotlib uses a Cartesian coordinate system where the origin (0,0) is typically at the bottom-left, with x increasing to the right and y increasing upward.

Data Coordinates
import matplotlib.pyplot as plt

# Data coordinates are your actual data values
x_data = [1, 2, 3, 4, 5]
y_data = [10, 20, 15, 25, 30]

plt.plot(x_data, y_data)
# Points appear at their data values
# (1, 10), (2, 20), (3, 15), etc.
plt.show()

Data coordinates match your actual data values.

Axes Coordinates
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# Axes coordinates: 0-1 relative to axes
# (0, 0) = bottom-left corner
# (1, 1) = top-right corner
# (0.5, 0.5) = center of axes

ax.text(0.5, 0.5, 'Center',
        transform=ax.transAxes)
plt.show()

Axes coordinates are normalized 0-1 values.
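
The relationship between the two coordinate systems can be checked numerically. In this sketch (axis limits chosen arbitrarily for illustration), the data point in the middle of the limits and the axes point (0.5, 0.5) map to the same pixel position. The Agg backend line is an assumption so the code runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
ax.set_xlim(0, 10)
ax.set_ylim(0, 100)

# transData maps data values to display (pixel) coordinates
center_from_data = ax.transData.transform((5, 50))

# transAxes maps the 0-1 axes square to display coordinates
center_from_axes = ax.transAxes.transform((0.5, 0.5))

print(np.allclose(center_from_data, center_from_axes))
```

If the axis limits change, the data point moves on screen while the axes point stays in the center, which is why transAxes suits fixed notes and labels.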

Practice: Plot Anatomy

Task: Import matplotlib.pyplot and numpy, then print the matplotlib version.

Solution:
import matplotlib.pyplot as plt
import matplotlib
import numpy as np

print(f"Matplotlib version: {matplotlib.__version__}")
print(f"NumPy version: {np.__version__}")

Task: Create a figure with axes using plt.subplots() and display it.

Solution:
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.set_title('Empty Figure')
plt.show()

Task: Create a simple line plot using the object-oriented interface (fig, ax = plt.subplots()).

Solution:
import matplotlib.pyplot as plt
import numpy as np

# Create figure and axes
fig, ax = plt.subplots(figsize=(8, 5))

# Plot data
x = np.linspace(0, 5, 50)
y = x ** 2
ax.plot(x, y)

# Customize
ax.set_title('Quadratic Function')
ax.set_xlabel('x')
ax.set_ylabel('x squared')
ax.grid(True)

plt.show()

Task: Create a plot following the complete workflow: create, plot, customize, save, and show.

Solution:
import matplotlib.pyplot as plt
import numpy as np

# 1. Create
fig, ax = plt.subplots(figsize=(10, 6))

# 2. Plot
x = np.linspace(0, 2 * np.pi, 100)
ax.plot(x, np.sin(x), label='sin(x)', color='blue')
ax.plot(x, np.cos(x), label='cos(x)', color='red')

# 3. Customize
ax.set_title('Trigonometric Functions', fontsize=14)
ax.set_xlabel('Radians')
ax.set_ylabel('Value')
ax.legend()
ax.grid(True, alpha=0.3)

# 4. Save
plt.tight_layout()
plt.savefig('trig_functions.png', dpi=150, bbox_inches='tight')

# 5. Show
plt.show()
02

Basic Plots

Start with line plots, then add markers, colors, and labels to make your visualizations informative and attractive. Line plots are the foundation of data visualization, perfect for showing trends and continuous data over time.

Key Concept

Line Plot Components

A line plot connects data points with straight lines. The essential components are: X values (horizontal position), Y values (vertical position), line style, color, and optional markers at each point.

When to use: Time series data, continuous measurements, trends over intervals, comparing multiple series.

Simple Line Plot

The most basic plot connects a series of points. Provide x and y coordinates as lists or arrays.

import matplotlib.pyplot as plt
import numpy as np

# Generate data
x = np.linspace(0, 10, 100)  # 100 points from 0 to 10
y = np.sin(x)                 # Sine of each x value

# Create the plot
plt.plot(x, y)

# Add labels and title
plt.title('Sine Wave')
plt.xlabel('X')
plt.ylabel('sin(x)')
plt.grid(True)
plt.show()

This example creates a smooth sine wave by generating 100 evenly-spaced x values and computing their sine. The np.linspace function creates the x array, and np.sin applies the sine function element-wise. The plt.plot function connects these points with a line. Adding grid(True) overlays gridlines that make it easier to read values from the chart. The more points you generate, the smoother the curve appears.

Multiple Lines on One Plot

Compare multiple data series by plotting them on the same axes. Use labels and a legend to identify each line.

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

# Plot multiple lines
plt.plot(x, np.sin(x), label='sin(x)')
plt.plot(x, np.cos(x), label='cos(x)')
plt.plot(x, np.sin(x) + np.cos(x), label='sin(x) + cos(x)')

# Customize
plt.title('Trigonometric Functions')
plt.xlabel('X')
plt.ylabel('Y')
plt.legend()  # Display legend with labels
plt.grid(True, alpha=0.3)
plt.show()

When plotting multiple lines, the label parameter assigns a name to each line that appears in the legend. The plt.legend() call displays the legend box, which by default appears in the best location to avoid overlapping data. Each subsequent plot call adds a new line in a different color from matplotlib's default color cycle. The alpha parameter on the grid controls transparency, making the gridlines subtle so they do not distract from the data.
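
The color cycle and the legend labels can be inspected directly. The sketch below assumes Matplotlib's default style is active; the series list and variable names are illustrative, and the Agg backend is selected only so the code runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
series = [
    ("sin(x)", np.sin),
    ("cos(x)", np.cos),
    ("sin(x) + cos(x)", lambda t: np.sin(t) + np.cos(t)),
]

fig, ax = plt.subplots()
# Each plot call takes the next color from the property cycle
lines = [ax.plot(x, func(x), label=name)[0] for name, func in series]

default_cycle = plt.rcParams["axes.prop_cycle"].by_key()["color"]
line_colors = [line.get_color() for line in lines]
legend_labels = [text.get_text() for text in ax.legend().get_texts()]

print(line_colors)
print(legend_labels)
```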

Line Styles and Colors

Customize appearance using format strings (quick) or keyword arguments (detailed control).

Format String Reference
Colors
  • b - Blue
  • g - Green
  • r - Red
  • c - Cyan
  • m - Magenta
  • y - Yellow
  • k - Black
  • w - White
Markers
  • o - Circle
  • s - Square
  • ^ - Triangle up
  • v - Triangle down
  • * - Star
  • + - Plus
  • x - X
  • . - Point
Line Styles
  • - - Solid
  • -- - Dashed
  • -. - Dash-dot
  • : - Dotted
  • (none) - No line
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 5, 20)

# Format strings: 'color marker linestyle'
plt.plot(x, x, 'r--')           # Red dashed line
plt.plot(x, x**1.5, 'bo')       # Blue circles only (no line)
plt.plot(x, x**2, 'g.-')        # Green point markers with solid line
plt.plot(x, x**2.5, 'ms-')      # Magenta squares with solid line

plt.title('Format String Examples')
plt.legend(['Linear', 'x^1.5', 'Quadratic', 'x^2.5'])
plt.show()

Format strings combine color, marker, and line style into a short code. The order is flexible: 'r--' and '--r' both work. When you omit the line style and only provide a marker (like 'bo'), matplotlib draws markers without connecting them. This is useful for scatter-like plots where you want to emphasize individual points rather than trends. Format strings are convenient for quick plots, but keyword arguments offer more options.
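
To see how a format string is interpreted, you can query the resulting line objects. This is a small sketch with illustrative variable names; the Agg backend line is an assumption so the code runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 5, 10)
fig, ax = plt.subplots()

(dashed,) = ax.plot(x, x, "r--")           # color + line style
(markers_only,) = ax.plot(x, x**2, "bo")   # color + marker, no line
(both,) = ax.plot(x, x**3, "gs-")          # color + marker + line style

# Each part of the format string ends up as a separate line property
for line in (dashed, markers_only, both):
    print(line.get_color(), line.get_marker(), line.get_linestyle())
```

The same properties can be set with the color, marker, and linestyle keyword arguments.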

Keyword Arguments for Fine Control

For precise control over appearance, use keyword arguments instead of format strings.

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 50)
y = np.sin(x)

plt.plot(x, y, 
         color='#E74C3C',        # Hex color
         linestyle='--',          # Dashed line
         linewidth=2.5,           # Line thickness
         marker='o',              # Circle markers
         markersize=6,            # Marker size
         markerfacecolor='white', # Marker fill color
         markeredgecolor='#E74C3C', # Marker border color
         markeredgewidth=1.5,     # Marker border thickness
         alpha=0.8,               # Transparency (0-1)
         label='Custom Style')

plt.title('Styled Line Plot')
plt.xlabel('X Values')
plt.ylabel('Y Values')
plt.legend()
plt.grid(True, linestyle=':', alpha=0.5)
plt.show()

Keyword arguments provide granular control over every visual aspect of your plot. You can specify colors using hex codes for exact brand colors, control line and marker sizes independently, and set different colors for marker faces and edges. The alpha parameter adds transparency, which is useful when lines overlap. This level of customization is essential for publication-quality figures and presentations where visual consistency matters.

Filling Areas

Use fill_between to shade the area between a line and a baseline, useful for showing ranges or emphasizing magnitude.

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y = np.sin(x)

plt.plot(x, y, color='blue', linewidth=2)
plt.fill_between(x, y, alpha=0.3, color='blue')  # Fill to y=0
plt.fill_between(x, y, 1, where=(y > 0.5), alpha=0.3, color='green')  # Conditional fill

plt.title('Area Fill Example')
plt.xlabel('X')
plt.ylabel('Y')
plt.axhline(y=0, color='black', linestyle='-', linewidth=0.5)  # Baseline
plt.show()

The fill_between function shades the area between your data and a baseline (default y=0). The first argument is x values, the second is your y data, and an optional third argument specifies a different baseline. The where parameter allows conditional filling, shading only where a condition is true. This is powerful for highlighting specific regions like values above a threshold. The axhline function draws a horizontal reference line at y=0.
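
The where argument is simply a boolean array with one entry per x value. The sketch below builds that mask explicitly and shades between the curve and a threshold line; the threshold value is illustrative, and interpolate=True is an optional extra that extends the fill to the exact crossing points. The Agg backend line is an assumption so the code runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y = np.sin(x)
threshold = 0.5  # illustrative cutoff

# Boolean mask: True where the curve is above the threshold
above = y > threshold

fig, ax = plt.subplots()
ax.plot(x, y, color="blue")
ax.axhline(threshold, color="gray", linestyle="--", linewidth=0.8)

# Shade between the curve and the threshold, only where the mask is True
band = ax.fill_between(x, y, threshold, where=above,
                       interpolate=True, alpha=0.3, color="green")

print(int(above.sum()), "of", x.size, "samples lie above the threshold")
```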

Practice: Basic Plots

Task: Plot the squares of numbers 1-10 (x vs x squared).

Solution:
import matplotlib.pyplot as plt
import numpy as np

x = np.arange(1, 11)
y = x ** 2
plt.plot(x, y)
plt.title('Squares')
plt.xlabel('x')
plt.ylabel('x squared')
plt.show()

Task: Create any plot and add a title, x-label, and y-label.

Solution:
import matplotlib.pyplot as plt

plt.plot([1, 2, 3, 4], [1, 4, 2, 3])
plt.title('My First Plot')
plt.xlabel('Time (s)')
plt.ylabel('Value')
plt.show()

Task: Plot both x squared and x cubed on the same graph with a legend.

Solution:
import matplotlib.pyplot as plt
import numpy as np

x = np.arange(1, 6)
plt.plot(x, x**2, label='x squared')
plt.plot(x, x**3, label='x cubed')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.show()

Task: Create a plot with red dashed line and circle markers using format strings.

Solution:
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 5, 10)
y = x ** 2

plt.plot(x, y, 'ro--')  # Red circles, dashed line
plt.title('Format String Demo')
plt.xlabel('X')
plt.ylabel('Y')
plt.grid(True)
plt.show()

Task: Create a line plot with custom color, linewidth, marker size, and transparency.

Solution:
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 20)
y = np.sin(x)

plt.plot(x, y,
         color='#3498DB',
         linewidth=2,
         marker='s',
         markersize=8,
         markerfacecolor='white',
         markeredgecolor='#3498DB',
         alpha=0.9,
         label='Styled Sine')

plt.title('Keyword Arguments Demo')
plt.xlabel('X')
plt.ylabel('sin(x)')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

Task: Plot a curve and fill the area underneath with a semi-transparent color.

Solution:
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y = np.exp(-x/3) * np.sin(x)

plt.plot(x, y, color='purple', linewidth=2, label='Damped sine')
plt.fill_between(x, y, alpha=0.3, color='purple')
plt.axhline(y=0, color='black', linewidth=0.5)

plt.title('Filled Area Plot')
plt.xlabel('X')
plt.ylabel('Y')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
03

Chart Types

Different data calls for different chart types. Choose the right visualization to communicate your insights effectively. Line plots excel at showing trends over time, bar charts compare discrete categories, scatter plots reveal relationships between variables, histograms display distributions, pie charts show proportions, and box plots summarize statistical properties. Selecting the appropriate chart type is crucial: the wrong choice can mislead or confuse your audience, while the right one makes patterns and insights immediately clear.

Chart Types Gallery: 6 Essential Charts
Line Plot

Best for trends over time or continuous data.

plt.plot(x, y)
Bar Chart

Best for comparing categories.

plt.bar(x, heights)
Scatter Plot

Best for relationships between two variables.

plt.scatter(x, y)
Histogram

Best for distribution of values.

plt.hist(data, bins=10)
Pie Chart

Best for parts of a whole (use sparingly).

plt.pie(sizes, labels=labels)
Box Plot

Best for statistical distributions.

plt.boxplot(data)
Quick Tip: Use line for time series, bar for categories, scatter for correlations, hist for distributions.

Bar Chart

import matplotlib.pyplot as plt

categories = ['A', 'B', 'C', 'D']
values = [25, 40, 30, 55]

plt.bar(categories, values, color='steelblue')
plt.title('Sales by Category')
plt.xlabel('Category')
plt.ylabel('Sales')
plt.show()

Bar charts are ideal for comparing discrete categories. Use plt.barh() for horizontal bars.

Scatter Plot

import matplotlib.pyplot as plt
import numpy as np

x = np.random.randn(100)
y = x + np.random.randn(100) * 0.5

plt.scatter(x, y, alpha=0.6, c='coral')
plt.title('Correlation Example')
plt.xlabel('X')
plt.ylabel('Y')
plt.show()

Use scatter plots to show relationships. The alpha parameter controls transparency (useful for overlapping points).

Histogram

import matplotlib.pyplot as plt
import numpy as np

data = np.random.randn(1000)

plt.hist(data, bins=30, edgecolor='black', alpha=0.7)
plt.title('Distribution of Values')
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.show()

Histograms show how data is distributed across value ranges. The bins parameter controls granularity: more bins show more detail but can look noisy, fewer bins smooth the distribution but may hide patterns. The edgecolor parameter adds borders to distinguish adjacent bars. Always consider your data size when choosing bins: with 1000 points, 30 bins gives about 33 points per bin on average, providing a reliable estimate of the distribution shape.
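
The points-per-bin arithmetic can be verified from the values that hist returns. This sketch uses a seeded NumPy generator (an assumption, for reproducibility) and the Agg backend so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)  # seeded for reproducibility
data = rng.standard_normal(1000)

fig, ax = plt.subplots()
# hist returns the per-bin counts, the bin edges, and the drawn bar patches
counts, edges, patches = ax.hist(data, bins=30, edgecolor="black", alpha=0.7)

print(len(edges))          # 30 bins are bounded by 31 edges
print(int(counts.sum()))   # every sample falls into some bin
print(counts.mean())       # average samples per bin: 1000 / 30
```

The average hides a wide spread: bins near the center of a normal distribution hold far more samples than bins in the tails.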

Pie Chart

Pie charts show parts of a whole. Use them sparingly, as bar charts are often clearer for comparisons.

import matplotlib.pyplot as plt

# Data for pie chart
sizes = [35, 25, 20, 15, 5]
labels = ['Python', 'JavaScript', 'Java', 'C++', 'Other']
colors = ['#3498DB', '#F39C12', '#E74C3C', '#9B59B6', '#95A5A6']
explode = (0.05, 0, 0, 0, 0)  # Offset the first slice

plt.figure(figsize=(8, 8))
plt.pie(sizes, labels=labels, colors=colors, explode=explode,
        autopct='%1.1f%%', startangle=90, shadow=True)
plt.title('Programming Language Popularity')
plt.axis('equal')  # Equal aspect ratio for circular pie
plt.show()

Pie charts represent proportions of a whole, with each slice sized according to its value. The explode parameter pulls slices away from center for emphasis. autopct formats percentage labels on each slice. startangle rotates the chart so the first slice begins at that angle (90 degrees is the top). The axis('equal') call ensures the pie is circular rather than elliptical. While visually appealing, pie charts become hard to read with many categories or similar-sized slices.

Box Plot

Box plots (box-and-whisker plots) show statistical distributions including median, quartiles, and outliers.

import matplotlib.pyplot as plt
import numpy as np

# Generate sample data for multiple groups
np.random.seed(42)
data = [np.random.normal(0, std, 100) for std in range(1, 5)]

fig, ax = plt.subplots(figsize=(10, 6))
# Note: Matplotlib 3.9+ renames the labels parameter to tick_labels
bp = ax.boxplot(data, labels=['A', 'B', 'C', 'D'], patch_artist=True)

# Color the boxes
colors = ['#3498DB', '#2ECC71', '#F39C12', '#E74C3C']
for patch, color in zip(bp['boxes'], colors):
    patch.set_facecolor(color)
    patch.set_alpha(0.7)

ax.set_title('Distribution Comparison')
ax.set_xlabel('Group')
ax.set_ylabel('Value')
ax.grid(True, axis='y', alpha=0.3)
plt.show()

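Box plots summarize a distribution with a few key statistics: the box spans the first quartile (25th percentile) to the third quartile (75th percentile), and the line inside it marks the median (50th percentile). By default, each whisker extends to the most extreme data point within 1.5 times the interquartile range of the box, and points beyond the whiskers are drawn individually as outliers. The whisker ends therefore coincide with the minimum and maximum only when there are no outliers. The patch_artist=True parameter allows filling boxes with color. Box plots excel at comparing distributions across groups because you can quickly see differences in spread, center, and skewness. They are essential for exploratory data analysis and identifying unusual patterns in your data.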
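
The statistics behind a box plot can be computed by hand. The sketch below assumes the default whisker rule (1.5 times the interquartile range) and compares the manual result with matplotlib.cbook.boxplot_stats, the helper Matplotlib provides for these numbers. The sample data and variable names are illustrative.

```python
import numpy as np
from matplotlib import cbook

rng = np.random.default_rng(42)  # seeded for reproducibility
values = rng.normal(0, 1, 200)

# Box: quartiles and median
q1, median, q3 = np.percentile(values, [25, 50, 75])
iqr = q3 - q1

# Fences: limits beyond which points count as outliers
lower_fence = q1 - 1.5 * iqr
upper_fence = q3 + 1.5 * iqr

# Whiskers end at the most extreme data points inside the fences
whisker_low = values[values >= lower_fence].min()
whisker_high = values[values <= upper_fence].max()
outliers = values[(values < lower_fence) | (values > upper_fence)]

# Matplotlib's own computation, for comparison
stats = cbook.boxplot_stats(values)[0]
print(np.isclose(stats["q1"], q1), np.isclose(stats["q3"], q3))
print(np.isclose(stats["whislo"], whisker_low),
      np.isclose(stats["whishi"], whisker_high))
print(len(stats["fliers"]) == len(outliers))
```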

Heatmap

Heatmaps display matrix data using color intensity, perfect for correlation matrices or 2D data.

import matplotlib.pyplot as plt
import numpy as np

# Create a correlation-like matrix
np.random.seed(42)
data = np.random.rand(5, 5)
# Make it symmetric for a correlation matrix effect
data = (data + data.T) / 2
np.fill_diagonal(data, 1)

fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(data, cmap='coolwarm', vmin=0, vmax=1)

# Add colorbar
cbar = plt.colorbar(im)
cbar.set_label('Correlation')

# Add labels
labels = ['A', 'B', 'C', 'D', 'E']
ax.set_xticks(range(len(labels)))
ax.set_yticks(range(len(labels)))
ax.set_xticklabels(labels)
ax.set_yticklabels(labels)

# Add value annotations
for i in range(len(labels)):
    for j in range(len(labels)):
        ax.text(j, i, f'{data[i, j]:.2f}', ha='center', va='center',
                # coolwarm is dark at both ends and light in the middle
                color='white' if abs(data[i, j] - 0.5) > 0.35 else 'black')

ax.set_title('Correlation Heatmap')
plt.tight_layout()
plt.show()

Heatmaps use color to represent values in a matrix. The imshow function displays the matrix, and cmap selects the color scheme (coolwarm goes from blue through white to red). The vmin and vmax parameters set the color scale range. Adding a colorbar provides a legend for interpreting colors. Text annotations show exact values in each cell, with color chosen based on the background to ensure readability. Heatmaps are indispensable for visualizing correlation matrices, confusion matrices, and any 2D tabular data.

Chart Selection Guide
Data Type | Best Chart | When to Use
Trends over time | Line Plot | Stock prices, temperature, continuous measurements
Category comparison | Bar Chart | Sales by region, survey responses, counts
Two variable relationship | Scatter Plot | Height vs weight, correlation analysis
Value distribution | Histogram | Age distribution, test scores, frequency
Parts of a whole | Pie Chart | Market share, budget allocation (limited categories)
Statistical summary | Box Plot | Compare distributions, identify outliers
Matrix/2D data | Heatmap | Correlation matrix, confusion matrix, schedules

Practice: Chart Types

Task: Create a bar chart showing sales for 4 products.

Solution:
import matplotlib.pyplot as plt

products = ['Widget', 'Gadget', 'Gizmo', 'Doodad']
sales = [120, 85, 200, 150]
plt.bar(products, sales)
plt.title('Product Sales')
plt.ylabel('Units Sold')
plt.show()

Task: Create a scatter plot with random x and y data.

Solution:
import matplotlib.pyplot as plt
import numpy as np

x = np.random.rand(30)
y = np.random.rand(30)

plt.scatter(x, y, color='coral', s=50)
plt.title('Random Scatter')
plt.xlabel('X')
plt.ylabel('Y')
plt.show()

Task: Generate 500 random numbers and plot their histogram with 20 bins.

Solution:
import matplotlib.pyplot as plt
import numpy as np

data = np.random.randn(500)
plt.hist(data, bins=20, edgecolor='black')
plt.title('Random Distribution')
plt.xlabel('Value')
plt.ylabel('Count')
plt.show()

Task: Create a horizontal bar chart showing programming language popularity.

Solution:
import matplotlib.pyplot as plt

languages = ['Python', 'JavaScript', 'Java', 'C++', 'Go']
popularity = [85, 78, 65, 55, 45]

plt.barh(languages, popularity, color='steelblue')
plt.title('Programming Language Popularity')
plt.xlabel('Popularity Score')
plt.tight_layout()
plt.show()

Task: Create a scatter plot where point color depends on y value.

Solution:
import matplotlib.pyplot as plt
import numpy as np

x = np.random.randn(50)
y = np.random.randn(50)
plt.scatter(x, y, c=y, cmap='viridis')
plt.colorbar(label='Y value')
plt.title('Colored Scatter')
plt.show()

Task: Create a box plot comparing 3 different distributions.

Solution:
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
data = [
    np.random.normal(0, 1, 100),   # Mean 0, std 1
    np.random.normal(2, 1.5, 100), # Mean 2, std 1.5
    np.random.normal(-1, 0.5, 100) # Mean -1, std 0.5
]

# Note: Matplotlib 3.9+ renames the labels parameter to tick_labels
plt.boxplot(data, labels=['Group A', 'Group B', 'Group C'])
plt.title('Distribution Comparison')
plt.ylabel('Value')
plt.grid(True, axis='y', alpha=0.3)
plt.show()
04

Customization

Control colors, fonts, axes limits, and more to create publication-quality visualizations.

What is Plot Customization?

Plot customization involves modifying the visual appearance of charts to improve clarity, match branding, or meet publication standards. Key customization areas include:

  • Colors: Line colors, fill colors, colormaps for data-driven coloring
  • Typography: Font family, size, weight for titles and labels
  • Axes: Limits, ticks, tick labels, logarithmic scales
  • Layout: Figure size, DPI, margins, tight layout
  • Annotations: Text, arrows, shapes to highlight features
  • Styles: Pre-built themes that change overall appearance

Colors and Styles

Matplotlib supports multiple ways to specify colors and provides built-in styles for quick theming.

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

# Color specification methods
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Named colors
axes[0, 0].plot(x, np.sin(x), color='crimson', linewidth=2)
axes[0, 0].set_title('Named Color: crimson')

# Hex colors
axes[0, 1].plot(x, np.sin(x), color='#3498DB', linewidth=2)
axes[0, 1].set_title('Hex Color: #3498DB')

# RGB tuple (0-1 scale)
axes[1, 0].plot(x, np.sin(x), color=(0.2, 0.6, 0.3), linewidth=2)
axes[1, 0].set_title('RGB Tuple: (0.2, 0.6, 0.3)')

# RGBA with transparency
axes[1, 1].plot(x, np.sin(x), color=(0.8, 0.2, 0.5, 0.7), linewidth=2)
axes[1, 1].set_title('RGBA with Alpha: 0.7')

plt.tight_layout()
plt.show()

Matplotlib accepts colors in many formats: named colors like 'crimson' or 'steelblue', hexadecimal codes for precise web colors, RGB tuples with values from 0 to 1, and RGBA tuples that include an alpha (transparency) channel. Using consistent colors across your visualizations creates a professional appearance and helps viewers associate colors with specific data categories.

Common Named Colors
  • red
  • blue
  • green
  • orange
  • purple
  • cyan
  • magenta
  • yellow
  • coral
  • crimson
  • steelblue
  • teal
  • gold
  • navy
  • olive
  • salmon

Using Styles

Styles are pre-built themes that change multiple visual settings at once.

import matplotlib.pyplot as plt
import numpy as np

# See all available styles
print(plt.style.available)
# ['Solarize_Light2', 'bmh', 'classic', 'dark_background', 
#  'fast', 'fivethirtyeight', 'ggplot', 'grayscale', 'seaborn-v0_8', ...]

# Apply a style
plt.style.use('seaborn-v0_8-darkgrid')

x = np.linspace(0, 10, 100)
plt.figure(figsize=(10, 6))
plt.plot(x, np.sin(x), label='sin(x)')
plt.plot(x, np.cos(x), label='cos(x)')
plt.title('Seaborn Dark Grid Style')
plt.legend()
plt.show()

# Reset to default
plt.style.use('default')

Matplotlib comes with dozens of built-in styles that instantly transform your plots. Popular choices include 'ggplot' (mimics R's ggplot2), 'fivethirtyeight' (inspired by the news site's data journalism), 'dark_background' (white text on black), and 'seaborn-v0_8-darkgrid' (clean statistical style). Use plt.style.available to see all options. You can also create custom style sheets or use plt.style.context() for temporary style changes.
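
A temporary style change with plt.style.context() might look like the sketch below. The style only applies inside the with block; the check on axes.facecolor is just one convenient setting to observe, and the Agg backend is selected so the code runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

before = plt.rcParams["axes.facecolor"]

# The ggplot style is active only inside this block
with plt.style.context("ggplot"):
    inside = plt.rcParams["axes.facecolor"]
    fig, ax = plt.subplots()
    ax.plot(x, np.sin(x))
    ax.set_title("Temporarily styled")

after = plt.rcParams["axes.facecolor"]

print(before, inside, after)
```

Compared with plt.style.use(), this avoids having to reset to 'default' afterwards.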

Colormaps

Colormaps map data values to colors, essential for heatmaps, scatter plots with color-coded data, and contour plots.

import matplotlib.pyplot as plt
import numpy as np

# Create data for colormap demonstration
x = np.random.randn(100)
y = np.random.randn(100)
colors = np.sqrt(x**2 + y**2)  # Distance from origin

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Different colormaps
cmaps = ['viridis', 'plasma', 'coolwarm', 'RdYlGn']
for ax, cmap in zip(axes.flat, cmaps):
    scatter = ax.scatter(x, y, c=colors, cmap=cmap, s=50, alpha=0.7)
    ax.set_title(f'Colormap: {cmap}')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    plt.colorbar(scatter, ax=ax, label='Distance')

plt.tight_layout()
plt.show()

Colormaps transform numerical values into colors. Sequential colormaps (viridis, plasma) work best for ordered data ranging from low to high. Diverging colormaps (coolwarm, RdYlGn) are ideal when data has a meaningful center point, with different colors for above and below. Categorical colormaps (Set1, tab10) provide distinct colors for discrete categories. The 'viridis' colormap is perceptually uniform and colorblind-friendly, making it the default choice for most applications.
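
Under the hood, a colormap is a function from numbers in the range 0-1 to RGBA colors, and a normalizer rescales your data into that range first. The sketch below shows both steps by hand; the value range of 0 to 10 is an arbitrary choice for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import Normalize

cmap = plt.get_cmap("viridis")
norm = Normalize(vmin=0, vmax=10)   # maps 0 -> 0.0 and 10 -> 1.0

values = np.array([0.0, 5.0, 10.0])
scaled = norm(values)               # data values rescaled to 0-1
rgba = cmap(scaled)                 # one (R, G, B, A) row per value

print(np.asarray(scaled))
print(rgba.shape)
```

Passing c=values and cmap='viridis' to scatter performs the same two steps automatically.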

Axis Limits and Ticks

Control exactly what ranges and tick marks appear on your axes for clearer data presentation.

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y = np.sin(x)

fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Default axis limits
axes[0, 0].plot(x, y)
axes[0, 0].set_title('Default Axis Limits')

# Custom axis limits
axes[0, 1].plot(x, y)
axes[0, 1].set_xlim(0, 5)        # X-axis range
axes[0, 1].set_ylim(-1.5, 1.5)  # Y-axis range
axes[0, 1].set_title('Custom Limits: xlim(0,5), ylim(-1.5,1.5)')

# Custom tick locations
axes[1, 0].plot(x, y)
axes[1, 0].set_xticks([0, np.pi, 2*np.pi, 3*np.pi])
axes[1, 0].set_xticklabels(['0', 'π', '2π', '3π'])
axes[1, 0].set_title('Custom Tick Labels')

# Tick formatting
axes[1, 1].plot(x, y)
axes[1, 1].set_yticks([-1, -0.5, 0, 0.5, 1])
axes[1, 1].set_yticklabels(['Very Low', 'Low', 'Mid', 'High', 'Very High'])
axes[1, 1].set_title('Descriptive Tick Labels')

plt.tight_layout()
plt.show()

The set_xlim and set_ylim methods control the visible range of your axes, useful for zooming into interesting regions or maintaining consistent scales across multiple plots. The set_xticks and set_yticks methods specify exact tick positions, while set_xticklabels and set_yticklabels let you replace numeric ticks with custom text. This is particularly useful for displaying mathematical notation, categorical labels, or human-readable descriptions instead of raw numbers.

Figure Size and DPI

Control the dimensions and resolution of your figures for different output needs.

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y = np.sin(x)

# Method 1: Set size when creating figure
plt.figure(figsize=(12, 4))  # Width=12 inches, Height=4 inches
plt.plot(x, y)
plt.title('Wide Figure (12x4 inches)')
plt.show()

# Method 2: With subplots and DPI
fig, ax = plt.subplots(figsize=(8, 6), dpi=150)  # Higher DPI = sharper
ax.plot(x, y, linewidth=2)
ax.set_title('High DPI Figure (150 dpi)')
plt.show()

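# Checking the figure size (uses the fig object from Method 2;
# after plt.show() returns, plt.gcf() may create a new, default-sized figure)
print(f"Figure size: {fig.get_size_inches()}")  # Array of [width, height] in inches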

The figsize parameter takes a tuple of (width, height) in inches. Common sizes include (6, 4) for general use, (10, 6) for presentations, and (12, 8) for detailed visualizations. DPI (dots per inch) controls resolution: Matplotlib's default of 100 dpi is adequate for on-screen viewing, while 150-300 dpi is typical for print. Higher DPI creates sharper raster images but larger file sizes. Always consider your final output medium when choosing these settings.
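
The pixel dimensions of a raster figure follow directly from these two settings: pixels = inches × dpi. A short sketch of the arithmetic, using the Agg backend so it runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so this runs without a display
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(8, 6), dpi=150)

width_in, height_in = fig.get_size_inches()
width_px = width_in * fig.dpi
height_px = height_in * fig.dpi

print(width_px, height_px)  # 8 in x 150 dpi and 6 in x 150 dpi
```

Passing a different dpi to savefig changes the saved pixel size without altering figsize.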

Annotations and Text

Add text, arrows, and shapes to highlight important features in your plots.

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2*np.pi, 100)
y = np.sin(x)

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(x, y, 'b-', linewidth=2)

# Add text at a specific location
ax.text(np.pi/2, 1.1, 'Maximum', fontsize=12, ha='center', 
        color='green', fontweight='bold')

# Add annotation with arrow
ax.annotate('Minimum', xy=(3*np.pi/2, -1), xytext=(3*np.pi/2, -1.5),
            fontsize=12, ha='center',
            arrowprops=dict(arrowstyle='->', color='red', lw=2))

# Add horizontal reference line
ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)

# Add vertical span to highlight a region
ax.axvspan(np.pi, 2*np.pi, alpha=0.2, color='yellow', label='Second half')

# Add text box
props = dict(boxstyle='round', facecolor='wheat', alpha=0.8)
ax.text(0.05, 0.95, 'Sine Wave: y = sin(x)', transform=ax.transAxes, 
        fontsize=10, verticalalignment='top', bbox=props)

ax.set_title('Annotation Examples')
ax.set_xlabel('x (radians)')
ax.set_ylabel('sin(x)')
ax.legend()
plt.tight_layout()
plt.show()

Annotations transform raw visualizations into informative graphics. By default, the text function places text at data coordinates, while annotate adds text with an arrow pointing to a specific location. The axhline and axvline functions draw horizontal and vertical reference lines, while axhspan and axvspan shade regions. The transform=ax.transAxes parameter places text in axes-relative coordinates, where (0, 0) is the bottom-left and (1, 1) the top-right corner of the plot area. This is useful for notes that should stay fixed regardless of the data range.
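
To make the coordinate systems concrete, the sketch below places text in data, axes, and figure coordinates. It then checks that the data point (50, 5) and the axes point (0.5, 0.5) land on the same pixel for these particular limits. The limits and label strings are arbitrary choices for illustration.

```python
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
ax.set_xlim(0, 100)
ax.set_ylim(0, 10)

# Data coordinates (the default): position follows the axis limits
ax.text(50, 5, 'data (50, 5)', ha='center')

# Axes coordinates: fractions of the plot area, independent of the data range
ax.text(0.5, 0.9, 'axes (0.5, 0.9)', transform=ax.transAxes, ha='center')

# Figure coordinates: fractions of the whole canvas
fig.text(0.5, 0.01, 'figure (0.5, 0.01)', ha='center')

# With these limits, the centre of the data range is the centre of the axes
data_px = ax.transData.transform((50, 5))
axes_px = ax.transAxes.transform((0.5, 0.5))
same_pixel = np.allclose(data_px, axes_px)
```

If the axis limits change, the data-coordinate text moves with the data, while the axes-coordinate text stays in place.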

Grid and Spines

Customize grid lines and axis borders (spines) for cleaner visualizations.

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y = np.sin(x)

fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Grid customization
axes[0, 0].plot(x, y)
axes[0, 0].grid(True, linestyle='--', alpha=0.7)
axes[0, 0].set_title('Dashed Grid')

# Major and minor grid
axes[0, 1].plot(x, y)
axes[0, 1].grid(True, which='major', linestyle='-', alpha=0.7)
axes[0, 1].grid(True, which='minor', linestyle=':', alpha=0.4)
axes[0, 1].minorticks_on()
axes[0, 1].set_title('Major + Minor Grid')

# Remove spines (axis borders)
axes[1, 0].plot(x, y)
axes[1, 0].spines['top'].set_visible(False)
axes[1, 0].spines['right'].set_visible(False)
axes[1, 0].set_title('No Top/Right Spines')

# Move spines to center
axes[1, 1].plot(x, y)
axes[1, 1].spines['left'].set_position('center')
axes[1, 1].spines['bottom'].set_position('center')
axes[1, 1].spines['top'].set_visible(False)
axes[1, 1].spines['right'].set_visible(False)
axes[1, 1].set_title('Centered Spines (Math Style)')

plt.tight_layout()
plt.show()

Grid lines help viewers read exact values from your plot. Use linestyle ('--', ':', '-') and alpha to control appearance. Minor ticks provide finer divisions without cluttering major tick labels. Spines are the axis borders; removing the top and right spines creates a cleaner, modern look common in data journalism. Setting a spine's position to 'center' moves it to the middle of the plot area, giving math-style crossing axes. The crossing point matches the data origin only when the axis limits are symmetric about zero; in the example above, x runs from 0 to 10, so the vertical spine sits at x = 5. To pin the spines to the origin regardless of the limits, use set_position('zero').
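
A short sketch of origin-crossing axes for data that is not symmetric about zero. It uses the explicit ('data', 0) form of set_position, for which 'zero' is the documented shorthand. The plotted polynomial is an arbitrary example.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-2, 4, 200)

fig, ax = plt.subplots()
ax.plot(x, x**2 - 2 * x - 1)

# Pin each spine to data coordinate 0 on the other axis
ax.spines['left'].set_position(('data', 0))
ax.spines['bottom'].set_position(('data', 0))

# Hide the two remaining borders
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
```

Here the axes cross at (0, 0) even though x spans -2 to 4, which 'center' would not achieve.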

Practice: Customization

Task: Create a line plot with hex color #E74C3C.

Solution:
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 50)
y = np.sin(x)
plt.plot(x, y, color='#E74C3C', linewidth=2)
plt.title('Custom Hex Color')
plt.show()

Task: Apply the 'ggplot' style to a simple plot.

Solution:
import matplotlib.pyplot as plt
import numpy as np

plt.style.use('ggplot')
x = np.linspace(0, 10, 50)
plt.plot(x, np.sin(x), label='sin')
plt.plot(x, np.cos(x), label='cos')
plt.legend()
plt.title('ggplot Style')
plt.show()
plt.style.use('default')  # Reset

Task: Create a sine plot and zoom into x=[0, π], y=[-0.5, 1.5].

Solution:
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2*np.pi, 100)
plt.plot(x, np.sin(x))
plt.xlim(0, np.pi)
plt.ylim(-0.5, 1.5)
plt.title('Zoomed View')
plt.xlabel('x')
plt.ylabel('sin(x)')
plt.show()

Task: Create a plot with dashed grid lines at 50% opacity.

Solution:
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 50)
plt.plot(x, x**2)
plt.grid(True, linestyle='--', alpha=0.5)
plt.title('Dashed Grid')
plt.xlabel('x')
plt.ylabel('y')
plt.show()

Task: Plot sin(x) and annotate the maximum point with an arrow.

Solution:
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2*np.pi, 100)
y = np.sin(x)

plt.figure(figsize=(10, 6))
plt.plot(x, y, 'b-', linewidth=2)

plt.annotate('Maximum', 
             xy=(np.pi/2, 1),      # Point to annotate
             xytext=(np.pi/2 + 1, 0.7),  # Text position
             fontsize=12,
             arrowprops=dict(arrowstyle='->', color='red'))

plt.title('Annotated Sine Wave')
plt.xlabel('x')
plt.ylabel('sin(x)')
plt.grid(True, alpha=0.3)
plt.show()

Task: Create a plot with only bottom and left axis borders visible.

Solution:
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 50)
y = np.exp(-x/5) * np.sin(x)

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(x, y, 'steelblue', linewidth=2)

# Remove top and right spines
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

ax.set_title('Clean Modern Style')
ax.set_xlabel('Time')
ax.set_ylabel('Amplitude')
plt.show()

05

Subplots and Saving

Create multiple plots in a single figure and save your visualizations to files.

What are Subplots?

Subplots allow you to display multiple plots within a single figure, enabling side-by-side comparisons or multi-panel visualizations. Key concepts include:

  • Figure: The overall canvas that contains all subplots
  • Axes: Individual plot areas within the figure
  • Grid Layout: Organizing subplots in rows and columns
  • GridSpec: Advanced layout control for irregular grids
  • Shared Axes: Common x or y axes across subplots for easier comparison

Basic Subplots

The subplots function creates a figure with a grid of axes that you can fill with different plots.

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

# Create a 2x2 grid of subplots
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Top-left: Line plot
axes[0, 0].plot(x, np.sin(x), 'b-', linewidth=2)
axes[0, 0].set_title('Sine Wave')
axes[0, 0].set_xlabel('x')
axes[0, 0].set_ylabel('sin(x)')
axes[0, 0].grid(True, alpha=0.3)

# Top-right: Cosine plot
axes[0, 1].plot(x, np.cos(x), 'r-', linewidth=2)
axes[0, 1].set_title('Cosine Wave')
axes[0, 1].set_xlabel('x')
axes[0, 1].set_ylabel('cos(x)')
axes[0, 1].grid(True, alpha=0.3)

# Bottom-left: Bar chart
categories = ['A', 'B', 'C', 'D', 'E']
values = [23, 45, 56, 78, 32]
axes[1, 0].bar(categories, values, color='steelblue')
axes[1, 0].set_title('Category Values')
axes[1, 0].set_ylabel('Count')

# Bottom-right: Scatter plot
scatter_x = np.random.randn(50)
scatter_y = np.random.randn(50)
axes[1, 1].scatter(scatter_x, scatter_y, c='coral', alpha=0.6, s=50)
axes[1, 1].set_title('Random Scatter')
axes[1, 1].set_xlabel('X')
axes[1, 1].set_ylabel('Y')

# Prevent overlap and display
plt.tight_layout()
plt.show()

The subplots function returns a Figure object and an array of Axes objects. For a 2x2 grid, axes is a 2D array accessed with [row, col] indexing (zero-based). Each axes object has its own plotting methods like plot, bar, scatter, and its own set_title, set_xlabel, set_ylabel methods. The tight_layout function automatically adjusts spacing to prevent titles and labels from overlapping, which is essential for multi-panel figures.
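
When every panel gets similar treatment, the 2D axes array can be traversed with its flat iterator instead of explicit [row, col] indices. This is a sketch with an arbitrary list of NumPy functions.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
functions = [np.sin, np.cos, np.sqrt, np.log1p]

fig, axes = plt.subplots(2, 2, figsize=(10, 8))

# axes.flat walks the grid row by row: [0,0], [0,1], [1,0], [1,1]
for ax, func in zip(axes.flat, functions):
    ax.plot(x, func(x))
    ax.set_title(func.__name__)
    ax.grid(True, alpha=0.3)

fig.tight_layout()
```

The same loop works unchanged if the grid later grows to 3x3, as long as the list of functions grows with it.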

Single Row or Column

When creating a single row or column of subplots, the axes array is 1D instead of 2D.

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

# Single row (1x3 grid)
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

axes[0].plot(x, np.sin(x), 'b-')
axes[0].set_title('Sine')

axes[1].plot(x, np.cos(x), 'r-')
axes[1].set_title('Cosine')

axes[2].plot(x, np.tan(x), 'g-')
axes[2].set_ylim(-5, 5)  # Limit y for tangent
axes[2].set_title('Tangent')

plt.tight_layout()
plt.show()

# Single column (3x1 grid)
fig, axes = plt.subplots(3, 1, figsize=(8, 10))

for i, (func, name) in enumerate([(np.sin, 'Sine'), (np.cos, 'Cosine'), (np.exp, 'Exponential')]):
    if name == 'Exponential':
        axes[i].plot(x, func(x/10), 'purple')  # Scale x for exp
    else:
        axes[i].plot(x, func(x), 'teal')
    axes[i].set_title(name)
    axes[i].set_xlabel('x')
    axes[i].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

When subplots creates a single row (1, n) or single column (n, 1), the axes variable becomes a 1D array accessed with a single index like axes[0], axes[1]. This is more convenient than 2D indexing for simple layouts. The example also shows iterating through axes with enumerate, which is useful when applying similar formatting to multiple subplots. Each subplot maintains independent x and y limits, allowing you to zoom differently on each panel.
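
The changing shape of the return value can surprise code that builds grids of variable size. The sketch below compares the three cases and shows the squeeze=False option, which always returns a 2D array. Variable names are illustrative.

```python
import matplotlib.pyplot as plt

# No grid arguments: a single Axes object, not an array
fig1, ax_single = plt.subplots()

# One row of three: a 1D array indexed as axes_row[i]
fig2, axes_row = plt.subplots(1, 3)

# squeeze=False: always a 2D array, indexed as axes_grid[row, col]
fig3, axes_grid = plt.subplots(1, 3, squeeze=False)

axes_row[2].set_title('1D indexing')
axes_grid[0, 2].set_title('2D indexing')
```

Using squeeze=False is one way to keep [row, col] indexing valid when the number of rows or columns is computed at runtime.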

Shared Axes

Use shared axes to maintain the same scale across subplots, making comparisons easier.

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

# Shared X axis (useful for time series at different scales)
fig, axes = plt.subplots(2, 1, figsize=(10, 8), sharex=True)

axes[0].plot(x, np.sin(x), 'b-', linewidth=2)
axes[0].set_ylabel('sin(x)')
axes[0].set_title('Shared X-Axis Example')
axes[0].grid(True, alpha=0.3)

axes[1].plot(x, np.sin(x) + np.random.randn(100) * 0.2, 'r-', alpha=0.7)
axes[1].set_ylabel('sin(x) + noise')
axes[1].set_xlabel('x')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Shared Y axis (useful for comparing distributions)
fig, axes = plt.subplots(1, 3, figsize=(15, 5), sharey=True)

data1 = np.random.normal(0, 1, 500)
data2 = np.random.normal(2, 1.5, 500)
data3 = np.random.normal(-1, 0.8, 500)

for ax, data, title in zip(axes, [data1, data2, data3], ['Group A', 'Group B', 'Group C']):
    ax.hist(data, bins=20, edgecolor='black', alpha=0.7)
    ax.set_title(title)
    ax.set_xlabel('Value')

axes[0].set_ylabel('Frequency')  # Only first needs y-label
plt.tight_layout()
plt.show()

The sharex=True parameter links the x-axes of all subplots so they zoom and pan together, essential for time series or sequential data. The sharey=True parameter does the same for y-axes, crucial when comparing distributions or values that should be on the same scale. With shared axes created by plt.subplots, x tick labels appear only on the bottom row and y tick labels only on the leftmost column, reducing visual clutter while maintaining clarity. Axis labels set with set_xlabel or set_ylabel are not hidden automatically, so add them only where you want them.
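
The linking can be observed directly: changing the x-limits on one shared subplot changes them on the other. A minimal sketch:

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

fig, axes = plt.subplots(2, 1, sharex=True)
axes[0].plot(x, np.sin(x))
axes[1].plot(x, np.cos(x))

# Zoom the top panel only...
axes[0].set_xlim(2, 4)

# ...and the bottom panel follows, because the x-axis is shared
bottom_limits = tuple(float(v) for v in axes[1].get_xlim())
print(bottom_limits)
```

The y-limits remain independent here because only sharex was requested.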

GridSpec for Complex Layouts

GridSpec provides fine-grained control for creating irregular subplot grids.

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

x = np.linspace(0, 10, 100)

# Create figure with GridSpec
fig = plt.figure(figsize=(12, 8))
gs = gridspec.GridSpec(3, 3, figure=fig)

# Large plot spanning 2 rows and 2 columns
ax1 = fig.add_subplot(gs[0:2, 0:2])
ax1.plot(x, np.sin(x), 'b-', linewidth=2)
ax1.plot(x, np.cos(x), 'r-', linewidth=2)
ax1.set_title('Main Plot (2x2)', fontsize=14)
ax1.legend(['sin', 'cos'])
ax1.grid(True, alpha=0.3)

# Right column: two stacked plots
ax2 = fig.add_subplot(gs[0, 2])
ax2.bar(['A', 'B', 'C'], [3, 7, 5], color='steelblue')
ax2.set_title('Bar Chart')

ax3 = fig.add_subplot(gs[1, 2])
ax3.scatter(np.random.randn(30), np.random.randn(30), c='coral')
ax3.set_title('Scatter')

# Bottom row: three equal plots
ax4 = fig.add_subplot(gs[2, 0])
ax4.hist(np.random.randn(200), bins=15, color='purple', alpha=0.7)
ax4.set_title('Histogram 1')

ax5 = fig.add_subplot(gs[2, 1])
ax5.hist(np.random.randn(200), bins=15, color='green', alpha=0.7)
ax5.set_title('Histogram 2')

ax6 = fig.add_subplot(gs[2, 2])
ax6.hist(np.random.randn(200), bins=15, color='orange', alpha=0.7)
ax6.set_title('Histogram 3')

plt.tight_layout()
plt.show()

GridSpec creates a grid that you can slice like a NumPy array to create subplots of varying sizes. The syntax gs[row_slice, col_slice] specifies which grid cells the subplot occupies. For example, gs[0:2, 0:2] creates a subplot spanning rows 0-1 and columns 0-1 (a 2x2 area). This is invaluable for dashboards where a main visualization needs more space, surrounded by smaller supporting plots. GridSpec also supports gaps between subplots via hspace and wspace parameters.
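
As a sketch of the spacing options, the layout below combines hspace and wspace with width_ratios, a GridSpec parameter that sets the relative column widths. The ratio and spacing values are arbitrary.

```python
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

fig = plt.figure(figsize=(10, 6))
gs = gridspec.GridSpec(
    2, 2, figure=fig,
    width_ratios=[3, 1],   # left column three times as wide as the right
    hspace=0.4,            # vertical gap, as a fraction of average axes height
    wspace=0.3,            # horizontal gap, as a fraction of average axes width
)

ax_main = fig.add_subplot(gs[:, 0])    # all rows, first column
ax_top = fig.add_subplot(gs[0, 1])     # top-right cell
ax_bottom = fig.add_subplot(gs[1, 1])  # bottom-right cell

ax_main.set_title('Main')
ax_top.set_title('Top')
ax_bottom.set_title('Bottom')
```

tight_layout() is left out on purpose, since it recomputes the spacing and would likely override the hspace and wspace values chosen here.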

Saving Figures

Save your visualizations to various file formats for reports, papers, or web publishing.

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y = np.sin(x)

# Create a figure
plt.figure(figsize=(10, 6))
plt.plot(x, y, 'b-', linewidth=2)
plt.title('Sine Wave')
plt.xlabel('x')
plt.ylabel('sin(x)')
plt.grid(True, alpha=0.3)

# Save as PNG (raster format)
plt.savefig('sine_wave.png', dpi=150)

# Save as PDF (vector format - scales without pixelation)
plt.savefig('sine_wave.pdf')

# Save as SVG (vector format - editable in Illustrator)
plt.savefig('sine_wave.svg')

# Save with tight bounding box (removes extra whitespace)
plt.savefig('sine_wave_tight.png', dpi=200, bbox_inches='tight')

# Save with transparent background
plt.savefig('sine_wave_transparent.png', dpi=150, transparent=True)

# Save with custom background color
plt.savefig('sine_wave_dark.png', dpi=150, facecolor='#2C3E50')

plt.show()

The savefig function exports your figure to a file. PNG is ideal for web and presentations with good quality at 150-200 dpi. PDF and SVG are vector formats that scale infinitely without losing quality, perfect for print publications and editable graphics. The bbox_inches='tight' parameter crops extra whitespace, while transparent=True removes the background for overlay use. Always call savefig before show() because show() may clear the figure in some backends.
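
In scripts that create several figures, saving through the Figure object avoids relying on which figure is current. The sketch below writes a PNG into an in-memory buffer (a filename works the same way) and reads the pixel size back from the PNG header. The buffer and variable names are illustrative.

```python
import io

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(x, np.sin(x))

# Save via the Figure object; the format is given explicitly because
# there is no file extension to infer it from
buffer = io.BytesIO()
fig.savefig(buffer, format='png', dpi=100)
plt.close(fig)  # release the figure once it has been written

png_bytes = buffer.getvalue()

# PNG stores width and height as big-endian integers in the IHDR chunk
width_px = int.from_bytes(png_bytes[16:20], 'big')
height_px = int.from_bytes(png_bytes[20:24], 'big')
print(width_px, height_px)
```

The size should come out as figsize multiplied by dpi, provided no cropping option such as bbox_inches='tight' is in effect.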

File Format Reference
Format | Type   | Best For                        | Extension
PNG    | Raster | Web, presentations, screenshots | .png
PDF    | Vector | Print, papers, reports          | .pdf
SVG    | Vector | Web (scalable), editing         | .svg
JPEG   | Raster | Photos (lossy compression)      | .jpg, .jpeg
EPS    | Vector | LaTeX, professional printing    | .eps

Practice: Subplots & Saving

Task: Create two plots side by side showing sine and cosine.

Solution:
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].plot(x, np.sin(x))
axes[0].set_title('Sine')

axes[1].plot(x, np.cos(x))
axes[1].set_title('Cosine')

plt.tight_layout()
plt.show()

Task: Create two plots stacked vertically.

Solution:
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
fig, axes = plt.subplots(2, 1, figsize=(8, 8))

axes[0].plot(x, x**2)
axes[0].set_title('Quadratic')

axes[1].plot(x, np.sqrt(x))
axes[1].set_title('Square Root')

plt.tight_layout()
plt.show()

Task: Create a 2x2 grid with line, bar, scatter, and histogram.

Solution:
import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(2, 2, figsize=(10, 8))

# Line plot
x = np.linspace(0, 10, 50)
axes[0, 0].plot(x, np.sin(x))
axes[0, 0].set_title('Line Plot')

# Bar chart
axes[0, 1].bar(['A', 'B', 'C', 'D'], [4, 7, 2, 8])
axes[0, 1].set_title('Bar Chart')

# Scatter plot
axes[1, 0].scatter(np.random.rand(30), np.random.rand(30))
axes[1, 0].set_title('Scatter Plot')

# Histogram
axes[1, 1].hist(np.random.randn(200), bins=15)
axes[1, 1].set_title('Histogram')

plt.tight_layout()
plt.show()

Task: Create two stacked plots that share the x-axis.

Solution:
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
fig, axes = plt.subplots(2, 1, figsize=(10, 6), sharex=True)

axes[0].plot(x, np.sin(x), 'b-')
axes[0].set_ylabel('sin(x)')
axes[0].set_title('Shared X-Axis')

axes[1].plot(x, np.cos(x), 'r-')
axes[1].set_ylabel('cos(x)')
axes[1].set_xlabel('x')

plt.tight_layout()
plt.show()

Task: Create a plot and save it as PNG at 200 DPI with tight bounding box.

Solution:
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
plt.figure(figsize=(10, 6))
plt.plot(x, np.sin(x), 'b-', linewidth=2)
plt.title('Saved Figure')
plt.xlabel('x')
plt.ylabel('sin(x)')
plt.grid(True, alpha=0.3)

plt.savefig('my_plot.png', dpi=200, bbox_inches='tight')
plt.show()

Task: Create a layout with one large plot on the left and two smaller plots stacked on the right.

Solution:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

x = np.linspace(0, 10, 100)

fig = plt.figure(figsize=(12, 6))
gs = gridspec.GridSpec(2, 2, figure=fig)

# Large left plot
ax1 = fig.add_subplot(gs[:, 0])
ax1.plot(x, np.sin(x), 'b-', linewidth=2)
ax1.set_title('Main Plot')

# Top-right small plot
ax2 = fig.add_subplot(gs[0, 1])
ax2.bar(['A', 'B', 'C'], [3, 7, 5])
ax2.set_title('Bar')

# Bottom-right small plot
ax3 = fig.add_subplot(gs[1, 1])
ax3.scatter(np.random.rand(20), np.random.rand(20))
ax3.set_title('Scatter')

plt.tight_layout()
plt.show()

06

Advanced Techniques

Master advanced visualization techniques including twin axes, 3D plots, animations, and real-world data integration.

Advanced Visualization Concepts

Advanced techniques help you create sophisticated visualizations for complex data scenarios:

  • Twin Axes: Plot two different scales on the same chart (e.g., temperature and precipitation)
  • Logarithmic Scales: Handle data spanning multiple orders of magnitude
  • 3D Plots: Visualize three-dimensional data relationships
  • Pandas Integration: Direct plotting from DataFrames
  • Real-time Updates: Create animated or updating visualizations

Twin Axes (Dual Y-Axis)

Display two different measurements with different scales on the same plot, sharing the x-axis.

import matplotlib.pyplot as plt
import numpy as np

# Sample data: temperature and rainfall over 12 months
months = np.arange(1, 13)
temperature = [5, 7, 12, 18, 23, 28, 31, 30, 25, 18, 11, 6]  # Celsius
rainfall = [80, 65, 70, 55, 45, 30, 25, 35, 50, 75, 85, 90]   # mm

fig, ax1 = plt.subplots(figsize=(12, 6))

# First axis: Temperature (line plot)
color1 = '#E74C3C'
ax1.set_xlabel('Month')
ax1.set_ylabel('Temperature (°C)', color=color1)
ax1.plot(months, temperature, color=color1, linewidth=2, marker='o', label='Temperature')
ax1.tick_params(axis='y', labelcolor=color1)
ax1.set_ylim(0, 35)

# Second axis: Rainfall (bar chart)
ax2 = ax1.twinx()  # Create twin axis sharing x
color2 = '#3498DB'
ax2.set_ylabel('Rainfall (mm)', color=color2)
ax2.bar(months, rainfall, color=color2, alpha=0.5, label='Rainfall')
ax2.tick_params(axis='y', labelcolor=color2)
ax2.set_ylim(0, 100)

# Title and month tick labels
ax1.set_title('Monthly Temperature and Rainfall')
ax1.set_xticks(months)
ax1.set_xticklabels(['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 
                     'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'])

# Combined legend
lines1, labels1 = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax1.legend(lines1 + lines2, labels1 + labels2, loc='upper left')

plt.tight_layout()
plt.show()

The twinx() method creates a second y-axis that shares the same x-axis. This is essential when comparing variables with different units or scales, like temperature and rainfall. Color-code both the data and the y-axis labels to make it clear which axis corresponds to which data. The example also demonstrates combining legends from both axes into a single legend box. Use twin axes sparingly, as they can be confusing if overdone.

Logarithmic Scales

Use log scales when data spans multiple orders of magnitude or follows exponential patterns.

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(1, 100, 100)
y_linear = x ** 2  # Quadratic power law (despite the name): a straight line on log-log axes
y_exp = np.exp(x / 20)

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Linear scale (both axes)
axes[0, 0].plot(x, y_exp, 'b-', linewidth=2)
axes[0, 0].set_title('Linear Scale')
axes[0, 0].set_xlabel('x')
axes[0, 0].set_ylabel('y')
axes[0, 0].grid(True, alpha=0.3)

# Logarithmic Y-axis
axes[0, 1].semilogy(x, y_exp, 'r-', linewidth=2)  # Or use set_yscale('log')
axes[0, 1].set_title('Logarithmic Y-Axis (semilogy)')
axes[0, 1].set_xlabel('x')
axes[0, 1].set_ylabel('y (log scale)')
axes[0, 1].grid(True, alpha=0.3)

# Logarithmic X-axis
axes[1, 0].semilogx(x, y_linear, 'g-', linewidth=2)
axes[1, 0].set_title('Logarithmic X-Axis (semilogx)')
axes[1, 0].set_xlabel('x (log scale)')
axes[1, 0].set_ylabel('y')
axes[1, 0].grid(True, alpha=0.3)

# Log-log scale (both axes)
axes[1, 1].loglog(x, y_linear, 'm-', linewidth=2)
axes[1, 1].set_title('Log-Log Scale (loglog)')
axes[1, 1].set_xlabel('x (log scale)')
axes[1, 1].set_ylabel('y (log scale)')
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

Logarithmic scales compress large ranges and expand small ranges, making patterns visible that would be invisible on linear scales. Use semilogy when y values span many orders of magnitude (like population growth), semilogx when x values do (like frequency response), and loglog for power-law relationships. With a logarithmic y-axis, an exponential curve y = a·exp(b·x) becomes a straight line with slope proportional to b. On log-log axes, a power law y = a·x^k becomes a straight line with slope k. This is powerful for identifying the mathematical nature of your data.
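
The straight-line behavior can be verified numerically by fitting a line in log space. This sketch reuses the two curves from the example above, with y_pow standing in for the quadratic; the fitted slopes should recover the growth rate and the exponent.

```python
import numpy as np

x = np.linspace(1, 100, 100)
y_exp = np.exp(x / 20)   # exponential: log(y) = x / 20
y_pow = x ** 2           # power law:   log(y) = 2 * log(x)

# Semilog-y view: fit log(y) against x
slope_semilog, _ = np.polyfit(x, np.log(y_exp), 1)

# Log-log view: fit log(y) against log(x)
slope_loglog, _ = np.polyfit(np.log(x), np.log(y_pow), 1)

print(round(slope_semilog, 3), round(slope_loglog, 3))
```

A slope near 1/20 = 0.05 in the first fit and near 2 in the second matches what the straight lines on the semilogy and loglog panels represent.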

3D Surface Plots

Visualize three-dimensional data using surface plots, wireframes, and scatter plots.

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # Only required for Matplotlib < 3.2
import numpy as np

# Create mesh grid
x = np.linspace(-5, 5, 50)
y = np.linspace(-5, 5, 50)
X, Y = np.meshgrid(x, y)
Z = np.sin(np.sqrt(X**2 + Y**2))  # Ripple pattern

fig = plt.figure(figsize=(15, 5))

# 3D Surface plot
ax1 = fig.add_subplot(131, projection='3d')
surf = ax1.plot_surface(X, Y, Z, cmap='viridis', edgecolor='none', alpha=0.8)
ax1.set_title('3D Surface')
ax1.set_xlabel('X')
ax1.set_ylabel('Y')
ax1.set_zlabel('Z')
fig.colorbar(surf, ax=ax1, shrink=0.5, aspect=10)

# 3D Wireframe
ax2 = fig.add_subplot(132, projection='3d')
ax2.plot_wireframe(X, Y, Z, color='steelblue', linewidth=0.5)
ax2.set_title('3D Wireframe')
ax2.set_xlabel('X')
ax2.set_ylabel('Y')
ax2.set_zlabel('Z')

# 3D Scatter
ax3 = fig.add_subplot(133, projection='3d')
n = 200
x_scatter = np.random.randn(n)
y_scatter = np.random.randn(n)
z_scatter = np.random.randn(n)
colors = np.sqrt(x_scatter**2 + y_scatter**2 + z_scatter**2)
scatter = ax3.scatter(x_scatter, y_scatter, z_scatter, c=colors, cmap='plasma', s=30)
ax3.set_title('3D Scatter')
ax3.set_xlabel('X')
ax3.set_ylabel('Y')
ax3.set_zlabel('Z')

plt.tight_layout()
plt.show()

3D plotting requires specifying projection='3d' when creating axes. Since Matplotlib 3.2, the explicit Axes3D import from mpl_toolkits.mplot3d is no longer needed; older versions require it to register the projection. The meshgrid function creates coordinate matrices from x and y vectors, essential for surface plots. Surface plots (plot_surface) fill the surface with color, wireframes (plot_wireframe) show only the grid lines, and scatter displays individual points in 3D space. 3D plots can be rotated interactively in Matplotlib windows, or in Jupyter notebooks when an interactive backend is enabled, helping viewers understand spatial relationships.
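
A compact alternative for a single 3D axes is to pass the projection through subplot_kw. This sketch assumes Matplotlib 3.2 or newer, where no extra import is needed; the helix data is an arbitrary example.

```python
import matplotlib.pyplot as plt
import numpy as np

# subplot_kw forwards the projection to the Axes that subplots creates
fig, ax = plt.subplots(subplot_kw={'projection': '3d'}, figsize=(6, 5))

t = np.linspace(0, 4 * np.pi, 200)
ax.plot(np.cos(t), np.sin(t), t)  # A helix: x, y, z

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('t')

# Set the initial camera angle (elevation and azimuth in degrees)
ax.view_init(elev=25, azim=-60)
```

view_init is useful for static output such as saved images, where the viewer cannot rotate the plot.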

Plotting with Pandas

Pandas DataFrames have built-in plotting that uses Matplotlib under the hood.

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Create sample DataFrame
np.random.seed(42)
dates = pd.date_range('2024-01-01', periods=100)
df = pd.DataFrame({
    'Date': dates,
    'Sales': np.random.randint(100, 500, 100) + np.sin(np.arange(100) / 10) * 100,
    'Costs': np.random.randint(50, 250, 100),
    'Profit': np.random.randint(20, 200, 100)
})
df.set_index('Date', inplace=True)

# Create a 2x2 grid of axes to hold the DataFrame plots
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Line plot: all columns
df.plot(ax=axes[0, 0], linewidth=1.5)
axes[0, 0].set_title('All Columns - Line Plot')
axes[0, 0].set_ylabel('Value')
axes[0, 0].legend(loc='upper left')

# Bar plot: single column
df['Sales'].resample('W').mean().plot(kind='bar', ax=axes[0, 1], color='steelblue')
axes[0, 1].set_title('Weekly Average Sales')
axes[0, 1].set_xlabel('Week')
axes[0, 1].tick_params(axis='x', rotation=45)

# Area plot: stacked
df[['Costs', 'Profit']].plot.area(ax=axes[1, 0], alpha=0.5)
axes[1, 0].set_title('Stacked Area Plot')
axes[1, 0].set_ylabel('Value')

# Scatter plot: two columns
df.plot.scatter(x='Costs', y='Sales', ax=axes[1, 1], alpha=0.6, c='Profit', 
                cmap='viridis', colorbar=True)
axes[1, 1].set_title('Costs vs Sales (colored by Profit)')

plt.tight_layout()
plt.show()

Pandas plotting methods like df.plot(), df.plot.bar(), df.plot.scatter() provide convenient shortcuts that automatically handle labels, legends, and date formatting. The ax parameter lets you place Pandas plots on specific matplotlib axes for multi-panel figures. Pandas also supports resampling (resample) for time series aggregation before plotting. This integration makes it easy to go from data analysis directly to visualization without manual data extraction.
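
Because the Pandas plot call returns the Matplotlib Axes it drew on, the usual Axes methods remain available afterwards. A small sketch with made-up numbers:

```python
import matplotlib.pyplot as plt
import pandas as pd

# Illustrative data, not taken from any real dataset
df = pd.DataFrame(
    {'Sales': [120, 150, 90, 180], 'Costs': [80, 95, 70, 110]},
    index=pd.date_range('2024-01-01', periods=4, freq='D'),
)

fig, ax = plt.subplots(figsize=(8, 4))

# Pandas draws onto the given Axes and hands it back
returned_ax = df.plot(ax=ax, marker='o')

# Continue customizing with plain Matplotlib
ax.set_ylabel('Value')
ax.set_title('Sales and Costs')
ax.grid(True, alpha=0.3)

legend_labels = [text.get_text() for text in ax.get_legend().get_texts()]
```

The column names become legend entries automatically, which is one of the conveniences described above.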

Error Bars

Display uncertainty or variability in your data using error bars.

import matplotlib.pyplot as plt
import numpy as np

# Sample data with errors
categories = ['Group A', 'Group B', 'Group C', 'Group D', 'Group E']
means = [45, 62, 38, 55, 70]
std_devs = [5, 8, 4, 7, 6]  # Standard deviations as error

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Vertical bar chart with error bars
axes[0].bar(categories, means, yerr=std_devs, capsize=5, 
            color='steelblue', edgecolor='black', alpha=0.8)
axes[0].set_title('Bar Chart with Error Bars')
axes[0].set_ylabel('Value')
axes[0].set_xlabel('Category')
axes[0].grid(True, axis='y', alpha=0.3)

# Line plot with error band
x = np.linspace(0, 10, 20)
y = np.sin(x) * 2 + 5
error = np.abs(0.5 + 0.3 * np.random.randn(20))  # Error magnitudes must be non-negative

axes[1].plot(x, y, 'b-', linewidth=2, marker='o', label='Mean')
axes[1].fill_between(x, y - error, y + error, alpha=0.3, color='blue', label='±1 std')
axes[1].errorbar(x, y, yerr=error, fmt='none', ecolor='blue', capsize=3, alpha=0.5)
axes[1].set_title('Line Plot with Error Band')
axes[1].set_xlabel('X')
axes[1].set_ylabel('Y')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

Error bars communicate uncertainty in your measurements, essential for scientific and statistical visualizations. The yerr parameter adds vertical error bars, while xerr handles horizontal errors. The capsize parameter controls the width of the error bar caps. For continuous data, fill_between creates error bands that are less cluttered than individual error bars. Always include error bars when presenting experimental or statistical data to help viewers assess the reliability of your findings.
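
Error values usually come from raw measurements rather than being typed in by hand. The sketch below computes asymmetric error bars from quartiles of simulated samples; the group means, spreads, and sample count are invented for illustration. Passing yerr as a 2xN array gives separate lower and upper lengths.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
groups = ['A', 'B', 'C']

# Simulated data: 30 measurements for each of the 3 groups
samples = rng.normal(loc=[45, 62, 38], scale=[5, 8, 4], size=(30, 3))

medians = np.median(samples, axis=0)
q25, q75 = np.percentile(samples, [25, 75], axis=0)

# Row 0: distance below the median, row 1: distance above (both non-negative)
yerr = np.vstack([medians - q25, q75 - medians])

x_pos = np.arange(len(groups))
fig, ax = plt.subplots()
ax.errorbar(x_pos, medians, yerr=yerr, fmt='o', capsize=5)
ax.set_xticks(x_pos)
ax.set_xticklabels(groups)
ax.set_ylabel('Value (median with interquartile range)')
```

Stating in the label or caption what the bars represent (standard deviation, standard error, or a percentile range) helps readers interpret them correctly.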

Contour Plots

Visualize 3D data in 2D using contour lines or filled contours.

import matplotlib.pyplot as plt
import numpy as np

# Create data
x = np.linspace(-3, 3, 100)
y = np.linspace(-3, 3, 100)
X, Y = np.meshgrid(x, y)
Z = np.exp(-(X**2 + Y**2)) + np.exp(-((X-1.5)**2 + (Y-1.5)**2))

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Contour lines
cs1 = axes[0].contour(X, Y, Z, levels=10, colors='black')
axes[0].clabel(cs1, inline=True, fontsize=8)  # Add labels to contour lines
axes[0].set_title('Contour Lines')
axes[0].set_xlabel('X')
axes[0].set_ylabel('Y')
axes[0].set_aspect('equal')

# Filled contours
cs2 = axes[1].contourf(X, Y, Z, levels=20, cmap='viridis')
plt.colorbar(cs2, ax=axes[1], label='Z value')
axes[1].set_title('Filled Contours')
axes[1].set_xlabel('X')
axes[1].set_ylabel('Y')
axes[1].set_aspect('equal')

# Contour lines on filled background
cs3a = axes[2].contourf(X, Y, Z, levels=20, cmap='coolwarm', alpha=0.8)
cs3b = axes[2].contour(X, Y, Z, levels=10, colors='black', linewidths=0.5)
plt.colorbar(cs3a, ax=axes[2], label='Z value')
axes[2].set_title('Combined Contours')
axes[2].set_xlabel('X')
axes[2].set_ylabel('Y')
axes[2].set_aspect('equal')

plt.tight_layout()
plt.show()

Contour plots represent 3D surfaces on a 2D plane, like topographic maps. The contour function draws lines of constant value, while contourf fills regions between contour levels with colors. The clabel function adds numeric labels directly on contour lines. The levels parameter controls the contour values: an integer requests roughly that many automatically chosen levels, while a list sets the exact values to draw. More levels show finer detail. Contour plots are invaluable for optimization landscapes, probability distributions, and any application where you need to show height or intensity patterns.
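
When particular thresholds matter, the levels can be given explicitly. A sketch using a single Gaussian bump and hand-picked values:

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-3, 3, 100)
X, Y = np.meshgrid(x, x)
Z = np.exp(-(X**2 + Y**2))  # Peak value close to 1 at the centre

fig, ax = plt.subplots(figsize=(5, 5))

# Draw contour lines at exactly these Z values
levels = [0.1, 0.25, 0.5, 0.75, 0.9]
cs = ax.contour(X, Y, Z, levels=levels, colors='black')
ax.clabel(cs, inline=True, fontsize=8)

ax.set_aspect('equal')
ax.set_title('Contours at chosen levels')
```

The levels actually used are available afterwards as cs.levels, which is handy when an integer was passed and Matplotlib chose the values.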

Practice: Advanced Techniques

Task: Plot temperature (line) and humidity (bars) on the same figure with different y-axes.

Solution:
import matplotlib.pyplot as plt
import numpy as np

days = np.arange(1, 8)
temp = [22, 25, 28, 26, 24, 23, 21]
humidity = [60, 55, 50, 58, 62, 65, 70]

fig, ax1 = plt.subplots(figsize=(10, 6))

ax1.plot(days, temp, 'r-o', linewidth=2, label='Temperature')
ax1.set_xlabel('Day')
ax1.set_ylabel('Temperature (°C)', color='red')
ax1.tick_params(axis='y', labelcolor='red')

ax2 = ax1.twinx()
ax2.bar(days, humidity, color='blue', alpha=0.5, label='Humidity')
ax2.set_ylabel('Humidity (%)', color='blue')
ax2.tick_params(axis='y', labelcolor='blue')

plt.title('Weekly Weather')
plt.tight_layout()
plt.show()

Task: Plot exponential growth using semilogy to show a straight line.

Solution:
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y = np.exp(x)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

axes[0].plot(x, y)
axes[0].set_title('Linear Scale')
axes[0].set_xlabel('x')
axes[0].set_ylabel('exp(x)')

axes[1].semilogy(x, y)
axes[1].set_title('Logarithmic Y Scale')
axes[1].set_xlabel('x')
axes[1].set_ylabel('exp(x)')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

Task: Create a 3D surface plot of z = sin(x) * cos(y).

Solution:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # Only required for Matplotlib < 3.2
import numpy as np

x = np.linspace(-np.pi, np.pi, 50)
y = np.linspace(-np.pi, np.pi, 50)
X, Y = np.meshgrid(x, y)
Z = np.sin(X) * np.cos(Y)

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
surf = ax.plot_surface(X, Y, Z, cmap='viridis')
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_title('3D Surface: sin(x) * cos(y)')
fig.colorbar(surf, shrink=0.5)
plt.show()

Task: Create a bar chart with error bars representing standard deviation.

Solution:
import matplotlib.pyplot as plt
import numpy as np

categories = ['A', 'B', 'C', 'D']
means = [25, 40, 35, 50]
std_devs = [3, 5, 4, 6]

plt.figure(figsize=(8, 6))
plt.bar(categories, means, yerr=std_devs, capsize=5, 
        color='steelblue', edgecolor='black')
plt.title('Bar Chart with Error Bars')
plt.xlabel('Category')
plt.ylabel('Value')
plt.grid(True, axis='y', alpha=0.3)
plt.show()

07

Interactive Demo

Experiment with different plot parameters to see how they affect visualizations in real-time.

Visualization Playground

Each chart type comes with example code and guidance on when to use it:

Line Plot

Best for: Trends over time, continuous data, comparing multiple series.

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y = np.sin(x)

plt.figure(figsize=(10, 6))
plt.plot(x, y, 'b-', linewidth=2, label='sin(x)')
plt.title('Line Plot Example')
plt.xlabel('X')
plt.ylabel('Y')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

Visualization Best Practices
DO:
  • Always include titles and axis labels
  • Choose colors with good contrast
  • Use legends when showing multiple series
  • Match chart type to data type
  • Keep visualizations simple and focused
  • Use consistent styling across related plots
DON'T:
  • Use pie charts with too many slices
  • Start y-axis at non-zero without good reason
  • Use 3D when 2D is sufficient
  • Overcrowd plots with too much data
  • Use rainbow colormaps for sequential data
  • Forget to add units to labels
Common Mistakes & Fixes
Mistake                         | Problem            | Fix
plt.show() then plt.savefig()   | Empty file saved   | Call savefig() before show()
Forgetting plt.tight_layout()   | Overlapping labels | Add plt.tight_layout() before show
Using plt.plot() for categories | Connected points   | Use plt.bar() for categories
Too many colors in legend       | Visual clutter     | Group data or use subplots
Default figure size             | Plots too small    | Use figsize=(10, 6) or larger

Real-World Examples

See how data visualization is applied in real scenarios across different industries and use cases.

Stock Price Analysis

Visualize stock prices with moving averages to identify trends and trading signals.

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Simulate stock data
np.random.seed(42)
dates = pd.date_range('2024-01-01', periods=100)
prices = 100 + np.cumsum(np.random.randn(100) * 2)

# Calculate moving averages
ma_20 = pd.Series(prices).rolling(20).mean()
ma_50 = pd.Series(prices).rolling(50).mean()

fig, ax = plt.subplots(figsize=(12, 6))

ax.plot(dates, prices, 'b-', alpha=0.6, label='Price')
ax.plot(dates, ma_20, 'r-', linewidth=2, label='20-day MA')
ax.plot(dates, ma_50, 'g-', linewidth=2, label='50-day MA')

# Highlight buy/sell signals
ax.fill_between(dates, prices, ma_20, 
                where=(prices > ma_20), 
                alpha=0.2, color='green', label='Above MA')
ax.fill_between(dates, prices, ma_20, 
                where=(prices < ma_20), 
                alpha=0.2, color='red', label='Below MA')

ax.set_title('Stock Price with Moving Averages')
ax.set_xlabel('Date')
ax.set_ylabel('Price ($)')
ax.legend(loc='upper left')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
Health Data Dashboard

Create a dashboard showing vital signs over time with normal range indicators.

import matplotlib.pyplot as plt
import numpy as np

# Simulate daily health data
days = np.arange(1, 31)
heart_rate = 70 + np.random.randn(30) * 8
blood_pressure_sys = 120 + np.random.randn(30) * 10
blood_pressure_dia = 80 + np.random.randn(30) * 5

fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True)

# Heart rate plot
axes[0].plot(days, heart_rate, 'r-o', markersize=4)
axes[0].axhspan(60, 100, alpha=0.2, color='green', label='Normal')
axes[0].set_ylabel('Heart Rate (bpm)')
axes[0].set_title('30-Day Health Monitoring')
axes[0].legend(loc='upper right')
axes[0].grid(True, alpha=0.3)

# Blood pressure plot
axes[1].plot(days, blood_pressure_sys, 'b-o', markersize=4, label='Systolic')
axes[1].plot(days, blood_pressure_dia, 'c-s', markersize=4, label='Diastolic')
axes[1].axhspan(90, 120, alpha=0.2, color='green', label='Normal Sys')
axes[1].set_xlabel('Day')
axes[1].set_ylabel('Blood Pressure (mmHg)')
axes[1].legend(loc='upper right')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
Sales Analysis

Analyze sales data by region and product category with grouped bar charts.

import matplotlib.pyplot as plt
import numpy as np

# Sales data
regions = ['North', 'South', 'East', 'West']
products = ['Electronics', 'Clothing', 'Food']
x = np.arange(len(regions))
width = 0.25

sales = {
    'Electronics': [120, 90, 150, 110],
    'Clothing': [80, 120, 100, 95],
    'Food': [150, 130, 140, 160]
}

fig, ax = plt.subplots(figsize=(10, 6))

for i, (product, values) in enumerate(sales.items()):
    offset = width * i
    bars = ax.bar(x + offset, values, width, label=product)
    ax.bar_label(bars, padding=3, fontsize=8)

ax.set_xlabel('Region')
ax.set_ylabel('Sales (thousands $)')
ax.set_title('Q4 Sales by Region and Product')
ax.set_xticks(x + width)
ax.set_xticklabels(regions)
ax.legend(loc='upper left')
ax.grid(True, axis='y', alpha=0.3)

plt.tight_layout()
plt.show()
Weather Analysis

Visualize temperature and precipitation patterns with dual-axis plots.

import matplotlib.pyplot as plt
import numpy as np

months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
          'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
temp = [5, 7, 12, 18, 23, 28, 31, 30, 25, 18, 11, 6]
rain = [80, 65, 70, 55, 45, 30, 25, 35, 50, 75, 85, 90]

fig, ax1 = plt.subplots(figsize=(12, 6))

# Temperature line
color1 = '#E74C3C'
ax1.plot(months, temp, color=color1, marker='o', linewidth=2)
ax1.fill_between(months, temp, alpha=0.2, color=color1)
ax1.set_xlabel('Month')
ax1.set_ylabel('Temperature (°C)', color=color1)
ax1.tick_params(axis='y', labelcolor=color1)
ax1.set_ylim(0, 35)

# Rainfall bars
ax2 = ax1.twinx()
color2 = '#3498DB'
ax2.bar(months, rain, color=color2, alpha=0.5)
ax2.set_ylabel('Rainfall (mm)', color=color2)
ax2.tick_params(axis='y', labelcolor=color2)
ax2.set_ylim(0, 100)

ax1.set_title('Annual Temperature and Rainfall')
plt.tight_layout()
plt.show()
Machine Learning Model Evaluation

Visualize model performance with confusion matrix and ROC curve.

import matplotlib.pyplot as plt
import numpy as np

# Simulate confusion matrix and ROC data
confusion_matrix = np.array([[85, 15], [10, 90]])
fpr = np.array([0, 0.1, 0.2, 0.4, 0.6, 0.8, 1.0])
tpr = np.array([0, 0.5, 0.7, 0.85, 0.92, 0.97, 1.0])
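# Trapezoidal area under the simulated ROC points (about 0.80)
auc = np.sum(np.diff(fpr) * (tpr[:-1] + tpr[1:]) / 2)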

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Confusion Matrix
im = axes[0].imshow(confusion_matrix, cmap='Blues')
axes[0].set_xticks([0, 1])
axes[0].set_yticks([0, 1])
axes[0].set_xticklabels(['Predicted Negative', 'Predicted Positive'])
axes[0].set_yticklabels(['Actual Negative', 'Actual Positive'])
axes[0].set_title('Confusion Matrix')

# Add text annotations
for i in range(2):
    for j in range(2):
        axes[0].text(j, i, str(confusion_matrix[i, j]), 
                    ha='center', va='center', fontsize=20,
                    color='white' if confusion_matrix[i, j] > 50 else 'black')

plt.colorbar(im, ax=axes[0])

# ROC Curve
axes[1].plot(fpr, tpr, 'b-', linewidth=2, label=f'ROC (AUC = {auc:.2f})')
axes[1].plot([0, 1], [0, 1], 'r--', label='Random Classifier')
axes[1].fill_between(fpr, tpr, alpha=0.3)
axes[1].set_xlabel('False Positive Rate')
axes[1].set_ylabel('True Positive Rate')
axes[1].set_title('ROC Curve')
axes[1].legend(loc='lower right')
axes[1].grid(True, alpha=0.3)
axes[1].set_xlim([0, 1])
axes[1].set_ylim([0, 1])

plt.tight_layout()
plt.show()
08

Quick Reference

A comprehensive cheat sheet of commonly used Matplotlib functions, parameters, and patterns.

Plot Functions
| Function | Description | Example |
|---|---|---|
| plt.plot() | Line plot | plt.plot(x, y, 'b-o') |
| plt.bar() | Vertical bar chart | plt.bar(cats, vals) |
| plt.barh() | Horizontal bar chart | plt.barh(cats, vals) |
| plt.scatter() | Scatter plot | plt.scatter(x, y, c=colors) |
| plt.hist() | Histogram | plt.hist(data, bins=30) |
| plt.pie() | Pie chart | plt.pie(sizes, labels=names) |
| plt.boxplot() | Box plot | plt.boxplot(data) |
| plt.contour() | Contour lines | plt.contour(X, Y, Z) |
| plt.contourf() | Filled contours | plt.contourf(X, Y, Z) |
| plt.imshow() | Display image/matrix | plt.imshow(data, cmap='viridis') |
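Most of these functions also return objects you can inspect. As a small sketch (the sample data, seed, and bin count are arbitrary choices for illustration), hist() returns the bin counts and bin edges alongside the drawn bars:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(42)  # fixed seed for reproducible sample data
data = rng.normal(loc=50, scale=10, size=1000)

fig, ax = plt.subplots(figsize=(10, 6))
counts, bin_edges, patches = ax.hist(data, bins=30, edgecolor="white")

ax.set_title("Distribution of Simulated Values")
ax.set_xlabel("Value")
ax.set_ylabel("Frequency")

# 30 bins are described by 31 edges, and every sample falls into exactly one bin.
print(len(counts), len(bin_edges), int(counts.sum()))
plt.close(fig)
```

Keeping the returned counts and edges is handy when you want to label the tallest bin or reuse the same binning on a second dataset.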
Customization Functions
| Function | Description | Example |
|---|---|---|
| plt.title() | Set plot title | plt.title('My Plot', fontsize=14) |
| plt.xlabel() | Set x-axis label | plt.xlabel('Time (s)') |
| plt.ylabel() | Set y-axis label | plt.ylabel('Value') |
| plt.xlim() | Set x-axis range | plt.xlim(0, 10) |
| plt.ylim() | Set y-axis range | plt.ylim(-1, 1) |
| plt.xticks() | Set x tick positions | plt.xticks([0, 5, 10]) |
| plt.yticks() | Set y tick positions | plt.yticks([0, 0.5, 1]) |
| plt.legend() | Add legend | plt.legend(loc='upper left') |
| plt.grid() | Add grid lines | plt.grid(True, alpha=0.3) |
| plt.text() | Add text annotation | plt.text(x, y, 'label') |
| plt.annotate() | Add annotation, optionally with an arrow | plt.annotate('note', xy=(x, y)) |
| plt.colorbar() | Add color scale bar | plt.colorbar(label='Value') |
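One detail worth seeing in code: annotate() draws an arrow only when arrowprops is passed. The sketch below uses the object-oriented form, ax.annotate(); the data, label text, and coordinates are made up for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y = np.sin(x)

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(x, y)

# xy is the point being annotated; xytext is where the label text sits.
peak_x = np.pi / 2
note = ax.annotate(
    "First peak",
    xy=(peak_x, 1.0),
    xytext=(peak_x + 2, 1.3),
    arrowprops=dict(arrowstyle="->"),  # without this argument, no arrow is drawn
)
ax.set_ylim(-1.5, 1.6)  # leave headroom so the label stays inside the axes
plt.close(fig)
```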
Figure & Layout Functions
| Function | Description | Example |
|---|---|---|
| plt.figure() | Create new figure | plt.figure(figsize=(10, 6)) |
| plt.subplots() | Create figure with subplots | fig, axes = plt.subplots(2, 2) |
| plt.subplot() | Add single subplot | plt.subplot(2, 2, 1) |
| plt.tight_layout() | Adjust layout spacing | plt.tight_layout() |
| plt.savefig() | Save figure to file | plt.savefig('plot.png', dpi=150) |
| plt.show() | Display the figure | plt.show() |
| plt.clf() | Clear current figure | plt.clf() |
| plt.close() | Close figure window | plt.close('all') |
| ax.twinx() | Create twin y-axis | ax2 = ax.twinx() |
| ax.twiny() | Create twin x-axis | ax2 = ax.twiny() |
Line Styles
| Code | Style |
|---|---|
| '-' | Solid line ― |
| '--' | Dashed line - - |
| '-.' | Dash-dot line -· |
| ':' | Dotted line ··· |
| '' | No line (markers only) |
Markers
| Code | Marker |
|---|---|
| 'o' | Circle ● |
| 's' | Square ■ |
| '^' | Triangle ▲ |
| '*' | Star ★ |
| '+' | Plus + |
| 'x' | Cross × |
| 'D' | Diamond ◆ |
Color Codes
| Code | Color |
|---|---|
| 'b' | Blue |
| 'g' | Green |
| 'r' | Red |
| 'c' | Cyan |
| 'm' | Magenta |
| 'y' | Yellow |
| 'k' | Black |
| 'w' | White |
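The line style, marker, and color codes can be combined into a single format string passed to plot(). The following is a small sketch with made-up data showing how two such strings are interpreted:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so this runs without a display
import matplotlib.pyplot as plt

x = [0, 1, 2, 3]

fig, ax = plt.subplots()
# Format string = color + line style + marker, here: red, dashed, circles.
(dashed,) = ax.plot(x, [0, 1, 4, 9], "r--o", label="r--o")
# A marker with no line style in the string draws markers only: black squares.
(squares,) = ax.plot(x, [0, 2, 4, 6], "ks", label="ks")
ax.legend()

# Inspect how Matplotlib parsed each format string.
print(dashed.get_color(), dashed.get_linestyle(), dashed.get_marker())
print(squares.get_color(), squares.get_marker())
plt.close(fig)
```

Format strings are convenient for quick plots; for anything beyond the single-letter colors, the keyword arguments color=, linestyle=, and marker= are more explicit.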
Popular Colormaps
| Name | Use Case |
|---|---|
| 'viridis' | Default, perceptually uniform |
| 'plasma' | High contrast sequential |
| 'coolwarm' | Diverging (blue to red) |
| 'RdYlGn' | Diverging (red to green) |
| 'Blues' | Sequential blues |
| 'hot' | Sequential (black through red and yellow to white) |
| 'gray' | Grayscale |
| 'tab10' | Categorical (10 colors) |
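Colormaps are not limited to imshow() or scatter(). As a hedged sketch (the number of lines and the phase shifts are arbitrary), a colormap object can be called with a value between 0 and 1 to pick evenly spaced colors for a family of related lines:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 200)
cmap = plt.get_cmap("viridis")

fig, ax = plt.subplots(figsize=(10, 6))
n_lines = 5
for i in range(n_lines):
    # Calling a colormap with a number in [0, 1] returns an RGBA tuple.
    color = cmap(i / (n_lines - 1))
    ax.plot(x, np.sin(x + 0.5 * i), color=color, label=f"phase {i}")
ax.legend()
plt.close(fig)
```

A sequential map such as 'viridis' suits ordered series like these; for unordered categories, a categorical map such as 'tab10' is usually the better fit.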
Ready-to-Use Templates
Basic Plot Template
import matplotlib.pyplot as plt
import numpy as np

# Data
x = np.linspace(0, 10, 100)
y = np.sin(x)

# Create figure
plt.figure(figsize=(10, 6))

# Plot
plt.plot(x, y, 'b-', linewidth=2, label='Data')

# Customize
plt.title('Title', fontsize=14)
plt.xlabel('X Label')
plt.ylabel('Y Label')
plt.legend()
plt.grid(True, alpha=0.3)

# Save and show
plt.tight_layout()
plt.savefig('plot.png', dpi=150, bbox_inches='tight')
plt.show()
Multi-Panel Template
import matplotlib.pyplot as plt
import numpy as np

# Data
x = np.linspace(0, 10, 100)

# Create subplots
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Plot each panel
axes[0, 0].plot(x, np.sin(x))
axes[0, 0].set_title('Panel 1')

axes[0, 1].plot(x, np.cos(x))
axes[0, 1].set_title('Panel 2')

axes[1, 0].bar(['A', 'B', 'C'], [1, 2, 3])
axes[1, 0].set_title('Panel 3')

axes[1, 1].scatter(np.random.rand(20), np.random.rand(20))
axes[1, 1].set_title('Panel 4')

# Adjust and show
plt.tight_layout()
plt.savefig('multi_panel.png', dpi=150)
plt.show()
Publication-Quality Template
import matplotlib.pyplot as plt
import numpy as np

# Set style for publications
plt.rcParams.update({
    'font.size': 12,
    'axes.labelsize': 14,
    'axes.titlesize': 16,
    'legend.fontsize': 11,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'figure.figsize': (8, 6),
    'figure.dpi': 100,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight'
})

# Data
x = np.linspace(0, 2*np.pi, 100)

# Create figure with object-oriented API
fig, ax = plt.subplots()

# Plot with styling
ax.plot(x, np.sin(x), 'b-', linewidth=2, label=r'$\sin(x)$')
ax.plot(x, np.cos(x), 'r--', linewidth=2, label=r'$\cos(x)$')

# Customize
ax.set_xlabel(r'$x$ (radians)')
ax.set_ylabel(r'$f(x)$')
ax.set_title('Trigonometric Functions')
ax.legend(loc='upper right', frameon=True)
ax.grid(True, alpha=0.3, linestyle='--')

# Remove top and right spines
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# Save for publication
plt.savefig('publication_plot.pdf')
plt.savefig('publication_plot.png')
plt.show()
Frequently Asked Questions

To change font sizes globally, use plt.rcParams.update({'font.size': 14}), which sets the base size that titles, labels, and tick labels scale from. To change a single element, pass fontsize=14 to it directly.

If a saved figure comes out blank, you likely called plt.show() before plt.savefig(). In many setups the figure is closed once show() returns, so a later savefig() writes a new, empty figure. Always call savefig() first.

For better-looking plots, try a built-in style such as plt.style.use('seaborn-v0_8') (this name requires Matplotlib 3.6 or later) or plt.style.use('ggplot'). Also increase the figure size with figsize=(10, 6) and remove the top and right spines for a cleaner look.
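A minimal sketch of these suggestions, assuming the 'ggplot' style is available in your installation. It uses plt.style.context() so the style applies only inside the block, instead of changing the global defaults the way plt.style.use() does.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so this runs without a display
import matplotlib.pyplot as plt

# Style names vary between Matplotlib versions, so check before relying on one.
print("ggplot" in plt.style.available)

# The context manager applies the style only to figures created inside it.
with plt.style.context("ggplot"):
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot([1, 2, 3], [1, 4, 9], marker="o")
    ax.set_title("Styled Plot")
    # Hide the top and right spines for a cleaner look.
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
plt.close(fig)
```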

plt.plot() is the pyplot interface (simpler, implicit). ax.plot() is the object-oriented interface (more control, explicit). Use OO style for complex plots and subplots.

Key Takeaways

Figure and Axes

Figure is the canvas, Axes is where data is plotted. One figure can have multiple axes.

Choose Right Chart

Line for trends, bar for categories, scatter for relationships, histogram for distributions.

Always Label

Add titles, axis labels, and legends to make plots self-explanatory.

Subplots for Comparison

Use subplots() to create multi-panel figures for side-by-side comparisons.

Save Before Show

Call savefig() before show() to save your plots to files.

Use Styles

plt.style.use() applies pre-made themes for consistent, attractive plots.

Knowledge Check

Quick Quiz

Test what you've learned about data visualization with Matplotlib

1. What is the difference between Figure and Axes?
2. Which chart type is best for showing distribution of values?
3. What does plt.tight_layout() do?
4. When should you call plt.savefig()?
5. How do you create a 2x2 grid of subplots?
6. Which chart type is best for showing the relationship between two variables?