"""Helpers for rendering a trained PPO agent to a video file or a GIF."""

import os
import subprocess
import tempfile

import gymnasium as gym
import imageio
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.vec_env import VecVideoRecorder


def _max_steps_for_episodes(vec_env, episodes):
    """Return the step budget needed to play ``episodes`` full episodes.

    The budget is ``episodes`` times the ``max_episode_steps`` of the spec of
    the first sub-environment in ``vec_env``.
    """
    return episodes * vec_env.get_attr("spec")[0].max_episode_steps


def generate_video(model, video_fp, video_length_in_episodes=5):
    """Record the model acting in its environment and write a video file.

    The rollout is first recorded into a temporary directory with
    ``VecVideoRecorder`` and then re-encoded with ffmpeg into ``video_fp``,
    dropping the trailing frame.

    Args:
        model: Model exposing ``get_env()`` and ``predict()`` (e.g. the PPO
            model returned by ``load_ppo_model_for_video``).
        video_fp: Destination path of the final video.
        video_length_in_episodes: Number of episodes to record.

    Raises:
        FileNotFoundError: If the ``ffmpeg`` executable is not on ``PATH``.
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    eval_env = model.get_env()
    max_video_length_in_steps = _max_steps_for_episodes(
        eval_env, video_length_in_episodes
    )

    with tempfile.TemporaryDirectory() as temp_dp:
        vec_env = VecVideoRecorder(
            eval_env,
            temp_dp,
            # Start recording at the very first step only.
            record_video_trigger=lambda x: x == 0,
            video_length=max_video_length_in_steps,
        )

        frame_count = 0
        episode_count = 0
        obs = vec_env.reset()
        for _ in range(max_video_length_in_steps):
            action, _ = model.predict(obs, deterministic=True)
            obs, _, dones, _ = vec_env.step(action)
            frame_count += 1
            # NOTE: truth-testing ``dones`` assumes a single-env VecEnv, as
            # built by ``load_ppo_model_for_video``.
            if dones:
                episode_count += 1
                if episode_count >= video_length_in_episodes:
                    break
        vec_env.close()

        temp_fp = vec_env.video_recorder.path

        # TODO: Fix this.
        # Use ffmpeg to remove the last frame (it is the first frame in a new
        # episode). This must run while the temporary directory still exists.
        # The command is passed as an argument list (no shell), so paths with
        # spaces or shell metacharacters are handled safely.
        subprocess.run(
            [
                "ffmpeg",
                "-y",
                "-i",
                os.fspath(temp_fp),
                "-vf",
                f"select='not(eq(n,{frame_count}))'",
                os.fspath(video_fp),
            ],
            check=True,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )


def generate_gif(model, file_path, video_length_in_episodes=5):
    """Roll out the model in its environment and save the frames as a GIF.

    One frame is rendered after the initial reset and one after every step
    that does not end the final requested episode. Only every second frame is
    written, at 25 fps.

    Args:
        model: Model exposing ``get_env()`` and ``predict()``.
        file_path: Destination path passed to ``imageio.mimsave``.
        video_length_in_episodes: Number of episodes to capture.
    """
    eval_env = model.get_env()
    max_video_length_in_steps = _max_steps_for_episodes(
        eval_env, video_length_in_episodes
    )

    def render_image():
        return eval_env.render(mode="rgb_array")

    episode_count = 0
    obs = eval_env.reset()
    images = [render_image()]
    for _ in range(max_video_length_in_steps):
        action, _ = model.predict(obs)
        obs, _, dones, _ = eval_env.step(action)
        # NOTE: truth-testing ``dones`` assumes a single-env VecEnv.
        if dones:
            episode_count += 1
            if episode_count >= video_length_in_episodes:
                break
        images.append(render_image())

    # Keep every second frame to halve the size of the resulting GIF.
    imageio.mimsave(file_path, [np.array(img) for img in images[::2]], fps=25)


def load_ppo_model_for_video(model_fp, env_id):
    """Load a saved PPO model attached to a renderable single-env VecEnv.

    Args:
        model_fp: Path of the saved model, passed to ``PPO.load``.
        env_id: Gymnasium environment id; the environment is created with
            ``render_mode="rgb_array"`` and wrapped in ``Monitor``.

    Returns:
        The loaded PPO model, with the new environment set as its env.
    """
    env = DummyVecEnv([lambda: Monitor(gym.make(env_id, render_mode="rgb_array"))])
    model = PPO.load(model_fp, env=env)
    return model