octaveで深層学習

『潮騒のアニマ』(講談社)なんていう推理小説を読んでいる。副題に法医昆虫学捜査官なんて 付いている。確かこのシリーズ、昔読んだことがあるぞ。

事件の現場に残された昆虫や微生物を手掛かりに、犯行日時や場所を特定してく物語。 意外な昆虫達の生態が解説されてたりして、推理+αで、得した気分。

この小説の中で、 フェルミ推定なんてのが紹介されてた。

東京都内にあるマンホールの蓋の枚数は? ってのが、就職試験の定番の問題らしい。 そんな問題必要なんですか?とか、ググレカスとか、ワトスン博士に聞いてみたらなんて、 答えようものならGoogleへは絶対に入れないぞ。

ボケ防止でオイラーも考えてみるか。知恵熱出そうだな。

そんなのよりも、こちらの方が現実的で面白そう。

RubyでSVMをつかう

勾配法

人工知能の学習理論

いよいよ 人工知能の学習理論に取り掛かる。

まずは、第1回目の学習とは、から。 ノートを一読し、プログラムとデータを頂いてきて、走らせてみると

[sakae@fedora 1]$ octave-cli -q
octave:1> jh1501
error: operator *: nonconformant arguments (op1 is 8x25, op2 is 26x1)
error: called from
    jh1501> at line -1 column -1
    jh1501 at line 76 column 4
error: evaluating argument list element number 1
error: called from
    jh1501> at line -1 column -1
    jh1501 at line 76 column 4

早速エラーの洗礼を受けましたよ。エラーになってる行は、

    75    t=ydata(:,ii);
    76    h=hid(w,th,x);
    77    o=out(u,ph,h);

よく分からないので、別の例題として5回目のやつを取ってきて実行すると、やはり同様な 所でエラー。で、プログラムをemacsで開いてみると、行末に ^M ってのが表示されてた。 これって、Windows機で作られたファイルだな。と言う事は、使われている訓練データも Windows機で作られたに違いない。と、フェルミ推定。

行末の扱いがUnixとWindowsでは違うんでないかと言う、第二段推定。そういう目で、もう 一度見直し。

直接のエラーは、掛け算で、次元が合わないと言ってる。ならば、どういう次元か確認するか。

>> size(w)
ans =

    8   25

>> size(th)
ans =

   8   1

>> size(x)
ans =

   26    1

推定するに、wとxの掛け算をやってるんだな。前の方に遡って、hidって関数を探す。

hid=@(w,th,x)(sig(w*x+th));

27行目でそれっぽいのが見つかった。確かにwとxを掛けている。それより、@ って、 どういう意味? helpで調べても、not found。まさか説明が無いとは!

こういう時は、本式のマニュアルを当たる。 11.11.2 Anonymous Functions に説明が有った。匿名関数って、 Schemeで言うLambdaの事か。そして、前の節にも@の用法が出てて、関数ハンドラとしても 使うとな。ちょいと、ややこしくないかい。

ノートの説明によると、5x5ドットの画像を学習する事になってるから、一文字が25ドットの データとなる。それに対してxが26ってのは、いかにもおかしいぞとオーラを放っていると 推定される。

で、今度は、前に方に向かってxの成り立ちを追っていく。

    31  A=dlmread('jh1501_train.txt');
    32  xdata=A';
    74    x=xdata(:,ii);
>> size(A)
ans =

   400    26
>> A
   :
 Columns 25 and 26:

   0.07000   0.00000
   0.38000   0.00000

読み込んだ元データを見ると、一つ余計なデータが混じっている。どうもdlmreadって 関数が怪しいな。helpして要点を上げると

 -- Built-in Function: DATA = dlmread (FILE)
 -- Built-in Function: DATA = dlmread (FILE, SEP)
 -- Built-in Function: DATA = dlmread (FILE, SEP, R0, C0)
 -- Built-in Function: DATA = dlmread (FILE, SEP, RANGE)
 -- Built-in Function: DATA = dlmread (..., "emptyvalue", EMPTYVAL)
     Read the matrix DATA from a text file which uses the delimiter SEP
     between data values.

     If SEP is not defined the separator between fields is determined
     from the file itself.

区切りはファイル自身に任せるとな。データファイルのセパレータにはスペースが使われている。それはいいんだけど、行末がWindows仕様のCRLFになってる。unixの行末はLFなんで、 このまま使うとCRがセパレータと解釈され、CRとLFの間にNULL文字が入っていると思う んだろう。(これで、何段目のフェルミ推定やら)

で、NULL文字を数値に変換して、0.000が補われた。かくして、余計なフィールドが1つ 追加された。

試しに、dos2unixを使って、unix用にトレーニングデータ(とテストデータ)を変換して みるか。

