Ensemble Modeling

Ensemble modeling is the process of building predictors that are combinations of several predictive models. In this tutorial we show how to use EasyModelAnalysis.jl (EMA.jl) to build such ensemble models.

The Predictive Models

For this tutorial we use a set of SIR-type models as the basis: a basic SIR model, an SIRHD model (which adds hospitalized H and dead D compartments), and an SIRHD model with vaccination. The models are constructed as follows:

using EasyModelAnalysis, LinearAlgebra

@parameters t β=0.05 c=10.0 γ=0.25
@variables S(t)=990.0 I(t)=10.0 R(t)=0.0
∂ = Differential(t)
N = S + I + R # This is recognized as a derived variable
eqs = [∂(S) ~ -β * c * I / N * S,
    ∂(I) ~ β * c * I / N * S - γ * I,
    ∂(R) ~ γ * I];

@named sys = ODESystem(eqs);
tspan = (0, 30)
prob = ODEProblem(sys, [], tspan);

@parameters t β=0.1 c=10.0 γ=0.25 ρ=0.1 h=0.1 d=0.1 r=0.1
@variables S(t)=990.0 I(t)=10.0 R(t)=0.0 H(t)=0.0 D(t)=0.0
∂ = Differential(t)
N = S + I + R + H + D # This is recognized as a derived variable
eqs = [∂(S) ~ -β * c * I / N * S,
    ∂(I) ~ β * c * I / N * S - γ * I - h * I - ρ * I,
    ∂(R) ~ γ * I + r * H,
    ∂(H) ~ h * I - r * H - d * H,
    ∂(D) ~ ρ * I + d * H];

@named sys2 = ODESystem(eqs);

prob2 = ODEProblem(sys2, [], tspan);

@parameters t β=0.1 c=10.0 γ=0.25 ρ=0.1 h=0.1 d=0.1 r=0.1 v=0.1
@parameters t β2=0.1 c2=10.0 ρ2=0.1 h2=0.1 d2=0.1 r2=0.1
@variables S(t)=990.0 I(t)=10.0 R(t)=0.0 H(t)=0.0 D(t)=0.0
@variables Sv(t)=0.0 Iv(t)=0.0 Rv(t)=0.0 Hv(t)=0.0 Dv(t)=0.0
@variables I_total(t)

∂ = Differential(t)
N = S + I + R + H + D + Sv + Iv + Rv + Hv + Dv # This is recognized as a derived variable
eqs = [∂(S) ~ -β * c * I_total / N * S - v * S, # vaccination moves S to Sv at rate v
    ∂(I) ~ β * c * I_total / N * S - γ * I - h * I - ρ * I,
    ∂(R) ~ γ * I + r * H,
    ∂(H) ~ h * I - r * H - d * H,
    ∂(D) ~ ρ * I + d * H,
    ∂(Sv) ~ -β2 * c2 * I_total / N * Sv + v * S,
    ∂(Iv) ~ β2 * c2 * I_total / N * Sv - γ * Iv - h2 * Iv - ρ2 * Iv,
    # vaccinated compartments are fed by Iv and Hv, not by I and H
    ∂(Rv) ~ γ * Iv + r2 * Hv,
    ∂(Hv) ~ h2 * Iv - r2 * Hv - d2 * Hv,
    ∂(Dv) ~ ρ2 * Iv + d2 * Hv,
    I_total ~ I + Iv
];

@named sys3 = ODESystem(eqs)
sys3 = structural_simplify(sys3)
prob3 = ODEProblem(sys3, [], tspan);
ODEProblem with uType Vector{Float64} and tType Int64. In-place: true
timespan: (0, 30)
u0: 10-element Vector{Float64}:
 990.0
  10.0
   0.0
   0.0
   0.0
   0.0
   0.0
   0.0
   0.0
   0.0
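A quick sanity check for compartmental models like these is that every flow leaving one compartment enters another, so the derivatives sum to zero and the total population N stays constant. The sketch below re-implements the SIRHD right-hand side in plain Julia, independently of ModelingToolkit, with the same default parameter values as sys2; the function name sirhd_rhs is hypothetical and exists only for this check.

```julia
# Hypothetical plain-Julia copy of the SIRHD right-hand side (sys2), using the
# same default parameter values. It exists only to check the flow bookkeeping.
function sirhd_rhs(u; β = 0.1, c = 10.0, γ = 0.25, ρ = 0.1, h = 0.1, d = 0.1, r = 0.1)
    S, I, R, H, D = u
    N = S + I + R + H + D
    new_infections = β * c * I / N * S   # flow S -> I
    dS = -new_infections
    dI = new_infections - γ * I - h * I - ρ * I
    dR = γ * I + r * H
    dH = h * I - r * H - d * H
    dD = ρ * I + d * H
    return [dS, dI, dR, dH, dD]
end

du = sirhd_rhs([990.0, 10.0, 0.0, 0.0, 0.0])
println(sum(du))   # total change in N; zero up to floating-point rounding
```

Evaluating it at a second, arbitrary state is a cheap way to catch sign or index mistakes in hand-written equations.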

Representing Ensemble Models with the SciML EnsembleProblem

The SciML libraries provide what is known as an EnsembleProblem, an object that solves many problems simultaneously and represents their aggregate solution. It is documented in the DifferentialEquations.jl documentation and supports features such as automated GPU acceleration; here we use only the subset of features needed for this demonstration. The key component of an EnsembleProblem is the prob_func, a function of (prob, i, repeat) that returns the problem to solve on the ith trajectory. The prob argument is a prototype problem, which we effectively ignore in our use case.
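To make the calling convention concrete, here is a hypothetical plain-Julia driver that mimics how an ensemble calls prob_func(prob, i, repeat) once per trajectory. It is not the SciML implementation, which also handles parallelism and output processing; the names run_ensemble, pick_ith, and toy_solve are illustrative only, and the "problems" are just decay rates.

```julia
# Hypothetical sketch of an ensemble driver's calling convention; NOT the SciML
# implementation. For each trajectory index i it asks prob_func for the problem
# to solve, then solves it.
function run_ensemble(prob_func, prototype, solve_one, trajectories)
    results = Any[]
    for i in 1:trajectories
        rep = 1   # repeat counter; SciML increments it when a trajectory is re-run
        push!(results, solve_one(prob_func(prototype, i, rep)))
    end
    return results
end

# Three toy "problems": decay rates k for y' = -k y with y(0) = 1.
problems = [0.1, 0.2, 0.3]
# Like the tutorial's setup: ignore the prototype and return the ith problem.
pick_ith(prototype, i, repeat) = problems[i]
# Exact solution y(t) = exp(-k t) sampled at t = 0, 1, 2, 3.
toy_solve(k) = [exp(-k * t) for t in 0:3]

sols = run_ensemble(pick_ith, nothing, toy_solve, length(problems))
```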

A simple EnsembleProblem that ensembles the three models built above can be constructed directly from a vector of the problems:

probs = [prob, prob2, prob3]
enprob = EnsembleProblem(probs)
EnsembleProblem with problem Array

Because this EnsembleProblem is built from a vector of problems, its prob_func returns model i on the ith trajectory, and solving it yields one solution per model, three in total. This looks like:

sol = solve(enprob; saveat = 1);
EnsembleSolution Solution of length 3 with uType:
ODESolution{Float64, 2, …}

The individual solutions are available as sol[1], sol[2], and sol[3]. Symbolic indexing also works on the whole ensemble; for example, this returns the time series of S from each of the models:

sol[:, S]

Building a Dataset

Now let's build a synthetic dataset from our ensemble model. For each of S, I, and R, we take a weighted linear combination of the three models' trajectories (weights 0.2, 0.5, and 0.3), using the symbolic indexing interface on the ensemble solution described above.

weights = [0.2, 0.5, 0.3]
data = [
    S => vec(sum(stack(weights .* sol[:, S]), dims = 2)),
    I => vec(sum(stack(weights .* sol[:, I]), dims = 2)),
    R => vec(sum(stack(weights .* sol[:, R]), dims = 2))
]
plot(sol; idxs = S)
scatter!(data[1][2])
plot(sol; idxs = I)
scatter!(data[2][2])
plot(sol; idxs = R)
scatter!(data[3][2])
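The stack-and-sum idiom above can look opaque, so here is the same computation on hypothetical toy series, next to an explicit generator that computes the same weighted sum. Note that stack requires Julia 1.9 or later.

```julia
# Hypothetical toy series standing in for one state's trajectory in each model.
w = [0.2, 0.5, 0.3]
series = [[1.0, 2.0, 3.0],        # model 1
          [10.0, 20.0, 30.0],     # model 2
          [100.0, 200.0, 300.0]]  # model 3

# Tutorial idiom: scale each series by its weight, stack into a matrix with one
# column per model, then sum across the columns (dims = 2).
combined = vec(sum(stack(w .* series), dims = 2))

# The same weighted sum written as an explicit generator over the models.
combined_loop = sum(w[k] * series[k] for k in eachindex(w))
```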

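Now let's split the data into three nested windows: a training window (t = 0 to 14) for calibrating the models, an ensembling window (t = 0 to 21) for fitting the ensemble weights, and a forecast window (t = 0 to 30) for evaluating predictions: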

fullS = vec(sum(stack(weights .* sol[:, S]), dims = 2))
fullI = vec(sum(stack(weights .* sol[:, I]), dims = 2))
fullR = vec(sum(stack(weights .* sol[:, R]), dims = 2))

t_train = 0:14
data_train = [
    S => (t_train, fullS[1:15]),
    I => (t_train, fullI[1:15]),
    R => (t_train, fullR[1:15])
]
t_ensem = 0:21
data_ensem = [
    S => (t_ensem, fullS[1:22]),
    I => (t_ensem, fullI[1:22]),
    R => (t_ensem, fullR[1:22])
]
t_forecast = 0:30
data_forecast = [
    S => (t_forecast, fullS),
    I => (t_forecast, fullI),
    R => (t_forecast, fullR)
]
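The pairing of time ranges with index ranges above relies on the solution being saved at every integer time starting from t = 0 (saveat = 1 on tspan = (0, 30)), so time t sits at 1-based index t + 1. A toy check with a hypothetical stand-in series:

```julia
# With saveat = 1 on tspan = (0, 30), values are saved at t = 0, 1, ..., 30,
# so time t is stored at 1-based index t + 1.
t_full = 0:30
full = collect(Float64, t_full) .^ 2   # hypothetical stand-in series: value = t^2
t_train = 0:14
train = full[t_train .+ 1]             # indices 1:15, matching fullS[1:15]
```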

Bayesian Calibration

Now let's perform a Bayesian calibration of each model against the training data. This yields multiple parameterizations for each model, so the resulting ensemble has (number of parameterizations) × (number of models) members. Here each model gets Uniform(0.01, 10.0) priors on β and γ:

probs = [prob, prob2, prob3]
ps = [[β => Uniform(0.01, 10.0), γ => Uniform(0.01, 10.0)] for i in 1:3]
datas = [data_train, data_train, data_train]
enprobs = bayesian_ensemble(probs, ps, datas)
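The internals of bayesian_ensemble are not shown in this tutorial. As a hedged illustration of what a Bayesian calibration produces, the sketch below swaps in a simple grid approximation of the posterior for a single parameter of a hypothetical toy model y(t) = exp(-γ t), using a Uniform(0.01, 10.0) prior like the ones above and an assumed Gaussian observation noise, and then draws several parameter values from it. Each draw plays the role of one parameterization in the ensemble; EMA.jl may use a different sampler entirely.

```julia
using Random

# Grid approximation of a posterior: a hypothetical stand-in, not what
# bayesian_ensemble does internally. Toy model y(t) = exp(-γ t), one parameter.
t = 0:14
γ_true = 0.25
y_obs = exp.(-γ_true .* t)                   # noise-free synthetic observations

γ_grid = range(0.01, 10.0; length = 9991)    # grid spacing 0.001
σ = 0.05                                     # assumed Gaussian noise level
loglik(γ) = -sum(abs2, exp.(-γ .* t) .- y_obs) / (2 * σ^2)

# Uniform(0.01, 10.0) prior is constant on the grid, so posterior ∝ likelihood.
logpost = loglik.(γ_grid)
post = exp.(logpost .- maximum(logpost))     # shift by max to avoid underflow
post ./= sum(post)

# Draw several parameter values by inverse-CDF sampling on the grid; each draw
# is one parameterization of the model.
rng = MersenneTwister(1)
cdf = cumsum(post)
draws = [γ_grid[min(searchsortedfirst(cdf, rand(rng)), length(γ_grid))] for _ in 1:10]
post_mean = sum(post .* γ_grid)
```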

Let's see how each model in the ensemble compares against the data when run with the fitted parameters:

sol = solve(enprobs);

plot(sol; idxs = S)
scatter!(t_train, data_train[1][2][2])
plot(sol; idxs = I)
scatter!(t_train, data_train[2][2][2])
plot(sol; idxs = R)
scatter!(t_train, data_train[3][2][2])

Training the Ensemble Model

Now let's train the ensemble model. To do so, we use data over a longer window than in the calibration step (t = 0 to 21 instead of t = 0 to 14). Here is how the calibrated ensemble compares against that data:

plot(sol; idxs = S)
scatter!(t_ensem, data_ensem[1][2][2])

We can obtain the optimal ensemble weights, in the least-squares sense, by solving a linear regression of the ensemble members' trajectories against the target data:

sol = solve(enprobs; saveat = t_ensem);
ensem_weights = ensemble_weights(sol, data_ensem)
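As a hedged sketch of the regression idea (not the actual code of ensemble_weights), the snippet below fits weights by ordinary least squares on hypothetical toy trajectories: each column of A is one ensemble member's series over the ensembling window, and b is a target built from known weights, so the fit should recover them.

```julia
using LinearAlgebra

# Hypothetical toy "ensemble members" sampled on the ensembling window t = 0..21.
t = 0:21
A = hcat(exp.(-0.1 .* t),         # member 1
         exp.(-0.2 .* t),         # member 2
         1 ./ (1 .+ 0.05 .* t))   # member 3
w_true = [0.2, 0.5, 0.3]
b = A * w_true                    # synthetic target built from known weights

# Ordinary least squares: w minimizes norm(A * w - b). For a tall matrix,
# `\` computes the least-squares solution via a pivoted QR factorization.
w = A \ b
residual = norm(A * w - b)
```

Plain least squares places no constraints on the weights, so in general they can be negative or fail to sum to one.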

Now we can check how well the weighted ensemble, using these weights, reproduces the data over the ensembling window:

sol = solve(enprobs; saveat = t_ensem);
ensem_prediction = sum(stack(ensem_weights .* sol[:, S]), dims = 2)
plot(sol; idxs = S, color = :blue)
plot!(t_ensem, ensem_prediction, lw = 5, color = :red)
scatter!(t_ensem, data_ensem[1][2][2])
ensem_prediction = sum(stack(ensem_weights .* sol[:, I]), dims = 2)
plot(sol; idxs = I, color = :blue)
plot!(t_ensem, ensem_prediction, lw = 3, color = :red)
scatter!(t_ensem, data_ensem[2][2][2])

Forecasting the Trained Ensemble

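Once we have obtained the ensemble model, we can forecast ahead with it. We extend each calibrated problem's time span to the end of the forecast window and reuse the trained weights: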

forecast_probs = [remake(enprobs.prob[i]; tspan = (t_train[1], t_forecast[end]))
                  for i in 1:length(enprobs.prob)]
fit_enprob = EnsembleProblem(forecast_probs)

sol = solve(fit_enprob; saveat = t_forecast);
ensem_prediction = sum(stack(ensem_weights .* sol[:, S]), dims = 2)
plot(sol; idxs = S, color = :blue)
plot!(t_forecast, ensem_prediction, lw = 3, color = :red)
scatter!(t_forecast, data_forecast[1][2][2])
ensem_prediction = sum(stack(ensem_weights .* sol[:, I]), dims = 2)
plot(sol; idxs = I, color = :blue)
plot!(t_forecast, ensem_prediction, lw = 3, color = :red)
scatter!(t_forecast, data_forecast[2][2][2])
ensem_prediction = sum(stack(ensem_weights .* sol[:, R]), dims = 2)
plot(sol; idxs = R, color = :blue)
plot!(t_forecast, ensem_prediction, lw = 3, color = :red)
scatter!(t_forecast, data_forecast[3][2][2])
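One way to judge a forecast like this is to score it only on points that were not used to fit the weights. The hypothetical sketch below fits weights on t = 0 to 21 and measures the root-mean-square error of the extrapolated ensemble on t = 22 to 30; the member and target series are toy functions, not the SIR models.

```julia
using LinearAlgebra

# Hypothetical toy members and target (not the SIR models).
members(t) = hcat(exp.(-0.1 .* t), exp.(-0.2 .* t), 1 ./ (1 .+ 0.05 .* t))
target(t) = exp.(-0.15 .* t)

t_full = 0:30
A_full = members(t_full)
y_full = target(t_full)

fit_idx = 1:22          # t = 0..21, the ensembling window
heldout = 23:31         # t = 22..30, not used for fitting

# Fit weights on the ensembling window only, then extrapolate them unchanged.
w = A_full[fit_idx, :] \ y_full[fit_idx]
forecast = A_full * w

# Root-mean-square error on the held-out part of the horizon.
rmse = sqrt(sum(abs2, forecast[heldout] .- y_full[heldout]) / length(heldout))
```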