Linear State Space Examples

This tutorial describes the support for linear and linear Gaussian state space models.

At this point, the package only supports linear time-invariant models without a separate p vector. The canonical form of the linear model is

\[u_{n+1} = A u_n + B w_{n+1}\]

with

\[z_n = C u_n + v_n\]

and optionally $v_n \sim N(0, D)$ and $w_{n+1} \sim N(0,I)$. If you pass noise into the solver directly, it no longer needs to be Gaussian. More generally, support could be added for $u_{n+1} = A(p,n) u_n + B(p,n) w_{n+1}$, where $p$ is a vector of differentiable parameters and $A$ and $B$ are potentially matrix-free operators.
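
To make the recursion concrete, here is a minimal hand-rolled sketch of the canonical form above (illustrative only, not the package API; the simulate_canonical name and the Gaussian draws are assumptions for the sketch):

using Random

# Illustrative only: iterate u_{n+1} = A u_n + B w_{n+1} and z_n = C u_n + v_n
# with w_{n+1} ~ N(0, I) and v_n ~ N(0, Diagonal(D_std.^2)).
function simulate_canonical(A, B, C, D_std, u0, T; rng = Random.default_rng())
    u = [u0]
    z = [C * u0 + D_std .* randn(rng, length(D_std))]
    for n in 1:T
        w = randn(rng, size(B, 2))                                 # w_{n+1} ~ N(0, I)
        push!(u, A * u[end] + B * w)                               # state evolution
        push!(z, C * u[end] + D_std .* randn(rng, length(D_std)))  # noisy observable
    end
    return u, z
end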

Simulating a Linear (and Time-Invariant) State Space Model

We begin by creating a LinearStateSpaceProblem and simulating a simple linear model.

using DifferenceEquations, LinearAlgebra, Distributions, Random, Plots, DataFrames, Zygote
A = [0.95 6.2;
     0.0 0.2]
B = [0.0; 0.01;;] # 2x1 matrix
C = [0.09 0.67;
     1.00 0.00]
D = [0.1, 0.1] # observation noise standard deviations (diagonal covariance)
u0 = zeros(2)
T = 10

prob = LinearStateSpaceProblem(A, B, u0, (0, T); C, observables_noise = D, syms = [:a, :b])
sol = solve(prob)
retcode: Success
Interpolation: Piecewise constant interpolation
t: 0:10
u: 11-element Vector{Vector{Float64}}:
 [0.0, 0.0]
 [0.0, -0.0066529695533883975]
 [-0.041248411231008066, -0.0003260792780927083]
 [-0.04120768219363245, -0.005171055886909522]
 [-0.07120784458278986, 0.002074506231364348]
 [-0.05478551371919141, -0.0009500929123929035]
 [-0.05793681409006784, 0.014676418931937342]
 [0.035953823992447086, 0.011935454170726788]
 [0.10815594865133082, -0.009049269610703065]
 [0.04664267963240527, 0.004410913620137401]
 [0.0716582100956369, -0.0035250997437888073]

The u vector of the simulated solution can be plotted using the standard recipes, including the use of the optional syms. See the SciML docs for more options.

plot(sol)

By default, the solution provides an interface to access the simulated u. That is, sol[...] == sol.u[...],

sol[2]
2-element Vector{Float64}:
  0.0
 -0.0066529695533883975

Or to get the first element of the last step

sol[end][1] #first element of last step
0.0716582100956369

Finally, to extract the full time series of the 2nd state variable

@show sol[2, :];  # time series of the 2nd state variable
11-element Vector{Float64}:
  0.0
 -0.0066529695533883975
 -0.0003260792780927083
 -0.005171055886909522
  0.002074506231364348
 -0.0009500929123929035
  0.014676418931937342
  0.011935454170726788
 -0.009049269610703065
  0.004410913620137401
 -0.0035250997437888073

The results for all of sol.u can be loaded into a DataFrame, where the column names will be the (optionally) provided symbols.

df = DataFrame(sol)
11×3 DataFrame
 Row │ timestamp  a           b
     │ Int64      Float64     Float64
─────┼─────────────────────────────────
   1 │         0   0.0         0.0
   2 │         1   0.0        -0.00665297
   3 │         2  -0.0412484  -0.000326079
   4 │         3  -0.0412077  -0.00517106
   5 │         4  -0.0712078   0.00207451
   6 │         5  -0.0547855  -0.000950093
   7 │         6  -0.0579368   0.0146764
   8 │         7   0.0359538   0.0119355
   9 │         8   0.108156   -0.00904927
  10 │         9   0.0466427   0.00441091
  11 │        10   0.0716582  -0.0035251

Other results, such as the simulated noise and observables, can be extracted from the solution

sol.z # observables
11-element Vector{Vector{Float64}}:
 [-0.18714943809424156, -0.48925906447821027]
 [-0.2620124140597859, -0.28657934793047346]
 [-0.31424331659238075, 0.47538328392125406]
 [-0.28995703759003366, 0.6626625996704064]
 [-0.44638015131207326, -0.8941294896873483]
 [-0.2858489925740556, -0.5752895881232579]
 [0.2987040430465367, -0.12859444838078796]
 [-0.17012177350159804, 0.017811035166601923]
 [-0.013978447544399074, 0.0420250471332029]
 [0.05838652708953471, 0.1990822575059309]
 [-0.7513961009515336, 0.25007547718357326]
sol.W # Simulated Noise
1×10 Matrix{Float64}:
 -0.665297  0.100451  -0.510584  0.310872  …  -1.14364  0.622077  -0.440728

We can also solve the model by passing in fixed noise, which will be useful for joint likelihoods. First, let's extract the noise from the previous solution, then rerun the simulation but with a different initial value

noise = sol.W
u0_2 = [0.1, 0.0]
prob2 = LinearStateSpaceProblem(
    A, B, u0_2, (0, T); C, observables_noise = D, syms = [:a, :b], noise)
sol2 = solve(prob2)
plot(sol2)

To construct an impulse response function (IRF), we can take the model and perturb just the first element of the noise,

function irf(A, B, C, T = 20)
    noise = Matrix([1.0; zeros(T - 1)]')
    problem = LinearStateSpaceProblem(A, B, zeros(2), (0, T); C, noise, syms = [:a, :b])
    return solve(problem)
end
plot(irf(A, B, C))

Let's find the 2nd observable at the end of the IRF.

function last_observable_irf(A, B, C)
    sol = irf(A, B, C)
    return sol.z[end][2]  # 2nd element of the last observable
end
last_observable_irf(A, B, C)
0.03119456447624772

But everything in this package is differentiable. Let's differentiate this last observable of the IRF with respect to the A, B, and C matrices using Zygote.jl,

gradient(last_observable_irf, A, B, C)  # calculates gradient wrt all arguments
([0.5822985368900442 0.0050313813671367304; 4.469834483178462 0.041592752634585235], [0.37735360253530714; 3.119456447624773;;], [0.0 0.0; 0.03119456447624772 5.242880000000006e-16])

Gradients of other model elements (e.g., .u) are also possible. With this in mind, let's find the gradient of the mean of the 1st state variable of a simulated solution with respect to a particular noise vector and the initial condition.

function mean_u_1(A, B, C, noise, u0, T)
    problem = LinearStateSpaceProblem(A, B, u0, (0, T); noise, syms = [:a, :b])
    sol = solve(problem)
    u = sol.u # see issue #75 workaround
    # can have nontrivial functions and even non-mutating loops
    return mean(u[i][1] for i in 1:T)
end
u0 = [0.0, 0.0]
noise = sol.W # from simulation above
mean_u_1(A, B, C, noise, u0, T)
# dropping a few arguments from derivative
gradient((noise, u0) -> mean_u_1(A, B, C, noise, u0, T), noise, u0)
([0.05079876954953124 0.045314515146875 … 0.0 0.0], [0.702526121523242, 5.600882710405467])

Simulating Ensembles and Fixing Noise

If you pass in a distribution for the initial condition, the solver will draw the initial condition from it. Below, we will simulate a deterministic evolution equation (B = nothing), without any observation noise.

using Distributions, DiffEqBase
u0 = MvNormal([1.0 0.1; 0.1 1.0])  # mean zero initial conditions
prob = LinearStateSpaceProblem(A, nothing, u0, (0, T); C)
sol = solve(prob)
plot(sol)

With this, we can simulate an ensemble of solutions from different initial conditions (and we will turn back on the noise). The EnsembleSummary calculates a set of quantiles by default.

T = 10
trajectories = 50
prob = LinearStateSpaceProblem(A, B, u0, (0, T); C)
sol = solve(EnsembleProblem(prob), DirectIteration(), EnsembleThreads(); trajectories)
summ = EnsembleSummary(sol)  # calculate summary statistics from the ensemble
plot(summ)  # shows quantiles by default
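
If you want different bands, the SciML EnsembleSummary constructor accepts a quantiles keyword; the snippet below is a hedged illustration (check the SciML ensemble documentation for the exact signature in your version), and the summ_iqr name is just for this example.

summ_iqr = EnsembleSummary(sol; quantiles = [0.25, 0.75])  # interquartile band rather than the default 5%/95%
plot(summ_iqr)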

Observables and Marginal Likelihood using a Kalman Filter

If you provide observables and a distribution for the observables_noise, then solving the model can calculate the likelihood.

The simplest case is a Gaussian prior and Gaussian observation noise. First, let's simulate some data with observation noise included. If you pass a matrix or vector for observables_noise, it is interpreted as the Cholesky factor of the covariance matrix (i.e., standard deviations in the diagonal case). At this point, only diagonal observation noise is allowed.

u0 = MvNormal([1.0 0.1; 0.1 1.0])  # draw from mean zero initial conditions
T = 10
prob = LinearStateSpaceProblem(A, B, u0, (0, T); C, observables_noise = D, syms = [:a, :b])
sol = solve(prob)
sol.z # simulated observables with observation noise
11-element Vector{Vector{Float64}}:
 [0.598475169829116, -1.0822256484943358]
 [0.1329933415785443, 5.047461784579176]
 [0.40911401600380315, 6.2466361975746745]
 [0.6730333761302273, 5.485213123192824]
 [0.513540597955876, 5.171856705187101]
 [-0.019971275916727205, 4.916955261454044]
 [0.26829145452798275, 4.382685757101418]
 [0.48987759810652076, 4.204651144441524]
 [0.7937912121382924, 4.230423809638505]
 [0.4292808157026788, 3.877742308604573]
 [-0.15897777524759493, 3.8992974828468716]

Next, we will find the log likelihood of these simulated observables using u0 as a prior and with the true parameters.

The new arguments we pass to the problem creation are u0_prior_var, u0_prior_mean, and observables. The u0 argument is ignored for the filtering problem, but must still match the state size. The KalmanFilter() argument to solve is unnecessary, since it can be selected automatically given the priors and observables.

Note

The timing convention is such that observables are expected to match the predictions starting at the second time period. Because the likelihood of the first element u0 comes from the prior, the observables start at the next element, and hence the observables and noise sequences should be one element shorter than the simulated states (i.e., of length T for tspan (0, T)).

observables = hcat(sol.z...)  # Observables required to be matrix.  Issue #55
observables = observables[:, 2:end] # see note above on likelihood and timing
noise = copy(sol.W) # save for later
u0_prior_mean = [0.0, 0.0]
# use covariance of distribution we drew from
u0_prior_var = cov(u0)

prob = LinearStateSpaceProblem(A, B, u0, (0, size(observables, 2)); C, observables,
    observables_noise = D, syms = [:a, :b], u0_prior_var, u0_prior_mean)
sol = solve(prob, KalmanFilter())
# plot(sol) The `u` is the sequence of posterior means.
sol.logpdf
-7.839964529508742

Hence, the logpdf provides the log likelihood marginalizing out the latent noise variables.
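
For intuition, that marginal likelihood is what a textbook predict/update Kalman recursion accumulates. The following hand-rolled sketch is illustrative only: it assumes observables_noise holds the standard deviations of diagonal Gaussian observation noise, the kalman_logpdf_sketch name is made up for this example, and the package's internal conventions may differ in detail.

using LinearAlgebra

# Textbook Kalman filter for u_{n+1} = A u_n + B w_{n+1}, z_n = C u_n + v_n,
# accumulating the log marginal likelihood of the observables.
function kalman_logpdf_sketch(A, B, C, D_std, u0_prior_mean, u0_prior_var, observables)
    u = copy(u0_prior_mean)            # filtered mean
    P = copy(u0_prior_var)             # filtered covariance
    R = Diagonal(D_std .^ 2)           # observation noise covariance
    Q = B * B'                         # state noise covariance since w ~ N(0, I)
    ll = 0.0
    for n in 1:size(observables, 2)
        u = A * u                      # predict mean
        P = A * P * A' + Q             # predict covariance
        S = C * P * C' + R             # innovation covariance
        e = observables[:, n] - C * u  # innovation
        ll -= 0.5 * (length(e) * log(2π) + logdet(S) + dot(e, S \ e))
        K = (P * C') / S               # Kalman gain
        u += K * e                     # update mean
        P -= K * C * P                 # update covariance
    end
    return ll
end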

As before, we can differentiate the Kalman filter itself.

function kalman_likelihood(A, B, C, D, u0_prior_mean, u0_prior_var, observables)
    prob = LinearStateSpaceProblem(A, B, u0, (0, size(observables, 2)); C, observables,
        observables_noise = D, syms = [:a, :b], u0_prior_var, u0_prior_mean)
    return solve(prob).logpdf
end
kalman_likelihood(A, B, C, D, u0_prior_mean, u0_prior_var, observables)
# Find the gradient wrt the A, B, C and priors variance.
gradient(
    (A, B, C, u0_prior_var) -> kalman_likelihood(
        A, B, C, D, u0_prior_mean, u0_prior_var, observables),
    A,
    B,
    C,
    u0_prior_var)
([-133.5095014493091 0.01625470364401546; -994.1287072958926 -17.60179936944906], [2.5912522479545523; 16.46829428759027;;], [-40.00627001093838 -0.6106697665931414; 3.6559842551553743 0.27602777474578843], [-0.033380711375140246 -0.0718773558763173; 0.2858994991572116 -0.24722736911425042])
Note

Some gradients, such as those for observables, have not been implemented, so test carefully. This is a general theme with Zygote.jl and reverse-mode AD. Your best friend in this process is the spectacular ChainRulesTestUtils.jl package (a minimal illustration follows below); see the test_rrule usage in the linear unit tests.
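
As a minimal illustration of the workflow (not taken from this package's test suite), test_rrule compares a registered rrule against finite differences. For example, checking the rule that ChainRules.jl ships for matrix multiplication:

using ChainRulesTestUtils

# Compare the registered rrule for * against finite differences.
# The package's unit tests apply the same pattern to its own rules.
test_rrule(*, rand(3, 3), rand(3, 3))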

Joint Likelihood with Noise

A key application of these methods is to find the joint likelihood of the latent variables (i.e., the noise) and the model definition.

The actual calculation of the likelihood is trivial in that case, and just requires iteration of the linear system while accumulating the likelihood given the observation noise.

Crucially, the differentiability with respect to the high-dimensional noise vector enables gradient-based sampling and estimation methods that would otherwise be infeasible.
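
As a rough sketch of that accumulation (illustrative only; it assumes observables_noise holds standard deviations of independent Gaussian observation noise and follows the timing convention from the note above, and the joint_likelihood_sketch name is made up), the computation is essentially:

using Distributions

# Hand-rolled sketch of the joint likelihood; the package's solve should do the
# equivalent (plus the chain rules needed for gradients).
function joint_likelihood_sketch(A, B, C, D_std, u0, noise, observables)
    u = copy(u0)
    ll = 0.0
    for n in 1:size(observables, 2)
        u = A * u + B * noise[:, n]                  # latent state step
        e = observables[:, n] - C * u                # observation error
        ll += sum(logpdf.(Normal.(0.0, D_std), e))   # Gaussian observation noise
    end
    return ll
end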

function joint_likelihood(A, B, C, D, u0, noise, observables)
    prob = LinearStateSpaceProblem(
        A, B, u0, (0, size(observables, 2)); C, observables, observables_noise = D, noise)
    return solve(prob).logpdf
end
u0 = [0.0, 0.0]
joint_likelihood(A, B, C, D, u0, noise, observables)
-1185.1078766525466

And as always, this can be differentiated with respect to the state-space matrices and the noise. Choosing a few parameters,

gradient(
    (A, u0, noise) -> joint_likelihood(A, B, C, D, u0, noise, observables), A, u0, noise)
([-90.52563586793134 -4.353483695771945; -521.7372931543744 -27.606445302959873], [375.60284243376253, 2996.329148709774], [27.25131937815565 23.70589211193122 … 2.472661536452844 -0.010303016357891895])

Composition of State Space Models and AD

While the above gradients have been with respect to the full state-space matrices A, B, etc., those matrices could themselves be generated by a separate differentiable procedure, with the whole composition differentiated end-to-end. For example, let's repeat the above examples where we generate the A matrix from a deeper structural parameter.

First, we will generate some observations with a generate_model proxy, which could be replaced with something more complicated but still differentiable

function generate_model(β)
    A = [β 6.2;
         0.0 0.2]
    B = Matrix([0.0 0.001]') # [0.0; 0.001;;] triggers a Zygote bug
    C = [0.09 0.67;
         1.00 0.00]
    D = [0.01, 0.01]
    return (; A, B, C, D)
end

function simulate_model(β, u0; T = 200)
    mod = generate_model(β)
    prob = LinearStateSpaceProblem(
        mod.A, mod.B, u0, (0, T); mod.C, observables_noise = mod.D)
    sol = solve(prob) # simulates
    observables = hcat(sol.z...)
    observables = observables[:, 2:end] # see note above on likelihood and timing
    return observables, sol.W
end

# Fix a "pseudo-true" and generate noise and observables
β = 0.95
u0 = [0.0, 0.0]
observables, noise = simulate_model(β, u0)
([-0.09954709114589842 0.005232798392808081 … 0.20762700975229048 0.05441666003670729; 0.009244093512761615 0.09910139963314891 … 0.11489849804601801 0.23434696840661354], [1.113706766440712 -0.23095703333715878 … -0.2006356175209672 -0.09819817554657458])

Next, we will evaluate the marginal likelihood using the Kalman filter for a particular β value,

function kalman_model_likelihood(β, u0_prior_mean, u0_prior_var, observables)
    mod = generate_model(β) # generate model from structural parameters
    prob = LinearStateSpaceProblem(
        mod.A, mod.B, u0, (0, size(observables, 2)); mod.C, observables,
        observables_noise = mod.D, u0_prior_var, u0_prior_mean)
    return solve(prob).logpdf
end
u0_prior_mean = [0.0, 0.0]
u0_prior_var = [1e-10 0.0;
                0.0 1e-10]  # starting with degenerate prior
kalman_model_likelihood(β, u0_prior_mean, u0_prior_var, observables)
339.974308955627

Given the observation error, we would not expect the pseudo-true value to exactly maximize the log likelihood. To show this, we can maximize the likelihood with Optimization.jl and OptimizationOptimJL, using a gradient-based LBFGS routine.

using Optimization, OptimizationOptimJL
# Create a function to minimize over β alone, using Zygote-based gradients
function kalman_objective(β, p)
    -kalman_model_likelihood(β, u0_prior_mean, u0_prior_var, observables)
end
kalman_objective(0.95, nothing)
gradient(β -> kalman_objective(β, nothing), β) # Verifying it can be differentiated

optf = OptimizationFunction(kalman_objective, Optimization.AutoZygote())
β0 = [0.91] # start off of the pseudotrue
optprob = OptimizationProblem(optf, β0)
optsol = solve(optprob, LBFGS())  # reverse-mode AD is overkill here
retcode: Success
u: 1-element Vector{Float64}:
 0.9900786340744506

In this way, this package composes with others, such as DifferentiableStateSpaceModels.jl, which takes a set of structural parameters and an expectational difference equation and generates a state-space model.

Similarly, we can find the joint likelihood for a particular β value and noise. Here we will add in a prior. Some form of prior or regularization is generally necessary for these sorts of nonlinear models.

function joint_model_posterior(β, u0, noise, observables, noise_prior, β_prior)
    mod = generate_model(β) # generate model from structural parameters
    prob = LinearStateSpaceProblem(mod.A, mod.B, u0, (0, size(observables, 2)); mod.C,
        observables, observables_noise = mod.D, noise)
    return solve(prob).logpdf + sum(logpdf.(noise_prior, noise)) + logpdf(β_prior, β) # posterior
end
u0 = [0.0, 0.0]
noise_prior = Normal(0.0, 1.0)
β_prior = Normal(β, 0.03) # prior local to the true value
joint_model_posterior(β, u0, noise, observables, noise_prior, β_prior)
64.07140511636317

This posterior can be turned into a differentiable objective of a single vector by stacking β and the noise,

function joint_model_objective(x, p)
    -joint_model_posterior(x[1], u0, Matrix(x[2:end]'), observables, noise_prior, β_prior)
end # extract the parameter and noise from the stacked vector
x0 = vcat([0.95], noise[1, :])  # starting at the true noise
joint_model_objective(x0, nothing)
gradient(x -> joint_model_objective(x, nothing), x0) # Verifying it can be differentiated

# optimize
optf = OptimizationFunction(joint_model_objective, Optimization.AutoZygote())
optprob = OptimizationProblem(optf, x0)
optsol = solve(optprob, LBFGS())
retcode: Success
u: 201-element Vector{Float64}:
  0.9974638076163224
 -0.061918324475983354
 -0.08909884072072614
  0.051033663315531656
  0.0317249889527825
 -0.08397498660787954
 -0.11553250279515634
 -0.1130319864113659
 -0.14445809898828096
 -0.21706056871830864
  ⋮
  0.08517384748002621
  0.07135785076995982
  0.04317100161435068
  0.10106197543690171
  0.0741233383247633
  0.14491287211656687
  0.1658242261886731
  0.11809470574069815
  0.0032163881613767463

This "solves" the problem relatively quickly, despite the high-dimensionality. However, from a statistics perspective note that this last optimization process does not do especially well in recovering the pseudotrue if you increase the prior variance on the β parameter. Maximizing the posterior is usually the wrong thing to do in high-dimensions because the mode is not a typical set.

Caveats on Gradients and Performance

A few notes on performance and gradients:

  1. As this uses reverse-mode AD, it will be efficient for fairly large systems as long as the ultimate output of your differentiable program is a scalar (e.g., a likelihood). With a little extra work and unit tests, it could support structured matrices, etc. as well.
  2. Getting to much higher scales, where A, B, C, D are so large that matrix-free operators are necessary, is feasible but will require generalizing those to LinearOperators. This would be reasonably easy for the joint likelihood and feasible, though more involved, for the Kalman filter.
  3. At this point, there is no support for forward-mode auto-differentiation. For smaller systems with a Kalman filter, forward-mode AD should dominate the alternatives, and efficient forward-mode rules for the Kalman filter exist (see the supplementary materials in the Differentiable State Space Models paper). However, it would be a significant amount of work to add end-to-end support and fulfill standard SciML interfaces, and perhaps it makes sense to wait for Enzyme or similar AD systems that provide forward, reverse, and mixed modes.
  4. Forward-mode AD is likely inappropriate for the joint-likelihood based models, since the dimensionality of the noise is always large.
  5. The gradient rules are written with ChainRules.jl, so in theory they will work with any AD system that supports it. In practice, though, Zygote.jl is the most tested, and other systems have inconsistent support for Julia at this time.