Subplots in Matplotlib

Subplots allow you to create multiple plots in a single figure. This is particularly useful for comparing different datasets or visualizing multiple variables simultaneously.


Creating Subplots

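The plt.subplot() function in Matplotlib adds one subplot to the current figure and makes it the active plot, so later calls such as plt.plot() and plt.title() draw into it. It takes three arguments: the number of rows, the number of columns, and the index of the subplot. The index starts at 1 and counts left to right, top to bottom.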
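To make the index ordering concrete, here is a small sketch (the figure size and the text labels are arbitrary choices for illustration) that fills a 2-by-3 grid and labels each panel with the index passed to plt.subplot(). It also uses the three-digit shorthand, where plt.subplot(235) refers to the same slot as plt.subplot(2, 3, 5); the shorthand only works while every number is a single digit.

```python
import matplotlib.pyplot as plt

# Fill a 2-row x 3-column grid; indices run 1..6, left to right, top to bottom
fig = plt.figure(figsize=(9, 5))
for index in range(1, 7):
    ax = plt.subplot(2, 3, index)  # plt.subplot returns the new Axes
    ax.text(0.5, 0.5, f"index {index}", ha="center", va="center")
    ax.set_xticks([])  # hide ticks; only the label matters here
    ax.set_yticks([])
plt.tight_layout()

# Three-digit shorthand on a second figure: same slot as plt.subplot(2, 3, 5)
fig_short = plt.figure()
ax_short = plt.subplot(235)
```

Because plt.subplot() returns the Axes it creates, you can keep that object and call its methods directly, as the loop does, instead of relying on whichever subplot is currently active.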

Example: Basic Subplots

import matplotlib.pyplot as plt
import numpy as np
 
# Data
x = np.linspace(0, 10, 100)
y1 = np.sin(x)
y2 = np.cos(x)
 
# Create subplots
plt.subplot(2, 1, 1)  # 2 rows, 1 column, 1st plot
plt.plot(x, y1, label='Sine')
plt.title("Sine Wave")
plt.legend()
 
plt.subplot(2, 1, 2)  # 2 rows, 1 column, 2nd plot
plt.plot(x, y2, label='Cosine', color='orange')
plt.title("Cosine Wave")
plt.legend()
 
# Adjust layout
plt.tight_layout()
 
# Display the plot
plt.show()

Using plt.subplots()

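The plt.subplots() function creates a figure and a whole grid of subplots in a single call. It returns the Figure and the Axes objects (a NumPy array of them when the grid has more than one subplot), so you draw on each subplot by calling its own methods, such as plot(), set_title(), and legend(), instead of switching the "current" subplot. This object-oriented style is more flexible and is the most common way to create subplots.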
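The shape of the returned axs object depends on the grid, which matters when you index it. The sketch below (grid sizes chosen only for illustration) checks the three cases: a 2-D array for a full grid, a 1-D array for a single row or column, and a bare Axes for a 1x1 grid. Passing squeeze=False always returns a 2-D array, and axs.flat iterates over every subplot regardless of shape.

```python
import matplotlib.pyplot as plt
from matplotlib.axes import Axes

fig, axs = plt.subplots(2, 2)                       # 2-D array, index as axs[row, col]
fig_row, axs_row = plt.subplots(1, 3)               # single row -> 1-D array, index as axs_row[i]
fig_one, ax_one = plt.subplots()                    # 1x1 grid -> a single Axes, not an array
fig_2d, axs_2d = plt.subplots(1, 3, squeeze=False)  # squeeze=False -> always 2-D

print(axs.shape, axs_row.shape, axs_2d.shape)  # (2, 2) (3,) (1, 3)
print(isinstance(ax_one, Axes))                # True

# axs.flat walks the grid row by row, whatever its shape
for i, ax in enumerate(axs.flat):
    ax.set_title(f"Axes {i}")
```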

Example: Multiple Subplots with plt.subplots()

import matplotlib.pyplot as plt
import numpy as np
 
# Data
x = np.linspace(0, 10, 100)
y1 = np.sin(x)
y2 = np.cos(x)
 
# Create subplots
fig, axs = plt.subplots(2, 2, figsize=(10, 6))
 
# Plot on each subplot
axs[0, 0].plot(x, y1, label='Sine')
axs[0, 0].set_title("Sine Wave")
axs[0, 0].legend()
 
axs[0, 1].plot(x, y2, label='Cosine', color='orange')
axs[0, 1].set_title("Cosine Wave")
axs[0, 1].legend()
 
axs[1, 0].plot(x, y1 + y2, label='Sine + Cosine', color='green')
axs[1, 0].set_title("Combined Wave")
axs[1, 0].legend()
 
axs[1, 1].axis('off')  # Empty subplot
 
# Adjust layout
plt.tight_layout()
 
# Display the plot
plt.show()
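In the 2x2 grid above, the fourth panel is switched off and left blank. If you would rather let the combined wave span the whole bottom row, one option in recent Matplotlib versions is plt.subplot_mosaic(), which lays out named panels from a nested list and returns the figure plus a dictionary of Axes keyed by those names. The panel names below are arbitrary labels chosen for this sketch.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y1 = np.sin(x)
y2 = np.cos(x)

# Repeating a name across a row makes that panel span those columns
fig, axd = plt.subplot_mosaic(
    [["sine", "cosine"],
     ["combined", "combined"]],
    figsize=(10, 6),
)

axd["sine"].plot(x, y1, label="Sine")
axd["sine"].set_title("Sine Wave")

axd["cosine"].plot(x, y2, label="Cosine", color="orange")
axd["cosine"].set_title("Cosine Wave")

axd["combined"].plot(x, y1 + y2, label="Sine + Cosine", color="green")
axd["combined"].set_title("Combined Wave")

# Every panel has a labelled line, so add a legend to each
for ax in axd.values():
    ax.legend()

plt.tight_layout()
# Call plt.show() to display the figure, as in the examples above
```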

Sharing Axes in Subplots

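Passing sharex=True and/or sharey=True to plt.subplots() links the x- or y-axes of the subplots so that they use the same limits and ticks, which makes values directly comparable across panels. Matplotlib also hides the redundant tick labels on the inner subplots.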
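Besides True, sharex and sharey also accept "row" and "col" (as well as "all" and "none"), which link axes only within each row or column. This sketch (the scale factors are arbitrary, chosen so the two rows have very different ranges) shares the x-axis everywhere but the y-axis only along rows, then shows that setting the x-limits on one panel moves all of them.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

# sharex=True: every panel uses the same x-axis
# sharey="row": panels share a y-axis only with others in the same row
fig, axs = plt.subplots(2, 2, sharex=True, sharey="row", figsize=(8, 6))
axs[0, 0].plot(x, np.sin(x))
axs[0, 1].plot(x, 3 * np.sin(x))
axs[1, 0].plot(x, 100 * np.cos(x))
axs[1, 1].plot(x, 50 * np.cos(x))

# Changing the limits of one shared axis updates all of its partners
axs[0, 0].set_xlim(0, 5)

print(axs[1, 1].get_xlim())                          # same limits as axs[0, 0]
print(axs[0, 0].get_ylim() == axs[0, 1].get_ylim())  # top row shares its y-axis
print(axs[0, 0].get_ylim() == axs[1, 0].get_ylim())  # the two rows are independent
```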

Example: Shared Axes

import matplotlib.pyplot as plt
import numpy as np
 
# Data
x = np.linspace(0, 10, 100)
y1 = np.sin(x)
y2 = np.cos(x)
 
# Create subplots with shared axes
fig, axs = plt.subplots(2, 1, sharex=True, sharey=True, figsize=(8, 6))
 
# Plot on each subplot
axs[0].plot(x, y1, label='Sine', color='blue')
axs[0].set_title("Sine Wave")
axs[0].legend()
 
axs[1].plot(x, y2, label='Cosine', color='red')
axs[1].set_title("Cosine Wave")
axs[1].legend()
 
# Adjust layout
plt.tight_layout()
 
# Display the plot
plt.show()

Practical Examples

Example 1: Comparing Stock Prices

import matplotlib.pyplot as plt
import numpy as np
 
# Data
x = np.arange(1, 6)
company_a = [100, 110, 115, 120, 125]
company_b = [90, 95, 100, 105, 110]
 
# Create subplots
fig, axs = plt.subplots(1, 2, figsize=(12, 5))
 
# Plot stock prices
axs[0].plot(x, company_a, label='Company A', marker='o')
axs[0].set_title("Company A Stock Prices")
axs[0].set_xlabel("Days")
axs[0].set_ylabel("Price")
axs[0].legend()
 
axs[1].plot(x, company_b, label='Company B', marker='o', color='green')
axs[1].set_title("Company B Stock Prices")
axs[1].set_xlabel("Days")
axs[1].legend()
 
# Adjust layout
plt.tight_layout()
 
# Display the plot
plt.show()
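Because each panel above autoscales its y-axis independently, the two plots cover different price ranges, which makes a side-by-side comparison easy to misread. A variant sketch with the same data passes sharey=True so both panels use one common price scale; the y-axis label is then only needed on the left panel.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.arange(1, 6)
company_a = [100, 110, 115, 120, 125]
company_b = [90, 95, 100, 105, 110]

# sharey=True puts both panels on one common price scale
fig, axs = plt.subplots(1, 2, sharey=True, figsize=(12, 5))

axs[0].plot(x, company_a, label='Company A', marker='o')
axs[0].set_title("Company A Stock Prices")
axs[0].set_xlabel("Days")
axs[0].set_ylabel("Price")  # one y-label serves both panels
axs[0].legend()

axs[1].plot(x, company_b, label='Company B', marker='o', color='green')
axs[1].set_title("Company B Stock Prices")
axs[1].set_xlabel("Days")
axs[1].legend()

plt.tight_layout()
# Call plt.show() to display the figure, as in the examples above
```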

Example 2: Weather Data

import matplotlib.pyplot as plt
 
# Data
months = ['Jan', 'Feb', 'Mar', 'Apr', 'May']
temp = [5, 7, 10, 15, 20]
rainfall = [50, 40, 60, 30, 20]
 
# Create subplots
fig, axs = plt.subplots(2, 1, figsize=(8, 8))
 
# Plot temperature
axs[0].bar(months, temp, color='orange')
axs[0].set_title("Average Monthly Temperature")
axs[0].set_ylabel("Temperature (°C)")
 
# Plot rainfall
axs[1].bar(months, rainfall, color='blue')
axs[1].set_title("Monthly Rainfall")
axs[1].set_ylabel("Rainfall (mm)")
 
# Adjust layout
plt.tight_layout()
 
# Display the plot
plt.show()
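The two weather panels differ only in their data, bar color, title, and y-label, so the repeated calls can also be driven by a loop over zip(axs, ...). This is a sketch of that refactoring with the same data; the panels list is a hypothetical helper structure introduced here, not part of the Matplotlib API.

```python
import matplotlib.pyplot as plt

months = ['Jan', 'Feb', 'Mar', 'Apr', 'May']
temp = [5, 7, 10, 15, 20]
rainfall = [50, 40, 60, 30, 20]

# Hypothetical helper list, one entry per panel: (values, bar color, title, y-axis label)
panels = [
    (temp, 'orange', "Average Monthly Temperature", "Temperature (°C)"),
    (rainfall, 'blue', "Monthly Rainfall", "Rainfall (mm)"),
]

fig, axs = plt.subplots(len(panels), 1, figsize=(8, 8))

# zip pairs each Axes with the settings for its panel
for ax, (values, color, title, ylabel) in zip(axs, panels):
    ax.bar(months, values, color=color)
    ax.set_title(title)
    ax.set_ylabel(ylabel)

plt.tight_layout()
# Call plt.show() to display the figure, as in the examples above
```

Adding a third panel then only needs one more entry in the list.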

Try It Yourself

Problem 1: Compare Sales Data

Create subplots to compare sales data for two products over a week.

import matplotlib.pyplot as plt
 
# Data
days = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
product_a = [10, 15, 20, 25, 30, 35, 40]
product_b = [5, 10, 15, 20, 25, 30, 35]
 
# Create subplots
fig, axs = plt.subplots(2, 1, figsize=(10, 8))
 
# Plot sales for Product A
axs[0].plot(days, product_a, label='Product A', marker='o', color='green')
axs[0].set_title("Product A Sales")
axs[0].set_ylabel("Units Sold")
axs[0].legend()
 
# Plot sales for Product B
axs[1].plot(days, product_b, label='Product B', marker='o', color='blue')
axs[1].set_title("Product B Sales")
axs[1].set_xlabel("Days")
axs[1].set_ylabel("Units Sold")
axs[1].legend()
 
# Adjust layout
plt.tight_layout()
 
# Display the plot
plt.show()

Subplots allow you to effectively visualize and compare multiple datasets within a single figure. Experiment with layouts, shared axes, and different styles to create insightful visualizations.
