30 lines
810 B
Python
30 lines
810 B
Python
import unittest
|
|
from core import thunder_density
|
|
from main import read_parameter
|
|
import numpy as np
|
|
import sympy
|
|
|
|
|
|
class Testing(unittest.TestCase):
    """Check ``thunder_density`` against a SymPy-derived reference value.

    The reference is the derivative with respect to ``i`` of
    ``1 - 1 / (1 + (i / a) ** b)``, evaluated at reproducible random points.
    """

    # Fixed seed so that any failing (i, a, b) combination can be reproduced.
    SEED = 12345
    # Number of random parameter combinations checked per run.
    N_SAMPLES = 10
    # Sampling range for the coefficients ``a`` and ``b``. The lower bound keeps
    # them away from 0, where the expected derivative shrinks below TOLERANCE
    # and the comparison would pass for almost any implementation.
    COEF_RANGE = (0.1, 1.0)
    # Half-open integer range [low, high) for the evaluation point ``i``.
    I_RANGE = (10, 100)
    # Maximum allowed absolute difference from the reference value.
    TOLERANCE = 1e-5

    @classmethod
    def setUpClass(cls):
        """Class-level fixture hook; currently does nothing.

        NOTE(review): loading configuration with
        ``read_parameter('default.toml')`` is disabled here. The test below
        passes every coefficient to ``thunder_density`` explicitly, so
        presumably no configuration is needed — confirm before removing.
        """

    def test_thunder_density(self):
        """``thunder_density(i, 0, a, b)`` must match d/di [1 - 1/(1 + (i/a)**b)].

        NOTE(review): the meaning of the second argument (passed as 0) is not
        visible from this file; it is kept exactly as in the original call.
        """
        i, a, b = sympy.symbols("i a b", positive=True)
        cdf = 1 - 1 / (1 + (i / a) ** b)
        # Differentiate once symbolically; each sample only substitutes numbers.
        expected_density = sympy.diff(cdf, i)

        rng = np.random.default_rng(self.SEED)
        for _ in range(self.N_SAMPLES):
            a_val = float(rng.uniform(*self.COEF_RANGE))
            b_val = float(rng.uniform(*self.COEF_RANGE))
            i_val = int(rng.integers(*self.I_RANGE))
            with self.subTest(i=i_val, a=a_val, b=b_val):
                actual = thunder_density(i_val, 0, a_val, b_val)
                # Convert the SymPy Float to a plain float so the comparison
                # does not mix SymPy and NumPy number types.
                expected = float(
                    expected_density.evalf(subs={i: i_val, a: a_val, b: b_val})
                )
                # Failure message: "does not match the automatic-differentiation result".
                np.testing.assert_allclose(
                    actual,
                    expected,
                    rtol=0,
                    atol=self.TOLERANCE,
                    err_msg="与自动微分结果不一致",
                )


if __name__ == "__main__":
    # Run every TestCase defined in this module when the file is executed
    # directly; importing the module runs nothing. ``module="__main__"`` is
    # unittest.main's default, spelled out for clarity.
    unittest.main(module="__main__")