[sakae@fedora 1]$ dos2unix jh1501_train.txt
dos2unix: converting file jh1501_train.txt to Unix format...
[sakae@fedora 1]$ dos2unix jh1501_test.txt
dos2unix: converting file jh1501_test.txt to Unix format...

これで走らせてみると、無事に動いた。

>> jh1501
[    0] Training Error=74.223810, Test Error=83.902507
[   50] Training Error=5.442099, Test Error=15.501147
[  100] Training Error=2.555474, Test Error=22.482521
[  150] Training Error=2.305331, Test Error=22.239709
[  200] Training Error=1.650833, Test Error=21.368431
[  250] Training Error=0.415568, Test Error=21.840705
[  300] Training Error=0.327562, Test Error=21.624879
[  350] Training Error=0.323466, Test Error=21.471653
[  400] Training Error=0.319438, Test Error=21.516938
[  450] Training Error=0.317358, Test Error=21.608426
[  500] Training Error=0.315550, Test Error=21.749871
[  550] Training Error=0.314320, Test Error=21.839179
[  600] Training Error=0.313282, Test Error=21.892183
[  650] Training Error=0.312383, Test Error=21.877501
[  700] Training Error=0.312011, Test Error=21.815095
[  750] Training Error=0.311337, Test Error=21.755891
[  800] Training Error=0.310792, Test Error=21.705044
[  850] Training Error=0.310198, Test Error=21.657126
[  900] Training Error=0.309682, Test Error=21.614150
[  950] Training Error=0.309068, Test Error=21.575336
TRAINED LIST (Error>0.50) -------------
Error/TRAINED = 0/400 = 0.000
TEST LIST (Error>0.50) -------------
14: TRUE=(1.00, 0.00), PRED=(0.35, 0.65)
39: TRUE=(1.00, 0.00), PRED=(0.01, 0.99)
73: TRUE=(1.00, 0.00), PRED=(0.02, 0.98)
145: TRUE=(1.00, 0.00), PRED=(0.46, 0.54)
165: TRUE=(1.00, 0.00), PRED=(0.04, 0.96)
180: TRUE=(1.00, 0.00), PRED=(0.04, 0.96)
204: TRUE=(0.00, 1.00), PRED=(0.52, 0.48)
226: TRUE=(0.00, 1.00), PRED=(0.91, 0.09)
266: TRUE=(0.00, 1.00), PRED=(0.66, 0.34)
278: TRUE=(0.00, 1.00), PRED=(0.64, 0.36)
280: TRUE=(0.00, 1.00), PRED=(0.72, 0.28)
302: TRUE=(0.00, 1.00), PRED=(0.77, 0.23)
306: TRUE=(0.00, 1.00), PRED=(0.98, 0.02)
332: TRUE=(0.00, 1.00), PRED=(0.66, 0.34)
347: TRUE=(0.00, 1.00), PRED=(0.63, 0.37)
364: TRUE=(0.00, 1.00), PRED=(0.95, 0.05)
Error/TEST = 16/400 = 0.040

なお、文字列を数値に変換する関数に空文字列を渡すと、下記の挙動となった。 dlmreadは組み込み関数なので、そのOS用にチューニングされていると、これまた フェルミ推定される。詳しくはCフラフラ語を読んで、フラフラになるが良い。

>> str2num("")
ans = [](0x0)
>> str2double("")
ans = NaN

で、いよいよコードを見るかと、取り掛かったんだけど、ちと歯が立たず。いままで あちこちのサイトを覗いて、観念的な事は分かった積りなんだけど、いきなりoctaveじゃ octave上がってしまっていて、声も出ないよ。

オイラーは、昔バスパートにぴったりな声をしてますって言われて、合唱部に誘われた 事がある。オクターブ上の声なんて、出るわけがない。

そんな訳で、一段かませて、オクターブへの橋渡しをする事にした。

haskellで

前回調べておいたオイラーのフィールドに合うのがあったので、観察。

線形システムの一括更新学習と逐次更新学習をそれぞれHaskellで書いてみた

このコード、幸いな事にhugsでも走る、超基本的なモジュールしか使っていない。 論よりrun。うだうだ考えてるより、取り合えず走れ。

[ob: ~]$ runhugs dnn.hs
This is sample of learning by linear system.
Learn parameter "a" of the function "f(x) = ax".

Training data = [(1.0,1.0),(2.0,2.5),(3.0,2.5),(4.0,4.5),(5.0,4.5)]
Initial value of "a" = 2.0

---------------------Batch learning---------------------

