St_Hakky’s blog

プログラミング/心理学/人事/留学/データサイエンス/機械学習/Deep Learning/バイオインフォマティクス/日頃思ったこと/人事のデータサイエンスしてみたい

[Deep Learning] Batch sizeをどうやって決めるかについてまとめる

こんにちは。

Deep Learningを自分でゼロから組んで(fine tuningとかではなく)、全部ゼロから学習させるのって大変ですよね。

特に、ハイパーパラメーターの設定にすごく悩みます。トップカンファレンスに出されているような高精度の論文では、そういうハイパーパラメーターはさも当然かのごとく設定されているので、まぁモデルを使い回す分には特に問題ないんですが、こういう風に自分で決めようとすると本当に悩ましいです。

また、Deep Learingは学習に非常に時間がかかりますし、それぞれのハイパーパラメーターの相互関係とかもあり、問題の切り分けが難しいです。

その意味で、グリッドサーチなども対象を決めてやらざるをえず、その場合でもあたりをつけておきたいのが正直な所だと思います。

実際、ハイパーパラメーターにはいくつも種類があるんですが、今回はその中でも、Batch sizeについてどうやって決めるかをまとめておこうと思います。

〇注意事項

この記事では、「私が今使っているデータに対してこうした」というものの中で、汎用的に当てはまるんじゃないかなってものを書いただけなので、「全てのデータに対して必ずこれが当てはまる」とは限りません。

また、私自身まだまだDeepの専門家集団から比べると弱者中の弱者ですので、もしかしたら間違えているかもしれません。

上記、未来の自分も含めお気をつけ下さい笑。

〇よく論文で見るBatch size

Deep Learningの論文を読んでいるとどうやって学習をさせたかみたいな話はほぼ乗っているので、そういうのを見ていてよく見かけるのは以下あたりではないかと思います(すみません、僕の観測範囲ですが)。

  • 1
  • 32
  • 128
  • 256
  • 512

だいたい、1だと完全に確率的勾配降下法になりますし、512だと学習速度をあげたかったのかなという気持ちが見えます。このあたりについてどれにするべきかというところを考察してみたいと思います。

〇mini-batch SGDにおけるBatch size

SGD以外にもAdamとか色々あるんですが、それについては以下の記事を参照してください。

st-hakky.hatenablog.com

SGDによるパラメーターの更新ですが、以下の式のような形で表わせます(参考)。

$$
\theta_{t+1} \longleftarrow \theta_{t} - \epsilon (t) \frac{1}{B} \sum^{B-1}_{b=0} \frac{\partial \mathcal{L} (\theta, \mathbf{m}_b )}{ \partial \theta }
$$

$B$ となっているところが、バッチのサイズです。 $B$ が1であれば、オンラインのSGDになりますし、逆に$B$ がデータ数全体であれば、バッチ学習になります(バッチ学習は確率的でもなくなるのですが)。

ミニバッチは、1〜データ全体の間の数を $B$ に指定することをさします。各エポックごとにランダムにミニバッチのセットが代わり、そのミニバッチそれぞれでパラメーターを更新します。

〇ミニバッチのサイズによって学習の何が変わるか

ミニバッチのサイズによって学習における何が変わるかなのですが、主に以下の3点が変わります(たぶん。。。おそらく。。。きっと。。。)。

  • 1つのサンプルデータに対する反応度
  • 1epochの計算速度
  • メモリ使用量

それぞれ見ていきます。

■1つのサンプルデータに対する反応度

1つのサンプルデータって言葉を使うとちょっとなんかアレなんですが、要は例えば画像のデータセットがあったとして、その中の1枚の画像のこと、と考えてください。

こう考えたときに、更新式をもう一度見てみます。

$$
\theta_{t+1} \longleftarrow \theta_{t} - \epsilon (t) \frac{1}{B} \sum^{B-1}_{b=0} \frac{\partial \mathcal{L} (\theta, \mathbf{m}_b )}{ \partial \theta }
$$

ここで、ミニバッチの単位で重みが更新されていきます。平均をミニバッチ毎にとって、そのミニバッチごとにパラメーターを更新するわけです。

ということは、ミニバッチの単位が小さければ小さいほど、1つ1つのデータに敏感に反応すると、観ることができます。逆にミニバッチの単位が大きければ大きいほど、平均化されるので、1つ1つのデータよりもミニバッチ全体の特徴を捉えるということになります。

