線形因子モデル

線形因子モデルは、goodfellow本ではその後に続く深層生成モデルの、須山本では第五章の応用モデルの、基礎を成しているとのことなので、やってみました。モデルとしては、PRMLの12章を参考にしています。

線形因子モデルは、潜在変数に情報を詰め込んで、そこからデータを再現します。

  • PPCA, factor analysisは大体これと同じ
  • ICAはローカルの潜在変数に独立を仮定し、ガウス分布以外の物を使う

ので、goodfellow本13章で紹介されている発展的なモデルも、これができればそんなに遠くない(はず)。

潜在変数の解釈

潜在変数というのは、深層学習の文脈で表現(representation)や特徴(feature)と呼ばれるもので、

  • ベイズ系の統計推測での潜在変数
  • NNが隠れ層で学習する多様体
  • オートエンコーダの隠れ層
  • スパースコーディングの成果物
  • 教師なし学習で事前学習する(していた)際の目的
  • CNNのフィルタが学習するもの
  • RNNが時系列で共有するもの(パラメータシェアリング)

などなどは、同じイデアを共有していて、緩やかに繋がっているものだと理解しています。深層学習ではその潜在変数(=表現・特徴)がスパースであるほど統計的に意義がある(uniformに分布してると何も言えないから)かつその潜在変数を使った後続の回帰や分類タスクの精度が上がるので、そのスパース性を求めて様々な正則化が適用されます。

この潜在変数をいかにうまく設計するかが、統計的推論・統計的機械学習のキモです。

実装

グラフィカルモデル

model

観測データ

$$ \quad Y=[y_1,…,y_n] \quad y_n \in \mathbb{R}^D $$

潜在変数

$$ \quad X=[x_1,…,x_n] \quad x_n \in \mathbb{R}^M \\ \quad \textbf{W} \in \mathbb{R}^{M \times D} \quad (\textbf{W}_d \in \mathbb{R}^M \quad W の d 番⽬の列ベクトル)\\ \quad \mu \in \mathbb{R}^D $$

パラメータ

$$ \quad \sigma^2_y \in \mathbb{R}^+ \\ \Sigma_w \\ \Sigma_{\mu} $$

個別の分布

$$ p(\textbf{W}) = \prod^D_{d=1} N(\textbf{W}_d | \textbf{0}, \Sigma_w) $$

$$ p(\mu) = N(\mu | \textbf{0}, \Sigma_{\mu}) $$

$$ p(\textbf{x}_n) = N(\textbf{x}_n | \textbf{0}, \textbf{I}_M) $$

$\textbf{y}_n$ の条件付き分布

$$ p(\textbf{y}_n | \textbf{x}_n, \textbf{W}, \mu) = N(\textbf{y}_n | \textbf{W}^T \textbf{x}_n + \mu, \sigma^2_y \textbf{I}_D) \ $$

同時分布

$$ p(\textbf{Y}, \textbf{X}, \textbf{W}, \mu) = p(\textbf{W})p(\mu)\prod^N_{n=1}p(\textbf{y}_n | \textbf{x}_n, \textbf{W}, \mu)p(\textbf{x}_n) \ $$

model

def print_shape(name, dist, sample_shape=()):
    print(name, ":", "event shape:", dist.event_shape, "batch shape:", dist.batch_shape)
    print(name, ":", "sample shape", dist.sample(key, sample_shape=sample_shape).shape)
    print(name, ":", "whole shape:", dist.shape(sample_shape=sample_shape))
    # print(name, ":", "sample", dist.sample(key, sample_shape=sample_shape))
    print("")

