Subplots in Matplotlib
Subplots allow you to create multiple plots in a single figure. This is particularly useful for comparing different datasets or visualizing multiple variables simultaneously.
Creating Subplots
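The plt.subplot() function in Matplotlib adds a single subplot to the current figure. It takes three arguments: the number of rows, the number of columns, and the index of the subplot. The index starts at 1 in the top-left cell and increases from left to right, then from top to bottom. Each call also makes that subplot the current one, so later plt.plot(), plt.title(), and plt.legend() calls apply to it.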
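To see how the index maps onto the grid, the following minimal sketch fills a 2 x 3 grid with plt.subplot() and records the row and column each index lands in, read back through the Axes' get_subplotspec(). It assumes a recent Matplotlib (3.2 or later for rowspan/colspan); the positions dictionary is just an illustrative name.

```python
import matplotlib.pyplot as plt

# Fill a 2 x 3 grid with plt.subplot(nrows, ncols, index).
# The index starts at 1 and runs left to right, then top to bottom.
fig = plt.figure(figsize=(9, 5))
positions = {}  # illustrative: index -> (row, column)
for index in range(1, 7):
    ax = plt.subplot(2, 3, index)  # returns the new Axes and makes it current
    ax.text(0.5, 0.5, f"subplot(2, 3, {index})",
            ha="center", va="center", transform=ax.transAxes)
    ax.set_xticks([])
    ax.set_yticks([])
    # Record which grid cell this index was mapped to (0-based row, column)
    spec = ax.get_subplotspec()
    positions[index] = (spec.rowspan.start, spec.colspan.start)

plt.tight_layout()
print(positions)
# {1: (0, 0), 2: (0, 1), 3: (0, 2), 4: (1, 0), 5: (1, 1), 6: (1, 2)}
```

Indices 1 to 3 fill the top row and 4 to 6 the bottom row. Add plt.show() at the end to display the figure.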
Example: Basic Subplots
import matplotlib.pyplot as plt
import numpy as np
# Data
x = np.linspace(0, 10, 100)
y1 = np.sin(x)
y2 = np.cos(x)
# Create subplots
plt.subplot(2, 1, 1) # 2 rows, 1 column, 1st plot
plt.plot(x, y1, label='Sine')
plt.title("Sine Wave")
plt.legend()
plt.subplot(2, 1, 2) # 2 rows, 1 column, 2nd plot
plt.plot(x, y2, label='Cosine', color='orange')
plt.title("Cosine Wave")
plt.legend()
# Adjust layout
plt.tight_layout()
# Display the plot
plt.show()
Using plt.subplots()
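The plt.subplots() function creates a figure and a whole grid of subplots in a single call. It returns the Figure object together with the Axes objects (a NumPy array of them when there is more than one subplot), so you plot on each subplot directly, for example axs[0, 0].plot(...), instead of switching the current subplot. This object-oriented style is the more common choice for multi-panel figures.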
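Because plt.subplots() hands back NumPy arrays, how you index them depends on the grid shape. This sketch (variable names are illustrative) shows the array shapes for a 2 x 2 grid and for a single row, and uses axs.flat to loop over every subplot in row order. With the default squeeze=True, a single row or column comes back as a 1-D array and a 1 x 1 grid as a single Axes; squeeze=False always returns a 2-D array.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

# 2 rows x 2 columns -> axs is a 2-D array of Axes, indexed axs[row, col]
fig, axs = plt.subplots(2, 2, figsize=(8, 6))
print(axs.shape)  # (2, 2)

# axs.flat walks the grid row by row, handy for styling every panel the same way
for k, ax in enumerate(axs.flat, start=1):
    ax.plot(x, np.sin(k * x))
    ax.set_title(f"sin({k}x)")
fig.tight_layout()

# A single row gives a 1-D array; squeeze=False keeps the 2-D shape
fig_row, row_axs = plt.subplots(1, 3)
fig_grid, grid_axs = plt.subplots(1, 3, squeeze=False)
print(row_axs.shape, grid_axs.shape)  # (3,) (1, 3)
```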
Example: Multiple Subplots with plt.subplots()
import matplotlib.pyplot as plt
import numpy as np
# Data
x = np.linspace(0, 10, 100)
y1 = np.sin(x)
y2 = np.cos(x)
# Create subplots
fig, axs = plt.subplots(2, 2, figsize=(10, 6))
# Plot on each subplot
axs[0, 0].plot(x, y1, label='Sine')
axs[0, 0].set_title("Sine Wave")
axs[0, 0].legend()
axs[0, 1].plot(x, y2, label='Cosine', color='orange')
axs[0, 1].set_title("Cosine Wave")
axs[0, 1].legend()
axs[1, 0].plot(x, y1 + y2, label='Sine + Cosine', color='green')
axs[1, 0].set_title("Combined Wave")
axs[1, 0].legend()
axs[1, 1].axis('off') # Empty subplot
# Adjust layout
plt.tight_layout()
# Display the plot
plt.show()
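The 2 x 2 example above hides its unused fourth cell with axis('off'). If you would rather have the combined wave fill the whole bottom row, one option, sketched here with the same data, is to mix grid shapes in plt.subplot(): cells 1 and 2 of a 2 x 2 grid for the top row, then cell 2 of a 2 x 1 grid for the bottom half. This works because both grids have two rows, so their rows line up. The ax_sin, ax_cos, and ax_sum names are illustrative.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y1 = np.sin(x)
y2 = np.cos(x)

fig = plt.figure(figsize=(10, 6))

# Top row: cells 1 and 2 of a 2 x 2 grid
ax_sin = plt.subplot(2, 2, 1)
ax_sin.plot(x, y1, label='Sine')
ax_sin.set_title("Sine Wave")
ax_sin.legend()

ax_cos = plt.subplot(2, 2, 2)
ax_cos.plot(x, y2, label='Cosine', color='orange')
ax_cos.set_title("Cosine Wave")
ax_cos.legend()

# Bottom row: cell 2 of a 2 x 1 grid, i.e. the full width of the lower half
ax_sum = plt.subplot(2, 1, 2)
ax_sum.plot(x, y1 + y2, label='Sine + Cosine', color='green')
ax_sum.set_title("Combined Wave")
ax_sum.legend()

plt.tight_layout()
```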
Sharing Axes in Subplots
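You can share axes between subplots for easier comparison. Passing sharex=True or sharey=True to plt.subplots() links the x- or y-axis limits of all subplots, so data in different panels are drawn on the same scale, and zooming or panning one subplot updates the others. With a shared x-axis, only the bottom subplot in each column keeps its x tick labels; with a shared y-axis, only the first subplot in each row keeps its y tick labels.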
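Besides True and False, sharex and sharey also accept 'row', 'col', 'all', and 'none'. The sketch below (with made-up curves, chosen only so the ranges differ) shares x-axes within each column and y-axes within each row of a 2 x 2 grid, then compares the resulting axis limits.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

# sharex="col": subplots in the same column share one x-axis
# sharey="row": subplots in the same row share one y-axis
fig, axs = plt.subplots(2, 2, sharex="col", sharey="row", figsize=(8, 6))
axs[0, 0].plot(x, np.sin(x))
axs[0, 1].plot(2 * x, 3 * np.sin(x))   # wider x range, larger amplitude
axs[1, 0].plot(x, np.cos(x) / 2)
axs[1, 1].plot(2 * x, np.cos(x) / 4)
fig.tight_layout()

# The top row's y-range covers both sine curves, the left column shares
# its x-range, and the right column's x-range is independent of the left's.
print(axs[0, 0].get_ylim() == axs[0, 1].get_ylim())  # True
print(axs[0, 0].get_xlim() == axs[1, 0].get_xlim())  # True
print(axs[0, 0].get_xlim() == axs[0, 1].get_xlim())  # False
```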
Example: Shared Axes
import matplotlib.pyplot as plt
import numpy as np
# Data
x = np.linspace(0, 10, 100)
y1 = np.sin(x)
y2 = np.cos(x)
# Create subplots with shared axes
fig, axs = plt.subplots(2, 1, sharex=True, sharey=True, figsize=(8, 6))
# Plot on each subplot
axs[0].plot(x, y1, label='Sine', color='blue')
axs[0].set_title("Sine Wave")
axs[0].legend()
axs[1].plot(x, y2, label='Cosine', color='red')
axs[1].set_title("Cosine Wave")
axs[1].legend()
# Adjust layout
plt.tight_layout()
# Display the plot
plt.show()
Practical Examples
Example 1: Comparing Stock Prices
import matplotlib.pyplot as plt
import numpy as np
# Data
x = np.arange(1, 6)
company_a = [100, 110, 115, 120, 125]
company_b = [90, 95, 100, 105, 110]
# Create subplots
fig, axs = plt.subplots(1, 2, figsize=(12, 5))
# Plot stock prices
axs[0].plot(x, company_a, label='Company A', marker='o')
axs[0].set_title("Company A Stock Prices")
axs[0].set_xlabel("Days")
axs[0].set_ylabel("Price")
axs[0].legend()
axs[1].plot(x, company_b, label='Company B', marker='o', color='green')
axs[1].set_title("Company B Stock Prices")
axs[1].set_xlabel("Days")
axs[1].set_ylabel("Price")
axs[1].legend()
# Adjust layout
plt.tight_layout()
# Display the plot
plt.show()
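The two panels above have independent y-axes, so Matplotlib scales each one to its own data, and the lines can look alike even though Company B's prices are lower. A variant worth considering, sketched below with the same data, passes sharey=True so both panels use one price scale and equal heights mean equal prices.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.arange(1, 6)
company_a = [100, 110, 115, 120, 125]
company_b = [90, 95, 100, 105, 110]

# sharey=True puts both panels on a single price scale
fig, axs = plt.subplots(1, 2, figsize=(12, 5), sharey=True)

axs[0].plot(x, company_a, label='Company A', marker='o')
axs[0].set_title("Company A Stock Prices")
axs[0].set_ylabel("Price")  # the shared axis only needs a label on the left

axs[1].plot(x, company_b, label='Company B', marker='o', color='green')
axs[1].set_title("Company B Stock Prices")

# Settings common to both panels
for ax in axs:
    ax.set_xlabel("Days")
    ax.legend()

plt.tight_layout()
```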
Example 2: Weather Data
import matplotlib.pyplot as plt
# Data
months = ['Jan', 'Feb', 'Mar', 'Apr', 'May']
temp = [5, 7, 10, 15, 20]
rainfall = [50, 40, 60, 30, 20]
# Create subplots
fig, axs = plt.subplots(2, 1, figsize=(8, 8))
# Plot temperature
axs[0].bar(months, temp, color='orange')
axs[0].set_title("Average Monthly Temperature")
axs[0].set_ylabel("Temperature (°C)")
# Plot rainfall
axs[1].bar(months, rainfall, color='blue')
axs[1].set_title("Monthly Rainfall")
axs[1].set_ylabel("Rainfall (mm)")
# Adjust layout
plt.tight_layout()
# Display the plot
plt.show()
Try It Yourself
Problem 1: Compare Sales Data
Create subplots to compare sales data for two products over a week.
import matplotlib.pyplot as plt
# Data
days = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
product_a = [10, 15, 20, 25, 30, 35, 40]
product_b = [5, 10, 15, 20, 25, 30, 35]
# Create subplots
fig, axs = plt.subplots(2, 1, figsize=(10, 8))
# Plot sales for Product A
axs[0].plot(days, product_a, label='Product A', marker='o', color='green')
axs[0].set_title("Product A Sales")
axs[0].set_ylabel("Units Sold")
axs[0].legend()
# Plot sales for Product B
axs[1].plot(days, product_b, label='Product B', marker='o', color='blue')
axs[1].set_title("Product B Sales")
axs[1].set_xlabel("Days")
axs[1].set_ylabel("Units Sold")
axs[1].legend()
# Adjust layout
plt.tight_layout()
# Display the plot
plt.show()
Subplots allow you to effectively visualize and compare multiple datasets within a single figure. Experiment with layouts, shared axes, and different styles to create insightful visualizations.