ここがまず1点目のミニバッチのサイズに対する学習の挙動の違いです。

■1epochの計算速度

1epochの計算速度が変わります(ここでいっているのは、学習の収束の速さなどではなく、単純に計算速度の話です)。パラメーターの更新の回数というのが、ミニバッチのサイズが小さいほど多くなるので、これは当然と言えば当然です。

なので、とにかく一回ネットワークの学習を一通りepochを回してやってしまいたいと思ったら、1つのバッチサイズは大きくした方が早く終わります。

これが2つ目です。

■メモリ使用量

ミニバッチ単位でデータを読みとり、それを使うとなると、これはミニバッチのサイズが大きければ大きいほど食います。

〇僕のミニバッチサイズの決め方

完全に独自流で誰に教わったわけでもありませんが、一人でやっているときは以下のような感じでやります。

1. 適当に32とかにして、一回学習させてみる
2. もっと学習速度あげて、他のパラメーターをとりあえず色々試したいなぁっておもったら、ミニバッチのサイズをあげる
2. lossの揺れ幅が各epochで大きいなぁってなったら、ミニバッチのサイズを上げる
2. 揺れ幅が小さくて、データのサイズも大きくない場合は、ミニバッチのサイズを下げる

まず、1をやります。その後、状況によって色々変えるという風にしてパラメーターを変えるみたいな感じでやります。その時設定するミニバッチのサイズは、上にあげたよく使われているミニバッチのサイズから選びます(ここはあんまり気にしません)。

〇バッチサイズを変えて学習させてみた

適当なデータで、ネットワークを組んで、バッチサイズを変えて学習させて見た例をTensorboardで出してみました(本筋ではないので、詳細には書きませんが)。縦軸がloss、横軸がepochです。また、学習によってlossがかわりますが、横軸の間隔はほぼ同じです。

以下はデータ数だいたい1万くらいのデータのときに、ミニバッチサイズ32でやった時です。

f:id:St_Hakky:20171116161013p:plain

めちゃブレブレっすね笑。で、以下がミニバッチサイズ128でやったときのやつです。

f:id:St_Hakky:20171116161100p:plain

学習の挙動が割とかわっていることがわかります。また、各epochの速さも変わっていました。

また、今回はデータの値が結構大きく変わる傾向があったので、あまりデータに敏感に反応して学習が進むとうまくいかなさそうというのはありました。なので、少しバッチサイズは大きめで取ったほうがいいのかなぁと思ってました。

〇まとめ

データに依存するのですが、一度ネットワークにデータを適当なミニバッチのサイズでかけてみて、それを見ながら考えるというのでいいのかな、と思っています。

ただ、これよりもいい方法や上で間違えているところがあれば、教えてください。。。。

それでは。

Pythonのクラスメソッド(class method)の定義の仕方とstaticmethodとの違い

こんにちは。今日は上の件について書きます。

○クラスメソッド (class method) とは

クラスメソッドとは、クラス内で定義されたメソッドで、インスタンス化しなくても呼び出すことができるメソッドのことです。

これは、インスタンスではなくて、クラスそのものに対してなんらかの操作をするメソッドを定義するときに用います。

普通、クラス内のメソッドを呼び出そうとした場合は、一度インスタンス化しないといけません。例えば、以下のクラスを定義したと思います。

# 普通のクラス
class SampleClass:
    
    def samle_method(self):
        print("sample")

これのクラスのメソッドを呼び出そうとしたとき、以下のようにするとエラーがでます。

# error!!
SampleClass.sample_method()

クラス内のメソッドを使おうとした場合は、まずインスタンス化して、その後、そのインスタンスからメソッドを呼び出す必要があります。

# 一度インスタンス化
sample_class = SampleClass()

# methodを呼び出す
sample_class.samle_method()

# 出力がちゃんとでる


ここで、クラスメソッドを使うと、インスタンス化することなく、メソッドを呼び出すことができます。

# クラスメソッドを使った例
class SampleClassMethod:

    @classmethod
    def samle_method(clf):
        print("sample")

# インスタンス化することなく、メソッドを呼び出せる
SampleClassMethod.samle_method()

# 出力は"sample"

こんな風な呼び出しを可能にするのが、クラスメソッドです。

○クラスメソッドの定義と呼び出し

クラスメソッドの定義の方法には次の二つの方法があります。

  • classmethod()関数を使う
  • @classmethodというデコレーターを使う

