Introduction to Matplotlib
Matplotlib is the foundational visualization library in Python's data science ecosystem. Created by John Hunter in 2003, it provides a MATLAB-like interface for creating static, animated, and interactive visualizations. Whether you need a quick exploratory plot or a publication-ready figure, Matplotlib gives you complete control over every aspect of your visualization.
Why Matplotlib?
While newer libraries like Seaborn and Plotly offer higher-level interfaces, Matplotlib remains essential for several reasons. First, it provides the most granular control over plot elements, allowing you to customize everything from tick marks to annotation positions. Second, several widely used tools, including Seaborn and the pandas plotting API, are built on top of Matplotlib, so understanding it helps you customize their output as well. Third, it excels at creating publication-quality figures with precise formatting requirements.
Matplotlib
A comprehensive library for creating static, animated, and interactive visualizations in Python. It produces publication-quality figures in a variety of formats and interactive environments.
Getting Started
To use Matplotlib, you first need to install and import it. The most common approach is to import the pyplot module, which provides a convenient interface similar to MATLAB. By convention, we import it as plt, which you will see in virtually all Python data science code.
# Install matplotlib (run in terminal)
# pip install matplotlib
# Standard import convention
import matplotlib.pyplot as plt
import numpy as np
# Check version
import matplotlib
print(matplotlib.__version__)  # e.g. 3.8.2 (your installed version may differ)
Your First Plot
Creating a basic plot in Matplotlib is remarkably simple. You pass your data to a plotting function, add some labels, and display the result. The library handles all the complex rendering details behind the scenes, allowing you to focus on your data rather than the mechanics of drawing.
# Create sample data
x = [1, 2, 3, 4, 5]
y = [2, 4, 6, 8, 10]
# Create a simple line plot
plt.plot(x, y)
plt.xlabel('X Values')
plt.ylabel('Y Values')
plt.title('My First Matplotlib Plot')
plt.show()
In Jupyter notebooks, add %matplotlib inline at the top of your notebook to display plots directly below code cells. For interactive plots, use %matplotlib widget (this requires the ipympl package).
Two Interfaces: pyplot vs Object-Oriented
Matplotlib offers two ways to create plots. The pyplot interface (also called the state-based or implicit interface) is simpler and works well for quick plots. The object-oriented (explicit) interface gives you more control and is better for complex figures with multiple subplots. Matplotlib's own documentation suggests the object-oriented style for complicated plots and for code that is meant to be reused.
# pyplot interface (state-based) - simpler but less control
plt.figure(figsize=(8, 4))
plt.plot([1, 2, 3], [1, 4, 9])
plt.title('pyplot Style')
plt.show()
# Object-oriented interface - more control, recommended for complex plots
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot([1, 2, 3], [1, 4, 9])
ax.set_title('Object-Oriented Style')
plt.show()
pyplot Interface
- Quick exploratory plots
- Single figure, simple layouts
- Uses plt.function() syntax
- State-based, tracks current figure
Object-Oriented
- Complex multi-panel figures
- Fine-grained customization
- Uses ax.method() syntax
- Explicit figure and axes objects
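To make the "tracks current figure" point concrete, here is a minimal sketch. It relies on plt.gcf(), which returns the figure pyplot currently treats as active; the variable names are illustrative only.

```python
import matplotlib.pyplot as plt

# pyplot keeps track of a "current" figure and axes behind the scenes
fig1, ax1 = plt.subplots()
fig2, ax2 = plt.subplots()   # the most recently created figure becomes current

current_is_fig2 = plt.gcf() is fig2     # plt.gcf() returns the current figure
plt.title('Goes to the current axes')   # state-based call lands on ax2

# The object-oriented call names its target explicitly
ax1.set_title('Goes to ax1')

print(current_is_fig2)
print(ax1.get_title(), '|', ax2.get_title())
plt.close('all')
```

With several figures open, the explicit form avoids having to remember which one is current.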
Figure & Axes Architecture
Understanding Matplotlib's architecture is crucial for creating effective visualizations. At its core, every Matplotlib plot consists of a Figure container that holds one or more Axes objects. The Figure is like a canvas, while each Axes is an individual plot within that canvas. Mastering this hierarchy unlocks the full power of the library.
The Matplotlib Hierarchy
Matplotlib organizes visualizations in a clear hierarchy. The Figure is the top-level container that represents the entire image. Inside the Figure, you have one or more Axes objects, which are the actual plots where data gets visualized. Each Axes contains additional elements like the x-axis, y-axis, title, and legend. Understanding this structure helps you manipulate any part of your visualization.
Figure & Axes
Figure: The top-level container for all plot elements. Think of it as the window or page that holds your visualization.
Axes: The area where data is plotted, including the x-axis, y-axis, and all visual elements like lines and markers.
# Creating Figure and Axes explicitly
fig = plt.figure(figsize=(10, 6)) # Create a figure 10 inches wide, 6 inches tall
ax = fig.add_subplot(111) # Add a single axes (1 row, 1 col, position 1)
# Plot on the axes
ax.plot([1, 2, 3, 4], [1, 4, 2, 3])
ax.set_xlabel('X Axis Label')
ax.set_ylabel('Y Axis Label')
ax.set_title('Understanding Figure and Axes')
plt.show()
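The hierarchy described above can be inspected directly on the objects. The following is a small sketch, not a complete tour: it only looks at the figure's list of axes, the two axis objects, and the lines drawn on the axes.

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
line, = ax.plot([1, 2, 3, 4], [1, 4, 2, 3], label='data')
ax.set_title('Hierarchy demo')

# Figure -> Axes
axes_in_figure = fig.axes        # list of Axes held by the figure

# Axes -> its parts
x_axis = ax.xaxis                # XAxis object (ticks, tick labels, axis label)
y_axis = ax.yaxis                # YAxis object
lines_on_axes = ax.get_lines()   # Line2D artists drawn on this Axes

print(len(axes_in_figure))
print(type(x_axis).__name__, type(y_axis).__name__)
print(lines_on_axes[0] is line)
plt.close(fig)
```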
The subplots() Shortcut
While you can create figures and axes separately, the plt.subplots() function provides a convenient
way to create both at once. This is the most common pattern you will see in professional code. It returns a
tuple containing the figure and axes, which you can unpack directly into variables.
# The most common pattern: create figure and axes together
fig, ax = plt.subplots(figsize=(10, 6))
# Now use ax to plot and customize
x = np.linspace(0, 10, 100)
ax.plot(x, np.sin(x), label='sin(x)')
ax.plot(x, np.cos(x), label='cos(x)')
ax.legend()
ax.set_title('Sine and Cosine Waves')
plt.show()
Figure Properties
The Figure object controls overall properties like size, resolution, and background color. Setting the right
figure size is important for both display and export. The figsize parameter takes a tuple of
width and height in inches. The dpi parameter controls resolution, with higher values producing
sharper but larger images.
# Figure with custom properties
fig, ax = plt.subplots(
    figsize=(12, 8),    # Width, height in inches
    dpi=100,            # Dots per inch (resolution)
    facecolor='white'   # Background color
)
ax.plot([1, 2, 3], [1, 2, 3])
ax.set_title('Custom Figure Properties')
# Access figure properties
print(f"Figure size: {fig.get_size_inches()}") # [12. 8.]
print(f"DPI: {fig.get_dpi()}") # 100.0
plt.tight_layout() # Adjust spacing to prevent overlap
plt.show()
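The relationship between figsize and dpi in the Figure Properties discussion can be checked by exporting a figure: the pixel dimensions of a saved PNG are the size in inches multiplied by the dpi. The sketch below saves to an in-memory buffer and reads the size from the PNG header. The helper png_pixel_size is a hypothetical name introduced for this illustration, and the example assumes the default savefig settings (no tight bounding box).

```python
import io
import struct
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(6, 4), dpi=100)
ax.plot([1, 2, 3], [1, 2, 3])

def png_pixel_size(figure, dpi):
    """Save the figure to an in-memory PNG and read width/height from its header."""
    buffer = io.BytesIO()
    figure.savefig(buffer, format='png', dpi=dpi)
    header = buffer.getvalue()[:24]
    # Bytes 16-23 of a PNG file hold width and height as big-endian integers
    return struct.unpack('>II', header[16:24])

size_100 = png_pixel_size(fig, dpi=100)   # 6 x 4 inches at 100 dpi
size_200 = png_pixel_size(fig, dpi=200)   # same figure at 200 dpi
print(size_100, size_200)
plt.close(fig)
```

Doubling the dpi doubles both pixel dimensions, which is why higher-dpi exports are sharper but produce larger files.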
Axes Properties and Methods
The Axes object provides methods to customize every aspect of your plot. Methods starting with set_
configure properties like labels and limits. The Axes also manages the actual data visualization through
plotting methods like plot(), scatter(), and bar(). Learning the key
Axes methods is essential for effective visualization.
fig, ax = plt.subplots(figsize=(10, 6))
# Sample data
x = np.arange(1, 11)
y = x ** 2
# Plot the data
ax.plot(x, y, color='blue', linewidth=2, marker='o')
# Customize axes properties
ax.set_xlabel('X Values', fontsize=12)
ax.set_ylabel('Y Squared', fontsize=12)
ax.set_title('Customizing Axes Properties', fontsize=14, fontweight='bold')
ax.set_xlim(0, 12) # Set x-axis limits
ax.set_ylim(0, 120) # Set y-axis limits
ax.set_xticks(range(0, 13, 2)) # Custom tick positions
ax.grid(True, linestyle='--', alpha=0.7) # Add grid
plt.show()
| Method | Description | Example |
|---|---|---|
| ax.set_xlabel() | Set x-axis label | ax.set_xlabel('Time (s)') |
| ax.set_ylabel() | Set y-axis label | ax.set_ylabel('Value') |
| ax.set_title() | Set plot title | ax.set_title('My Plot') |
| ax.set_xlim() | Set x-axis range | ax.set_xlim(0, 100) |
| ax.set_ylim() | Set y-axis range | ax.set_ylim(-1, 1) |
| ax.legend() | Display legend | ax.legend(loc='best') |
| ax.grid() | Add grid lines | ax.grid(True) |
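Several of the set_ methods in the table can also be combined into a single ax.set() call, which takes the property names as keyword arguments. This is a brief sketch of that shortcut using the same example values as the table; it is an alternative spelling, not a replacement for the individual methods.

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0, 50, 100], [-1, 0, 1])

# ax.set() forwards each keyword to the matching set_* method
ax.set(
    xlabel='Time (s)',
    ylabel='Value',
    title='My Plot',
    xlim=(0, 100),
    ylim=(-1, 1),
)

print(ax.get_xlabel(), ax.get_ylabel(), ax.get_title())
print(ax.get_xlim(), ax.get_ylim())
plt.close(fig)
```

The individual methods remain useful when you need text styling arguments such as fontsize, which ax.set() does not accept per label.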
Practice Questions
Task: Create a figure that is 8 inches wide and 5 inches tall, then plot the squares of numbers 1-5.
# Given: Numbers 1-5
numbers = [1, 2, 3, 4, 5]
# Your code here: Create figure with figsize, calculate squares, plot, add title
Expected Output: A line plot showing points (1,1), (2,4), (3,9), (4,16), (5,25) with title "Squares"
numbers = [1, 2, 3, 4, 5]
fig, ax = plt.subplots(figsize=(8, 5))
squares = [n ** 2 for n in numbers]
ax.plot(numbers, squares, marker='o')
ax.set_title('Squares')
ax.set_xlabel('Number')
ax.set_ylabel('Square')
plt.show()
Task: Create a plot with custom x and y limits, a grid, and formatted labels.
# Given: Temperature data
hours = [6, 9, 12, 15, 18, 21]
temps = [15, 20, 28, 32, 25, 18]
# Your code: Plot temps, set xlim 0-24, ylim 10-40, add grid, label axes
Expected Output: Temperature plot with proper axis limits, grid, and descriptive labels
hours = [6, 9, 12, 15, 18, 21]
temps = [15, 20, 28, 32, 25, 18]
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(hours, temps, marker='s', color='red', linewidth=2)
ax.set_xlim(0, 24)
ax.set_ylim(10, 40)
ax.set_xlabel('Hour of Day', fontsize=12)
ax.set_ylabel('Temperature (°C)', fontsize=12)
ax.set_title('Daily Temperature Profile', fontsize=14)
ax.grid(True, linestyle='--', alpha=0.7)
ax.set_xticks([0, 6, 12, 18, 24])
plt.show()
Task: Plot three mathematical functions on the same axes with a legend.
# Given: x values from 0 to 2π
x = np.linspace(0, 2 * np.pi, 100)
# Your code: Plot sin(x), cos(x), and tan(x) clipped to [-2, 2]
# Add labels, legend (upper right), title, and grid
Expected Output: Three curves with different colors, a legend, and appropriate styling
x = np.linspace(0, 2 * np.pi, 100)
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(x, np.sin(x), label='sin(x)', color='blue')
ax.plot(x, np.cos(x), label='cos(x)', color='red')
ax.plot(x, np.clip(np.tan(x), -2, 2), label='tan(x)', color='green', linestyle='--')
ax.set_xlabel('x (radians)')
ax.set_ylabel('f(x)')
ax.set_title('Trigonometric Functions')
ax.set_ylim(-2.5, 2.5)
ax.legend(loc='upper right')
ax.grid(True, alpha=0.3)
plt.show()
Basic Plot Types
Matplotlib provides a rich variety of plot types for different data visualization needs. From line plots showing trends over time to scatter plots revealing relationships between variables, each plot type serves a specific purpose. Mastering these fundamental chart types gives you the building blocks for more complex visualizations.
Line Plots
Line plots are ideal for showing continuous data and trends over time or ordered categories. They connect data points with lines, making it easy to see patterns, trends, and fluctuations. Use line plots when your x-axis represents a continuous variable like time, or when you want to emphasize the connection between sequential data points.
# Basic line plot
months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun']
sales = [12000, 15000, 13500, 17000, 19000, 21000]
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(months, sales, marker='o', linewidth=2, markersize=8, color='#3498db')
ax.set_xlabel('Month')
ax.set_ylabel('Sales ($)')
ax.set_title('Monthly Sales Performance')
ax.grid(True, axis='y', alpha=0.3)
plt.show()
# Multiple lines for comparison
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(months, sales, marker='o', label='2024', linewidth=2)
ax.plot(months, [11000, 13000, 14000, 15500, 17000, 18500], marker='s', label='2023', linewidth=2)
ax.legend()
ax.set_title('Year-over-Year Sales Comparison')
plt.show()
Scatter Plots
Scatter plots display individual data points as markers, making them perfect for exploring relationships between two numerical variables. They help identify correlations, clusters, and outliers in your data. You can also encode additional variables using color and size to create information-rich visualizations.
# Basic scatter plot
np.random.seed(42)
study_hours = np.random.uniform(1, 10, 50)
exam_scores = 50 + 5 * study_hours + np.random.normal(0, 5, 50)
fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(study_hours, exam_scores, alpha=0.7, s=100, c='#e74c3c', edgecolors='white')
ax.set_xlabel('Study Hours')
ax.set_ylabel('Exam Score')
ax.set_title('Study Hours vs Exam Performance')
plt.show()
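# Scatter with color mapping (third variable)
ages = np.random.randint(18, 25, 50)  # Integer ages from 18 to 24 (upper bound is exclusive)
fig, ax = plt.subplots(figsize=(10, 6))  # New figure: the previous one was already shown
scatter = ax.scatter(study_hours, exam_scores, c=ages, cmap='viridis', s=100, alpha=0.7)
fig.colorbar(scatter, ax=ax, label='Student Age')
ax.set_xlabel('Study Hours')
ax.set_ylabel('Exam Score')
plt.show()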
Key Scatter Parameters
s controls marker size, c sets color (can be array for colormap),
alpha controls transparency (0-1), and edgecolors sets marker border color.
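The scatter discussion mentions encoding extra variables through both color and size. As a rough sketch of the size part, the example below passes an array to s as well as to c. The sleep_hours variable and the scaling factor of 30 are made up for illustration, and the data is randomly generated.

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(42)
study_hours = rng.uniform(1, 10, 50)
exam_scores = 50 + 5 * study_hours + rng.normal(0, 5, 50)
sleep_hours = rng.uniform(4, 9, 50)   # hypothetical third variable
ages = rng.integers(18, 25, 50)       # hypothetical fourth variable

fig, ax = plt.subplots(figsize=(10, 6))
points = ax.scatter(
    study_hours, exam_scores,
    s=sleep_hours * 30,       # marker area in points^2, scaled up for visibility
    c=ages, cmap='viridis',   # color mapped from the ages array
    alpha=0.7, edgecolors='white',
)
fig.colorbar(points, ax=ax, label='Student Age')
ax.set_xlabel('Study Hours')
ax.set_ylabel('Exam Score')
ax.set_title('Marker size = sleep hours, color = age')
```

Call plt.show() to display the figure. Because s sets marker area rather than radius, doubling a value doubles the area of the marker, not its width.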
Bar Charts
Bar charts compare discrete categories using rectangular bars. They are excellent for showing quantities
across different groups or categories. Matplotlib supports both vertical bars (bar()) and
horizontal bars (barh()), as well as grouped and stacked variations for comparing multiple series.
# Vertical bar chart
categories = ['Electronics', 'Clothing', 'Food', 'Books', 'Sports']
revenue = [45000, 32000, 28000, 15000, 22000]
fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.bar(categories, revenue, color=['#3498db', '#e74c3c', '#2ecc71', '#f39c12', '#9b59b6'])
ax.set_xlabel('Category')
ax.set_ylabel('Revenue ($)')
ax.set_title('Revenue by Product Category')
# Add value labels on bars
for bar, val in zip(bars, revenue):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 500,
            f'${val:,}', ha='center', va='bottom', fontsize=10)
plt.show()
# Horizontal bar chart (good for long category names)
fig, ax = plt.subplots(figsize=(10, 6))
ax.barh(categories, revenue, color='#3498db')
ax.set_xlabel('Revenue ($)')
ax.set_title('Revenue by Product Category')
plt.show()
Grouped and Stacked Bar Charts
When comparing multiple data series across categories, grouped bars place bars side-by-side while stacked bars pile them on top of each other. Grouped bars are better for comparing individual values, while stacked bars show how parts contribute to a whole.
# Grouped bar chart
categories = ['Q1', 'Q2', 'Q3', 'Q4']
product_a = [25, 30, 35, 40]
product_b = [20, 28, 32, 38]
x = np.arange(len(categories))
width = 0.35
fig, ax = plt.subplots(figsize=(10, 6))
bars1 = ax.bar(x - width/2, product_a, width, label='Product A', color='#3498db')
bars2 = ax.bar(x + width/2, product_b, width, label='Product B', color='#e74c3c')
ax.set_xlabel('Quarter')
ax.set_ylabel('Sales (thousands)')
ax.set_title('Quarterly Sales Comparison')
ax.set_xticks(x)
ax.set_xticklabels(categories)
ax.legend()
plt.show()
# Stacked bar chart
fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(categories, product_a, label='Product A', color='#3498db')
ax.bar(categories, product_b, bottom=product_a, label='Product B', color='#e74c3c')
ax.set_ylabel('Total Sales (thousands)')
ax.set_title('Quarterly Sales - Stacked')
ax.legend()
plt.show()
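The stacked example above uses two series, where bottom is simply the first series. With three or more series, each layer has to start at the running total of the layers beneath it. Here is one way to sketch that with a cumulative bottom array; the Product C values are invented for the example.

```python
import numpy as np
import matplotlib.pyplot as plt

categories = ['Q1', 'Q2', 'Q3', 'Q4']
series = {
    'Product A': np.array([25, 30, 35, 40]),
    'Product B': np.array([20, 28, 32, 38]),
    'Product C': np.array([10, 12, 15, 18]),   # hypothetical extra series
}

fig, ax = plt.subplots(figsize=(10, 6))
bottom = np.zeros(len(categories))
for label, values in series.items():
    ax.bar(categories, values, bottom=bottom, label=label)
    bottom = bottom + values   # the next series starts where this one ended

ax.set_ylabel('Total Sales (thousands)')
ax.set_title('Quarterly Sales - Stacked (three series)')
ax.legend()
print(bottom)   # running totals per quarter after the last layer
```

After the loop, bottom holds the total height of each stacked bar, which is handy for placing total labels above the bars.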
Histograms
Histograms visualize the distribution of numerical data by grouping values into bins. They show how frequently values occur within different ranges, helping you understand the shape of your data. Use histograms to identify patterns like normal distributions, skewness, or multimodal distributions.
# Basic histogram
np.random.seed(42)
exam_scores = np.random.normal(75, 10, 200) # Mean=75, Std=10, 200 students
fig, ax = plt.subplots(figsize=(10, 6))
counts, bins, patches = ax.hist(exam_scores, bins=20, color='#3498db',
edgecolor='white', alpha=0.7)
ax.set_xlabel('Exam Score')
ax.set_ylabel('Number of Students')
ax.set_title('Distribution of Exam Scores')
ax.axvline(exam_scores.mean(), color='red', linestyle='--', label=f'Mean: {exam_scores.mean():.1f}')
ax.legend()
plt.show()
# Histogram with density curve
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(exam_scores, bins=20, density=True, alpha=0.7, color='#3498db', edgecolor='white')
# Overlay a normal PDF fitted with the sample mean and standard deviation (requires SciPy)
from scipy import stats
x_range = np.linspace(exam_scores.min(), exam_scores.max(), 100)
ax.plot(x_range, stats.norm.pdf(x_range, exam_scores.mean(), exam_scores.std()),
        'r-', linewidth=2, label='Normal Distribution')
ax.set_xlabel('Exam Score')
ax.set_ylabel('Density')
ax.legend()
plt.show()
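The values returned by ax.hist() are worth a closer look, since they explain what the two y-axis labels above ("Number of Students" versus "Density") mean. The sketch below, using freshly generated sample data, checks that there is one more bin edge than there are bins, that raw counts add up to the sample size, and that with density=True the bar areas sum to 1.

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(42)
exam_scores = rng.normal(75, 10, 200)   # 200 simulated scores

fig, (ax_counts, ax_density) = plt.subplots(1, 2, figsize=(12, 5))

counts, bin_edges, _ = ax_counts.hist(exam_scores, bins=20)
density, density_edges, _ = ax_density.hist(exam_scores, bins=20, density=True)

total_count = counts.sum()                         # every sample falls in one bin
area = np.sum(density * np.diff(density_edges))    # bar height x bar width, summed

print(len(counts), len(bin_edges))
print(total_count, round(area, 6))
plt.close(fig)
```

This is why a probability density curve can be overlaid on a density=True histogram: both are scaled so that the total area is 1.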
Pie Charts
Pie charts show parts of a whole as slices of a circle. While often overused, they work well for
displaying simple proportions with a small number of categories (typically 5 or fewer). Use the
explode parameter to emphasize specific slices and autopct to display percentages.
# Pie chart with percentages
categories = ['Electronics', 'Clothing', 'Food', 'Books', 'Other']
market_share = [35, 25, 20, 10, 10]
colors = ['#3498db', '#e74c3c', '#2ecc71', '#f39c12', '#95a5a6']
explode = [0.05, 0, 0, 0, 0] # Slightly separate first slice
fig, ax = plt.subplots(figsize=(10, 8))
wedges, texts, autotexts = ax.pie(market_share, labels=categories, colors=colors,
                                  autopct='%1.1f%%', explode=explode,
                                  shadow=True, startangle=90)
# Style the percentage labels
for autotext in autotexts:
    autotext.set_fontsize(11)
    autotext.set_fontweight('bold')
ax.set_title('Market Share by Category', fontsize=14, fontweight='bold')
plt.show()
Practice Questions
Task: Create a bar chart showing website traffic by day of week.
# Given: Daily traffic data
days = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
visitors = [1200, 1500, 1400, 1600, 1800, 2200, 1900]
# Your code: Create bar chart with different color for weekends
Expected Output: Bar chart with weekdays in blue, weekends in green
days = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
visitors = [1200, 1500, 1400, 1600, 1800, 2200, 1900]
colors = ['#3498db'] * 5 + ['#2ecc71'] * 2 # Blue weekdays, green weekends
fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(days, visitors, color=colors)
ax.set_xlabel('Day of Week')
ax.set_ylabel('Visitors')
ax.set_title('Website Traffic by Day')
plt.show()
Task: Create a scatter plot of house prices vs. square footage with a trend line.
# Given: Housing data
np.random.seed(42)
sqft = np.random.uniform(1000, 3000, 30)
price = 50000 + 150 * sqft + np.random.normal(0, 30000, 30)
# Your code: Scatter plot + linear trend line using np.polyfit
Expected Output: Scatter plot with red trend line, proper labels
np.random.seed(42)
sqft = np.random.uniform(1000, 3000, 30)
price = 50000 + 150 * sqft + np.random.normal(0, 30000, 30)
fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(sqft, price, alpha=0.7, s=80, c='#3498db')
# Add trend line
z = np.polyfit(sqft, price, 1)
p = np.poly1d(z)
x_line = np.linspace(sqft.min(), sqft.max(), 100)
ax.plot(x_line, p(x_line), 'r-', linewidth=2, label='Trend')
ax.set_xlabel('Square Footage')
ax.set_ylabel('Price ($)')
ax.set_title('House Prices vs. Square Footage')
ax.legend()
plt.show()
Task: Create a histogram of employee salaries with mean and median lines.
# Given: Salary data (right-skewed distribution)
np.random.seed(42)
salaries = np.random.exponential(scale=50000, size=500) + 30000
# Your code: Histogram with 25 bins, vertical lines for mean (red) and median (green)
Expected Output: Histogram with legend showing mean and median values
np.random.seed(42)
salaries = np.random.exponential(scale=50000, size=500) + 30000
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(salaries, bins=25, color='#3498db', edgecolor='white', alpha=0.7)
mean_sal = salaries.mean()
median_sal = np.median(salaries)
ax.axvline(mean_sal, color='red', linestyle='--', linewidth=2, label=f'Mean: ${mean_sal:,.0f}')
ax.axvline(median_sal, color='green', linestyle='-', linewidth=2, label=f'Median: ${median_sal:,.0f}')
ax.set_xlabel('Salary ($)')
ax.set_ylabel('Frequency')
ax.set_title('Employee Salary Distribution')
ax.legend()
plt.show()
Task: Create a grouped bar chart comparing revenue across regions and quarters.
# Given: Revenue data by region
quarters = ['Q1', 'Q2', 'Q3', 'Q4']
north = [120, 145, 160, 180]
south = [100, 130, 140, 155]
west = [90, 110, 125, 145]
# Your code: Grouped bar chart with 3 bars per quarter, legend, and value labels
Expected Output: Three colored bars per quarter with values displayed on top
quarters = ['Q1', 'Q2', 'Q3', 'Q4']
north = [120, 145, 160, 180]
south = [100, 130, 140, 155]
west = [90, 110, 125, 145]
x = np.arange(len(quarters))
width = 0.25
fig, ax = plt.subplots(figsize=(12, 6))
bars1 = ax.bar(x - width, north, width, label='North', color='#3498db')
bars2 = ax.bar(x, south, width, label='South', color='#e74c3c')
bars3 = ax.bar(x + width, west, width, label='West', color='#2ecc71')
# Add value labels
for bars in [bars1, bars2, bars3]:
    for bar in bars:
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 2,
                f'{bar.get_height():.0f}', ha='center', va='bottom', fontsize=9)
ax.set_xlabel('Quarter')
ax.set_ylabel('Revenue (thousands)')
ax.set_title('Quarterly Revenue by Region')
ax.set_xticks(x)
ax.set_xticklabels(quarters)
ax.legend()
plt.show()
Customizing Your Visualizations
The true power of Matplotlib lies in its extensive customization options. You can control every visual aspect of your plots, from colors and fonts to line styles and marker shapes. Learning these customization techniques helps you create professional, publication-ready figures that effectively communicate your data story.
Colors in Matplotlib
Matplotlib accepts colors in multiple formats, giving you flexibility in how you specify them. You can use named colors like 'red' or 'steelblue', hex codes like '#3498db', RGB tuples like (0.2, 0.4, 0.6), or shorthand codes like 'b' for blue. Using consistent, accessible colors improves your visualization quality significantly.
# Different ways to specify colors
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
x = [1, 2, 3, 4, 5]
y = [2, 4, 6, 8, 10]
# Named color
axes[0, 0].plot(x, y, color='steelblue', linewidth=3)
axes[0, 0].set_title("Named: 'steelblue'")
# Hex code
axes[0, 1].plot(x, y, color='#e74c3c', linewidth=3)
axes[0, 1].set_title("Hex: '#e74c3c'")
# RGB tuple (0-1 range)
axes[0, 2].plot(x, y, color=(0.2, 0.6, 0.3), linewidth=3)
axes[0, 2].set_title("RGB: (0.2, 0.6, 0.3)")
# Shorthand codes
axes[1, 0].plot(x, y, color='g', linewidth=3) # green
axes[1, 0].set_title("Shorthand: 'g'")
# With alpha transparency
axes[1, 1].plot(x, y, color='purple', alpha=0.5, linewidth=10)
axes[1, 1].set_title("With alpha=0.5")
# Using colormaps
colors = plt.cm.viridis(np.linspace(0, 1, 5))
for i, c in enumerate(colors):
    axes[1, 2].plot([1, 2], [i, i+1], color=c, linewidth=3)
axes[1, 2].set_title("Colormap: viridis")
plt.tight_layout()
plt.show()
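All of the color formats shown above end up as the same kind of value internally: an RGBA tuple with components between 0 and 1. As a small illustration, the matplotlib.colors helpers to_rgba, to_rgb, and to_hex can convert between the formats, which is useful for checking that two specifications describe the same color.

```python
from matplotlib import colors as mcolors

named = mcolors.to_rgba('steelblue')
from_hex = mcolors.to_rgba('#4682b4')                # hex code for steelblue
short_blue = mcolors.to_rgb('b')                     # shorthand for pure blue
translucent = mcolors.to_rgba('purple', alpha=0.5)   # alpha becomes the 4th component

print(named == from_hex)
print(mcolors.to_hex('steelblue'))
print(short_blue, translucent[3])
```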
Line Styles and Markers
Differentiating multiple data series requires varying line styles and markers. Matplotlib provides solid, dashed, dotted, and dash-dot line styles. Markers include circles, squares, triangles, and many more. Combining these options creates visually distinct lines that remain readable even in black-and-white printing.
# Line styles
fig, ax = plt.subplots(figsize=(12, 6))
x = np.linspace(0, 10, 50)
ax.plot(x, np.sin(x), linestyle='-', label='Solid (-)', linewidth=2)
ax.plot(x, np.sin(x + 0.5), linestyle='--', label='Dashed (--)', linewidth=2)
ax.plot(x, np.sin(x + 1), linestyle='-.', label='Dash-dot (-.)', linewidth=2)
ax.plot(x, np.sin(x + 1.5), linestyle=':', label='Dotted (:)', linewidth=2)
ax.legend()
ax.set_title('Line Style Options')
plt.show()
# Marker styles
fig, ax = plt.subplots(figsize=(12, 6))
x = np.arange(1, 6)
markers = ['o', 's', '^', 'D', 'v', 'p', '*', 'h']
marker_names = ['Circle', 'Square', 'Triangle Up', 'Diamond', 'Triangle Down', 'Pentagon', 'Star', 'Hexagon']
for i, (marker, name) in enumerate(zip(markers, marker_names)):
    ax.plot(x, [i+1]*5, marker=marker, markersize=12, linestyle='', label=name)
ax.set_yticks(range(1, len(markers)+1))
ax.set_yticklabels(marker_names)
ax.set_title('Marker Style Options')
ax.legend(loc='upper right')
plt.show()
Format String Shortcut
Combine color, marker, and line style in one string: plt.plot(x, y, 'ro--')
creates red circles with dashed lines. Format: [color][marker][linestyle].
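To show that the format string is only shorthand, the following sketch draws the same line twice, once with 'ro--' and once with the equivalent keyword arguments, and then compares the resulting line properties. The data values are arbitrary.

```python
import matplotlib.pyplot as plt
from matplotlib import colors as mcolors

fig, ax = plt.subplots()
x = [1, 2, 3, 4, 5]
y = [2, 4, 6, 8, 10]

short_line, = ax.plot(x, y, 'ro--')                                    # format string
long_line, = ax.plot(x, y, color='red', marker='o', linestyle='--')    # keywords

# Compare colors as RGBA tuples, since 'r' and 'red' are different spellings
same_color = (mcolors.to_rgba(short_line.get_color())
              == mcolors.to_rgba(long_line.get_color()))

print(same_color)
print(short_line.get_marker(), short_line.get_linestyle())
plt.close(fig)
```

The keyword form is longer but accepts any color format, whereas the format string is limited to the single-letter color codes and a few other short forms.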
Labels, Titles, and Text
Clear labels and titles are essential for understandable visualizations. Matplotlib provides extensive control over text properties including font size, weight, family, and color. You can add annotations to highlight specific data points and use mathematical notation with LaTeX-style formatting.
fig, ax = plt.subplots(figsize=(12, 7))
x = np.linspace(0, 10, 100)
y = np.sin(x)
ax.plot(x, y, 'b-', linewidth=2)
# Customized labels and title
ax.set_xlabel('Time (seconds)', fontsize=14, fontweight='bold', color='#333')
ax.set_ylabel('Amplitude', fontsize=14, fontweight='bold', color='#333')
ax.set_title('Sine Wave with Custom Styling', fontsize=18, fontweight='bold',
color='#2c3e50', pad=20)
# Add text annotation
ax.text(5, 0.5, 'Peak Region', fontsize=12, style='italic',
bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.5))
# Add annotation with arrow
ax.annotate('Maximum', xy=(np.pi/2, 1), xytext=(3, 0.7),
fontsize=12, arrowprops=dict(arrowstyle='->', color='red'),
color='red')
# Mathematical notation using LaTeX
ax.text(8, -0.5, r'$y = \sin(x)$', fontsize=14,
bbox=dict(facecolor='white', edgecolor='gray'))
plt.tight_layout()
plt.show()
Legends
Legends identify different data series in your plot. Matplotlib offers extensive legend customization including position, number of columns, frame styling, and font properties. A well-placed legend enhances readability without obscuring important data points.
fig, ax = plt.subplots(figsize=(12, 6))
x = np.linspace(0, 10, 100)
ax.plot(x, np.sin(x), label='sin(x)', linewidth=2)
ax.plot(x, np.cos(x), label='cos(x)', linewidth=2)
ax.plot(x, np.sin(x) * np.cos(x), label='sin(x)·cos(x)', linewidth=2)
# Customized legend
ax.legend(
    loc='upper right',     # Position: 'best', 'upper left', 'lower right', etc.
    fontsize=11,           # Font size
    frameon=True,          # Show frame
    facecolor='white',     # Background color
    edgecolor='gray',      # Frame color
    framealpha=0.9,        # Frame transparency
    ncol=3,                # Number of columns
    title='Functions',     # Legend title
    title_fontsize=12
)
ax.set_title('Legend Customization Example')
plt.show()
# Legend positions
positions = ['best', 'upper left', 'upper right', 'lower left',
'lower right', 'center left', 'center right', 'upper center',
'lower center', 'center']
print(f"Available legend locations: {positions}")
Styles and Themes
Matplotlib includes built-in style sheets that change the overall appearance of your plots. These provide consistent, professional looks without manually setting every property. You can also combine styles or create custom ones for your organization's branding.
# View available styles
print(plt.style.available)
# ['Solarize_Light2', 'bmh', 'classic', 'dark_background', 'fivethirtyeight',
# 'ggplot', 'grayscale', 'seaborn-v0_8', 'seaborn-v0_8-bright', ...]
# Apply a style
plt.style.use('seaborn-v0_8-whitegrid')
fig, ax = plt.subplots(figsize=(10, 6))
x = np.linspace(0, 10, 100)
ax.plot(x, np.sin(x), linewidth=2)
ax.plot(x, np.cos(x), linewidth=2)
ax.set_title('Seaborn Whitegrid Style')
plt.show()
# Use style as context manager (temporary)
with plt.style.context('dark_background'):
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(x, np.sin(x), linewidth=2)
    ax.set_title('Dark Background Style')
    plt.show()
# Reset to default
plt.style.use('default')
| Style | Description | Best For |
|---|---|---|
| seaborn-v0_8 | Clean, modern aesthetic | General data science |
| ggplot | R's ggplot2 inspired | Statistical graphics |
| fivethirtyeight | Bold, journalistic | Presentations |
| dark_background | Dark theme | Slides, dark UIs |
| bmh | Bayesian Methods for Hackers | Academic papers |
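For the custom styling mentioned in this section, one lightweight option is a dictionary of rcParams applied temporarily with plt.rc_context(). The sketch below assumes a made-up "house style" with three settings; the dictionary name and values are illustrative, not a recommended theme.

```python
import matplotlib.pyplot as plt

house_style = {               # hypothetical house style, values are illustrative
    'lines.linewidth': 3,
    'axes.grid': True,
    'axes.titlesize': 16,
}

default_linewidth = plt.rcParams['lines.linewidth']

# Settings apply only inside the with-block
with plt.rc_context(house_style):
    fig, ax = plt.subplots(figsize=(10, 6))
    styled_line, = ax.plot([0, 1, 2], [0, 1, 4])
    ax.set_title('Custom rcParams via rc_context')

restored = plt.rcParams['lines.linewidth'] == default_linewidth
print(styled_line.get_linewidth(), restored)
plt.close(fig)
```

Artists created inside the block keep the custom settings, while the global defaults are restored once the block exits.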
Practice Questions
Task: Create a styled line plot with custom colors, markers, and line style.
# Given: Monthly revenue data
months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun']
revenue = [45, 52, 48, 61, 55, 67]
# Your code: Plot with hex color #2ecc71, square markers, dashed line, linewidth 2
Expected Output: Green dashed line with square markers
months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun']
revenue = [45, 52, 48, 61, 55, 67]
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(months, revenue, color='#2ecc71', marker='s', linestyle='--',
linewidth=2, markersize=10)
ax.set_xlabel('Month')
ax.set_ylabel('Revenue (thousands)')
ax.set_title('Monthly Revenue')
plt.show()
Task: Create a plot with annotations highlighting the maximum value.
# Given: Stock price data
days = np.arange(1, 31)
prices = 100 + np.cumsum(np.random.randn(30))
# Your code: Line plot with annotation arrow pointing to the max price
Expected Output: Line plot with "Peak: $X.XX" annotation at the maximum
np.random.seed(42)
days = np.arange(1, 31)
prices = 100 + np.cumsum(np.random.randn(30))
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(days, prices, 'b-', linewidth=2)
max_idx = np.argmax(prices)
max_price = prices[max_idx]
max_day = days[max_idx]
ax.annotate(f'Peak: ${max_price:.2f}',
xy=(max_day, max_price),
xytext=(max_day + 3, max_price - 2),
fontsize=12, fontweight='bold',
arrowprops=dict(arrowstyle='->', color='red', lw=2),
color='red')
ax.scatter([max_day], [max_price], color='red', s=100, zorder=5)
ax.set_xlabel('Day')
ax.set_ylabel('Price ($)')
ax.set_title('Stock Price with Peak Annotation')
plt.show()
Task: Create a publication-quality comparison chart with full styling.
# Given: Performance metrics for 3 algorithms
categories = ['Accuracy', 'Speed', 'Memory', 'Scalability']
algo_a = [0.92, 0.78, 0.85, 0.70]
algo_b = [0.88, 0.95, 0.72, 0.88]
algo_c = [0.95, 0.65, 0.90, 0.75]
# Your code: Grouped bar chart with custom colors, legend, annotations
# Use fivethirtyeight style, add title with subtitle effect
Expected Output: Professional grouped bar chart with value labels
categories = ['Accuracy', 'Speed', 'Memory', 'Scalability']
algo_a = [0.92, 0.78, 0.85, 0.70]
algo_b = [0.88, 0.95, 0.72, 0.88]
algo_c = [0.95, 0.65, 0.90, 0.75]
with plt.style.context('fivethirtyeight'):
    fig, ax = plt.subplots(figsize=(12, 7))
    x = np.arange(len(categories))
    width = 0.25
    bars1 = ax.bar(x - width, algo_a, width, label='Algorithm A', color='#3498db')
    bars2 = ax.bar(x, algo_b, width, label='Algorithm B', color='#e74c3c')
    bars3 = ax.bar(x + width, algo_c, width, label='Algorithm C', color='#2ecc71')
    # Value labels
    for bars in [bars1, bars2, bars3]:
        for bar in bars:
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
                    f'{bar.get_height():.2f}', ha='center', fontsize=9)
    ax.set_ylabel('Score', fontsize=12)
    ax.set_title('Algorithm Performance Comparison\n', fontsize=16, fontweight='bold')
    ax.text(0.5, 1.02, 'Higher is better for all metrics', transform=ax.transAxes,
            ha='center', fontsize=10, style='italic', color='gray')
    ax.set_xticks(x)
    ax.set_xticklabels(categories)
    ax.set_ylim(0, 1.15)
    ax.legend(loc='upper right', framealpha=0.9)
    plt.tight_layout()
    plt.show()
Subplots & Figure Layouts
Complex data stories often require multiple visualizations displayed together. Matplotlib's subplot system lets you create grid layouts with multiple axes in a single figure. From simple side-by-side comparisons to sophisticated dashboard-style layouts, mastering subplots is essential for professional data visualization.
Basic Subplots with plt.subplots()
The plt.subplots() function creates a figure with a grid of axes. Specify the number of
rows and columns, and it returns a figure object and an array of axes. For a single row or column,
the axes array is 1-dimensional. For grids, it is 2-dimensional and you access individual axes using
row and column indices.
# Create a 2x2 grid of subplots
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
x = np.linspace(0, 10, 100)
# Access each subplot by row, column index
axes[0, 0].plot(x, np.sin(x), 'b-')
axes[0, 0].set_title('Sine Wave')
axes[0, 1].plot(x, np.cos(x), 'r-')
axes[0, 1].set_title('Cosine Wave')
axes[1, 0].plot(x, np.exp(-x/5), 'g-')
axes[1, 0].set_title('Exponential Decay')
axes[1, 1].plot(x, np.log(x + 1), 'm-')
axes[1, 1].set_title('Logarithmic Growth')
plt.tight_layout() # Prevent overlapping
plt.show()
# For single row or column, axes is 1D
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
axes[0].plot(x, np.sin(x))
axes[1].plot(x, np.cos(x))
axes[2].plot(x, np.tan(x))
plt.tight_layout()
plt.show()
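The shape of the returned axes array is a common source of indexing errors, so it can help to print it. This short sketch compares a 2x2 grid, a single row, and a single row created with squeeze=False, which keeps the array 2-dimensional. It also uses .flat to loop over every Axes without caring about the grid shape.

```python
import matplotlib.pyplot as plt

fig_grid, axes_grid = plt.subplots(2, 2)
fig_row, axes_row = plt.subplots(1, 3)
fig_kept, axes_kept = plt.subplots(1, 3, squeeze=False)   # always 2D

print(axes_grid.shape, axes_row.shape, axes_kept.shape)

# .flat iterates over every Axes regardless of the grid shape
for index, ax in enumerate(axes_grid.flat):
    ax.set_title(f'Panel {index}')

plt.close('all')
```

Passing squeeze=False is convenient in functions that accept any number of rows and columns, because axes[row, col] indexing then always works.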
Sharing Axes
When comparing related data, sharing x or y axes helps viewers make accurate comparisons. The
sharex and sharey parameters link the axis limits across subplots. This
ensures that all plots use the same scale, making differences in the data immediately apparent.
# Shared x-axis for time series comparison
fig, axes = plt.subplots(3, 1, figsize=(12, 8), sharex=True)
time = np.arange(100)
np.random.seed(42)
# Different metrics over the same time period
axes[0].plot(time, np.cumsum(np.random.randn(100)), 'b-', linewidth=2)
axes[0].set_ylabel('Revenue')
axes[0].set_title('Business Metrics Over Time')
axes[1].plot(time, 50 + np.cumsum(np.random.randn(100) * 0.5), 'g-', linewidth=2)
axes[1].set_ylabel('Customers')
axes[2].plot(time, 80 + np.cumsum(np.random.randn(100) * 0.3), 'r-', linewidth=2)
axes[2].set_ylabel('Satisfaction')
axes[2].set_xlabel('Days')
plt.tight_layout()
plt.show()
# Shared y-axis for comparing distributions
fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=True)
data1 = np.random.normal(50, 10, 1000)
data2 = np.random.normal(60, 15, 1000)
axes[0].hist(data1, bins=30, color='#3498db', edgecolor='white')
axes[0].set_title('Group A')
axes[0].set_xlabel('Value')
axes[0].set_ylabel('Frequency')
axes[1].hist(data2, bins=30, color='#e74c3c', edgecolor='white')
axes[1].set_title('Group B')
axes[1].set_xlabel('Value')
plt.tight_layout()
plt.show()
sharex / sharey
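When True (equivalent to 'all'), every subplot shares the same axis. Use 'row' or 'col' to share only within each row or column; the default, False ('none'), keeps the axes independent. When the grid is created with plt.subplots(), shared axes automatically hide redundant tick labels.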
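To see what the 'row' and 'col' options do in practice, here is a small sketch (not from the original lesson) that shares the x-axis within each column and the y-axis within each row. The data is synthetic and the Agg backend line is only an assumption for running without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; drop this line in a notebook
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(42)
short_x = np.arange(10)    # left column covers 10 points
long_x = np.arange(100)    # right column covers 100 points

# Share x within each column and y within each row
fig, axes = plt.subplots(2, 2, figsize=(10, 6), sharex='col', sharey='row')

axes[0, 0].plot(short_x, rng.normal(0, 1, 10))
axes[0, 1].plot(long_x, rng.normal(0, 5, 100))
axes[1, 0].plot(short_x, rng.normal(50, 1, 10))
axes[1, 1].plot(long_x, rng.normal(50, 5, 100))
fig.tight_layout()

# Limits match inside a column (x) and inside a row (y), but not across columns
same_x_in_column = axes[0, 0].get_xlim() == axes[1, 0].get_xlim()
same_y_in_row = axes[0, 0].get_ylim() == axes[0, 1].get_ylim()
x_differs_across_columns = axes[0, 0].get_xlim() != axes[0, 1].get_xlim()
print(same_x_in_column, same_y_in_row, x_differs_across_columns)
```

Sharing per row or per column is useful when the panels in one direction are comparable but the grid as a whole mixes different scales.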
GridSpec for Complex Layouts
For layouts where subplots have different sizes, use GridSpec. This powerful tool
lets you create grids where individual plots span multiple rows or columns. It is perfect for
creating dashboard-style visualizations with a mix of large and small charts.
from matplotlib.gridspec import GridSpec
fig = plt.figure(figsize=(14, 10))
gs = GridSpec(3, 3, figure=fig)
# Large plot spanning top row
ax1 = fig.add_subplot(gs[0, :]) # Row 0, all columns
x = np.linspace(0, 10, 100)
ax1.plot(x, np.sin(x) * np.exp(-x/10), linewidth=2)
ax1.set_title('Main Time Series (Spanning Full Width)', fontsize=14)
# Two medium plots in middle row
ax2 = fig.add_subplot(gs[1, :2]) # Row 1, columns 0-1
ax2.bar(['A', 'B', 'C', 'D'], [25, 40, 30, 35], color='#3498db')
ax2.set_title('Category Breakdown')
ax3 = fig.add_subplot(gs[1, 2]) # Row 1, column 2
ax3.pie([35, 25, 20, 20], labels=['Q1', 'Q2', 'Q3', 'Q4'], autopct='%1.0f%%')
ax3.set_title('Quarterly Split')
# Three small plots in bottom row
ax4 = fig.add_subplot(gs[2, 0])
ax4.scatter(np.random.rand(20), np.random.rand(20), c='#e74c3c', s=50)
ax4.set_title('Scatter')
ax5 = fig.add_subplot(gs[2, 1])
ax5.hist(np.random.randn(100), bins=15, color='#2ecc71', edgecolor='white')
ax5.set_title('Distribution')
ax6 = fig.add_subplot(gs[2, 2])
ax6.plot([1, 2, 3, 4], [1, 4, 2, 3], 'o-', color='#9b59b6', linewidth=2)
ax6.set_title('Trend')
plt.tight_layout()
plt.show()
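The slice notation in the GridSpec example can be inspected without looking at the rendered figure. The sketch below is an illustration under stated assumptions: the layout names are hypothetical labels chosen for this example, and the Agg backend line is only there so it runs without a display. It prints which grid rows and columns each slice covers.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; drop this line in a notebook
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

fig = plt.figure(figsize=(8, 6))
gs = GridSpec(3, 3, figure=fig)

# Hypothetical names mapped to GridSpec slices, mirroring the dashboard above
layout = {
    'top': gs[0, :],            # row 0, all columns
    'middle_left': gs[1, :2],   # row 1, columns 0-1
    'middle_right': gs[1, 2],   # row 1, column 2
    'bottom_left': gs[2, 0],    # row 2, column 0
}

spans = {}
for name, spec in layout.items():
    ax = fig.add_subplot(spec)
    ax.set_title(name)
    # rowspan and colspan are ranges of the grid cells the subplot occupies
    spans[name] = (list(spec.rowspan), list(spec.colspan))
    print(name, spans[name])
```

Indexing a GridSpec works like indexing a 2D array: the first index selects rows, the second selects columns, and a slice makes the subplot span several cells.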
Saving Figures
After creating your visualization, you will often need to save it for reports, presentations, or
publications. The savefig() function exports figures in various formats including PNG,
PDF, SVG, and EPS. Control resolution with dpi and use bbox_inches='tight'
to remove extra whitespace.
# Create a figure to save
fig, ax = plt.subplots(figsize=(10, 6))
x = np.linspace(0, 10, 100)
ax.plot(x, np.sin(x), linewidth=2)
ax.set_title('Publication-Ready Figure')
ax.set_xlabel('X Axis')
ax.set_ylabel('Y Axis')
# Save in different formats
fig.savefig('my_plot.png', dpi=300, bbox_inches='tight', facecolor='white')
fig.savefig('my_plot.pdf', bbox_inches='tight') # Vector format for papers
fig.savefig('my_plot.svg', bbox_inches='tight') # Vector format for web
# With transparent background
fig.savefig('my_plot_transparent.png', dpi=300, bbox_inches='tight', transparent=True)
print("Figures saved successfully!")
plt.show()
# Common savefig parameters
# dpi: Resolution (300 for print, 150 for web)
# bbox_inches: 'tight' removes whitespace
# facecolor: Background color
# transparent: True for transparent background
# format: 'png', 'pdf', 'svg', 'eps', 'jpg'
| Format | Type | Best For | Recommended DPI |
|---|---|---|---|
| PNG | Raster | Web, presentations | 150-300 |
| PDF | Vector | Academic papers, print | N/A (scalable) |
| SVG | Vector | Web, interactive | N/A (scalable) |
| EPS | Vector | LaTeX documents | N/A (scalable) |
| JPG | Raster | Photos (not charts) | 150-300 |
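For raster formats, the pixel size of the saved image is the figure size in inches multiplied by dpi. The sketch below checks that arithmetic by saving to an in-memory buffer instead of a file; it is an illustration, not part of the original lesson. It deliberately omits bbox_inches='tight', because trimming whitespace changes the final dimensions, and the Agg backend line is only an assumption for running without a display.

```python
import io

import matplotlib
matplotlib.use("Agg")  # assumption: headless run; drop this line in a notebook
import matplotlib.pyplot as plt

# A 4 x 3 inch figure
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot([0, 1, 2], [0, 1, 4])

sizes = {}
for dpi in (100, 150, 300):
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=dpi)   # write the PNG into memory
    buf.seek(0)
    img = plt.imread(buf, format='png')       # array of shape (height, width, channels)
    height, width = img.shape[:2]
    sizes[dpi] = (width, height)

print(sizes)  # {100: (400, 300), 150: (600, 450), 300: (1200, 900)}
```

Doubling dpi doubles both width and height, so the pixel count (and roughly the file size) grows fourfold.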
Practice Questions
Task: Create a 1x2 subplot comparing two distributions.
# Given: Two datasets
np.random.seed(42)
before = np.random.normal(100, 15, 200)
after = np.random.normal(110, 12, 200)
# Your code: Side-by-side histograms with shared y-axis, proper titles
Expected Output: Two histograms showing "Before" and "After" with same y-scale
Solution:
np.random.seed(42)
before = np.random.normal(100, 15, 200)
after = np.random.normal(110, 12, 200)
fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=True)
axes[0].hist(before, bins=20, color='#e74c3c', edgecolor='white', alpha=0.7)
axes[0].set_title('Before Treatment')
axes[0].set_xlabel('Value')
axes[0].set_ylabel('Frequency')
axes[0].axvline(before.mean(), color='black', linestyle='--', label=f'Mean: {before.mean():.1f}')
axes[0].legend()
axes[1].hist(after, bins=20, color='#2ecc71', edgecolor='white', alpha=0.7)
axes[1].set_title('After Treatment')
axes[1].set_xlabel('Value')
axes[1].axvline(after.mean(), color='black', linestyle='--', label=f'Mean: {after.mean():.1f}')
axes[1].legend()
plt.suptitle('Treatment Effect Comparison', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
Task: Create a 2x2 subplot showing different views of sales data.
# Given: Monthly sales data
months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun']
online = [120, 150, 140, 180, 200, 220]
store = [80, 90, 95, 100, 85, 110]
# Create 4 plots: line chart, bar chart, stacked bar, pie chart
Expected Output: Four different visualizations of the same data
Solution:
months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun']
online = [120, 150, 140, 180, 200, 220]
store = [80, 90, 95, 100, 85, 110]
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
# Line chart
axes[0, 0].plot(months, online, 'o-', label='Online', linewidth=2)
axes[0, 0].plot(months, store, 's-', label='Store', linewidth=2)
axes[0, 0].set_title('Sales Trend')
axes[0, 0].legend()
axes[0, 0].set_ylabel('Sales (K)')
# Grouped bar chart
x = np.arange(len(months))
width = 0.35
axes[0, 1].bar(x - width/2, online, width, label='Online')
axes[0, 1].bar(x + width/2, store, width, label='Store')
axes[0, 1].set_title('Monthly Comparison')
axes[0, 1].set_xticks(x)
axes[0, 1].set_xticklabels(months)
axes[0, 1].legend()
# Stacked bar
axes[1, 0].bar(months, online, label='Online')
axes[1, 0].bar(months, store, bottom=online, label='Store')
axes[1, 0].set_title('Total Sales (Stacked)')
axes[1, 0].legend()
# Pie chart for totals
axes[1, 1].pie([sum(online), sum(store)], labels=['Online', 'Store'],
autopct='%1.1f%%', colors=['#3498db', '#e74c3c'])
axes[1, 1].set_title('Channel Distribution')
plt.suptitle('Sales Dashboard', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()
Task: Create a dashboard layout with one large plot on top and three smaller plots below.
# Given: Stock data
np.random.seed(42)
days = np.arange(1, 101)
stock_price = 100 + np.cumsum(np.random.randn(100) * 2)
volume = np.random.uniform(1, 5, 100)
returns = np.diff(stock_price) / stock_price[:-1] * 100
# Create: Large price chart on top, volume bar, returns histogram, returns scatter below
Expected Output: Professional dashboard with main chart and supporting visuals
Solution:
from matplotlib.gridspec import GridSpec
np.random.seed(42)
days = np.arange(1, 101)
stock_price = 100 + np.cumsum(np.random.randn(100) * 2)
volume = np.random.uniform(1, 5, 100)
returns = np.diff(stock_price) / stock_price[:-1] * 100
fig = plt.figure(figsize=(16, 10))
gs = GridSpec(2, 3, figure=fig, height_ratios=[2, 1])
# Main price chart (top, spans all columns)
ax1 = fig.add_subplot(gs[0, :])
ax1.plot(days, stock_price, 'b-', linewidth=2)
ax1.fill_between(days, stock_price.min(), stock_price, alpha=0.2)
ax1.set_title('Stock Price Over Time', fontsize=14, fontweight='bold')
ax1.set_xlabel('Day')
ax1.set_ylabel('Price ($)')
ax1.grid(True, alpha=0.3)
# Volume bar chart (bottom left)
ax2 = fig.add_subplot(gs[1, 0])
# Color each bar by that day's return; day 1 has no return, so it gets a neutral blue
colors = ['#3498db'] + ['#2ecc71' if r >= 0 else '#e74c3c' for r in returns]
ax2.bar(days, volume, color=colors, width=1)
ax2.set_title('Trading Volume')
ax2.set_xlabel('Day')
ax2.set_ylabel('Volume (M)')
# Returns histogram (bottom center)
ax3 = fig.add_subplot(gs[1, 1])
ax3.hist(returns, bins=20, color='#9b59b6', edgecolor='white')
ax3.axvline(0, color='black', linestyle='--')
ax3.set_title('Returns Distribution')
ax3.set_xlabel('Daily Return (%)')
ax3.set_ylabel('Frequency')
# Returns scatter (bottom right)
ax4 = fig.add_subplot(gs[1, 2])
colors = ['#2ecc71' if r >= 0 else '#e74c3c' for r in returns]
ax4.scatter(days[1:], returns, c=colors, alpha=0.6, s=30)
ax4.axhline(0, color='black', linestyle='-', linewidth=0.5)
ax4.set_title('Daily Returns')
ax4.set_xlabel('Day')
ax4.set_ylabel('Return (%)')
plt.suptitle('Stock Analysis Dashboard', fontsize=16, fontweight='bold') # Default position stays inside the figure, so plt.show() does not clip it
plt.tight_layout()
plt.show()
Key Takeaways
Figure & Axes Architecture
Every Matplotlib visualization has a Figure (canvas) containing one or more Axes (individual plots). Use fig, ax = plt.subplots() to create both at once. Calling methods on these objects (the object-oriented approach) gives you more control than relying on pyplot's implicit "current figure" functions.
Choose the Right Plot Type
Use line plots for trends over time, scatter plots for relationships between variables, bar charts for categorical comparisons, and histograms for distributions. Match your chart type to your data and message.
Customize for Clarity
Use colors, markers, and line styles to differentiate data series. Add clear labels, titles, and legends. The format string shortcut ('ro--') combines color, marker, and linestyle in one argument.
Master Subplots
Create multi-panel figures with plt.subplots(rows, cols). Use sharex and sharey for consistent scales. For complex layouts, use GridSpec to span rows and columns.
Use Built-in Styles
Apply professional styling with plt.style.use('seaborn-v0_8'). Use styles as context managers for temporary changes. Explore available styles with plt.style.available.
Save Publication-Quality
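Export figures with fig.savefig(). Use PNG for web and presentations (dpi=150, or 300 when the image will be printed) and a vector format such as PDF, SVG, or EPS when the figure must scale without losing quality. Add bbox_inches='tight' to trim extra whitespace and facecolor='white' for a clean background.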
Knowledge Check
Quick Quiz
Test your understanding of Matplotlib fundamentals