Start_python’s diary

ふたり暮らし

アラフィフ夫婦のフリーランスプラン

Q-learning(Q学習)のQ値を見える化

gymの倒立振子を使って強化学習Q-learning(Q学習)第2回

 

はじめに

前回は、状態を「4つの要素を6分割」して1296通りの中から今ある状態のときの「右と左」に「報酬と罰則」を与えながら得点の高い方を選ぶやり方でした。

今回は、状態を「2つの要素を8分割と6分割」にして48通りでやってみます。

これで8✖️6✖️2の三次元配列のq_tableにできるのて、得点の変化を見ながら学習させていけそうです。

表の横が棒の角度(-0.5,~0.5)、縦が棒の角速度(-2.0~2.0)です。
上が大きいときは右に移動、下が大きいときは左に移動となります。

f:id:Start_python:20191127224706g:plain

30試行回数くらいの様子です。50試行回数超えると安定してきます。前回より試行回数が少なくて安定するのはおそらく総要素数が少ないためだと思います。

 

プログラムのコード

# coding:utf-8
# [0]ライブラリのインポート
import gym  #倒立振子(cartpole)の実行環境
import numpy as np import time import os # [1]Q関数を離散化して定義する関数 ------------ # 観測した状態を離散値にデジタル変換する def bins(clip_min, clip_max, num): return np.linspace(clip_min, clip_max, num + 1)[1:-1] # 各値を離散値に変換 def digitize_state(observation): cart_pos, cart_v, pole_angle, pole_v = observation digitized = [ #np.digitize(cart_pos, bins=bins(-2.4, 2.4, num_dizitized)), #np.digitize(cart_v, bins=bins(-3.0, 3.0, num_dizitized)), np.digitize(pole_angle, bins=bins(-0.5, 0.5, num_split1)), np.digitize(pole_v, bins=bins(-2.0, 2.0, num_split2)) ] return sum([x * (num_split1**i) for i, x in enumerate(digitized)]) # [2]行動a(t)を求める関数 ------------------------------------- def get_action(next_state, episode): #徐々に最適行動のみをとる、ε-greedy法 epsilon = 0.5 * (1 / (episode + 1)) if epsilon <= np.random.uniform(0, 1): next_action = np.argmax(q_table[next_state]) else: next_action = np.random.choice([0, 1]) return next_action # [3]Qテーブルを更新する関数 ------------------------------------- def update_Qtable(q_table, state, action, reward, next_state): gamma = 0.99 alpha = 0.5 next_Max_Q=max(q_table[next_state][0],q_table[next_state][1] ) q_table[state, action] = (1 - alpha) * q_table[state, action] + alpha * (reward + gamma * next_Max_Q) return q_table # [4]. メイン関数開始 パラメータ設定-------------------------------------------------------- env = gym.make('CartPole-v0') max_number_of_steps = 200 #1試行のstep数 num_consecutive_iterations = 100 #学習完了評価に使用する平均試行回数 num_episodes = 100 #総試行回数 goal_average_reward = 195 #この報酬を超えると学習終了(中心への制御なし) num_render = 50 #表示開始の試行回数 # 状態を分割してQ関数(表)を作成 num_split1 = 8 #分割数1 num_split2 = 6 #分割数2 q_table = np.random.uniform( low=-1, high=1, size=(num_split1*num_split2, env.action_space.n)) total_reward_vec = np.zeros(num_consecutive_iterations) #各試行の報酬を格納 np.set_printoptions(precision=1, suppress=True) #print用フォーマット # [5] メインルーチン-------------------------------------------------- for episode in range(num_episodes): #試行数分繰り返す # 環境の初期化 observation = env.reset() state = digitize_state(observation) action = np.argmax(q_table[state]) episode_reward = 0 for t in range(max_number_of_steps): #1試行のループ if episode > num_render: # cartPoleを描画する env.render() time.sleep(0.001) # 行動a_tの実行により、s_{t+1}, r_{t}などを計算する observation, reward, done, info = env.step(action) # 報酬を設定し与える if done: if t < 195: reward = -200 #こけたら罰則 else: reward = 1 #立ったまま終了時は罰則はなし else: reward = 1 #各ステップで立ってたら報酬追加 episode_reward += reward #報酬を追加 # 離散状態s_{t+1}を求め、Q関数を更新する next_state = digitize_state(observation) #t+1での観測状態を、離散値に変換 q_table = update_Qtable(q_table, state, action, reward, next_state) # 次の行動a_{t+1}を求める action = get_action(next_state, episode) # a_{t+1} state = next_state if episode > num_render: # ここにq_tableを表示 q_table_print = np.reshape(q_table, (1,num_split1*num_split2*2),'F') os.system('cls') print(np.reshape(q_table_print, (2,num_split2,num_split1))) # 棒が倒れたらやり直し if done: print('%d Episode finished after %f time steps / mean %f' % (episode, t + 1, total_reward_vec.mean())) total_reward_vec = np.hstack((total_reward_vec[1:], episode_reward)) #報酬を記録 if episode > num_render: time.sleep(1) break

 

解説

前回なら下の状態のとき、

カート位置 □□■□□□
カート速度 □■□□□□
棒の角度  □□□□■□
棒の角速度 □□□■□□

今回は下の状態になります。

棒の角度  □□□□□■□□
棒の角速度 □□□■□□

わかりやすく表にすると下の位置です。

□□□□□□□□
□□□□□□□□
□□□□□□□□
□□□□□■□□
□□□□□□□□
□□□□□□□□

横が棒の角度、縦が棒の角速度です。

この位置にある数値[2つ]の大きい方の配列番号でactionが0か1かを決めます。

最初はカートの位置と棒の角度の2要素でやってみましたが上手くいきませんでした。

 

np.set_printoptions関数

print時の表示のフォーマットを指定できます。一度宣言するだけでオッケーです。

np.set_printoptions(precision=1, suppress=True) 

小数点以下を一桁にして指数表示の禁止します。

printした画面をクリアする

Windowsの場合はこちらです。

import os

os.system('cls')

 


リスト (多次元配列)

またまた多次元配列が出てきました。リストが理解できると強化学習もかなりわかりやすいです。逆にリストが理解できなければ先にリストの勉強をした方がいいかもしれません。

np.reshape関数

配列を形状変換します。

np.reshape(q_table, (1,8*4*2),'F')

配列q_tableを1✖️64の一次元配列に変えます。「'F'」で並べ方(方向)を指定しています。

np.reshape(b, (2,8,4))

配列bを2✖️8✖️4の三次元配列に変えます。並べ方はそのままです。

 

最後にカートの位置と棒の角度の2要素でやってみたときの様子を載せておきます。

f:id:Start_python:20191127225103g:plain

明らかにマイナス得点ばかりで安定しませんでした。カートの位置よりも棒の角速度が重要だと気付かされました。

 

 

次回はSARSA法について勉強していきたいと思います。

 


参考サイト

qiita.com

code-examples.net

 

deepage.net

qiita.com

 

 

保存ファイル

lesson45.py

 

 

文責:Luke