181 lines
5.8 KiB
Python
181 lines
5.8 KiB
Python
import numpy as np
|
|
from matplotlib.axes import Axes
|
|
|
|
from .env import Env
|
|
from ..plotters.plot_objs import square
|
|
|
|
|
|
class CartPoleEnv(Env):
    """ Cartpole Environment

    The state vector is ``[x, v_x, theta, v_theta]``: cart position, cart
    velocity, pole angle and pole angular velocity. Consistent with the
    dynamics and plotting below, ``theta = 0`` is the pole hanging straight
    down and ``theta = +/-pi`` is the pole upright (the goal used by
    ``reset`` is ``theta = -pi``).

    Ref :
        https://ocw.mit.edu/courses/
        electrical-engineering-and-computer-science/
        6-832-underactuated-robotics-spring-2009/readings/
        MIT6_832s09_read_ch03.pdf
    """

    def __init__(self):
        """ Build the environment configuration and initialize the base Env.

        Config keys:
            state_size (int): dimension of the state vector
            input_size (int): dimension of the input vector
            dt (float): explicit-Euler integration step used by ``step``
            max_step (int): ``step`` reports done once the step count
                exceeds this value
            input_lower_bound / input_upper_bound (list): clipping bounds
                applied to the input in ``step``
            mp (float): pole mass in the cart-pole equations
            mc (float): cart mass in the cart-pole equations
            l (float): pole length, used by the dynamics and the plot
            g (float): gravitational acceleration
            cart_size (tuple): size argument passed to ``square`` when
                drawing the cart
        """
        self.config = {"state_size": 4,
                       "input_size": 1,
                       "dt": 0.02,
                       "max_step": 500,
                       "input_lower_bound": [-3.],
                       "input_upper_bound": [3.],
                       "mp": 0.2,
                       "mc": 1.,
                       "l": 0.5,
                       "g": 9.81,
                       "cart_size": (0.15, 0.1),
                       }

        super().__init__(self.config)

    def reset(self, init_x=None):
        """ reset state

        Args:
            init_x (numpy.ndarray, optional): initial state,
                shape(state_size, ). If None, the cart starts at rest at
                x = 0 and the pole angle is drawn from a standard normal
                distribution.

        Returns:
            init_x (numpy.ndarray): initial state, shape(state_size, )
            info (dict): information, ``{"goal_state": goal state}``
        """
        self.step_count = 0

        # NOTE: the random angle is drawn even when init_x is supplied, so
        # the global NumPy RNG advances on every reset.
        theta = np.random.randn(1)
        self.curr_x = np.array([0., 0., theta[0], 0.])

        if init_x is not None:
            self.curr_x = init_x

        # goal: cart at rest at the origin with the pole upright
        self.g_x = np.array([0., 0., -np.pi, 0.])

        # clear memory
        self.history_x = []
        self.history_g_x = []

        return self.curr_x, {"goal_state": self.g_x}

    def step(self, u):
        """ step environments

        Integrates the cart-pole dynamics over one ``dt`` with explicit Euler.

        Args:
            u (numpy.ndarray) : input, shape(input_size, ); ``u[0]`` is the
                force term in the cart-pole equations. Lists and scalars are
                accepted and converted to a 1-D float array.
        Returns:
            next_x (numpy.ndarray): next state, shape(state_size, )
            cost (float): stage cost, evaluated on the clipped input and on
                the state *before* this step
            done (bool): True once the step count exceeds
                ``config["max_step"]`` (i.e. from the (max_step + 1)-th call
                after ``reset``)
            info (dict): information, ``{"goal_state": goal state}``
        """
        # u[0] and u**2 below need an ndarray with at least one element
        u = np.atleast_1d(np.asarray(u, dtype=float))

        # clip action
        if self.config["input_lower_bound"] is not None:
            u = np.clip(u,
                        self.config["input_lower_bound"],
                        self.config["input_upper_bound"])

        # physical parameters
        mp = self.config["mp"]
        mc = self.config["mc"]
        pole_len = self.config["l"]
        g = self.config["g"]

        # current state
        v_x = self.curr_x[1]
        theta = self.curr_x[2]
        v_theta = self.curr_x[3]

        sin_th = np.sin(theta)
        cos_th = np.cos(theta)
        force = u[0]
        # shared denominator of both accelerations
        denom = mc + mp * (sin_th**2)

        # x
        d_x0 = v_x
        # v_x
        d_x1 = (force
                + mp * sin_th * (pole_len * (v_theta**2) + g * cos_th)) / denom
        # theta
        d_x2 = v_theta
        # v_theta
        d_x3 = (-force * cos_th
                - mp * pole_len * (v_theta**2) * cos_th * sin_th
                - (mc + mp) * g * sin_th) / (pole_len * denom)

        next_x = self.curr_x + \
            np.array([d_x0, d_x1, d_x2, d_x3]) * self.config["dt"]

        # TODO: costs
        # quadratic input cost + state cost; (cos(theta) + 1)^2 vanishes at
        # the upright angle theta = +/-pi
        costs = 0.
        costs += 0.1 * np.sum(u**2)
        costs += (6. * self.curr_x[0]**2
                  + 12. * (cos_th + 1.)**2
                  + 0.1 * v_x**2
                  + 0.1 * v_theta**2)

        # save history
        self.history_x.append(next_x.flatten())
        self.history_g_x.append(self.g_x.flatten())

        # update state (flatten already returns a fresh copy)
        self.curr_x = next_x.flatten()
        # update step count
        self.step_count += 1

        done = self.step_count > self.config["max_step"]

        return next_x.flatten(), costs, done, {"goal_state": self.g_x}

    def plot_func(self, to_plot, i=None, history_x=None, history_g_x=None):
        """ plot cartpole object function

        Args:
            to_plot (axis or imgs): plotted objects. When a matplotlib Axes
                is given, the artists are created on it and returned.
                Otherwise it must be the dict previously returned, and its
                artists are updated to frame ``i``.
            i (int): frame count
            history_x (numpy.ndarray): history of state, shape(iters, state)
            history_g_x (numpy.ndarray): history of goal state,
                shape(iters, state); not used by this environment
        Returns:
            None or imgs (dict): imgs keys are "cart", "pole" and "center"
        """
        if isinstance(to_plot, Axes):
            imgs = {}  # create new imgs

            imgs["cart"] = to_plot.plot([], [], c="k")[0]
            imgs["pole"] = to_plot.plot([], [], c="k", linewidth=5)[0]
            imgs["center"] = to_plot.plot([], [], marker="o", c="k",
                                          markersize=10)[0]
            # centerline
            to_plot.plot(np.linspace(-1., 1., num=50), np.zeros(50),
                         c="k", linestyle="dashed")

            # set axis
            to_plot.set_xlim([-1., 1.])
            to_plot.set_ylim([-0.55, 1.5])

            return imgs

        # set imgs
        cart_x, cart_y, pole_x, pole_y = \
            self._plot_cartpole(history_x[i])

        to_plot["cart"].set_data(cart_x, cart_y)
        to_plot["pole"].set_data(pole_x, pole_y)
        # Line2D.set_data needs sequences: bare scalars were deprecated in
        # Matplotlib 3.7 and are rejected by newer releases.
        to_plot["center"].set_data([history_x[i][0]], [0.])

    def _plot_cartpole(self, curr_x):
        """ plot cartpole functions

        Args:
            curr_x (numpy.ndarray): current cartpole state
        Returns:
            cart_x (numpy.ndarray): x data of cart
            cart_y (numpy.ndarray): y data of cart
            pole_x (numpy.ndarray): x data of pole, shape(2, )
            pole_y (numpy.ndarray): y data of pole, shape(2, )
        """
        # cart, centered at (x, 0)
        cart_x, cart_y = square(curr_x[0], 0.,
                                self.config["cart_size"], 0.)

        # pole from the cart center to its tip; shifting by -pi/2 makes
        # theta = 0 point straight down (tip at y = -l)
        tip_angle = curr_x[2] - np.pi / 2
        pole_len = self.config["l"]
        pole_x = np.array([curr_x[0],
                           curr_x[0] + pole_len * np.cos(tip_angle)])
        pole_y = np.array([0., pole_len * np.sin(tip_angle)])

        return cart_x, cart_y, pole_x, pole_y
|