
Shahbaz Khan

3D plotting in Python using matplotlib

Data visualization is an area in which a large number of Python libraries have been developed. Among these, Matplotlib is one of the most popular choices. Although it was initially developed for 2D charts such as histograms, bar charts, scatter plots, and line plots, Matplotlib has extended its capabilities to offer 3D plotting as well, through its mplot3d toolkit.

In this tutorial, we will look at various aspects of 3D plotting in Python.

We will begin by plotting a single point in a 3D coordinate space. We will then learn how to customize our plots and move on to more complex plots such as 3D Gaussian surfaces and 3D polygons.

Plot a single point in a 3D space

Let us begin by going through every step necessary to create a 3D plot in Python, with an example of plotting a point in 3D space.

Step 1: Import the libraries

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

The first is the standard import statement for plotting with Matplotlib, which you would also use for 2D plotting.
The second import, of the Axes3D class, is required to enable 3D projections; it is not used anywhere else in the code.

Note that the second import is required for Matplotlib versions before 3.2.0. For versions 3.2.0 and higher, you can plot 3D plots without importing mpl_toolkits.mplot3d.Axes3D.
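If your code has to run on both older and newer Matplotlib releases, one option is to import Axes3D only when the installed version needs it. The sketch below is a minimal illustration under that assumption: the version check is a simple hand-rolled parse of matplotlib.__version__, not an official helper, and the Agg backend is selected only so the snippet runs without a display.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend; drop this and call plt.show() in a notebook
import matplotlib.pyplot as plt

# Parse "major.minor" from a version string such as "3.8.2" (hand-rolled check)
major, minor = (int(part) for part in matplotlib.__version__.split(".")[:2])

if (major, minor) < (3, 2):
    # Older releases register the '3d' projection only once Axes3D is imported
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

fig = plt.figure(figsize=(4, 4))
ax = fig.add_subplot(111, projection="3d")
print(ax.name)  # projection name of the new axes
```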

Step 2: Create figure and axes

fig = plt.figure(figsize=(4,4))
ax = fig.add_subplot(111, projection='3d')


Here we first create a figure of size 4 × 4 inches.
We then create a 3D axes object by calling the add_subplot method and passing the value ‘3d’ to its projection parameter.
We will use this axes object, ‘ax’, to add plots to the figure.

Note that these two steps will be common in most of the 3D plotting you do in Python using Matplotlib.

Step 3: Plot the point

After we create the axes object, we can use it to create any type of plot we want in the 3D space.
To plot a single point, we will use the scatter() method and pass it the three coordinates of the point.

fig = plt.figure(figsize=(4,4))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(2,3,4) # plot the point (2,3,4) on the figure
plt.show()


As you can see, a single point has been plotted (in blue) at (2,3,4).
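scatter() is not limited to a single point: each of its first three arguments can also be a sequence, with one entry per point. A minimal sketch, assuming three arbitrary example points and the headless Agg backend:

```python
import matplotlib

matplotlib.use("Agg")  # headless backend; drop this and call plt.show() in a notebook
import matplotlib.pyplot as plt

# Coordinates of three example points: (2, 3, 4), (0, 1, 5) and (4, 4, 1)
xs = [2, 0, 4]
ys = [3, 1, 4]
zs = [4, 5, 1]

fig = plt.figure(figsize=(4, 4))
ax = fig.add_subplot(111, projection="3d")
points = ax.scatter(xs, ys, zs)  # a single call draws all three points
```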

Plotting a 3D continuous line

Now that we know how to plot a single point in 3D, we can similarly plot a continuous line passing through a list of 3D coordinates.

We will use the plot() method and pass 3 arrays, one each for the x, y, and z coordinates of the points on the line.

import numpy as np
x = np.linspace(-4*np.pi,4*np.pi,50)
y = np.linspace(-4*np.pi,4*np.pi,50)
z = x**2 + y**2
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot(x,y,z)
plt.show()


We generate x, y, and z coordinates for 50 points. The x and y coordinates are produced with np.linspace, which returns 50 evenly spaced values between -4π and +4π. The z coordinate is simply the sum of the squares of the corresponding x and y coordinates.
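Because x and y are identical in this example, every point satisfies z = 2x², so the line is really a parabola lying in the vertical plane x = y. If you want a curve that uses all three dimensions independently, a parametric curve such as a helix is a common choice. The sketch below checks the first claim and draws a helix; the number of turns and samples are arbitrary.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend; drop this and call plt.show() in a notebook
import matplotlib.pyplot as plt
import numpy as np

# The original line: x and y are equal, so x**2 + y**2 reduces to 2 * x**2
x = np.linspace(-4 * np.pi, 4 * np.pi, 50)
y = np.linspace(-4 * np.pi, 4 * np.pi, 50)
z = x**2 + y**2

# A helix: two full turns around the z-axis while rising linearly
t = np.linspace(0, 4 * np.pi, 200)
hx, hy, hz = np.cos(t), np.sin(t), t

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.plot(hx, hy, hz)
```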

Customizing a 3D plot

Let us create a scatter plot in 3D space and look at how we can customize its appearance in different ways. We set NumPy's random seed so that you get the same random numbers as in this tutorial.

np.random.seed(42)
xs = np.random.random(100)*10+20
ys = np.random.random(100)*5+7
zs = np.random.random(100)*15+50
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(xs,ys,zs)
plt.show()

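np.random.random returns values in [0, 1), so random(100)*10+20 lies in [20, 30), random(100)*5+7 in [7, 12), and random(100)*15+50 in [50, 65). The sketch below checks those ranges and shows that re-seeding reproduces the same numbers. It also shows NumPy's newer Generator interface (np.random.default_rng) as an alternative; note that it produces a different stream from np.random.seed, so its values will not match the ones used in this tutorial.

```python
import numpy as np

np.random.seed(42)
xs = np.random.random(100) * 10 + 20  # values in [20, 30)
ys = np.random.random(100) * 5 + 7    # values in [7, 12)
zs = np.random.random(100) * 15 + 50  # values in [50, 65)

# Re-seeding restarts the same pseudo-random sequence
np.random.seed(42)
xs_again = np.random.random(100) * 10 + 20

# Alternative: a local Generator object (a different stream from np.random.seed)
rng = np.random.default_rng(42)
xs_rng = rng.random(100) * 10 + 20
```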

Let us now add a title to this plot.

Adding a title

We will call the set_title method of the axes object to add a title to the plot.

ax.set_title("Atom velocity distribution")
plt.show()


Note that the preceding code (creating the figure and adding the scatter plot) is not repeated here, but you should run it before this snippet.
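For reference, a self-contained version of the titled scatter plot might look like the sketch below. It recreates the data with the same seed; the Agg backend line is there only so it runs headless and can be dropped in a notebook, where you would finish with plt.show().

```python
import matplotlib

matplotlib.use("Agg")  # headless backend; drop this and call plt.show() in a notebook
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
xs = np.random.random(100) * 10 + 20
ys = np.random.random(100) * 5 + 7
zs = np.random.random(100) * 15 + 50

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.scatter(xs, ys, zs)
ax.set_title("Atom velocity distribution")  # the title belongs to this axes object
```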

Let us now add labels to each axis on the plot.

Adding axes labels

We can set a label for each axis in a 3D plot by calling the set_xlabel, set_ylabel, and set_zlabel methods on the axes object.

ax.set_xlabel("Atomic mass (dalton)")
ax.set_ylabel("Atomic radius (pm)")
ax.set_zlabel("Atomic velocity (x10⁶ m/s)")
plt.show()

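Each of these setters also accepts an optional labelpad argument, which adds extra space between a label and its axis; this can help when a label sits too close to the tick labels. A minimal sketch, where the padding value of 10 is arbitrary and the Agg backend is used only to run headless:

```python
import matplotlib

matplotlib.use("Agg")  # headless backend; drop this and call plt.show() in a notebook
import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")

# labelpad pushes each label further away from its axis
ax.set_xlabel("Atomic mass (dalton)", labelpad=10)
ax.set_ylabel("Atomic radius (pm)", labelpad=10)
ax.set_zlabel("Atomic velocity (x10⁶ m/s)", labelpad=10)
```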

Modifying the markers

As we have seen in our previous examples, the marker for each point, by default, is a filled blue circle of constant size.
We can alter the appearance of the markers to make them more expressive.

Let us begin by changing the color and style of the marker.

ax.scatter(xs,ys,zs, marker="x", c="red")
plt.show()

We have used the marker and c parameters to change the style and color of the individual points.
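By default, 3D scatter plots fade markers that are farther from the viewer to give a sense of depth, which can make a single color look uneven. The depthshade parameter of the 3D scatter method turns this off. A minimal sketch comparing the two settings side by side (Agg backend only so it runs headless):

```python
import matplotlib

matplotlib.use("Agg")  # headless backend; drop this and call plt.show() in a notebook
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
xs = np.random.random(100) * 10 + 20
ys = np.random.random(100) * 5 + 7
zs = np.random.random(100) * 15 + 50

fig = plt.figure(figsize=(8, 4))

# Left: default depth shading, farther markers are drawn more faded
ax_shaded = fig.add_subplot(121, projection="3d")
ax_shaded.scatter(xs, ys, zs, marker="o", c="red")
ax_shaded.set_title("depthshade=True (default)")

# Right: every marker drawn with the same opacity
ax_flat = fig.add_subplot(122, projection="3d")
ax_flat.scatter(xs, ys, zs, marker="o", c="red", depthshade=False)
ax_flat.set_title("depthshade=False")
```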

Modifying the axes limits and ticks

By default, the range and interval of values on each axis are set automatically from the input data.
We can, however, change them to values of our choosing.

Let us create another scatter plot representing a new set of data points, and then modify its axes range and interval.

np.random.seed(42)
ages = np.random.randint(low = 8, high = 30, size=35)
heights = np.random.randint(130, 195, 35)
weights = np.random.randint(30, 160, 35)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(xs = heights, ys = weights, zs = ages)
ax.set_title("Age-wise body weight-height distribution")
ax.set_xlabel("Height (cm)")
ax.set_ylabel("Weight (kg)")
ax.set_zlabel("Age (years)")
plt.show()


We have plotted data of 3 variables, namely, height, weight and age on the 3 axes.
As you can see, the limits on the X, Y, and Z axes have been assigned automatically based on the input data.

Let us modify the minimum and maximum limits on each axis by calling the set_xlim, set_ylim, and set_zlim methods.

ax.set_xlim(100,200)
ax.set_ylim(20,160)
ax.set_zlim(5,35)
plt.show()


The limits for the three axes have been modified based on the min and max values we passed to the respective methods.
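Instead of hard-coding the limits, you can derive them from the data. The helper below, rounded_limits, is a hypothetical name used only for this sketch: it rounds the smallest value down and the largest value up to a multiple of step, and the resulting pair can be passed to set_xlim, set_ylim, or set_zlim. The step sizes chosen here are arbitrary.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend; drop this and call plt.show() in a notebook
import matplotlib.pyplot as plt
import numpy as np


def rounded_limits(values, step):
    """Hypothetical helper: widen (min, max) outward to the nearest multiple of step."""
    values = np.asarray(values, dtype=float)
    low = step * np.floor(values.min() / step)
    high = step * np.ceil(values.max() / step)
    return low, high


np.random.seed(42)
ages = np.random.randint(low=8, high=30, size=35)
heights = np.random.randint(130, 195, 35)
weights = np.random.randint(30, 160, 35)

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.scatter(heights, weights, ages)
ax.set_xlim(*rounded_limits(heights, 20))
ax.set_ylim(*rounded_limits(weights, 20))
ax.set_zlim(*rounded_limits(ages, 5))
```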
We can also modify the individual ticks on each axis. Currently, the X-axis ticks are [100,120,140,160,180,200].
Let us change them to [100,125,150,175,200].

ax.set_xticks([100,125,150,175,200])
plt.show()



Similarly, we can update the Y and Z ticks using the set_yticks and set_zticks methods.

ax.set_yticks([20,55,90,125,160])
ax.set_zticks([5,15,25,35])
plt.show()

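Rather than typing out every tick, evenly spaced ticks can be generated with np.linspace (a fixed number of ticks) or np.arange (a fixed step). A minimal sketch reproducing the tick values used above, with the Agg backend only so it runs headless:

```python
import matplotlib

matplotlib.use("Agg")  # headless backend; drop this and call plt.show() in a notebook
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.set_xlim(100, 200)
ax.set_ylim(20, 160)
ax.set_zlim(5, 35)

ax.set_xticks(np.linspace(100, 200, 5))  # 5 ticks: 100, 125, 150, 175, 200
ax.set_yticks(np.linspace(20, 160, 5))   # 5 ticks: 20, 55, 90, 125, 160
ax.set_zticks(np.arange(5, 36, 10))      # step of 10: 5, 15, 25, 35
```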

Change the size of the plot

If we want our plots to be bigger or smaller than the default size, we can set the size either when initializing the figure, using the figsize parameter of plt.figure, or on an existing figure, by calling its set_size_inches method. In both approaches, we specify the width and height of the plot in inches.

Since we have already seen the first method, let us look at the second approach now, i.e., modifying the size of an existing plot. We will change the size of our scatter plot to 6×6 inches.

fig.set_size_inches(6, 6)
plt.show()


The size of our scatter plot has been increased compared to its previous default size.
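Because sizes are given in inches, the pixel size of a saved image is the size in inches multiplied by the figure's dpi (dots per inch). A minimal sketch, assuming a dpi of 100 set explicitly so the numbers are predictable; it reads the width and height back from the PNG header to confirm. The Agg backend is used only so it runs headless.

```python
import struct
from io import BytesIO

import matplotlib

matplotlib.use("Agg")  # headless backend; drop this and call plt.show() in a notebook
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(4, 4), dpi=100)
ax = fig.add_subplot(111, projection="3d")

fig.set_size_inches(6, 6)  # resize the existing figure to 6 x 6 inches

buffer = BytesIO()
fig.savefig(buffer, format="png", dpi=fig.dpi)  # use the figure's own dpi explicitly
png_bytes = buffer.getvalue()

# A PNG starts with an 8-byte signature followed by the IHDR chunk, whose
# first two fields (bytes 16 to 24) are the width and height in pixels.
width_px, height_px = struct.unpack(">II", png_bytes[16:24])
```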

Turn off/on gridlines

All the plots we have created so far have gridlines on them by default.
We can turn them off by calling the grid method of the axes object and passing the value False.
If we want the gridlines back, we can call the same method with True.

ax.grid(False)
plt.show()


Set 3D plot colors based on class

Let us suppose that the individuals represented by our scatter plot were further divided into two or more categories.
We can represent this information by plotting the individuals of each category with a different color.
For instance, let us divide our data into ‘Male’ and ‘Female’ categories.
We will create a new array of the same size as the number of data points, and assign the values 0 for ‘Male’ and 1 for the ‘Female’ category.
We will then pass this array to the color parameter c when creating the scatter plot.

np.random.seed(42)
ages = np.random.randint(low = 8, high = 30, size=35)
heights = np.random.randint(130, 195, 35)
weights = np.random.randint(30, 160, 35)
gender_labels = np.random.choice([0, 1], 35) #0 for male, 1 for female
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(xs = heights, ys = weights, zs = ages, c=gender_labels)
ax.set_title("Age-wise body weight-height distribution")
ax.set_xlabel("Height (cm)")
ax.set_ylabel("Weight (kg)")
ax.set_zlabel("Age (years)")
plt.show()


The plot now shows each of the two categories in a different color.
But how would we know which color corresponds to which category?

We can add a ‘colorbar’ to solve this problem.

scat_plot = ax.scatter(xs = heights, ys = weights, zs = ages, c=gender_labels)
cb = fig.colorbar(scat_plot, ax=ax, pad=0.2) # attach the colorbar to this figure and axes
cb.set_ticks([0,1])
cb.set_ticklabels(["Male", "Female"])
plt.show()

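A colorbar works, but for a small number of discrete categories an alternative is to draw one scatter call per category and label each one, so that a legend maps colors to names. A sketch of that alternative, using the same seeded data and the 0 = Male, 1 = Female encoding (Agg backend only so it runs headless):

```python
import matplotlib

matplotlib.use("Agg")  # headless backend; drop this and call plt.show() in a notebook
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
ages = np.random.randint(low=8, high=30, size=35)
heights = np.random.randint(130, 195, 35)
weights = np.random.randint(30, 160, 35)
gender_labels = np.random.choice([0, 1], 35)  # 0 for male, 1 for female

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")

# One scatter call per category; each gets its own color and legend entry
for code, name in [(0, "Male"), (1, "Female")]:
    mask = gender_labels == code
    ax.scatter(heights[mask], weights[mask], ages[mask], label=name)

ax.set_xlabel("Height (cm)")
ax.set_ylabel("Weight (kg)")
ax.set_zlabel("Age (years)")
ax.legend()
```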

Adding legends

Often we have more than one set of data that we want to plot on the same figure.
In such a situation, we should assign a label to each plot and add a legend to the figure to distinguish the plots from each other.

For example, let us suppose that our age-height-weight data were collected from 3 states of the United States, namely, Florida, Georgia, and California.
We want to plot scatter plots for the 3 states and add a legend to distinguish them from each other.

Let us create the 3 plots in a for-loop and assign a different label to them each time.

np.random.seed(42)
labels = ["Florida", "Georgia", "California"]
# create a fresh figure and 3D axes for this plot
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
for label in labels:
    ages = np.random.randint(low = 8, high = 20, size=20)
    heights = np.random.randint(130, 195, 20)
    weights = np.random.randint(30, 160, 20)
    ax.scatter(xs = heights, ys = weights, zs = ages, label=label)
ax.set_title("Age-wise body weight-height distribution")
ax.set_xlabel("Height (cm)")
ax.set_ylabel("Weight (kg)")
ax.set_zlabel("Age (years)")
ax.legend(loc="best")
plt.show()

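ax.legend also accepts a title argument, and ax.get_legend_handles_labels() returns the artists and labels the legend is built from, which is a handy way to check that every loop iteration was labeled. A minimal sketch, where the legend title "State" is an arbitrary choice and the Agg backend is used only to run headless:

```python
import matplotlib

matplotlib.use("Agg")  # headless backend; drop this and call plt.show() in a notebook
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
labels = ["Florida", "Georgia", "California"]

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")

for label in labels:
    ages = np.random.randint(low=8, high=20, size=20)
    heights = np.random.randint(130, 195, 20)
    weights = np.random.randint(30, 160, 20)
    ax.scatter(heights, weights, ages, label=label)

legend = ax.legend(loc="best", title="State")  # title shown above the entries
handles, legend_labels = ax.get_legend_handles_labels()
```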

Plot markers of varying size

In the scatter plots we have seen so far, all the point markers have been of constant size.

We can alter the size of markers by passing custom values to the parameter s of the scatter plot.
We can either pass a single number to set all the markers to a new fixed size, or we can provide an array of values, where each value represents the size of one marker.

In our example, we will calculate a new variable called ‘bmi’ from the heights and weights of individuals and make the sizes of individual markers proportional to their BMI values.

np.random.seed(42)
ages = np.random.randint(low = 8, high = 30, size=35)
heights = np.random.randint(130, 195, 35)
weights = np.random.randint(30, 160, 35)
bmi = weights/((heights*0.01)**2)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(xs = heights, ys = weights, zs = ages, s=bmi*5 )
ax.set_title("Age-wise body weight-height distribution")
ax.set_xlabel("Height (cm)")
ax.set_ylabel("Weight (kg)")
ax.set_zlabel("Age (years)")
plt.show()



The larger a marker in this plot, the higher that individual's BMI, and vice versa.
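Two details are worth making explicit. BMI is weight in kilograms divided by the square of height in meters, which is why the code multiplies heights by 0.01. And Matplotlib's s parameter is the marker area in points², so s=bmi*5 makes the marker area, not its diameter, proportional to BMI. If you prefer sizes in a fixed range regardless of the data, one option is a linear rescaling; scale_sizes below is a hypothetical helper for this sketch, and the 20 to 200 range is arbitrary.

```python
import numpy as np


def bmi(weight_kg, height_cm):
    """Body-mass index: weight in kg divided by height in meters squared."""
    return weight_kg / (height_cm * 0.01) ** 2


def scale_sizes(values, min_size=20.0, max_size=200.0):
    """Hypothetical helper: map values linearly onto [min_size, max_size] (marker areas)."""
    values = np.asarray(values, dtype=float)
    span = values.max() - values.min()
    if span == 0:
        # All values equal: give every marker the midpoint size
        return np.full_like(values, (min_size + max_size) / 2)
    return min_size + (max_size - min_size) * (values - values.min()) / span


np.random.seed(42)
ages = np.random.randint(low=8, high=30, size=35)
heights = np.random.randint(130, 195, 35)
weights = np.random.randint(30, 160, 35)
sizes = scale_sizes(bmi(weights, heights))  # pass as ax.scatter(..., s=sizes)
```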

Plotting a Gaussian distribution

You may be aware of a univariate Gaussian distribution plotted on a 2D plane, popularly known as the ‘bell-shaped curve.’

We can also plot a Gaussian distribution in 3D space, using the multivariate normal distribution.
Here we define two variables, X and Y, and plot their joint probability density.

from scipy.stats import multivariate_normal
X = np.linspace(-5,5,50)
Y = np.linspace(-5,5,50)
X, Y = np.meshgrid(X,Y)
X_mean = 0; Y_mean = 0
X_var = 5; Y_var = 8
pos = np.empty(X.shape+(2,))
pos[:,:,0]=X
pos[:,:,1]=Y
rv = multivariate_normal([X_mean, Y_mean],[[X_var, 0], [0, Y_var]])
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X, Y, rv.pdf(pos), cmap="plasma")
plt.show()


The same plot_surface method can be used to draw other surfaces in 3D space.
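For two independent variables (a diagonal covariance matrix, as in the example above), the joint density has a closed form: f(x, y) = exp(-(x - μx)²/(2σx²) - (y - μy)²/(2σy²)) / (2π σx σy), where σx² = X_var and σy² = Y_var. The sketch below evaluates that formula with NumPy alone, which can stand in for scipy.stats.multivariate_normal when only this diagonal case is needed; bivariate_normal_pdf is a hypothetical helper name, and the final lines check numerically that the density integrates to about 1.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend; drop this and call plt.show() in a notebook
import matplotlib.pyplot as plt
import numpy as np


def bivariate_normal_pdf(x, y, mean=(0.0, 0.0), var=(5.0, 8.0)):
    """Hypothetical helper: joint density of two independent normal variables."""
    (mx, my), (vx, vy) = mean, var
    exponent = -((x - mx) ** 2) / (2 * vx) - ((y - my) ** 2) / (2 * vy)
    return np.exp(exponent) / (2 * np.pi * np.sqrt(vx * vy))


X, Y = np.meshgrid(np.linspace(-5, 5, 50), np.linspace(-5, 5, 50))
Z = bivariate_normal_pdf(X, Y)

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.plot_surface(X, Y, Z, cmap="plasma")

# Sanity check: a Riemann sum over a wide grid should be close to 1
g = np.linspace(-30, 30, 1201)
GX, GY = np.meshgrid(g, g)
step = g[1] - g[0]
total = bivariate_normal_pdf(GX, GY).sum() * step**2
```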

Plotting a 3D Polygon

We can also plot polygons with 3-dimensional vertices in Python.

from mpl_toolkits.mplot3d.art3d import Poly3DCollection
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
x = [1, 0, 3, 4]
y = [0, 5, 5, 1]
z = [1, 3, 4, 0]
vertices = [list(zip(x,y,z))]
poly = Poly3DCollection(vertices, alpha=0.8)
ax.add_collection3d(poly)
ax.set_xlim(0,5)
ax.set_ylim(0,5)
ax.set_zlim(0,5)
plt.show()

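One subtlety: the four vertices in this example do not lie in a single plane, so the shape is not a flat polygon; Matplotlib fills the outline of the projected vertices. A quick check is the determinant of the three edge vectors from the first vertex, which is zero only for coplanar points. The sketch below computes it and, as one possible workaround, splits the quadrilateral into two triangles, each of which is always flat.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend; drop this and call plt.show() in a notebook
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

x = [1, 0, 3, 4]
y = [0, 5, 5, 1]
z = [1, 3, 4, 0]
points = np.array(list(zip(x, y, z)), dtype=float)

# Edge vectors from the first vertex; a zero determinant means the points are coplanar
edges = points[1:] - points[0]
volume_det = np.linalg.det(edges)

# Fan triangulation from vertex 0: triangles (0, 1, 2) and (0, 2, 3)
triangles = [points[[0, 1, 2]], points[[0, 2, 3]]]

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.add_collection3d(Poly3DCollection(triangles, alpha=0.8, edgecolors="k"))
ax.set_xlim(0, 5)
ax.set_ylim(0, 5)
ax.set_zlim(0, 5)
```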

Rotate a 3D plot with the mouse

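To create an interactive plot in the classic Jupyter Notebook, run the magic command %matplotlib notebook at the beginning of the notebook. In JupyterLab and newer Notebook versions, that backend is not supported; the %matplotlib widget magic, provided by the ipympl package, serves the same purpose.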

This enables us to interact with the 3D plots, by zooming in and out of the plot, as well as rotating them in any direction.

%matplotlib notebook
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
from scipy.stats import multivariate_normal
X = np.linspace(-5,5,50)
Y = np.linspace(-5,5,50)
X, Y = np.meshgrid(X,Y)
X_mean = 0; Y_mean = 0
X_var = 5; Y_var = 8
pos = np.empty(X.shape+(2,))
pos[:,:,0]=X
pos[:,:,1]=Y
rv = multivariate_normal([X_mean, Y_mean],[[X_var, 0], [0, Y_var]])
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X, Y, rv.pdf(pos), cmap="plasma")
plt.show()

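Outside a notebook, or when you need a reproducible view for a saved image, the camera angle can be set in code with ax.view_init, whose elev and azim arguments are the elevation and azimuth in degrees. A minimal sketch; the angles 30 and 45 are arbitrary, and the surface is an unnormalized version of the Gaussian above so the snippet needs no SciPy.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend; drop this and call plt.show() in a notebook
import matplotlib.pyplot as plt
import numpy as np

X, Y = np.meshgrid(np.linspace(-5, 5, 50), np.linspace(-5, 5, 50))
Z = np.exp(-(X**2) / 10 - (Y**2) / 16)  # variances 5 and 8, unnormalized

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.plot_surface(X, Y, Z, cmap="plasma")

ax.view_init(elev=30, azim=45)  # look down 30 degrees, rotated 45 degrees about z
```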

Plot two different 3D distributions

We can add two different 3D plots to the same figure with the help of the fig.add_subplot method.
The 3-digit number we pass to the method specifies the number of rows and columns in the grid and the position of the current plot within it.
The first two digits indicate the number of rows and columns into which the figure is divided.
The last digit indicates the position of the subplot in the grid.

For example, if we pass the value 223 to the add_subplot method, we are referring to the 3rd plot in the 2×2 grid (considering row-first ordering).
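The three-digit shorthand is equivalent to passing the rows, columns, and index as separate arguments, and the separate form is the one to use once any of them reaches 10 or more. A sketch confirming that 223 and (2, 2, 3) describe the same cell; get_subplotspec().get_geometry() reports (rows, columns, first cell, last cell), with 0-based cell indices.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend; drop this and call plt.show() in a notebook
import matplotlib.pyplot as plt

fig_a = plt.figure()
ax_a = fig_a.add_subplot(223, projection="3d")  # three-digit shorthand

fig_b = plt.figure()
ax_b = fig_b.add_subplot(2, 2, 3, projection="3d")  # explicit rows, cols, index

geometry_a = ax_a.get_subplotspec().get_geometry()
geometry_b = ax_b.get_subplotspec().get_geometry()
```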

Let us now look at an example where we plot two different distributions side by side in a single figure.

#data generation for 1st plot
np.random.seed(42)
xs = np.random.random(100)*10+20
ys = np.random.random(100)*5+7
zs = np.random.random(100)*15+50
#data generation for 2nd plot
np.random.seed(42)
ages = np.random.randint(low = 8, high = 30, size=35)
heights = np.random.randint(130, 195, 35)
weights = np.random.randint(30, 160, 35)
fig = plt.figure(figsize=(8,4))
#First plot
ax = fig.add_subplot(121, projection='3d')
ax.scatter(xs,ys,zs, marker="x", c="red")
ax.set_title("Atom velocity distribution")
ax.set_xlabel("Atomic mass (dalton)")
ax.set_ylabel("Atomic radius (pm)")
ax.set_zlabel("Atomic velocity (x10⁶ m/s)")
#Second plot
ax = fig.add_subplot(122, projection='3d')
ax.scatter(xs = heights, ys = weights, zs = ages)
ax.set_title("Age-wise body weight-height distribution")
ax.set_xlabel("Height (cm)")
ax.set_ylabel("Weight (kg)")
ax.set_zlabel("Age (years)")
plt.show()


We can add as many subplots as we want in this way, as long as they fit in the grid.
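For larger grids, plt.subplots can create every 3D axes at once through its subplot_kw argument, which is forwarded to each subplot; you then loop over the returned array. A minimal sketch with a 2×3 grid of random scatter plots, where the grid shape, figure size, and titles are arbitrary:

```python
import matplotlib

matplotlib.use("Agg")  # headless backend; drop this and call plt.show() in a notebook
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
fig, axes = plt.subplots(2, 3, figsize=(12, 8), subplot_kw={"projection": "3d"})

# axes is a 2x3 NumPy array of 3D axes; .flat iterates over it row by row
for i, ax in enumerate(axes.flat, start=1):
    xs, ys, zs = np.random.random((3, 20))
    ax.scatter(xs, ys, zs)
    ax.set_title(f"Plot {i}")
```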

Output Python 3D plot to HTML

If we want to embed a 3D plot in an HTML page without first saving it as an image file,
we can encode the figure as base64 and insert the result into the src attribute of an HTML img tag.

import base64
from io import BytesIO
np.random.seed(42)
xs = np.random.random(100)*10+20
ys = np.random.random(100)*5+7
zs = np.random.random(100)*15+50
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(xs,ys,zs)
#encode the figure
temp = BytesIO()
fig.savefig(temp, format="png")
fig_encode_bs64 = base64.b64encode(temp.getvalue()).decode('utf-8')
html_string = """
<h2>This is a test html</h2>
<img src = 'data:image/png;base64,{}'/>
""".format(fig_encode_bs64)

We can now write this HTML string to an HTML file, which we can then view in a browser.

with open("test.html", "w") as f:
    f.write(html_string)

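If you embed several figures, the encoding steps can be wrapped in a small function. figure_to_img_tag below is a hypothetical helper for this sketch; it also closes the figure after encoding, on the assumption that the figure is no longer needed once it has been turned into HTML.

```python
import base64
from io import BytesIO

import matplotlib

matplotlib.use("Agg")  # headless backend; no window is needed to render the PNG
import matplotlib.pyplot as plt
import numpy as np


def figure_to_img_tag(fig, close=True):
    """Hypothetical helper: return an <img> tag with the figure embedded as base64 PNG."""
    buffer = BytesIO()
    fig.savefig(buffer, format="png")
    if close:
        plt.close(fig)  # free the figure once it has been encoded
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"<img src='data:image/png;base64,{encoded}'/>"


np.random.seed(42)
fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.scatter(*np.random.random((3, 100)))

html_string = "<h2>This is a test html</h2>\n" + figure_to_img_tag(fig)
```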

Conclusion

In this tutorial, we learned how to plot 3D plots in Python using the matplotlib library.
We began by plotting a point in the 3D coordinate space, and then plotted 3D curves and scatter plots.

Then we learned various ways of customizing a 3D plot in Python, such as adding a title, legends, axes labels to the plot, resizing the plot, switching on/off the gridlines on the plot, modifying the axes ticks, etc.
We also learned how to vary the color of the markers by data point category and their size by a data value (BMI).

After that, we learned how to plot surfaces in a 3D space. We plotted a Gaussian distribution and a 3D polygon in Python.

We then saw how we can interact with a Python 3D plot in a Jupyter notebook.

Finally, we learned how to plot multiple subplots on the same figure, and how to embed a figure in an HTML page.

Published on Web Code Geeks with permission by Shahbaz Khan, partner at our WCG program. See the original article here: 3D plotting in Python using matplotlib

Opinions expressed by Web Code Geeks contributors are their own.
