TF-Agents 에이전트 구성하기


TF-Agents - TensorFlow 강화학습 라이브러리

강화학습의 에이전트 (Agent)환경 (Environment)에서 현재의 상태 (State)를 인식하고, 행동 (Action)을 선택합니다.

이번에는 TF-Agents 에이전트를 구성하는 과정에 대해 소개합니다.



1) Q-Network 준비하기

예제

import tensorflow as tf

from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.agents.dqn import dqn_agent
from tf_agents.networks import q_network
from tf_agents.utils import common

env_name = 'CartPole-v0'
train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

q_net = q_network.QNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    fc_layer_params=(100,)
)

print(q_net)
print(q_net.input_tensor_spec)
<tf_agents.networks.q_network.QNetwork object at 0x7f463908ec88>
BoundedTensorSpec(shape=(4,), dtype=tf.float32, name='observation', minimum=array([-4.8000002e+00, -3.4028235e+38, -4.1887903e-01, -3.4028235e+38],
    dtype=float32), maximum=array([4.8000002e+00, 3.4028235e+38, 4.1887903e-01, 3.4028235e+38],
    dtype=float32))

tf_agents.networks.q_network 모듈의 QNetwork 클래스는 Q-Learning에 사용되는 인공신경망 (Neural Network)입니다.

예제에서는 train_env.observation_spec(), train_env.action_spec()을 인자로 입력했습니다.

이 인자들은 신경망의 입력과 출력을 결정합니다.

fc_layer_params는 신경망의 레이어별 뉴런 유닛의 개수를 지정합니다.

QNetwork 객체의 input_tensor_spec은 신경망의 입력 사양을 반환합니다.




2) 에이전트 구성하기

예제

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3)
train_step_counter = tf.Variable(0)

agent = dqn_agent.DqnAgent(
  train_env.time_step_spec(),
  train_env.action_spec(),
  q_network=q_net,
  optimizer=optimizer,
  td_errors_loss_fn=common.element_wise_squared_loss,
  train_step_counter=train_step_counter)

agent.initialize()

print(agent)
<tf_agents.agents.dqn.dqn_agent.DqnAgent object at 0x7f4623e709b0>

tf_agents.agents 모듈의 DqnAgent 클래스는 DQN Agent를 구성하기 위해 사용합니다.

첫번째, 두번째 인자는 TimeStepAction의 사양입니다.

앞에서 구성한 Q-Network를 신경망으로 사용하고, AdamOptimizer를 옵티마이저로 합니다.

td_errors_loss_fn는 타겟과 출력값의 오차를 계산하기 위한 함수를 지정합니다.

tf.utils.common 모듈의 element_wise_squared_loss 함수는 최소제곱오차 (Mean Squared Error)를 반환합니다.

train_step_counter로 지정한 tf.Variable(0)은 훈련이 한 번 이루어질 때마다 값이 1씩 증가합니다.

DqnAgentinitilize() 메서드는 에이전트를 초기화합니다.



이전글/다음글