def model(D, M, N, obs=None, debug=False):
    # mu
    loc_mu = jnp.full(D, 0) #jax.random.normal(key, (D,))
    scale_mu = jnp.full(D, 1)
    dist_mu = dist.Normal(loc=loc_mu, scale=scale_mu).to_event()
    mu = numpyro.sample("latent_mu", dist_mu)
    if (debug):
        print_shape("dist_mu", dist_mu)

    # W
    loc_W = jnp.full((M, D), 0) #jax.random.normal(key, (M,D))
    scale_W = jnp.full((M, D), 1)
    dist_W = dist.Normal(loc=loc_W, scale=scale_W).to_event()
    W = numpyro.sample("latent_W", dist_W)
    if (debug):
        print_shape("dist_W", dist_W)

    # X, latent
    loc_x = jnp.full((N, M), 0) #jax.random.normal(key, (N,M))       
    scale_x = jnp.full((N, M), 1)        
    dist_X = dist.Normal(loc=loc_x, scale=scale_x)
    X = numpyro.sample("latent_x", dist_X)
    if (debug):
        print_shape("dist_X", dist_X)

    # Y
    loc_Y = jnp.zeros((N, D))
    for i in range(N):
        loc_Y = loc_Y.at[i].set(jnp.dot(W.T, X[i]) + mu)
    sacle_Y = jnp.full_like(loc_Y, 1)
    dist_Y = dist.Normal(loc=loc_Y, scale=sacle_Y)
    Y = numpyro.sample("Y_obs", dist_Y, obs=obs)
    if (debug):
        print("sacle_Y.shape", sacle_Y.shape)
        print("loc_Y.shape", loc_Y.shape)
        print_shape("dist_Y", dist_Y)

D, M, N = 64*64, 32, 10
print("D:", D, "M:", M, "N:", N, "\n")
prior_model_trace = handlers.trace(handlers.seed(model, key))
prior_model_exec = prior_model_trace.get_trace(D=D, M=M, N=N, obs=None, debug=True)

olivetti face dataset

import pandas as pd
from sklearn import datasets
from skimage.transform import rescale
from skimage import data, color

data = datasets.fetch_olivetti_faces()
df = pd.DataFrame(data.data)
print(df.shape)
df.head()

N = 9
img_res = 64

rndidx = np.random.choice(df.shape[0], N)
imgs = jnp.zeros((len(rndidx), df.shape[1]))
for i in range(len(rndidx)):
    imgs = imgs.at[i].set(df.loc[rndidx[i]].values)

col, row = int(round(np.sqrt(N))), int(round(np.sqrt(N)))
fig = plt.figure(figsize=(10, 10))
for i in range(1, col*row+1):
    fig.add_subplot(row, col, i)
    plt.gray() 
    plt.imshow(imgs[i-1].reshape(img_res, img_res))
    plt.grid(None)
plt.show()

model

reduce 4096 to 9

D = imgs.shape[1]
M = 9
print("reduce", D, "to", M)
guide = numpyro.infer.autoguide.AutoDelta(model)

optimizer = numpyro.optim.Adam(step_size=0.0005)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(key, 2000, D=D, M=M, N=N, obs=jnp.array(imgs))
params = svi_result.params
pp.pprint(params)

復元

expected_mu = params["latent_mu_auto_loc"]
expected_W = params["latent_W_auto_loc"]
expected_x_n = params["latent_x_auto_loc"]
print("expected_mu.shape", expected_mu.shape)
print("expected_W.shape", expected_W.shape)
print("expected_x_n.shape", expected_x_n.shape)

imgs_reconstructed = jnp.zeros((N, D))
for i in range(N):
    x_n = expected_x_n[i]
    imgs_reconstructed = imgs_reconstructed.at[i].set(jnp.dot(expected_W.T, x_n) + expected_mu)

col, row = int(round(np.sqrt(N))), int(round(np.sqrt(N)))
fig = plt.figure(figsize=(10, 10))
for i in range(1, col*row+1):
    fig.add_subplot(row, col, i)
    plt.imshow(imgs_reconstructed[i-1].reshape(img_res, img_res))
    plt.grid(None)
plt.gray()    
plt.show()

model

解釈

$W$が(9, 4096)、$x_i$が(9, 1)なので、$N(\textbf{y}_n | \textbf{W}^T \textbf{x}_n + \mu, \sigma^2_y)$で、4096次元を9次元にエンコードし、そのコードから4096次元をデコードしました。

ただこの場合、$W$の役割がよくわからない。$X$は各$x_i$に顔画像一枚のなんらかの情報をエンコードしてるんだろうけど、データ全体で共有される$W$は何を符号化したものなのか。(教えていただけると嬉しいです。)

model

↑ が$W$を画像として表示したものです。それぞれの顔の特徴らしきものが見えるんだけど、、、

追記

PRML 12.2.4によると、

$W$の列ベクトルは観測変数同士の相関を捉える役割を担い、因子分析モデルの文献では因子負荷(factor loading)と呼ばれる。

とのこと。よくわからんけどまぁいいか。