PythonLinearNonlinearControl/tests/configs/test_cartpole.py

31 lines
946 B
Python

import pytest
import numpy as np
from PythonLinearNonlinearControl.configs.cartpole \
import CartPoleConfigModule
class TestCalcCost():
def test_calc_costs(self):
# make config
config = CartPoleConfigModule()
# set
pred_len = 5
state_size = 4
input_size = 1
pop_size = 2
pred_xs = np.ones((pop_size, pred_len, state_size))
g_xs = np.ones((pop_size, pred_len, state_size)) * 0.5
input_samples = np.ones((pop_size, pred_len, input_size)) * 0.5
costs = config.input_cost_fn(input_samples)
assert costs.shape == (pop_size, pred_len, input_size)
costs = config.state_cost_fn(pred_xs, g_xs)
assert costs.shape == (pop_size, pred_len, 1)
costs = config.terminal_state_cost_fn(pred_xs[:, -1, :],\
g_xs[:, -1, :])
assert costs.shape == (pop_size, 1)