雑多な技術系メモ

自分用のメモ。内容は保証しません。よろしくお願いします。

二変数の正規分布のプロット

以下のようなグラフをプロットするソースコードの紹介

f:id:ttt242242:20190807144456p:plain

ソースコード

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
from mpl_toolkits.mplot3d import Axes3D

mu_x, mu_y = 0, 0   # 平均
variance_x, variance_y = 15, 15 # 分散

x = np.linspace(-10,10,500)
y = np.linspace(-10,10,500)
X, Y = np.meshgrid(x,y)
pos = np.empty(X.shape + (2,))
pos[:, :, 0] = X; pos[:, :, 1] = Y  # pos[x][y][x, y]
rv = multivariate_normal([mu_x, mu_y], [[variance_x, 0], [0, variance_y]])  # 多変量正規分布の定義

fig = plt.figure()
ax = fig.gca(projection='3d')
ax.plot_wireframe(X, Y, rv.pdf(pos))
ax.set_xlabel('X axis')
ax.set_ylabel('Y axis')
ax.set_zlabel('Z axis')
plt.show()

参考文献

https://stackoverflow.com/questions/38698277/plot-normal-distribution-in-3d