注意
跳转到末尾以下载完整示例代码。
MNIST 数字数据集上的 UMAP
一个简单的示例,演示如何在 MNIST 等大型数据集上使用 UMAP。我们首先获取 MNIST 数据集,然后使用 UMAP 将其降维到仅 2 维,以便于可视化。
注意,UMAP 不仅能够将单个数字类别分组,还能保留不同数字类别之间的整体全局结构——将 1 保持远离 0,并对 3、5、8 和 4、7、9 进行分组,这些数字在某些情况下可能会相互融合。
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
import umap
sns.set(context="paper", style="white")
mnist = fetch_openml("mnist_784", version=1)
X_train, X_test, y_train, y_test = train_test_split(
mnist.data, mnist.target, stratify=mnist.target, random_state=42
)
reducer = umap.UMAP(random_state=42)
embedding_train = reducer.fit_transform(X_train)
embedding_test = reducer.transform(X_test)
fig, ax = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(12, 10))
ax[0].scatter(
embedding_train[:, 0], embedding_train[:, 1], c=y_train, cmap="Spectral" # , s=0.1
)
ax[1].scatter(
embedding_test[:, 0], embedding_test[:, 1], c=y_test, cmap="Spectral" # , s=0.1
)
plt.setp(ax[0], xticks=[], yticks=[])
plt.setp(ax[1], xticks=[], yticks=[])
plt.suptitle("MNIST data embedded into two dimensions by UMAP", fontsize=18)
ax[0].set_title("Training Set", fontsize=12)
ax[1].set_title("Test Set", fontsize=12)
plt.show()