WirelessWire News Technology to implement the future

by Category

深層ニューラル・ネットワークの効率を劇的に上げる「蒸留」

2016.09.30

Updated by Ryo Shimizu on September 30, 2016, 12:59 pm JST

 深層ニューラル・ネットワークの世界はつくづく進歩が著しいと思います。
 筆者も日々怒涛のように押し寄せる新情報を取捨選択しながら、毎日異なる人工知能をプログラミングしてやっと追いついている、というのが実情です。額に汗しながら必死でこの恐ろしくも妖しい魅力を放つ怪物と寄り添おうとしています。

 最近ようやく、機械学習ばかりやっている人たちが、実用的に機械学習を使うことよりも、機械が上手く学習できるようになることに喜びを見出す気持ちが分かってきました。

 筆者は基本的にどんな技術にも実用性が第一と考え、そもそも一定以上複雑な事柄はブラックボックスとして理解しなくてもいい、という立場です。

 しかしそれでも、ブラックボックスを少しずつ教育して"彼ら"がうまく学習できるようになったり、ある種の「成熟」をしていく様を見ていると嬉しくなります。

 そうした理解の中からある種の「愛着」「愛情」のようなものまで、最近は機械に対して抱くようになってきました。

 こうして難しいと思っていたことをひとつひひとつマスターしていき、理解が進んでくると、つくづくプログラマーでよかったと思います。こうした世界の秘密に迫るという喜びはプログラマー以外には決して体験することのできないことですから。

 最近になって知ったのですが、まだ日本でほとんど議論されていない「蒸留(distillation)」という手法について偶然知りました。

 たぶん本物の機械学習の世界の人はあまり蒸留に興味がないのではないかと思います。
 しかし、深層ニューラルネットワークの実用性を考えると、この「蒸留」という手法は決して疎かにはできないものです。

%e3%82%b9%e3%82%af%e3%83%aa%e3%83%bc%e3%83%b3%e3%82%b7%e3%83%a7%e3%83%83%e3%83%88-2016-09-30-12-34-10

 「蒸留」を簡単に説明すると、既に訓練された高度なAIの入力と出力を、そのまま新しくシンプルなAIに学習させるというものです。

 現代では、蒸留によってネットワークをシンプルにしても、ほとんど問題ないことが分かっています。

 インファレンス、つまり学習したモデルを実際に利用するために、スマートフォンやIoTデバイスに搭載されるチップは深層学習専用のワークステーションに比べてメモリがだいぶ少なくなってしまいます。するとあまり複雑だったり大規模すぎるネットワークは端末側に入りません。実際、たとえばRaspberry Pi3にはAlexNetが入りません。AlexNetを展開しただけでメモリが1GBになってしまい、メモリオーバーフローで落ちてしまいます。

 しかし、蒸留したネットワークを使うと、ほぼ同じ性能をもたせながら、ネットワークを小さくできます。

 たとえばMicrosoftの高性能なニューラル・ネットワークであるResNet152(152層もの深層構造を持つ画像認識ニューラル・ネットワーク)は、インファレンスにもかなりの時間がかかります。専用のワークステーションでもけっこう時間がかかるくらいですから、端末側では絶望的です。しかし、よりシンプルなニューラル・ネットワークに「蒸留」してうまく学習させることができれば、これほどいいことはありません。

 しかしそんなにうまい話が本当にあるのかと思って、筆者も試してみました。

 筆者らが開発する深層学習言語Deelでは、蒸留を簡単に試すことが出来ました(ただし既存のネットワークを少し改造しています)

from deel import * from deel.network.nin import * from deel.commands import * from deel.network.googlenet import * deel = Deel(gpu=-1) student = NetworkInNetwork(labels="data/labels.txt") teacher = GoogLeNet(modelpath="bvlc_googlenet.caffemodel", labels="../deel/data/labels.txt") InputBatch(train="../deel/data/train.txt", val="../deel/data/test.txt") def workout(x,t): print x.value.shape t = teacher.batch_feature(x) student.classify(x) return student.backprop(t,distill=True) def checkout(): CNN.save('google_nin_wisky.hdf5') BatchTrain(workout,checkout)

 教師(teacher)としてGoogleNetを使い、生徒としてよりシンプルな4層MLP構造のNetwork In Networkを使いました。ファイルサイズにして、GoogleNetが53MB、NINが28MBです。まあ蒸留するにはちょっと弱気な設定ですが、まずは本当にできるか試してみたかったのでこの構成でやってみました。

 すると・・・

