Description
In the notebook example/2022-12-10-textrl-elon-musk.ipynb, the reward calculation in the MyRLEnv class should be updated so that the sentiment classifier scores the actual prompt text plus the generated text. Specifically, the get_reward method needs to be modified.
Current Code:
```python
class MyRLEnv(TextRLEnv):
    def get_reward(self, input_item, predicted_list, finish):
        reward = 0
        if finish or len(predicted_list) >= self.env_max_length:
            predicted_text = tokenizer.convert_tokens_to_string(predicted_list[0])
            # sentiment classifier
            reward = sentiment(input_item[0] + predicted_text)[0][0]['score'] * 10
        return reward
```

The current code concatenates input_item[0] with the predicted text before computing the sentiment score. However, input_item is a dictionary whose prompt text is stored under the 'input' key, so input_item[0] does not return the prompt (indexing the dict with 0 raises a KeyError). The reward line should instead read:

```python
reward = sentiment(input_item['input'] + predicted_text)[0][0]['score'] * 10
```
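To check the change outside the notebook, the reward logic can be pulled into a plain function, with the tokenizer and the sentiment pipeline passed in as callables. This is a minimal sketch under stated assumptions: compute_reward and fake_sentiment are hypothetical names, not part of TextRL, and the fake pipeline only mimics the nested [[{'label': ..., 'score': ...}]] shape that the notebook's [0][0]['score'] indexing expects. It needs no model download.

```python
from typing import Any, Callable, Dict, List


def compute_reward(
    input_item: Dict[str, str],
    predicted_list: List[List[str]],
    finish: bool,
    max_length: int,
    tokens_to_text: Callable[[List[str]], str],
    sentiment: Callable[[str], List[List[Dict[str, Any]]]],
) -> float:
    """Hypothetical standalone version of MyRLEnv.get_reward."""
    reward = 0.0
    # Score only once generation has finished or reached the length limit.
    if finish or len(predicted_list) >= max_length:
        predicted_text = tokens_to_text(predicted_list[0])
        # Read the prompt from the "input" key instead of indexing with [0].
        reward = sentiment(input_item["input"] + predicted_text)[0][0]["score"] * 10
    return reward


seen_texts: List[str] = []


def fake_sentiment(text: str) -> List[List[Dict[str, Any]]]:
    """Stub standing in for the Hugging Face pipeline; records what it scored."""
    seen_texts.append(text)
    return [[{"label": "LABEL_0", "score": 0.5}]]


item = {"input": "Elon Musk is "}
print(compute_reward(item, [["great", "!"]], True, 50, "".join, fake_sentiment))  # 5.0
print(seen_texts[-1])  # Elon Musk is great!

# The old indexing fails on a dict keyed by "input".
try:
    item[0]
except KeyError as err:
    print("KeyError:", err)  # KeyError: 0
```

Passing the tokenizer and classifier in as arguments keeps the function independent of any loaded model, so the indexing can be verified with stubs before running the full training loop.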