それぞれのやり方を見ていきます。

■classmethod()関数を使う場合

これはあんまり推奨されていないやり方っぽいので、紹介だけします。

class SampleClassMethod:
    def sample_method(clf):
        return clf
    sample_method = classmethod(sample_method)

クラス内で、クラスメソッド化したいメソッドを再度定義することで実現できます。これは、以下で紹介するデコレーターと基本的には同じ動作をするので、以下の方を使えばいいかと思います。

■@classmethodのデコレーターを使う

以下のような感じです。

class SampleClassMethod:
    @classmethod
    def sample_method(clf):
        return clf

こっちの方がPythonっぽいといえばそうかもしれないです笑。

○staticmethodとclassmethodの違い

似てないんですけど、混乱しがちなものにstaticmethodがあります。どちらもインスタンスを作成せずともメソッドを呼び出すことができるのですが、違いがあります。

簡単に特徴を述べると以下のような違いがあります。

  • staticmethod : こちらは、クラス内に定義されていても、クラスに関係なく動き、受け取った引数のみ考慮します。乱暴に言うと、ただの関数と変わりません。
  • classmethod : こちらは、クラス自身を引数として受け取り、受け取った引数と共に用いることができます。

具体的な例をみながら、classmethodの特徴と、staticmethodとの違いについてみたいと思います。まず、以下のようなクラスを用意します。

class DiffClassStaticMethod:
    @classmethod
    def class_meth(*args):
        return args
    
    @staticmethod
    def static_meth(*args):
        return args

まずclassmethodを使ってみます。

# classmethodの方の呼び出し
# 引数に何も用意しなくても、Demo.klassmethはDemoクラスを第一引数として受け取ります。
DiffClassStaticMethod.class_meth()

# 出力 => (__main__.DiffClassStaticMethod,)

引数に何も指定しなくても、第一引数に、クラスを受け取っていることがわかります。次に、staticmethodを使ってみます。

# ごく普通のシンプルな関数としてふるまいます
DiffClassStaticMethod.static_meth()

# 出力 => ()

クラスを受け取らず、staticmethodに定義された引数のみ受け取ります。

正直staticmethodは実用的には関数と何ら変わりないので、関数として書いちゃう方がいいと思います。

○classmethodの慣例

上でみたように、classmethodは、第一引数にクラスを受け取ります。なので、それを以下のようにclfとして書くのが慣例らしいです(エラーはでないですが)。

class ClassMethod:
    @classmethod
    def class_meth(clf, *args):
        print(clf)

以上です。

PandasのDataFrame / Seriesでリスト内の要素にマッチする or しない行 (row) だけ取り出す

こんにちは。

今日はpandasのメモを。

〇やりたいこと

listとかでよくやる以下みたいな判定をやりたいんです。

list_data = [1,2,3,4,5]

if 1 in list_data:
  print('あります')
else:
  print('ありませんでした')

このような、in演算子を用いて、リスト内に含まれている要素であるかどうかという判定を行い、pandasのマッチした or していない行だけ取り出したいんです。

つまり、pandasだと以下のようなことをやりたいイメージです。

# 当然ですが、以下のコードはエラーが起きます。
import pandas as pd
data = pd.DataFrame({'a': [1,2,3,4,5], 'b': [1,2,3,4,5]})

new_data = data[data['a'] in [1,2]]
# エラー
# ValueError: The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().

これどうやってやるんだろうと思って調べていたんですが、やり方を見つけました。

〇リスト内の要素にマッチ”する”rowを取得する

さて、どうやってやるかなんですが、以下のようにしてやります。

import pandas as pd
data = pd.DataFrame({'a': [1,2,3,4,5], 'b': [1,2,3,4,5]})

new_data = data[data['a'].isin([1,2])]
print(new_data)

isin()を使うんですね、なるほどという感じです。これの出力は、以下のような感じです。

   a  b
0  1  1
1  2  2

〇リスト内の要素にマッチ”しない”rowを取得する

次にマッチ"しない"rowの取得ですが、以下のようにやります。

import pandas as pd
data = pd.DataFrame({'a': [1,2,3,4,5], 'b': [1,2,3,4,5]})

new_data = data[~data['a'].isin([1,2])]
print(new_data)


前に~をつけると、マッチしないものがでに入るそうです(おまえそこは、isnotinとかじゃないのかよと思った人僕の友達です)。

