Making Neuroblox models work with GraphDynamics

In this notebook I'll show some toy examples of Neuroblox systems and make them interoperate with GraphDynamics.

There are some functions in Neuroblox.GraphDynamicsInterop that try to automate this process, but here I'm going to show you the "manual" way of doing it, which is typically better supported and more powerful / robust.

using Neuroblox, ModelingToolkit, GraphDynamics
using Neuroblox.GraphDynamicsInterop

using Neuroblox:
    paramscoping,
    NeuralMassBlox,
    get_namespaced_sys,
    generate_weight_param,
    Connector

using CairoMakie, StochasticDiffEq

Van der Pol

First, let's consider something that's suspiciously like the noisy Van der Pol oscillator in Neuroblox/src/blox/neural_mass.jl, except:

  • It has 2 inputs instead of only 1 (I added jcn_x as an additional complication)
  • It has some 'computed variables' (i.e. variables which don't really exist in the system solution, but can be calculated based on its parameters and states)

struct VanDerPol <: NeuralMassBlox
    params
    system
    namespace
    function VanDerPol(; name, namespace=nothing, θ=1.0, ϕ=0.1)
        p = paramscoping(θ=θ, ϕ=ϕ)
        θ, ϕ = p
        sts = @variables begin
            # our regular dynamical variables
            x(t)=0.0
            [output=true]
            y(t)=0.0

            # our algebraic inputs (these are really computed variables too)
            jcn_x(t)
            [input=true]
            jcn(t)
            [input=true]

            # our extra computed variables
            jcn_tot(t)
            r(t)
        end

        @brownian ξ

        eqs = [
            # Dynamical equations
            D(x) ~ y + jcn_x,
            D(y) ~ θ*(1-x^2)*y - x + ϕ*ξ + jcn,

            # Extra computed variables
            jcn_tot ~ jcn_x + jcn,
            r ~ √(x^2 + y^2)
        ]
        sys = System(eqs, t, sts, p; name=name)
        new(p, sys, namespace)
    end
end;

GraphDynamics version

Marking as supported

First, let's mark VanDerPol as a supported type by defining issupported.

We'll also tell GraphDynamicsInterop that it's not a composite block by defining a method for components that just returns a Tuple containing the oscillator itself:

GraphDynamicsInterop.issupported(::VanDerPol) = true
GraphDynamicsInterop.components(v::VanDerPol) = (v,)

Converting to Subsystem

GraphDynamics.jl uses a type called Subsystem, which bundles together a SubsystemStates (the dynamical states of an object) and a SubsystemParams (whatever parameters might affect the evolution of the object, or be important to it in some other way).

We need to teach GraphDynamicsInterop how to convert a VanDerPol into a Subsystem{VanDerPol}:

function GraphDynamicsInterop.to_subsystem(v::VanDerPol)
    # Extract the default values of the parameters θ and ϕ
    θ = GraphDynamicsInterop.recursive_getdefault(v.θ)
    ϕ = GraphDynamicsInterop.recursive_getdefault(v.ϕ)

    params = SubsystemParams{VanDerPol}(; θ, ϕ)

    # Extract the default values of dynamical states x and y
    x = GraphDynamicsInterop.recursive_getdefault(v.x)
    y = GraphDynamicsInterop.recursive_getdefault(v.y)
    states = SubsystemStates{VanDerPol}(; x, y)

    # Form a Subsystem from the states and params
    Subsystem(states, params)
end

Here it is in action:

let @named v = VanDerPol(;θ=10.0)
    sys = GraphDynamicsInterop.to_subsystem(v)
    @info "States and params out of a Subsystem:" sys.x sys.y sys.θ sys.ϕ
end
┌ Info: States and params out of a Subsystem:
│   sys.x = 0.0
│   sys.y = 0.0
│   sys.θ = 10.0
└   sys.ϕ = 0.1

Inputs

Now let's define what an "input" to a VanDerPol must look like. In the MTK definition above, we said it had two possible inputs, jcn_x and jcn, so we'll make the "zero input" to a VanDerPol be a NamedTuple with those two names as keys:

GraphDynamics.initialize_input(s::Subsystem{VanDerPol}) = (;jcn_x = 0.0, jcn = 0.0)
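
As a rough mental model, the zero input can be thought of as the starting value of an accumulator, with every incoming connection contributing a NamedTuple of the same shape. The plain-Julia sketch below illustrates that idea only; the helper add_inputs is hypothetical, the numbers are made up, and summing field by field is an assumption rather than a description of what GraphDynamics does internally.

```julia
# Hypothetical helper (not part of GraphDynamics): add two inputs with identical field names.
add_inputs(a::NamedTuple{names}, b::NamedTuple{names}) where {names} =
    NamedTuple{names}(map(+, values(a), values(b)))

zero_input = (; jcn_x = 0.0, jcn = 0.0)      # same shape as initialize_input above
from_conn1 = (; jcn_x = 0.0, jcn = 0.5)      # made-up contribution from one connection
from_conn2 = (; jcn_x = 0.125, jcn = 0.25)   # made-up contribution from another

# Fold all contributions into the zero input (assumed accumulation rule: summation).
total = foldl(add_inputs, (from_conn1, from_conn2); init = zero_input)
```

Because every contribution has the same keys as the zero input, the result can always be unpacked by name, which is what the differential defined next relies on.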

The subsystem differential

Now we can get to the interesting stuff: defining the differential equations of a Subsystem.

The idea is that we add a method to GraphDynamics.subsystem_differential that takes in the subsystem, whatever inputs were "sent" to it, and the time (which we don't need), and then computes a SubsystemStates whose entries correspond to the derivatives of the respective states:

function GraphDynamics.subsystem_differential(sys::Subsystem{VanDerPol}, inputs, t)
    # Unpack the states and params we need
    (;x, y, θ) = sys # this is fancy syntax for x = sys.x; y = sys.y; θ = sys.θ

    # Unpack the inputs
    (;jcn_x, jcn) = inputs

    return SubsystemStates{VanDerPol}(
        #=d/dt=#x = y + jcn_x,
        #=d/dt=#y = θ*(1-x^2)*y - x + jcn
    )
end
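
The property-destructuring syntax used above works on any object with named properties (Julia 1.7 or later), so the right-hand sides can be checked by hand without constructing a real Subsystem. In this sketch the NamedTuples toy_sys and toy_inputs are stand-ins with made-up values, not GraphDynamics objects.

```julia
# NamedTuples standing in for a Subsystem and its inputs; the values are made up.
toy_sys = (; x = 0.5, y = -1.0, θ = 2.0, ϕ = 0.1)
toy_inputs = (; jcn_x = 0.0, jcn = 0.25)

# Property destructuring pulls fields out by name; unused fields (ϕ) are simply ignored.
(; x, y, θ) = toy_sys
(; jcn_x, jcn) = toy_inputs

# The same right-hand sides as in subsystem_differential above.
dx = y + jcn_x
dy = θ * (1 - x^2) * y - x + jcn
```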

Noise terms

The keen-eyed may have noticed that we didn't include the ϕ*ξ term in the differential for y. This is because subsystem_differential only describes the deterministic (non-stochastic) part of the dynamics.

To include stochastic noise, we first tell GraphDynamics that our VanDerPol oscillator is stochastic (by default, a subsystem is assumed not to be stochastic), and then define how the noise is applied:

GraphDynamics.isstochastic(::Type{VanDerPol}) = true

function GraphDynamics.apply_subsystem_noise!(v_noise, sys::Subsystem{VanDerPol}, t)
    v_noise[2] = sys.ϕ
end

The above method works by mutating a vector of potential noise terms because noise is typically "sparse", i.e. not all of our variables experience noise directly.

Writing v_noise[2] = sys.ϕ is equivalent to the ϕ*ξ term in the equation for y (the second state), where ξ is a Brownian variable.

If x also were to experience noise, you'd mutate v_noise[1] as well.

Currently, GraphDynamics assumes that each source of noise in the equations is independent, and does not support cases where x and y see correlated noise. This is equivalent to having either 0 or 1 Brownian variable per state.
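
To make the "0 or 1 independent Brownian variable per state" picture concrete, here is a hand-rolled Euler-Maruyama loop in plain Julia. It is only a sketch under simplifying assumptions (zero inputs, fixed step size); toy_drift, toy_noise! and euler_maruyama are hypothetical names, and the actual solves in this notebook go through StochasticDiffEq rather than anything like this.

```julia
using Random

# Deterministic part, mirroring subsystem_differential with zero inputs.
toy_drift(u, θ) = [u[2], θ * (1 - u[1]^2) * u[2] - u[1]]

# Noise coefficients, mirroring apply_subsystem_noise!: only y (index 2) is noisy.
function toy_noise!(g, ϕ)
    g[2] = ϕ
    return g
end

function euler_maruyama(u0, θ, ϕ, dt, nsteps; rng = Random.default_rng())
    u = copy(u0)
    g = zeros(length(u))   # entries left at zero mean "no noise on this state"
    toy_noise!(g, ϕ)
    for _ in 1:nsteps
        dW = sqrt(dt) .* randn(rng, length(u))   # one independent increment per state
        u = u .+ toy_drift(u, θ) .* dt .+ g .* dW
    end
    return u
end

u_end = euler_maruyama([0.0, 0.0], 1.0, 0.1, 1e-3, 1000; rng = MersenneTwister(1234))
```

Since the noise vector is multiplied elementwise with independent increments, there is no way in this scheme for x and y to share a noise source, which is the restriction described above.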

Computed properties

Now let's deal with the computed properties r, jcn_x, jcn and jcn_tot.

r is different from the others because it does not depend on the inputs, it only depends on the internal states / parameters of the subsystem itself. We can tell GraphDynamics how to compute r by adding a method to computed_properties which returns a NamedTuple whose keys are the property names, and the values are functions to compute them:

function GraphDynamics.computed_properties(v::Subsystem{VanDerPol})
    r_func(v) = √(v.x^2 + v.y^2)
    (; r = r_func)
end

Likewise, for computed properties that depend on a subsystem's inputs, we define a method on computed_properties_with_inputs, except the functions returned will have an extra argument for the inputs:

function GraphDynamics.computed_properties_with_inputs(v::Subsystem{VanDerPol})
    jcn_x(v, input) = input.jcn_x
    jcn(v, input) = input.jcn
    jcn_tot(v, input) = input.jcn_x + input.jcn
    (; jcn_x, jcn, jcn_tot)
end
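
To see how such NamedTuples of functions can be consumed, the sketch below evaluates every property by mapping over them. The stand-ins toy_v and toy_input are plain NamedTuples with made-up values, and this is only an illustration of the data shape, not of how GraphDynamics evaluates computed properties internally.

```julia
# NamedTuples standing in for a Subsystem and its inputs (values made up).
toy_v = (; x = 3.0, y = 4.0)
toy_input = (; jcn_x = 0.5, jcn = 1.5)

# Same shape as the return values of the two methods above.
props = (; r = v -> √(v.x^2 + v.y^2))
props_with_inputs = (;
    jcn_x = (v, input) -> input.jcn_x,
    jcn = (v, input) -> input.jcn,
    jcn_tot = (v, input) -> input.jcn_x + input.jcn,
)

# map over a NamedTuple keeps the keys, so each result is addressable by property name.
vals = map(f -> f(toy_v), props)
vals_with_inputs = map(f -> f(toy_v, toy_input), props_with_inputs)
```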

Solving a system of Van der Pols

Let's simulate a couple of VanDerPol oscillators. We can't couple them together yet because we haven't talked about connections, but we can at least run two parallel VdP oscillators and look at the results.

First, let's build a graph containing the two oscillators:

tspan = (0.0, 2.0)
seed = 1234

g_vdp = MetaDiGraph()
@named v1 = VanDerPol(θ = 1.0, ϕ = 2.0)
@named v2 = VanDerPol(θ = 0.5, ϕ = 0.25)

add_blox!(g_vdp, v1)
add_blox!(g_vdp, v2);

Here's a solution computed with the regular machinery:

let
    @named sys = system_from_graph(g_vdp)
    # Seed with a set value so we get consistent results
    prob = SDEProblem(sys, [], tspan; seed=seed)
    sol = solve(prob, RKMil())
    f = Figure()
    ax = Axis(f[1, 1], xlabel="t")
    lines!(ax, sol.t, sol[v1.r], label="v1.r")
    lines!(ax, sol.t, sol[v2.r], label="v2.r")
    f
end

And now let's try to do the same with GraphDynamics:

let
    # this is how we tell it to use GraphDynamics! ------↓
    @named sys = system_from_graph(g_vdp; graphdynamics=true)
    # Use the same seed as the MTK solution to make sure we get consistent results
    prob = SDEProblem(sys, [], tspan; seed=seed)
    sol = solve(prob, RKMil())
    f = Figure()
    ax = Axis(f[1, 1], xlabel="t")
    lines!(ax, sol.t, sol[:v1₊r], label="v1.r")
    lines!(ax, sol.t, sol[:v2₊r], label="v2.r")
    f
end

DBS Source blox

Now let's implement another blox, the DBS Source blox from Neuroblox/src/blox/DBS_sources.jl, because let's say we want to use it to drive our VanDerPol oscillators.

struct DBS <: Neuroblox.StimulusBlox
    params::Vector{Num}
    system::ODESystem
    namespace::Union{Symbol, Nothing}
    stimulus::Function
    function DBS(;
        name,
        namespace=nothing,
        frequency=130.0,
        amplitude=2.5,
        pulse_width=0.066,
        offset=0.0,
        start_time=0.0,
        smooth=1e-4
    )
        # Ensure consistent numeric types for all parameters
        frequency, amplitude, pulse_width, offset, start_time, smooth =
            promote(frequency, amplitude, pulse_width, offset, start_time, smooth)

        # Convert to kHz (to match internal time in ms)
        frequency_khz = frequency/1000.0

        # Create stimulus function based on smooth/non-smooth square wave
        stimulus = if smooth == 0
            t -> Neuroblox.square(t, frequency_khz, amplitude, offset, start_time, pulse_width)
        else
            t -> Neuroblox.square(t, frequency_khz, amplitude, offset, start_time, pulse_width, smooth)
        end

        p = Neuroblox.paramscoping(
            tunable=false;
            frequency=frequency,
            amplitude=amplitude,
            pulse_width=pulse_width,
            offset=offset,
            start_time=start_time
        )

        sts = @variables u(t) [output = true]
        eqs = [u ~ stimulus(t)]
        sys = System(eqs, t, sts, p; name=name)

        new(p, sys, namespace, stimulus)
    end
end
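
For intuition about what the stimulus closure computes in the non-smooth case, here is a toy square pulse train in plain Julia. It is a hypothetical stand-in written for this notebook, not the implementation of Neuroblox.square; in particular, the argument order is copied from the call above, and the choice that the pulse sits at offset + amplitude is an assumption.

```julia
# Hypothetical stand-in for a non-smooth square pulse train (time in ms, frequency in kHz).
function toy_square(t, frequency_khz, amplitude, offset, start_time, pulse_width)
    t < start_time && return offset          # nothing happens before the start time
    period = 1 / frequency_khz
    phase = mod(t - start_time, period)      # time since the start of the current cycle
    return phase < pulse_width ? offset + amplitude : offset
end

frequency_khz = 130.0 / 1000.0               # 130 Hz, as in the DBS defaults
stim = t -> toy_square(t, frequency_khz, 2.5, 0.0, 0.0, 0.066)

during_pulse = stim(0.01)    # inside the first pulse
between_pulses = stim(1.0)   # after the pulse has ended, before the next one starts
```

With the default frequency the period is roughly 7.7 ms while the pulse lasts only 0.066 ms, so the signal spends most of its time at the offset value.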

This blox is a bit special because it doesn't actually have any dynamical state. It basically just has its output equation u ~ stimulus(t), and in the GraphDynamics approach we'll store the stimulus function as a parameter.

GraphDynamics version

Basics

GraphDynamicsInterop.issupported(::DBS) = true

GraphDynamicsInterop.components(d::DBS) = (d,)


GraphDynamics.initialize_input(s::Subsystem{DBS}) = (;)

function GraphDynamicsInterop.to_subsystem(d::DBS)
    # Extract the DBS stimulus function
    stimulus = getfield(d, :stimulus)
    params = SubsystemParams{DBS}(; stimulus)

    # Return *empty* states
    states = SubsystemStates{DBS}()

    # Form a Subsystem from the states and params
    Subsystem(states, params)
end

A do-nothing differential

Since there's no dynamics, we can skip subsystem_differential and instead just tell it that apply_subsystem_differential! does nothing:

function GraphDynamics.apply_subsystem_differential!(_, d::Subsystem{DBS}, _, _)
    nothing
end

Connections

Now, all we need to do is define some connections between some blox. Let's go for a connection from the DBS source to a Van der Pol.

Suppose our regular Neuroblox connection rule looks like:

function Neuroblox.Connector(
    blox_src::DBS,
    blox_dest::NeuralMassBlox;
    kwargs...
)
    sys_src = get_namespaced_sys(blox_src)
    sys_dest = get_namespaced_sys(blox_dest)

    w = generate_weight_param(blox_src, blox_dest; kwargs...)

    eq = sys_dest.jcn ~ w * sys_src.u

    return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w)
end

Then the GraphDynamics version of this can be as simple as:

function (c::GraphDynamicsInterop.BasicConnection)(sys_src::Subsystem{DBS},
                                                   sys_dst::Subsystem{VanDerPol},
                                                   t)

    w = c.weight
    jcn_x = 0.0 # do nothing to jcn_x
    jcn = w * sys_src.stimulus(t) # drive jcn
    (; jcn_x, jcn) # This must match the form of initialize_input(sys_dst)
end
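
The method above uses Julia's callable-struct syntax: defining (c::T)(args...) specifies what happens when an instance of T is called like a function. The sketch below shows the same pattern in isolation. ToyConnection is a hypothetical type playing the role of BasicConnection, and the NamedTuples toy_src and toy_dst are made-up stand-ins for the two subsystems.

```julia
# Hypothetical connection type; GraphDynamicsInterop.BasicConnection plays this role above.
struct ToyConnection
    weight::Float64
end

# Making the struct callable: this defines what conn(src, dst, t) does.
function (c::ToyConnection)(src, dst, t)
    jcn_x = 0.0                        # leave jcn_x untouched
    jcn = c.weight * src.stimulus(t)   # drive jcn with the weighted stimulus
    return (; jcn_x, jcn)              # same keys as the destination's zero input
end

toy_src = (; stimulus = t -> 1.0 + t)  # NamedTuple standing in for Subsystem{DBS}
toy_dst = (; x = 0.0, y = 0.0)         # NamedTuple standing in for Subsystem{VanDerPol}

conn = ToyConnection(2.0)
input = conn(toy_src, toy_dst, 0.5)
```

Because the rule dispatches on the types of both endpoints, a different pair of blox types can get its own method without touching this one.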

Solve a system with the DBS and VdP Blox

g_vdp_dbs = MetaDiGraph()
@named dbs = DBS()
add_edge!(g_vdp_dbs, dbs => v1; weight=1.0)
add_edge!(g_vdp_dbs, dbs => v2; weight=1.0)

let
    @named sys = system_from_graph(g_vdp_dbs)

    prob = SDEProblem(sys, [], tspan; seed=seed)
    sol = solve(prob, RKMil())
    f = Figure()
    ax = Axis(f[1, 1], xlabel="t")
    lines!(ax, sol.t, sol[v1.r], label="v1.r")
    lines!(ax, sol.t, sol[v2.r], label="v2.r")
    f
end

And with GraphDynamics:

let
    # this is how we tell it to use GraphDynamics! ----------↓
    @named sys = system_from_graph(g_vdp_dbs; graphdynamics=true)

    prob = SDEProblem(sys, [], tspan; seed=seed)
    sol = solve(prob, RKMil())
    f = Figure()
    ax = Axis(f[1, 1], xlabel="t")
    lines!(ax, sol.t, sol[:v1₊r], label="v1.r")
    lines!(ax, sol.t, sol[:v2₊r], label="v2.r")
    f
end

TODO: A neuron model with events

TODO: Composite blox


This page was generated using Literate.jl.