Update: example script
This commit is contained in:
parent
1e0c26c1de
commit
215bf5d28c
|
@ -47,7 +47,7 @@ def plot_result(history, history_g=None, ylabel="x",
|
||||||
axis1.legend(ncol=1, bbox_to_anchor=(0., 1.02, 1., 0.102), loc=3)
|
axis1.legend(ncol=1, bbox_to_anchor=(0., 1.02, 1., 0.102), loc=3)
|
||||||
figure.savefig(path, bbox_inches="tight", pad_inches=0.05)
|
figure.savefig(path, bbox_inches="tight", pad_inches=0.05)
|
||||||
|
|
||||||
def plot_results(args, history_x, history_u, history_g=None):
|
def plot_results(history_x, history_u, history_g=None, args=None):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -56,14 +56,21 @@ def plot_results(args, history_x, history_u, history_g=None):
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
plot_result(history_x, history_g=history_g, ylabel="x",
|
env = "Env"
|
||||||
name= args.env + "-state_history",
|
controller_type = "controller"
|
||||||
save_dir="./result/" + args.controller_type)
|
|
||||||
plot_result(history_u, history_g=np.zeros_like(history_u), ylabel="u",
|
|
||||||
name= args.env + "-input_history",
|
|
||||||
save_dir="./result/" + args.controller_type)
|
|
||||||
|
|
||||||
def save_plot_data(args, history_x, history_u, history_g=None):
|
if args is not None:
|
||||||
|
env = args.env
|
||||||
|
controller_type = args.controller_type
|
||||||
|
|
||||||
|
plot_result(history_x, history_g=history_g, ylabel="x",
|
||||||
|
name= env + "-state_history",
|
||||||
|
save_dir="./result/" + controller_type)
|
||||||
|
plot_result(history_u, history_g=np.zeros_like(history_u), ylabel="u",
|
||||||
|
name= env + "-input_history",
|
||||||
|
save_dir="./result/" + controller_type)
|
||||||
|
|
||||||
|
def save_plot_data(history_x, history_u, history_g=None, args=None):
|
||||||
""" save plot data
|
""" save plot data
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -72,16 +79,23 @@ def save_plot_data(args, history_x, history_u, history_g=None):
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
path = os.path.join("./result/" + args.controller_type,
|
env = "Env"
|
||||||
args.env + "-history_x.pkl")
|
controller_type = "controller"
|
||||||
|
|
||||||
|
if args is not None:
|
||||||
|
env = args.env
|
||||||
|
controller_type = args.controller_type
|
||||||
|
|
||||||
|
path = os.path.join("./result/" + controller_type,
|
||||||
|
env + "-history_x.pkl")
|
||||||
save_pickle(path, history_x)
|
save_pickle(path, history_x)
|
||||||
|
|
||||||
path = os.path.join("./result/" + args.controller_type,
|
path = os.path.join("./result/" + controller_type,
|
||||||
args.env + "-history_u.pkl")
|
env + "-history_u.pkl")
|
||||||
save_pickle(path, history_u)
|
save_pickle(path, history_u)
|
||||||
|
|
||||||
path = os.path.join("./result/" + args.controller_type,
|
path = os.path.join("./result/" + controller_type,
|
||||||
args.env + "-history_g.pkl")
|
env + "-history_g.pkl")
|
||||||
save_pickle(path, history_g)
|
save_pickle(path, history_g)
|
||||||
|
|
||||||
def load_plot_data(env, controller_type, result_dir="./result"):
|
def load_plot_data(env, controller_type, result_dir="./result"):
|
||||||
|
|
10
README.md
10
README.md
|
@ -178,17 +178,17 @@ Use that histories to visualize the Animation or Figures.
|
||||||
|
|
||||||
```py
|
```py
|
||||||
# plot results
|
# plot results
|
||||||
plot_results(args, history_x, history_u, history_g=history_g)
|
plot_results(history_x, history_u, history_g=history_g)
|
||||||
save_plot_data(args, history_x, history_u, history_g=history_g)
|
save_plot_data(history_x, history_u, history_g=history_g)
|
||||||
|
|
||||||
# create animation
|
# create animation
|
||||||
animator = Animator(args, env)
|
animator = Animator(env)
|
||||||
animator.draw(history_x, history_g)
|
animator.draw(history_x, history_g)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Run Experiments
|
## Run Example Script
|
||||||
|
|
||||||
You can run the experiments as follows:
|
You can run the example script as follows:
|
||||||
|
|
||||||
```
|
```
|
||||||
python scripts/simple_run.py --env CartPole --controller CEM --save_anim 1
|
python scripts/simple_run.py --env CartPole --controller CEM --save_anim 1
|
||||||
|
|
|
@ -37,8 +37,8 @@ def run(args):
|
||||||
history_x, history_u, history_g = runner.run(env, controller, planner)
|
history_x, history_u, history_g = runner.run(env, controller, planner)
|
||||||
|
|
||||||
# plot results
|
# plot results
|
||||||
plot_results(args, history_x, history_u, history_g=history_g)
|
plot_results(history_x, history_u, history_g=history_g, args=args)
|
||||||
save_plot_data(args, history_x, history_u, history_g=history_g)
|
save_plot_data(history_x, history_u, history_g=history_g, args=args)
|
||||||
|
|
||||||
if args.save_anim:
|
if args.save_anim:
|
||||||
animator = Animator(env, args=args)
|
animator = Animator(env, args=args)
|
||||||
|
|
Loading…
Reference in New Issue