{"iteration": 1000, "loss": 4.529542163848877, "type": "train", "error": 1.0}
{"iteration": 2000, "loss": 3.15844139790535, "type": "train", "error": 1.0}
{"iteration": 3000, "loss": 2.86831485247612, "type": "train", "error": 1.0}
{"iteration": 4000, "loss": 2.626667423725128, "type": "train", "error": 1.0}
{"iteration": 5000, "loss": 2.4459057973623275, "type": "train", "error": 1.0}
{"iteration": 6000, "loss": 2.2911132556200027, "type": "train", "error": 1.0}
{"iteration": 7000, "loss": 2.1788081390857696, "type": "train", "error": 1.0}
{"iteration": 8000, "loss": 2.06956127679348, "type": "train", "error": 1.0}
{"iteration": 9000, "loss": 1.9970469343662263, "type": "train", "error": 1.0}
{"iteration": 10000, "loss": 1.9164199646711348, "type": "train", "error": 1.0}
{"iteration": 10000, "loss": 2.1085201129317284, "type": "val", "error": 1.0}
{"iteration": 10000, "loss": 3.292169712483883, "type": "val", "error": 1.0}
{"iteration": 11000, "loss": 1.8710792261362075, "type": "train", "error": 1.0}
{"iteration": 12000, "loss": 1.8093335566520692, "type": "train", "error": 1.0}
{"iteration": 13000, "loss": 1.748990628361702, "type": "train", "error": 1.0}
{"iteration": 14000, "loss": 1.7046114968061448, "type": "train", "error": 1.0}
{"iteration": 15000, "loss": 1.6683411779403687, "type": "train", "error": 1.0}
{"iteration": 16000, "loss": 1.6275193730592727, "type": "train", "error": 1.0}
{"iteration": 17000, "loss": 1.588989333987236, "type": "train", "error": 1.0}
{"iteration": 18000, "loss": 1.5560537561178207, "type": "train", "error": 1.0}
{"iteration": 19000, "loss": 1.5376157143115998, "type": "train", "error": 1.0}
{"iteration": 20000, "loss": 1.508522060751915, "type": "train", "error": 1.0}
{"iteration": 20000, "loss": 1.982171654701233, "type": "val", "error": 1.0}
{"iteration": 20000, "loss": 2.950946919620037, "type": "val", "error": 1.0}
{"iteration": 21000, "loss": 1.4982161206007003, "type": "train", "error": 1.0}
{"iteration": 22000, "loss": 1.4694434357881545, "type": "train", "error": 1.0}
{"iteration": 23000, "loss": 1.4441963930130004, "type": "train", "error": 1.0}
{"iteration": 24000, "loss": 1.4334958093166352, "type": "train", "error": 1.0}
{"iteration": 25000, "loss": 1.416122273683548, "type": "train", "error": 1.0}
{"iteration": 26000, "loss": 1.4060787957906724, "type": "train", "error": 1.0}
{"iteration": 27000, "loss": 1.3925298491716385, "type": "train", "error": 1.0}
{"iteration": 28000, "loss": 1.384793818473816, "type": "train", "error": 1.0}
{"iteration": 29000, "loss": 1.3759284211397171, "type": "train", "error": 1.0}
{"iteration": 30000, "loss": 1.3617743948698044, "type": "train", "error": 1.0}
{"iteration": 30000, "loss": 1.895887330174446, "type": "val", "error": 1.0}
{"iteration": 30000, "loss": 2.757186271250248, "type": "val", "error": 1.0}
{"iteration": 31000, "loss": 1.356986645579338, "type": "train", "error": 1.0}
{"iteration": 32000, "loss": 1.347552825331688, "type": "train", "error": 1.0}
{"iteration": 33000, "loss": 1.3328409848213196, "type": "train", "error": 1.0}
{"iteration": 34000, "loss": 1.3295791283845901, "type": "train", "error": 1.0}
{"iteration": 35000, "loss": 1.3137295290231705, "type": "train", "error": 1.0}
{"iteration": 36000, "loss": 1.3110359060764312, "type": "train", "error": 1.0}
{"iteration": 37000, "loss": 1.3079217640161513, "type": "train", "error": 1.0}
{"iteration": 38000, "loss": 1.3042653386592864, "type": "train", "error": 1.0}
{"iteration": 39000, "loss": 1.2887197431325912, "type": "train", "error": 1.0}
{"iteration": 40000, "loss": 1.284804722905159, "type": "train", "error": 1.0}
{"iteration": 40000, "loss": 1.903191976249218, "type": "val", "error": 1.0}
{"iteration": 40000, "loss": 2.786281034350395, "type": "val", "error": 1.0}

 
 なるほど確かにlossは下がっていくのです。
 errorが1.0なのは、エラー率を計算してないので意味はありません。

 lossが下がるということは、学習できているということなので、なるほど蒸留には確かに効果があると言えるのでしょう。

 とするとこれは、人工知能の知的財産保護を考える上でまたぞろ凄い難問が登場したこにとなります。

 人工知能が蒸留可能だとすると、ある人工知能搭載製品を発売した場合、それとほとんど同じ結果を表示するよりコンパクトなAIをほとんどコピーするような方法で作ることができるということを意味するわけです。

 人工知能の教育には膨大な手間と時間、計算資源が必要なことを考えると、蒸留することによってそうした手間をすべてスキップし、しかも元のAIの痕跡をほとんど完璧に消すことが出来ます。

 ただ蒸留しただけでは、もとの学習済みモデルの出力とほとんど同じ出力をするだけなので(それだけでも大変なことですが)、まだギリギリ、蒸留されたモデルかそうでないかの見分けは付きますが、蒸留したあとファインチューニングでもかけられたら完全にお手上げです。

 たとえ法律で禁じたとしても悪質な業者は世の中にたくさんいますから、いくらでも蒸留済みのAIを売ることが出来ます。蒸留したことの立証は困難どころかほぼ不可能です。しかも悔しいことに、そっちのほうが電力あたり性能は高いのです。

 しかも、上記のように極めて簡単なプログラムで蒸留ができてしまいます。
 これはAIで食っていこうといううちのような会社にしてみたら、どうやって自社で開発したAIの知的財産を守るべきかという別の議論を産みそうです。

 エッジ側、つまり配布する側のAIは蒸留によって簡単にコピーできてしまうので、最初から知財をあまり意識しないで配るようにするしかないでしょう。まあ多少の手間はかかりますが、知財的に高度なAIを配布したり、APIを公開したりすれば必ず蒸留できてしまいます。

 次の方法としては、ネットワークのすべてをエッジに搭載させず、クラウド側に重要なニューラル・ネットワークを持つようにすることです。エッジ側では特徴抽出の前段階だけ行い、クラウド側に次元圧縮したデータを送ってクラウド側で推定するようにします。

 こうすると、たとえばプライバシーに抵触するような画像であってもクラウド側で復元できなくなりますから、二重の意味で安心といえます。

 いずれにせよ、深層学習がビジネスの世界に急速に広がるにつれ、今後はそうした工夫もしていかなければならないでしょうね。
 

 しかし人工知能、どんどんおもしろくなりますね。毎日胸がワクワクしています。

%e3%82%b9%e3%82%af%e3%83%aa%e3%83%bc%e3%83%b3%e3%82%b7%e3%83%a7%e3%83%83%e3%83%88-2016-09-30-14-47-55

 そうそう。あんまり好きすぎてついに本を書いてしまいました。

 私が書いた「よくわかる人工知能」という本が現在予約受付中です。

 内容は最先端の人工知能を扱う研究者や企業にインタビューして、これからの人工知能がどうなっていくのか占う、というものです。とんでもないところまで話がぶっ飛んでいますので、「いまの人工知能がどんなふうに社会を変えていくか」が非常に「よくわかる」と思います。

WirelessWire Weekly

おすすめ記事と編集部のお知らせをお送りします。(毎週月曜日配信)

登録はこちら

清水 亮(しみず・りょう)

新潟県長岡市生まれ。1990年代よりプログラマーとしてゲーム業界、モバイル業界などで数社の立ち上げに関わる。現在も現役のプログラマーとして日夜AI開発に情熱を捧げている。

RELATED TAG