Kerasでニューラルネットワークを作る
Deep Q-Network(DQN)による倒立振子 第3回
はじめに
今回は「mainQN = QNetwork(hidden_size=hidden_size, learning_rate=learning_rate) # メインのQネットワーク」この一行を見ていきます。
クラス:QNetworkを呼び出しています。
# [2]Q関数をディープラーニングのネットワークをクラスとして定義 class QNetwork: def __init__(self, learning_rate=0.01, state_size=4, action_size=2, hidden_size=10): self.model = Sequential() self.model.add(Dense(16, activation='relu', input_dim=4)) self.model.add(Dense(16, activation='relu')) self.model.add(Dense(2, activation='linear')) self.optimizer = Adam(lr=0.01) # 誤差を減らす学習方法はAdam self.model.compile(loss=huberloss, optimizer=self.optimizer)
self.model = Sequential()
Sequentialはただ層を積み上げるだけの単純なモデル。まずSequentialモデルを作ります。
self.model.add(Dense(16, activation='relu', input_dim=4))
最初の隠れ層を作ります。16個のノードそれぞれに対して4つの数値が入ってきますが、その4つの数値の重み付け和(+バイアス)を適当にとって最後に関数をかけます。そのときの関数が活性化関数です。今回はランプ関数。
self.model.add(Dense(16, activation='relu'))
次の隠れ層を作ります。16個のノードからそれぞれ値が次の16個のノードにやってきます。ノードにやってきた16個の値に対して重み付き和とランプ関数を適用します。
self.model.add(Dense(2, activation='linear'))
最後に出力層です。16個のノードから(活性化関数が適用されたあと)それぞれ値がやってきます。出力は2次元で、ノードにやってきた10個の値に対して重み付き和と線形関数を適用します。
self.optimizer = Adam(lr=0.01) # 誤差を減らす学習方法はAdam
オプティマイザ(最適化アルゴリズム)はモデルをコンパイルする際に必要となるパラメータの1つです。オプティマイザはたくさんありますが今回はAdamを使っています。
Adamオプティマイザ
例: keras.optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
デフォルトパラメータは提案論文に従います。
引数
lr: 0以上の浮動小数点数、学習率
beta_1: 浮動小数点数, 0 < beta < 1. 一般的に1に近い値です
beta_2: 浮動小数点数, 0 < beta < 1. 一般的に1に近い値です
epsilon: 0以上の浮動小数点数、微小量、NoneならばデフォルトでK.epsilon()
decay: 0以上の浮動小数点数、各更新の学習率減衰.
amsgrad: 論文"On the Convergence of Adam and Beyond"にあるAdamの変種であるAMSGradを適用するかどうか
self.model.compile(loss=huberloss, optimizer=self.optimizer)
最後にコンパイルを行います。
1つめは最適化関数、2つめは損失関数、3つめは評価指標のリストです。
これで2つの隠れ層を持つニューラルネットワークが出来上がります。
活性化関数(activation)とは?
ニューラルネットワークでは、前の層からやってきた値の重み付き和を取ります。そこにバイアスの定数を加え、最後にある関数を適用します。
この適用する関数が活性化関数です。
活性化関数はたくさんありますが今回使っているのはこの2つ
ランプ関数(relu)
0以上の値を取る。急で折れ線。値がある程度(ゼロから見て)大きくなると線形になり、小さくなるとゼロに近くなる。reluは係数で0未満の値をとることもできる。
線形関数(linear)
単純に係数をかけてバイアス(偏り)を加える。
本日はここまで。次回に続きます。
参考サイト
保存ファイル
lesson49.py
文責:Luke