73 lines
2.4 KiB
Python
73 lines
2.4 KiB
Python
import pytest
|
|
import numpy as np
|
|
|
|
from PythonLinearNonlinearControl.envs.cartpole import CartPoleEnv
|
|
|
|
def _expected_next_state(config, x, u):
    """Return the state one forward-Euler step after ``x`` under input ``u``.

    Computes ``x + dt * f(x, u)`` with the same closed-form cart-pole
    expression the tests used to spell out by hand. Every term is derived
    from ``x`` itself rather than from hard-coded ``np.pi / 6.`` / ``1.**2``
    literals, so the expectation cannot silently drift away from the state
    actually passed to the environment.

    Args:
        config: The environment's ``config`` mapping; reads ``"mp"``,
            ``"mc"``, ``"l"``, ``"g"`` and ``"dt"``.
        x: Length-4 state vector. ``x[1]`` and ``x[3]`` are used as the
            rates of ``x[0]`` and ``x[2]``; ``x[2]`` is the angle fed to
            sin/cos.
        u: Scalar input applied over the step.

    Returns:
        ``np.array([dx0, dx1, dx2, dx3]) * dt + x``.
    """
    mp = config["mp"]
    mc = config["mc"]
    pole_len = config["l"]
    g = config["g"]

    theta = x[2]
    theta_dot_sq = x[3] ** 2
    sin_t = np.sin(theta)
    cos_t = np.cos(theta)
    # Common factor of both acceleration denominators.
    denom = mc + mp * sin_t ** 2

    d_x0 = x[1]
    d_x1 = (u + mp * sin_t * (pole_len * theta_dot_sq + g * cos_t)) / denom
    d_x2 = x[3]
    d_x3 = (-u * cos_t
            - mp * pole_len * theta_dot_sq * cos_t * sin_t
            - (mp + mc) * g * sin_t) / (pole_len * denom)

    return np.array([d_x0, d_x1, d_x2, d_x3]) * config["dt"] + x


class TestCartPoleEnv():
    """Check ``CartPoleEnv.step`` against a hand-computed Euler step."""

    def test_step(self):
        """A unit input advances the state by one Euler step of the dynamics."""
        env = CartPoleEnv()

        curr_x = np.ones(4)
        curr_x[2] = np.pi / 6.

        # Hand the env a copy so the expectation below is always computed
        # from the untouched initial state.
        env.reset(init_x=curr_x.copy())

        u = np.ones(1)

        next_x, _, _, _ = env.step(u)

        expected = _expected_next_state(env.config, curr_x, u[0])

        assert next_x == pytest.approx(expected, abs=1e-5)

    def test_bound_step(self):
        """An over-large input is expected to act as ``input_upper_bound``."""
        env = CartPoleEnv()

        curr_x = np.ones(4)
        curr_x[2] = np.pi / 6.

        env.reset(init_x=curr_x.copy())

        u = np.ones(1) * 1e3

        next_x, _, _, _ = env.step(u)

        # This test assumes step() clips the 1e3 input down to the configured
        # upper bound, so the expectation is built with the bound itself.
        u_bound = env.config["input_upper_bound"][0]
        expected = _expected_next_state(env.config, curr_x, u_bound)

        assert next_x == pytest.approx(expected, abs=1e-5)