これの出力は以下のような感じです。

   a  b
2  3  3
3  4  4
4  5  5

〇そのほかの使い方について

上の例はいずれもSeriesに対してisinを使用し、DataFrameから対象の列を取ってくる例でしたが、DataFrameに大してもisinは使えます。

詳しくは公式ドキュメントを参照。


それでは。

chardetを使って文字コードを判定する

こんにちは。

今日は文字コードの判定について書きます。

文字コードの識別

もう基本UTF-8にしろよって感じなんですけど、たまにアップロードされたファイルの文字コードがなにか知りたいとかっていうシーンがあります。

今回はそれをPythonでやろうというものなのですが、そもそも文字コードを完璧に識別することはできません。完璧にやろうとすると、どこかから識別してもらうしか方法は基本的にありません。


しかし、バイトシーケンスのエンコーディングを識別するために、何もできないかというとそうでもありません。


バイトストリームが人の読むプレーンテキストだと仮定すると、言語にもルールがあるので、ヒューリスティックな方法や統計手法を使ってエンコーディングを探し出すことはできます。たとえば、ある特定のバイト値が結構あれば、このエンコーディングだな、みたいな感じです。


PyPIから提供されているchardetというパッケージは、対応している30種類のエンコーディングをそんな感じで判定してくれます(CHARacter DETectionで、chardet)。完璧ではないので、本来はあんまりするべきではなさそうですが、どうしても文字コードが特定できない場合に、このライブラリが使えます。

◯chardetを使って識別できるエンコーディング

本家のサイトの一番最初にかかれています。

◯インストール

普通にpipで入ります。

pip install chardet

◯chardetを使う方法

chardetを使うには、次の2つの方法があります。

それぞれ使い方を見ていきます。

コマンドラインからchardetを使う

以下のような感じで使うことができます。

chardetect file1 file2

ファイルは一つでもいいですが、複数ファイルを渡すこともできます。

実際に、utf-8で書かれたREADME.mdファイルで試してみます。

> chardetect README.md
README.md: utf-8 with confidence 0.99

あってますね(笑)。結果はこんなふうに、confidenceも含めて、どの文字コードと識別したかまで返してくれます。

pythonのコードの中から使う

当然のごとく、pythonコードからも使うことができます。

import chardet

with open("README.md", "rb") as f:
  print(chardet.detect(f.read()))

これを実行すると、以下のような出力が帰ってきます。dictで取得できるみたいですね。

{'encoding': 'utf-8', 'confidence': 0.99, 'language': ''}

たぶんこの手の話はwebページのクローリングとかでよく使うそうだなと思うんですが、最近はrequestsにもそういうのがあるとはいえ、必要になったら使おうかなと思います。

Centosでの作業したメモとか

こんにちは。

以下、Centosでよく行う作業のメモです。

Centosの環境

CentOS Linux release 7.3.1611 (Core)

Centosのバージョンの確認とOSが32bitか64bitかの確認

バージョンの確認は以下の通りで行えます。

cat /etc/redhat-release

32bitか64bitかの確認は、以下のコマンドで行えます。

arch

この出力結果が、X86_64の場合は64bit、i686の場合は32bitとなります。

Google Chromeをインストール

yumでいれたいので、yumでいれます。入れる時に参考になったのはこちらのサイト

まず準備としてリポジトリファイルを設定します。

vim /etc/yum.repos.d/google.chrome.repo

上で「google.chrome.repo」のファイルを開いたら、内容を以下の通り追加島須。

[google-chrome]
name=google-chrome
baseurl=http://dl.google.com/linux/chrome/rpm/stable/$basearch
enabled=1
gpgcheck=1
gpgkey=https://dl-ssl.google.com/linux/linux_signing_key.pub

んでもって情報を反映します。

yum update

管理者権限が必要になる箇所が何箇所かありますが、適時sudoを付けていれればいいかと。

そして、上記が完了したら、インストールします。

yum install google-chrome-stable

途中何回か色々聞かれますが、確認してインストールでいいと思います。以上で入ります。

Basic認証

htpasswdをインストールしたいので、それを含んだパッケージをインストールします。

yum -y install httpd-tools

〇終わりに

実際にはこれ以外にもいろいろやっているんだけど、メモすんのめんどくさい笑。

時間があれば、他にもCentosで何か作業したらメモを追加していきたいと思います。