
I am trying to reproduce the results for the simple grid-world environment in [1]. However, using a dynamically learned PBA makes the performance worse, and I cannot obtain the results shown in Figure 1(a) of [1], even with the same hyperparameters. Here is the result I got:

[plot omitted: with the dynamic PBA, performance is clearly worse than without it]

The issue I found is that the learning procedure gets stuck because of a bad PBA in the early stages of training. Without the PBA, Sarsa converges fine.

Has anyone tried this method before? I am really puzzled as to how the authors obtained such good results. Several top-conference papers use the same method as in [1], for example [2] and [3].

[1] Expressing Arbitrary Reward Functions as Potential-Based Advice

[2] Learning from demonstration for shaping through inverse reinforcement learning

[3] Policy Transfer using Reward Shaping
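
For context, here is my understanding of the dynamic (learned) PBA of [1], which the code below tries to implement: a secondary potential function $\Phi$ is learned online with Sarsa on the negation of the expert reward, $r^\Phi = -R^{expert}(s, a)$, and the advice added to the environment reward at time $t$ is

$$F(s_t, a_t, s_{t+1}, a_{t+1}) = \gamma \, \Phi^{t+1}(s_{t+1}, a_{t+1}) - \Phi^{t}(s_t, a_t),$$

where $\Phi^t$ and $\Phi^{t+1}$ denote the potential table before and after the Sarsa update at step $t$. This is why my DynamicPBA class below keeps two tables, q_table ($\Phi^t$) and q_table_ ($\Phi^{t+1}$).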

Is the method itself defective, or is something wrong with my code? Here is part of my code:

import copy
import numpy as np
import pandas as pd

def expert_reward(s, action):
    # expert advice: give a reward of 1 for the RIGHT and DOWN actions
    # (RIGHT and DOWN are action indices defined elsewhere in the full script)
    if (action == RIGHT) or (action == DOWN):
        return 1.0
    return 0.0

class DynamicPBA:
    def __init__(self, actions, learning_rate=0.1, reward_decay=0.99):
        self.lr = learning_rate
        self.gamma = reward_decay
        self.actions = actions
        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)  # potential table before the current update (Phi^t)
        self.q_table_ = pd.DataFrame(columns=self.actions, dtype=np.float64)  # potential table after the current update (Phi^{t+1})
        self.check_state_exist(str((0,0)))

    def learn(self, s, a, r, s_, a_): #(s,a) denotes current state and action, r denotes reward, (s_, a_) denotes the next state and action
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        q_target = r + self.gamma * self.q_table.loc[s_, a_]
        self.q_table.loc[s, a] = self.q_table.loc[s, a] + self.lr * (q_target - q_predict)

    def update(self):
        # copy the freshly updated table (Phi^{t+1}) into the stale one (Phi^t)
        self.q_table = copy.deepcopy(self.q_table_)

    def check_state_exist(self, state):
        if state not in self.q_table.index:
            # append a row of zeros for the new state to both tables
            # (pd.concat is used because DataFrame.append was removed in pandas 2.0)
            new_row = pd.DataFrame(
                [[0.0] * len(self.actions)],
                columns=self.q_table.columns,
                index=[state],
            )
            self.q_table = pd.concat([self.q_table, new_row])
            self.q_table_ = pd.concat([self.q_table_, new_row.copy()])

####### Main part

RL = SarsaTable(actions=list(range(len(actions_dict))), reward_decay=0.99, learning_rate=0.05)
expert = DynamicPBA(actions=list(range(len(actions_dict))), learning_rate=0.1, reward_decay=0.99)
for episode in range(100):
    # initial observation
    s = (0,0)
    env.reset(s)
    action = RL.choose_action(str(s))

    r_episode_s = 0
    r_episode = 0

    current_step = 0
    while True:        

        # RL take action and get next observation and reward
        s_, _, reward, status = env.step(action)
        current_step += 1

        action_ = RL.choose_action(str(s_))
        # update dynamic potentials
        expert_r = -expert_reward(s, action)  # the potential function is learned on the negated expert reward
        expert.learn(str(s), action, expert_r, str(s_), action_)

        # compute the advice term F = gamma * Phi^{t+1}(s', a') - Phi^t(s, a)
        F = expert.gamma * expert.q_table_.loc[str(s_), action_] - expert.q_table.loc[str(s), action]

        # roll the potential tables forward: q_table <- q_table_
        expert.update()

        RL.learn(str(s), action, reward + F, str(s_), action_, status)  # Sarsa update with the shaped reward r + F

        # swap observation
        s = s_
        action = action_

        # break while loop when end of this episode
        if status != 'not_over':
            break
        if current_step > 10000:
            print(episode, r_episode, r_episode_s, current_step)
            break
        # learning rate decay
        RL.lr = RL.lr * 0.999
#     expert.update()
bcxiao
1 Answer


Is the method itself defective or anything wrong with my code?

There does indeed appear to be an issue with your code; the publications are fine (I know most of those authors and would very much trust their writing too :) ).

The first issue I see, and likely the most important one, is that the update() calls of DynamicPBA repeatedly overwrite the contents of self.q_table with those of self.q_table_, but the contents of self.q_table_ are never written to: learn() updates self.q_table instead. So your q_table is essentially always filled with a bunch of $0$ values that never get a chance to properly learn.

I did not check the rest of the implementation in detail, but at a glance it all looks fine to me. So I guess that changing the last line of learn() to update an entry in self.q_table_ rather than self.q_table should go a long way towards fixing your issue.
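
As a minimal sketch of that fix (untested, and keeping everything else in your class exactly as posted), learn() would keep reading its bootstrap values from the stale table but write the update into self.q_table_:

    def learn(self, s, a, r, s_, a_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        q_target = r + self.gamma * self.q_table.loc[s_, a_]
        # write into q_table_ (Phi^{t+1}); update() later copies it back into
        # q_table (Phi^t), after the shaping term F has been computed
        self.q_table_.loc[s, a] = q_predict + self.lr * (q_target - q_predict)

That way q_table always lags one update behind q_table_, matching the $\Phi^t$ / $\Phi^{t+1}$ distinction in the paper's definition of $F$.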

Dennis Soemers