MNIST 数字数据集上的 UMAP

一个简单的示例,演示如何在 MNIST 等大型数据集上使用 UMAP。我们首先获取 MNIST 数据集,然后使用 UMAP 将其降维到仅 2 维,以便于可视化。

注意,UMAP 不仅能够将单个数字类别分组,还能保留不同数字类别之间的整体全局结构——将 1 保持远离 0,并对 3、5、8 和 4、7、9 进行分组,这些数字在某些情况下可能会相互融合。

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split

import umap

sns.set(context="paper", style="white")

mnist = fetch_openml("mnist_784", version=1)
X_train, X_test, y_train, y_test = train_test_split(
    mnist.data, mnist.target, stratify=mnist.target, random_state=42
)

reducer = umap.UMAP(random_state=42)
embedding_train = reducer.fit_transform(X_train)
embedding_test = reducer.transform(X_test)

fig, ax = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(12, 10))
ax[0].scatter(
    embedding_train[:, 0], embedding_train[:, 1], c=y_train, cmap="Spectral"  # , s=0.1
)
ax[1].scatter(
    embedding_test[:, 0], embedding_test[:, 1], c=y_test, cmap="Spectral"  # , s=0.1
)
plt.setp(ax[0], xticks=[], yticks=[])
plt.setp(ax[1], xticks=[], yticks=[])
plt.suptitle("MNIST data embedded into two dimensions by UMAP", fontsize=18)
ax[0].set_title("Training Set", fontsize=12)
ax[1].set_title("Test Set", fontsize=12)
plt.show()

由 Sphinx-Gallery 生成的图库