---
tags:
- deep-reinforcement-learning
- reinforcement-learning
- stable-baselines3
- atari
model-index:
- name: PPO Agent
  results:
  - task:
      type: reinforcement-learning
    dataset:
      type: PongNoFrameskip-v4
      name: PongNoFrameskip-v4
    metrics:
    - type: mean_reward
      value: 20.3
---
# PPO Agent playing PongNoFrameskip-v4
This is a trained model of a **PPO agent playing PongNoFrameskip-v4 using the [stable-baselines3 library](https://stable-baselines3.readthedocs.io/en/master/index.html)** (our agent is the 🟢 one).

The training report: https://wandb.ai/simoninithomas/HFxSB3/reports/Atari-HFxSB3-Benchmark--VmlldzoxNjI3NTIy

## Evaluation Results
Mean reward: `20.3 +/- 0.0`
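
This score can be reproduced with stable-baselines3's built-in `evaluate_policy` helper. A minimal sketch, assuming the same 4-frame-stacked preprocessing used for training; the number of evaluation episodes is a choice here, not something recorded in this card:

```python
from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import VecFrameStack

# Download the trained agent from the Hub and load it
checkpoint = load_from_hub("tk-42/ppo-PongNoFrameskip-v4", "ppo-PongNoFrameskip-v4.zip")
model = PPO.load(checkpoint)

# Rebuild the evaluation environment (same preprocessing as training)
eval_env = make_atari_env("PongNoFrameskip-v4", n_envs=1, seed=0)
eval_env = VecFrameStack(eval_env, n_stack=4)

mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10, deterministic=True)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")
```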

# Usage (with Stable-baselines3)
- You need to use `gymnasium==0.29.1` since it **includes the Atari ROMs**.
- The action space is 6 since we use only the **actions that are possible in this game** (you can verify this with the sketch below).
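
To confirm the reduced action space, build the environment and inspect it directly. A minimal sketch; `get_action_meanings()` comes from the underlying ALE environment, so we unwrap the first sub-environment of the vectorized env:

```python
from stable_baselines3.common.env_util import make_atari_env

env = make_atari_env("PongNoFrameskip-v4", n_envs=1)
print(env.action_space.n)  # 6

# Peek inside the vectorized env at the raw ALE environment
print(env.envs[0].unwrapped.get_action_meanings())
# ['NOOP', 'FIRE', 'RIGHT', 'LEFT', 'RIGHTFIRE', 'LEFTFIRE']
```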

Watch your agent interact:

```python
# Import the libraries
from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack

# Load the model checkpoint from the Hugging Face Hub
checkpoint = load_from_hub("tk-42/ppo-PongNoFrameskip-v4", "ppo-PongNoFrameskip-v4.zip")

# Colab runs Python 3.7 while this agent was trained with Python 3.8,
# so we override these schedule objects to avoid pickle errors:
custom_objects = {
    "learning_rate": 0.0,
    "lr_schedule": lambda _: 0.0,
    "clip_range": lambda _: 0.0,
}

model = PPO.load(checkpoint, custom_objects=custom_objects)

# Recreate the environment the agent was trained on (4 stacked frames)
env = make_atari_env("PongNoFrameskip-v4", n_envs=1)
env = VecFrameStack(env, n_stack=4)

obs = env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()
```
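
Note that the `custom_objects` override only replaces the pickled schedule callables that can fail to deserialize across Python versions; since the model is used purely for inference here, the zeroed-out learning rate and clip range are never actually used.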

## Training Code
```python
import wandb

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack, VecVideoRecorder
from stable_baselines3.common.callbacks import CheckpointCallback

from wandb.integration.sb3 import WandbCallback

from huggingface_sb3 import push_to_hub

config = {
    "env_name": "PongNoFrameskip-v4",
    "num_envs": 8,
    "total_timesteps": int(10e6),
    "seed": 4089164106,
}

run = wandb.init(
    project="HFxSB3",
    config=config,
    sync_tensorboard=True,  # Auto-upload sb3's tensorboard metrics
    monitor_gym=True,       # Auto-upload the videos of agents playing the game
    save_code=True,         # Save the code to W&B
)

# There already exists an environment generator
# that will make and wrap Atari environments correctly.
# Here we also use multi-worker training (n_envs=8 => 8 environments)
env = make_atari_env(config["env_name"], n_envs=config["num_envs"], seed=config["seed"])  # PongNoFrameskip-v4

print("ENV ACTION SPACE: ", env.action_space.n)

# Frame-stacking with 4 frames
env = VecFrameStack(env, n_stack=4)
# Video recorder
env = VecVideoRecorder(env, "videos", record_video_trigger=lambda x: x % 100000 == 0, video_length=2000)

# Hyperparameters from the rl-trained-agents reference config:
# https://github.com/DLR-RM/rl-trained-agents/blob/10a9c31e806820d59b20d8b85ca67090338ea912/ppo/PongNoFrameskip-v4_1/PongNoFrameskip-v4/config.yml
model = PPO(
    policy="CnnPolicy",
    env=env,
    batch_size=256,
    clip_range=0.1,
    ent_coef=0.01,
    gae_lambda=0.9,
    gamma=0.99,
    learning_rate=2.5e-4,
    max_grad_norm=0.5,
    n_epochs=4,
    n_steps=128,
    vf_coef=0.5,
    tensorboard_log="runs",
    verbose=1,
)

model.learn(
    total_timesteps=config["total_timesteps"],
    callback=[
        WandbCallback(
            gradient_save_freq=1000,
            model_save_path=f"models/{run.id}",
        ),
        CheckpointCallback(
            save_freq=10000,
            save_path="./pong",
            name_prefix=config["env_name"],
        ),
    ],
)

model.save("ppo-PongNoFrameskip-v4.zip")
push_to_hub(
    repo_id="tk-42/ppo-PongNoFrameskip-v4",
    filename="ppo-PongNoFrameskip-v4.zip",
    commit_message="Added Pong trained agent",
)
```
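
The `CheckpointCallback` above snapshots the model every 10,000 steps per environment, so an interrupted run can be resumed from the latest snapshot. A minimal sketch, assuming a hypothetical checkpoint file such as `./pong/PongNoFrameskip-v4_1000000_steps.zip` exists (the exact filename depends on when the run stopped):

```python
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack

# Rebuild the training environment with the same preprocessing
env = make_atari_env("PongNoFrameskip-v4", n_envs=8, seed=4089164106)
env = VecFrameStack(env, n_stack=4)

# Hypothetical checkpoint path; pick the latest file written by CheckpointCallback
model = PPO.load("./pong/PongNoFrameskip-v4_1000000_steps.zip", env=env)

# Continue training without resetting the timestep counter
model.learn(total_timesteps=int(10e6), reset_num_timesteps=False)
```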