テンソルネットワークの入り口(その2)

はじめに

 今回は「テンソルネットワークの入り口」の2回目として、実際に深層学習に応用した例を紹介したい。今回のソースコードはここにある。

深層学習への応用

 Googleの出しているテンソルネットワークのフレームワークTensorNetworkと、深層学習フレームワークTensorFlowを用いて、画像分類を行う。データセットはFashion-MNISTである。訓練データ数は60000、テストデータ数は10000、画像サイズは28\times28である。

 テンソルネットワークを考える前のニューラルネットワークの構造は以下の通り(src/main_2.py)。

全結合層だけを用いた構成である。入力画像のサイズは28\times 28であり、これをベクトルに引き延ばした784(=28\times 28)次元ベクトルが入力となる。一番最初(3行目)の全結合層を式で書くと

     \begin{align*} Y=WX+B \end{align*}

となる。ここで、YN(=512)次元ベクトル、XK(=784)次元ベクトル、WN\times K行列、BN次元ベクトルとなる。上式を成分で書くと

(1)    \begin{align*} Y_n=W_{nk}X_k+B_n \end{align*}

となる(アインシュタイン縮約記法を用いた)。図で示すと以下のようになる。

図7

さて、これをテンソルネットワークで表すため天下り的ではあるが以下のように書き換える。

図8

つまり、1階テンソルX,Y,Bを2階テンソルX^{\prime},Y^{\prime},B^{\prime}に変更し、Wを点線の矩形内に示した2つのテンソルの積(MPS)に置換する。式で書くと

(2)    \begin{align*} Y^{\prime}_{mn}=A_{mkj}C_{lkn}X^{\prime}_{jl}+B^{\prime}_{mn} \end{align*}

となる。Yの次元をN=512としたので、Y^{\prime}32\times16の行列とすることができる。先の説明では、MPSに置き換える際に特異値分解を用いた。上の表式に現れるWは学習により決まる行列であるから特異値分解を適用することはできない。そこで、A,Cも学習から決まるテンソルとみなす。以上を実現するコードが以下である(src/tensor_network_layer_for_mnist.py)。このコードはTensorNetworkの公式が公開しているサンプルプログラムからの流用である。

4行目のtensornetworkがテンソルネットワークのモジュールである。8行目から12行目までの変数名は式(2)周辺で用いた変数名と一致させている。20行目のself.a_varAに、23行目のself.c_varCに、26行目のself.baisB^{\prime}に相当する。40行目から65行目までのコードが式(2)の計算に対応する。tensornetworkの使い方に関しては公式のチュートリアルを見てほしい。

 最初に示した全結合層だけのネットワーク構造の中の最初の層だけを上のクラスのインスタンスで置き換える。

3行目がテンソルネットワークで置き換えた層である。それ以外の層は変えていない。さて、パラメータの保存に必要なメモリ量を第1層について計算してみる。式(1)の右辺の成分の総数L_1は、(N,K)=(512,784)であるから行列Wの要素数が784\times512=401408Bの要素数が512となり、これらを足し合わせてL_1=401920となる。一方、式(2)の右辺のパラメータの総数L_2は、kの次元をKとすると、Aの成分数が32\times K\timesK\times28=896KCの成分数が16\times K\times28\timesK=448KB^{\prime}の成分数が512=32\times16なのでこれらを足し合わせてL_2=1344K+512となる。L_1\geq L_2のときメモリ量を節約できることになる。この式を評価するとK\leq298であれば良いことになる。今回の計算ではK=1とした。従ってL_2/L_1=1856/401920=0.4\%となり大幅な削減を実現できる。以下に計算結果を示す。

図9

精度を維持しつつ(少し良くなっている)、ネットワーク容量を圧縮できた。

まとめ

 今回は、テンソルネットワークを使うことで、精度を維持しつつニューラルネットワークのパラメータ数を大幅に減らすことができることを見た。全結合層以外へのテンソルネットワークの応用についても今後調べたい。

参考文献

  • 情報幾何学の基礎
  • テンソルネットワーク入門
  • 物理学の手法「テンソルネットワーク」を用いて計算の圧倒的効率化を図る
  • Google TensorNetwork(GitHub)
  • Google TensorNetwork(サンプルコード)
  • Google TensorNetwork(解説)
  • Kumada Seiya

    Kumada Seiya

    仕事であろうとなかろうと勉強し続ける、その結果”中身”を知ったエンジニアになれる

    最近の記事

    • 関連記事
    • おすすめ記事
    • 特集記事

    アーカイブ

    カテゴリー

    PAGE TOP