Approximation error: 58.0(Training times: 0)
Approximation error: 1.55199999140787(Training times: 1)
Approximation error: 0.987519999848165(Training times: 2)
Approximation error: 0.981875199998358(Training times: 3)
Approximation error: 0.981818752000021(Training times: 4)
Approximation error: 0.981818187519996(Training times: 5)
Approximation error: 0.9818181818752(Training times: 6)
Approximation error: 0.981818181818752(Training times: 7)
Approximation error: 0.981818181818187(Training times: 8)
Approximation error: 0.981818181818182(Training times: 9)
Approximation error: 0.981818181818182(Training times: 10)
Learned "a" value: a = 0.981818181920111

------------------Sequential learning------------------
Approximation error: 58.0 (Training times: 0)
Approximation error: 13.0375000257726 (Training times: 1)
Approximation error: 3.03437499990874 (Training times: 2)
Approximation error: 1.15234374910796 (Training times: 3)
Approximation error: 0.99121093770322 (Training times: 4)
Learned "a" value: a = 0.96874999985863

2種類の学習方法が提示されてたけど、基本原理は一緒だな。取り合えず 適当な初期値を決めて、それで計算。正しい値と比べて、次の値を決める。 これを繰り返して、正しい値にある誤差内で一致したらおしまい。

これって、デルタシグマ形 AD コンバータとそっくりじゃん。ちょろちょろと追いつていく所が。 効率的にやろうとすれば、逐次比較型 バイナリーサーチ法か。果たして 機械学習にはあるのだろうか?

で、一番の問題は、比較結果から、ちょろっと変化させる方法。足したらいいのか、 引いたらいいのか。これを間違うと絶対に正しい答えに到達出来ない。

この問題の解決に微分が使われるとな。そして、これをニューロンに適用すると 難しい偏微分とかが出て来るとな。

単一ニューロンによる逐次更新学習(オンライン学習)アルゴリズムをHaskellで書いた

ふう、大変だ。OpenBSD上のhugsでは、動かん。

[ob: ~]$ runhugs nyu.hs
This is sample of learning algorithm with single neuron.

Program error: arithmetic overflow

しょうがないので、Fedoraに入れてる、ghcでやってみる。

[sakae@fedora ~]$ runhaskell nyu.hs
This is sample of learning algorithm with single neuron.
Training data: [[0.0,1.0,0.0],[1.0,0.0,0.0],[0.0,0.0,0.0],[1.0,3.0,1.0],[2.0,1.0,1.0],[1.5,2.0,1.0]]
initial weights: [5.965031025690636e-3,6.998965691455217e-4,8.090974964325992e-4]
Learned weights: [-2.47306843004465,1.752590684740484e-2,-0.5414719252377221]
Input: [0.0,1.0]
Target output: 0.0
Output of neuron: 2.952429496881601e-3

Input: [1.0,0.0]
Target output: 0.0
Output of neuron: 2.952429496881601e-3

Input: [0.0,0.0]
Target output: 0.0
Output of neuron: 5.4232818408139576e-5

Input: [1.0,3.0]
Target output: 1.0
Output of neuron: 0.9979293706483066

Input: [2.0,1.0]
Target output: 1.0
Output of neuron: 0.8982408878699047

Input: [1.5,2.0]
Target output: 1.0
Output of neuron: 0.9848997700454285

hugsで動かないのが癪に障るので、当たりを付けて初期値を手動で与えてみた。

        ws_with_theta = [0.001, 0.002, 0.003] :: [Double]

そしたら動いた。乱数発生は昔から難解な事になっている。 Haskell 98 言語とライブラリを見る限りでは、間違った使い方をしていないようなんだけど。。。

取り合えず動いたので、いじり倒せます。なお、Windows7機の32Bit版hugsでは、 何の問題もなく、走りましたよ。

octaveのコード読み

一番最初の講義は、学生を引き付けるつかみのものだと思うので、5番目のやつを やってみる。土台はdebianに入っていたoctave 3.8.2という環境。

octave:1> jh1505
Training Characters.
[0] Training error=62.363847, Test error=71.904803
[50] Training error=6.069484, Test error=18.877028
[100] Training error=2.565096, Test error=23.489988
[150] Training error=2.248376, Test error=23.601673
[200] Training error=2.162079, Test error=23.432055
[250] Training error=2.117818, Test error=23.210915
[300] Training error=2.062441, Test error=22.793030
[350] Training error=1.552640, Test error=21.585753
[400] Training error=0.302569, Test error=21.431707
[450] Training error=0.210536, Test error=21.528878
Error in Train:
   Error/TRAINED = 0/400 = 0.000
Error in Test:14 39 73 165 180 226 266 278 280 306 364 370 382
   Error/TEST = 13/400 = 0.033
Elapsed time is 39.8303 seconds.

実行時間を測る為、冒頭に tic(); を入れてタイマースタート。一番最後に、toc()を 入れて、かかった時間を表示。タイマーは動きっぱなしなんで、toc()をすると、ticからの 継続時間が測れるぞ。

