PythonLinearNonlinearControl/PythonLinearNonlinearControl/envs/cartpole.py

181 lines
5.8 KiB
Python

import numpy as np
from matplotlib.axes import Axes
from .env import Env
from ..plotters.plot_objs import square
class CartPoleEnv(Env):
    """ Cartpole Environment

    The state vector is [x, v_x, theta, v_theta] (cart position, cart
    velocity, pole angle, pole angular velocity) and there is a single
    input u[0]. With theta = 0 the pole is drawn hanging straight down
    from the cart; the goal state uses theta = -pi.

    Ref :
        https://ocw.mit.edu/courses/
        electrical-engineering-and-computer-science/
        6-832-underactuated-robotics-spring-2009/readings/
        MIT6_832s09_read_ch03.pdf
    """

    def __init__(self):
        """ Build the default configuration and pass it to the base Env
        """
        self.config = {"state_size": 4,
                       "input_size": 1,
                       "dt": 0.02,  # integration step used in step()
                       "max_step": 500,
                       "input_lower_bound": [-3.],
                       "input_upper_bound": [3.],
                       "mp": 0.2,  # pole mass in the referenced model
                       "mc": 1.,  # cart mass in the referenced model
                       "l": 0.5,  # pole length (also used for drawing)
                       "g": 9.81,  # gravitational acceleration
                       "cart_size": (0.15, 0.1),
                       }

        super(CartPoleEnv, self).__init__(self.config)

    def reset(self, init_x=None):
        """ reset state

        Args:
            init_x (numpy.ndarray): initial state, shape(state_size, ).
                When None, the cart starts at rest at the origin with a
                pole angle drawn from a standard normal distribution.

        Returns:
            init_x (numpy.ndarray): initial state, shape(state_size, )
            info (dict): information, holds the goal under "goal_state"
        """
        self.step_count = 0

        # the random angle is always drawn, even when init_x overrides it,
        # so the random stream seen by callers does not depend on init_x
        theta = np.random.randn(1)
        self.curr_x = np.array([0., 0., theta[0], 0.])

        if init_x is not None:
            self.curr_x = init_x

        # goal
        self.g_x = np.array([0., 0., -np.pi, 0.])

        # clear memory
        self.history_x = []
        self.history_g_x = []

        return self.curr_x, {"goal_state": self.g_x}

    def step(self, u):
        """ step environments

        Advances the state by one explicit Euler step of size dt.

        Args:
            u (numpy.ndarray) : input, shape(input_size, )

        Returns:
            next_x (numpy.ndarray): next state, shape(state_size, )
            cost (float): costs, evaluated on the clipped input and on
                the state before the update
            done (bool): end the simulation or not; True once more than
                max_step steps have been taken since the last reset
            info (dict): information, holds the goal under "goal_state"
        """
        # clip action
        if self.config["input_lower_bound"] is not None:
            u = np.clip(u,
                        self.config["input_lower_bound"],
                        self.config["input_upper_bound"])

        # physical parameters and current state, named once for readability
        mp = self.config["mp"]
        mc = self.config["mc"]
        length = self.config["l"]
        g = self.config["g"]

        pos = self.curr_x[0]
        vel = self.curr_x[1]
        theta = self.curr_x[2]
        omega = self.curr_x[3]

        # step
        # x
        d_x0 = vel
        # v_x
        d_x1 = (u[0] + mp * np.sin(theta)
                * (length * (omega**2) + g * np.cos(theta))) \
            / (mc + mp * (np.sin(theta)**2))
        # theta
        d_x2 = omega
        # v_theta
        d_x3 = (-u[0] * np.cos(theta)
                - mp * length * (omega**2) * np.cos(theta) * np.sin(theta)
                - (mc + mp) * g * np.sin(theta)) \
            / (length * (mc + mp * (np.sin(theta)**2)))

        next_x = self.curr_x + \
            np.array([d_x0, d_x1, d_x2, d_x3]) * self.config["dt"]

        # TODO: costs
        # input effort plus penalties on cart offset, on the pole angle
        # (the (cos(theta) + 1) term vanishes at theta = +-pi) and on both
        # velocities
        costs = 0.
        costs += 0.1 * np.sum(u**2)
        costs += 6. * pos**2 \
            + 12. * (np.cos(theta) + 1.)**2 \
            + 0.1 * vel**2 \
            + 0.1 * omega**2

        # save history
        self.history_x.append(next_x.flatten())
        self.history_g_x.append(self.g_x.flatten())

        # update; each flatten() is a fresh copy, so the stored history,
        # the internal state and the returned array never alias
        self.curr_x = next_x.flatten().copy()
        # update costs
        self.step_count += 1

        return next_x.flatten(), costs, \
            self.step_count > self.config["max_step"], \
            {"goal_state": self.g_x}

    def plot_func(self, to_plot, i=None, history_x=None, history_g_x=None):
        """ plot cartpole object function

        Called in two modes. Given a matplotlib Axes it creates the empty
        artists and configures the axes; given the dict it returned before,
        it updates those artists to frame i of history_x.

        Args:
            to_plot (axis or imgs): plotted objects
            i (int): frame count
            history_x (numpy.ndarray): history of state, shape(iters, state)
            history_g_x (numpy.ndarray): history of goal state,
                                         shape(iters, state); not used here

        Returns:
            None or imgs : when to_plot is an Axes, a dict of line artists
                with keys "cart", "pole" and "center"; otherwise None
        """
        if isinstance(to_plot, Axes):
            imgs = {}  # create new imgs

            imgs["cart"] = to_plot.plot([], [], c="k")[0]
            imgs["pole"] = to_plot.plot([], [], c="k", linewidth=5)[0]
            imgs["center"] = to_plot.plot([], [], marker="o", c="k",
                                          markersize=10)[0]
            # centerline
            to_plot.plot(np.linspace(-1., 1., num=50), np.zeros(50),
                         c="k", linestyle="dashed")

            # set axis
            to_plot.set_xlim([-1., 1.])
            to_plot.set_ylim([-0.55, 1.5])

            return imgs

        # set imgs
        cart_x, cart_y, pole_x, pole_y = \
            self._plot_cartpole(history_x[i])

        to_plot["cart"].set_data(cart_x, cart_y)
        to_plot["pole"].set_data(pole_x, pole_y)
        # Line2D.set_data expects sequences; newer Matplotlib releases
        # reject bare scalars, so wrap the single point in lists
        to_plot["center"].set_data([history_x[i][0]], [0.])

        return None

    def _plot_cartpole(self, curr_x):
        """ plot cartpole functions

        Args:
            curr_x (numpy.ndarray): current cartpole state

        Returns:
            cart_x (numpy.ndarray): x data of cart
            cart_y (numpy.ndarray): y data of cart
            pole_x (numpy.ndarray): x data of pole
            pole_y (numpy.ndarray): y data of pole
        """
        # cart
        cart_x, cart_y = square(curr_x[0], 0.,
                                self.config["cart_size"], 0.)

        # pole: a segment of length l from the cart centre; the -pi/2
        # offset makes theta = 0 point straight down
        pole_x = np.array([curr_x[0], curr_x[0] + self.config["l"]
                           * np.cos(curr_x[2]-np.pi/2)])
        pole_y = np.array([0., self.config["l"]
                           * np.sin(curr_x[2]-np.pi/2)])

        return cart_x, cart_y, pole_x, pole_y