注意
跳转至末尾以下载完整的示例代码。
在 MNIST 数字数据集上使用 UMAP
一个简单示例,演示如何在 MNIST 等大型数据集上使用 UMAP。我们首先获取 MNIST 数据集,然后使用 UMAP 将其降至仅 2 维,以便于可视化。
请注意,UMAP 不仅能够将单独的数字类别分组,还能保留不同数字类别之间的整体全局结构——将 1 与 0 保持较远距离,并将 3、5、8 和 4、7、9 这些有时可以相互混合的数字进行分组。
import umap
from sklearn.datasets import fetch_openml
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(context="paper", style="white")
mnist = fetch_openml("mnist_784", version=1)
reducer = umap.UMAP(random_state=42)
embedding = reducer.fit_transform(mnist.data)
fig, ax = plt.subplots(figsize=(12, 10))
color = mnist.target.astype(int)
plt.scatter(embedding[:, 0], embedding[:, 1], c=color, cmap="Spectral", s=0.1)
plt.setp(ax, xticks=[], yticks=[])
plt.title("MNIST data embedded into two dimensions by UMAP", fontsize=18)
plt.show()