from pycppad import *
import time
def pycppad_test_runge_kutta_4_cpp():
    """Compare a pure-python Runge-Kutta integration with a taped pycppad one.

    The ODE is y'(t) = x_1 * y(t) with y(0) = x_0.  It is integrated from
    t = 0 to t = 1 in N steps twice:

    1. directly in python, calling runge_kutta_4 with float values, and
    2. by recording M Runge-Kutta steps into a pycppad function object
       ``rk4`` and then evaluating that object N // M times.

    Both results are checked against x_0 * exp(x_1 * tf), and the test
    finally asserts that the function object version is at least 30 times
    faster than the pure-python version.

    Raises:
        AssertionError: if either solution is not within 1e-10 of the
            expected value, or if the speed ratio requirement is not met.
    """
    # Coefficient in the ODE.  ``fun`` looks this name up in the enclosing
    # scope each time it is called, so rebinding x_1 below switches the
    # computation between float and ad(float).
    x_1 = 0

    def fun(t, y):
        """Right-hand side of the ODE, f(t, y) = x_1 * y (t is not used)."""
        return x_1 * y

    # Number of Runge-Kutta time steps to include in the function object
    M = 100

    # Start time for recording the pycppad function object
    s0 = time.time()

    # Declare three independent variables. The operation sequence does not
    # depend on x, so we could use any value here.
    x = numpy.array([.1, .1, .1])
    a_x = independent(numpy.array(x))

    # First independent variable, x[0], is the value of y(0)
    a_y = numpy.array([a_x[0]])

    # Make x_1 a variable so we can use rk4 with various coefficients.
    x_1 = a_x[1]

    # Make dt a variable so we can use rk4 with various step sizes.
    dt = a_x[2]

    # f(t, y) does not depend on t, so no need to make t a variable.
    t = ad(0.)

    # Record the operations for M time steps
    for k in range(M):
        a_y = runge_kutta_4(fun, t, a_y, dt)
        t = t + dt

    # define the AD function rk4 : x -> y
    rk4 = adfun(a_x, a_y)

    # amount of time it took to tape this function object
    tape_sec = time.time() - s0

    # make the function object more efficient
    s0 = time.time()
    rk4.optimize()
    opt_sec = time.time() - s0

    ti = 0.             # initial time
    tf = 1.             # final time
    N = M * 100         # number of time steps
    dt = (tf - ti) / N  # size of time step
    x_0 = 2.            # use this for initial value of y(t)
    x_1 = .5            # use this for coefficient in ODE

    # python version of integrator with float values
    s0 = time.time()
    t = ti
    y = numpy.array([x_0])
    for k in range(N):
        y = runge_kutta_4(fun, t, y, dt)
        t = t + dt

    # number of seconds to solve the ODE using python float
    python_sec = time.time() - s0

    # check solution is correct
    assert abs(y[0] - x_0 * exp(x_1 * tf)) < 1e-10

    # pycppad function object version of integrator
    s0 = time.time()
    x = numpy.array([x_0, x_1, dt])
    # Each evaluation of rk4 replays the M recorded steps, so N // M
    # evaluations cover all N steps.  Floor division keeps the count an
    # int; true division would hand range() a float on Python 3.
    for k in range(N // M):
        y = rk4.forward(0, x)
        # the final value of this pass is the initial value of the next one
        x[0] = y[0]

    # number of seconds to solve the ODE using pycppad function object
    cpp_sec = time.time() - s0

    # check solution is correct
    assert abs(y[0] - x_0 * exp(x_1 * tf)) < 1e-10

    # Uncomment the print call below to see actual times on your machine
    report_format = 'cpp_sec = %8f,\n'
    report_format = report_format + 'python_sec/cpp_sec = %5.1f\n'
    report_format = report_format + 'tape_sec/cpp_sec = %5.1f\n'
    report_format = report_format + 'opt_sec/cpp_sec = %5.1f'
    # print(report_format % (
    #     cpp_sec, python_sec / cpp_sec, tape_sec / cpp_sec, opt_sec / cpp_sec
    # ))

    # check that C++ is always more than 30 times faster (on all systems)
    assert 30. * cpp_sec <= python_sec