Update: example script

This commit is contained in:
Shunichi09 2020-05-03 12:22:16 +09:00
parent 1e0c26c1de
commit 215bf5d28c
3 changed files with 35 additions and 21 deletions

View File

@ -47,7 +47,7 @@ def plot_result(history, history_g=None, ylabel="x",
axis1.legend(ncol=1, bbox_to_anchor=(0., 1.02, 1., 0.102), loc=3) axis1.legend(ncol=1, bbox_to_anchor=(0., 1.02, 1., 0.102), loc=3)
figure.savefig(path, bbox_inches="tight", pad_inches=0.05) figure.savefig(path, bbox_inches="tight", pad_inches=0.05)
def plot_results(args, history_x, history_u, history_g=None): def plot_results(history_x, history_u, history_g=None, args=None):
""" """
Args: Args:
@ -56,14 +56,21 @@ def plot_results(args, history_x, history_u, history_g=None):
Returns: Returns:
None None
""" """
plot_result(history_x, history_g=history_g, ylabel="x", env = "Env"
name= args.env + "-state_history", controller_type = "controller"
save_dir="./result/" + args.controller_type)
plot_result(history_u, history_g=np.zeros_like(history_u), ylabel="u",
name= args.env + "-input_history",
save_dir="./result/" + args.controller_type)
def save_plot_data(args, history_x, history_u, history_g=None): if args is not None:
env = args.env
controller_type = args.controller_type
plot_result(history_x, history_g=history_g, ylabel="x",
name= env + "-state_history",
save_dir="./result/" + controller_type)
plot_result(history_u, history_g=np.zeros_like(history_u), ylabel="u",
name= env + "-input_history",
save_dir="./result/" + controller_type)
def save_plot_data(history_x, history_u, history_g=None, args=None):
""" save plot data """ save plot data
Args: Args:
@ -72,16 +79,23 @@ def save_plot_data(args, history_x, history_u, history_g=None):
Returns: Returns:
None None
""" """
path = os.path.join("./result/" + args.controller_type, env = "Env"
args.env + "-history_x.pkl") controller_type = "controller"
if args is not None:
env = args.env
controller_type = args.controller_type
path = os.path.join("./result/" + controller_type,
env + "-history_x.pkl")
save_pickle(path, history_x) save_pickle(path, history_x)
path = os.path.join("./result/" + args.controller_type, path = os.path.join("./result/" + controller_type,
args.env + "-history_u.pkl") env + "-history_u.pkl")
save_pickle(path, history_u) save_pickle(path, history_u)
path = os.path.join("./result/" + args.controller_type, path = os.path.join("./result/" + controller_type,
args.env + "-history_g.pkl") env + "-history_g.pkl")
save_pickle(path, history_g) save_pickle(path, history_g)
def load_plot_data(env, controller_type, result_dir="./result"): def load_plot_data(env, controller_type, result_dir="./result"):

View File

@ -178,17 +178,17 @@ Use that histories to visualize the Animation or Figures.
```py ```py
# plot results # plot results
plot_results(args, history_x, history_u, history_g=history_g) plot_results(history_x, history_u, history_g=history_g)
save_plot_data(args, history_x, history_u, history_g=history_g) save_plot_data(history_x, history_u, history_g=history_g)
# create animation # create animation
animator = Animator(args, env) animator = Animator(env)
animator.draw(history_x, history_g) animator.draw(history_x, history_g)
``` ```
## Run Experiments ## Run Example Script
You can run the experiments as follows: You can run the example script as follows:
``` ```
python scripts/simple_run.py --env CartPole --controller CEM --save_anim 1 python scripts/simple_run.py --env CartPole --controller CEM --save_anim 1

View File

@ -37,8 +37,8 @@ def run(args):
history_x, history_u, history_g = runner.run(env, controller, planner) history_x, history_u, history_g = runner.run(env, controller, planner)
# plot results # plot results
plot_results(args, history_x, history_u, history_g=history_g) plot_results(history_x, history_u, history_g=history_g, args=args)
save_plot_data(args, history_x, history_u, history_g=history_g) save_plot_data(history_x, history_u, history_g=history_g, args=args)
if args.save_anim: if args.save_anim:
animator = Animator(env, args=args) animator = Animator(env, args=args)