どんなグラフを表示するか、一応確認しておく。

Figure1は、0のモザイクが2x5、6のモザイクも2x5で表示された。Figure2は、青線で トレーニングエラー、赤線でテストエラーのグラフ、450回分まで表示。 Figure3は、色付きモザイク、1x2の隠れ層から出力分。2x4の入力から隠れ層への、 重みデータ。

重みは、前段が5x5、後段は、2x4になってた。要約されたんだな。

そして、ハラハラしなくてもいいように、グラフ表示部分を、 M-; を使って、コメント/アンコメント してみる。やってみたけど、大勢に影響は なかったぞ。

コードを読んでいく時は、余計なものが無い方が、集中出来るから消しておくか。 後、どんな変数名を使っているか整理しておかんとな。octaveには、そんな事も あろうかと、変数名一覧を表示するコマンドが用意されてる。

octave> who
Variables in the current scope:

A                M                err1             o
ALPHA            MODCYC           err2             out
B                N                fff              ph
CYCLE            PIX              h                sig
EPSILON          RIDGE            hid              t
ETA              counter          i                th
Err0             cycle            ii               u
Err1             delta1           max1             w
Err2             delta2           max2             x
H                dph              maxarg1          xdata
HYPERPARAMETER1  dth              maxarg2          xtest
HYPERPARAMETER2  du               n                ydata
LASSO            dw               ntest            ytest

やや、面白いコマンドが有るな。

     'profile on'
          Start the profiler, clearing all previously collected data if
          there is any.

     'profile off'
          Stop profiling.  The collected data can later be retrieved and
          examined with calls like 'S = profile ("info")'.

          See also: profshow, profexplore.

こんな風に使うのか

profile on;
  # code here
profile off;
SS=profile('info');

まずは概要。

octave:3> profshow (SS)
   #         Function Attr     Time (s)        Calls
----------------------------------------------------
  79         binary *             3.955      2818605
  84 anonymous@:68:28             3.730       208800
  89 anonymous@:69:28             3.557       208800
  85 anonymous@:67:26             3.436       417600
  38         binary +             1.702      2443760
  40         binary -             1.663      2016005
  65          drawnow             1.258            1
  87              exp             0.631       417600
  69        postfix '             0.615       800003
  90        binary .*             0.506       800000
  86         prefix -             0.398       417600
  88        binary ./             0.288       417600
  91 anonymous@:56:15             0.259       400000
  11              mod             0.211       200501
  81         binary /             0.151       400556
  83            floor             0.125       200000
  92              dot             0.015         8000
  78          dlmread             0.012            2
  50              max             0.005         1601
  34          findall             0.003            1

そして探検モードに突入。

octave:5> profexplore

Top
  1) anonymous@:68:28: 208800 calls, 7.059 total, 3.730 self
  2) anonymous@:69:28: 208800 calls, 6.577 total, 3.557 self
  3) binary *: 2401005 calls, 3.044 total, 3.044 self
  4) binary -: 2016001 calls, 1.663 total, 1.663 self
  5) clf: 1 calls, 1.276 total, 0.001 self
  6) binary +: 1608544 calls, 1.017 total, 1.017 self
  7) postfix ': 800002 calls, 0.615 total, 0.615 self
  8) binary .*: 800000 calls, 0.506 total, 0.506 self
  9) anonymous@:56:15: 400000 calls, 0.259 total, 0.259 self
  10) mod: 200500 calls, 0.211 total, 0.211 self
  11) binary /: 400556 calls, 0.151 total, 0.151 self
  12) floor: 200000 calls, 0.125 total, 0.125 self
  13) dot: 8000 calls, 0.015 total, 0.015 self
  14) dlmread: 2 calls, 0.012 total, 0.012 self
  15) max: 1600 calls, 0.005 total, 0.005 self
  16) close: 1 calls, 0.003 total, 0.000 self
  17) fprintf: 29 calls, 0.001 total, 0.001 self
  18) binary ==: 1302 calls, 0.001 total, 0.001 self
  19) not: 800 calls, 0.001 total, 0.001 self
  20) binary <: 800 calls, 0.000 total, 0.000 self
  21) randn: 4 calls, 0.000 total, 0.000 self
  22) zeros: 14 calls, 0.000 total, 0.000 self
  23) profile: 1 calls, 0.000 total, 0.000 self
  24) clear: 1 calls, 0.000 total, 0.000 self
  25) toc: 1 calls, 0.000 total, 0.000 self
  26) tic: 1 calls, 0.000 total, 0.000 self

profexplore> help

Commands for profile explorer:

exit   Return to Octave prompt.
quit   Return to Octave prompt.
help   Display this help message.
up [N] Go up N levels, where N is an integer.  Default is 1.
N      Go down a level into option N.

ああ、またページが尽きた。続く。