Neuroblox API
This page documents the API functions available to users of Neuroblox. It focuses on utilities such as plotting, system generation and querying, and reinforcement learning. For documentation of the various Blox, see the Blox documentation page.
Model Creation and Querying
At the highest level, a Neuroblox model consists of a weighted directed graph (represented as a MetaDiGraph). Its nodes are Blox representing neurons and populations of neurons. Its edges are connectors dictating how the dynamics of the Blox affect each other. The weights of the edges represent the strengths of synaptic connections. The model graph is used to generate a system of ordinary differential equations that can be solved using Julia's differential equation solvers.
Graphs and Systems
The following functions are used in the construction of graphs.
NeurobloxBase.add_blox! — Function
add_blox!(g::MetaDiGraph, blox)
Add a Blox as a vertex to the graph g.
NeurobloxBase.create_adjacency_edges! — Function
create_adjacency_edges!(g::MetaDiGraph, adj_matrix::Matrix{T}; connection_rule = "basic")
Given an adjacency matrix, populate the graph g with the edges stored in the adjacency matrix. The connection_rule keyword argument sets the type of connection. Possible values are:
- "basic": simple weighted connection
- "psp": postsynaptic potential connection
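The following plain-Julia sketch illustrates what populating a graph from an adjacency matrix involves. It converts a matrix into a list of weighted directed edges. It does not use MetaDiGraph or Neuroblox. The helper adjacency_to_edges is hypothetical. The convention that a nonzero entry adj[i, j] denotes an edge from node i to node j is also an assumption.

```julia
# Hypothetical helper (not part of Neuroblox): convert an adjacency matrix into
# weighted directed edges, assuming adj[i, j] != 0 means an edge i -> j.
function adjacency_to_edges(adj::AbstractMatrix{T}) where {T<:Real}
    edges = Tuple{Int,Int,T}[]
    for j in axes(adj, 2), i in axes(adj, 1)   # column-major traversal
        w = adj[i, j]
        iszero(w) || push!(edges, (i, j, w))   # zero entries mean "no connection"
    end
    return edges
end

adj = [0.0 1.5 0.0;
       0.0 0.0 2.0;
       0.5 0.0 0.0]
for (src, dst, w) in adjacency_to_edges(adj)
    println("edge $src -> $dst with weight $w")
end
```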
The following functions are used in the construction of systems from graphs or lists of Blox.
NeurobloxBase.system_from_graph — Function
system_from_graph(g::MetaDiGraph, p=Num[]; name, simplify=true, graphdynamics=false, kwargs...)
Take in a MetaDiGraph g describing a network of neural structures (and optionally a vector of extra parameters p) and construct a System. This System can be used to construct various Problem types (e.g. ODEProblem) for use with DifferentialEquations.jl solvers.
If simplify is set to true (the default), structural_simplify is called on the resulting system, and any remaining keyword arguments are forwarded to structural_simplify. That is,
@named sys = system_from_graph(g; kwarg1=x, kwarg2=y)
is equivalent to
@named sys = system_from_graph(g; simplify=false)
sys = structural_simplify(sys; kwarg1=x, kwarg2=y)
See the docstring for structural_simplify for information on which options it supports.
If graphdynamics=true (defaults to false), the output will be a GraphSystem from GraphDynamics.jl, and the kwargs will be passed to the GraphDynamics constructor instead of to ModelingToolkit.jl. For large neural systems, the GraphDynamics.jl backend is typically significantly faster than the default backend. However, it is experimental and does not yet support all Neuroblox.jl features.
NeurobloxBase.system_from_parts — Function
system_from_parts(parts::AbstractVector; name)
Compose a list of Blox into one ModelingToolkit System without adding the connections between these Blox.
The following functions are used to query graphs.
NeurobloxBase.get_system — Function
get_system(g::MetaDiGraph)
Get a vector of ModelingToolkit Systems corresponding to the Blox of the graph g.
Missing docstring for connectors_from_graph.
NeurobloxBase.generate_discrete_callbacks — Function
generate_discrete_callbacks(g::MetaDiGraph, bc::Connector, eqs::AbstractVector{<:Equation}; t_block = missing)
Get the list of discrete callbacks for a graph g.
Blox
Blox are the basic components of neural circuits. For documentation of the kinds of blox available, please see the Blox documentation page. Internally these are represented as ModelingToolkit systems.
The following functions are used to query blox.
NeurobloxBase.get_exci_neurons — Function
Get the excitatory neurons of a Blox.
NeurobloxBase.get_inh_neurons — Function
Get the inhibitory neurons of a Blox.
Missing docstring for get_input_equations.
Base.nameof — Function
nameof(blox)
Return the un-namespaced name of blox. See also namespaceof and namespaced_nameof.
NeurobloxBase.namespaced_nameof — Function
namespaced_nameof(blox)
Return the name of the blox, prefixed by its entire inner namespace (all levels of the namespace except the highest level). See also inner_namespaceof.
ModelingToolkit.inputs — Function
ModelingToolkit.inputs(blox; namespaced = false)
Return the input variables of a blox. If the kwarg namespaced is set, the resulting variables will be namespaced using the system's inner namespace.
ModelingToolkit.outputs — Function
ModelingToolkit.outputs(blox::AbstractBlox; namespaced=false)
Return the output variables of a blox. If the kwarg namespaced is set, the resulting variables will be namespaced using the system's inner namespace.
Connectors
Connectors link pairs of Blox. Each Connector is characterized by four things:
- connection equations dictating how the state of one Blox affects the other
- spike affects, which are triggered when the source spikes
- a weight representing the strength of the synaptic connection
- a learning rule dictating how that weight changes when the source fires
When constructing the graph, the desired connections are represented as edges and added using add_edge!. However, the Connectors are not instantiated until the full system is created using system_from_graph.
The following are used to query properties of connectors.
NeurobloxBase.discrete_callbacks — Function
discrete_callbacks(c::Connector)
Get the discrete events of a connection. These include the affects triggered whenever a spike occurs, as well as other events (e.g. switching behavior at some time t_event).
NeurobloxBase.sources — Function
sources(c::Connector)
Get the source Blox of the connection.
NeurobloxBase.destinations — Function
destinations(c::Connector)
Get the destination Blox of the connection.
NeurobloxBase.weights — Function
weights(c::Connector)
Get the weight or weight matrix for a connection.
NeurobloxBase.delays — Function
delays(c::Connector)
Get the delays of the connection.
NeurobloxBase.spike_affects — Function
spike_affects(c::Connector)
Get the affects that are triggered every time the presynaptic Blox fires.
NeurobloxBase.learning_rules — Function
learning_rules(c::Connector)
Get the learning rule for the weight of the connector. Examples are NeurobloxPharma.HebbianPlasticity and NeurobloxPharma.HebbianModulationPlasticity.
The following functions must be implemented whenever the user defines a new kind of connection between two types of Blox.
NeurobloxBase.connection_equations — Function
connection_equations(source, destination, w; kwargs...)
Return the list of equations for a connection. This should be implemented any time one creates a new type of connection between Blox. If the connection is event-based, implement connection_callbacks as well.
NeurobloxBase.connection_spike_affects — Function
connection_spike_affects(source, destination, w)
Return the list of affects that occur every time the presynaptic neuron fires. This should be implemented any time one creates a new type of connection that uses event-based spiking between Blox.
Missing docstring for connection_learning_rule.
NeurobloxBase.connection_callbacks — Function
connection_callbacks(source, destination; kwargs...)
Return the callbacks for the connection, if it is an event-based connection. This should be implemented any time one creates a new type of connection between Blox. If the connection is continuous, implement connection_equations as well.
Connection rules define the structure of connections between composite Blox (i.e. Blox consisting of multiple neurons). A rule is selected with the connection_rule keyword argument to add_edge!, whose value can be :hypergeometric, :density, or :weightmatrix. Internally, these values correspond to the following functions.
NeurobloxBase.hypergeometric_connections — Function
hypergeometric_connections(neurons_src, neurons_dst)
Create connections between two populations of neurons. For each postsynaptic (destination) neuron, randomly generate connections from the pool of presynaptic neurons. The sampling guarantees that almost all source neurons have the same number of connections, and that almost all destination neurons have the same number of connections.
Keyword arguments:
- rng: choice of random number generator
- density: the total number of connections, as a fraction of all possible connections (density = 1 is fully connected)
- weight: a number or vector indicating the weights of each connection
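The balancing idea behind hypergeometric_connections can be sketched in plain Julia. This is not the sampling procedure Neuroblox uses. It is a simpler greedy scheme: each destination takes its k inputs from the currently least-used sources, breaking ties at random. This gives every destination exactly k inputs and keeps source out-degrees within one of each other. The function name balanced_connections and the choice k = round(density * n_src) are assumptions for illustration.

```julia
using Random

# Hypothetical sketch (not Neuroblox's implementation): each destination neuron
# receives exactly k inputs, chosen among the least-used sources so far.
function balanced_connections(rng::AbstractRNG, n_src::Int, n_dst::Int, density::Real)
    k = round(Int, density * n_src)          # assumed: inputs per destination neuron
    0 <= k <= n_src || throw(ArgumentError("density must lie in [0, 1]"))
    conn = falses(n_src, n_dst)              # conn[i, j] == true means i -> j
    outdeg = zeros(Int, n_src)               # connections made so far per source
    for j in 1:n_dst
        order = shuffle(rng, 1:n_src)        # random order for tie-breaking
        sort!(order; by = i -> outdeg[i], alg = MergeSort)  # stable: ties keep shuffled order
        for i in order[1:k]                  # take the k least-used sources
            conn[i, j] = true
            outdeg[i] += 1
        end
    end
    return conn
end

conn = balanced_connections(MersenneTwister(42), 10, 7, 0.3)
println("in-degrees:  ", vec(sum(conn; dims = 1)))   # every destination gets k = 3
println("out-degrees: ", vec(sum(conn; dims = 2)))   # differ by at most 1
```

The equal destination in-degrees come from fixing k per destination. The bound on source out-degrees comes from always drawing from the least-used sources first.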
NeurobloxBase.density_connections — Function
density_connections(neurons_src, neurons_dst, name_src, name_dst; kwargs...)
Create connections between two populations of neurons. Unlike hypergeometric_connections, this process is entirely random and does not guarantee that neurons have roughly equal numbers of connections.
Keyword arguments:
- rng: choice of random number generator
- density: the total number of connections, as a fraction of all possible connections (density = 1 is fully connected)
- weight: a number or vector indicating the weights of each connection
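For contrast, a fully random scheme in the spirit of density_connections can be sketched by choosing the required number of source-destination pairs uniformly at random. The helper random_density_connections is hypothetical. Reading density as the fraction of all possible pairs that are connected is an assumption based on the description above. Individual in- and out-degrees are left to chance.

```julia
using Random

# Hypothetical sketch (not Neuroblox's implementation): connect a fixed total
# number of (source, destination) pairs chosen uniformly at random.
function random_density_connections(rng::AbstractRNG, n_src::Int, n_dst::Int, density::Real)
    0 <= density <= 1 || throw(ArgumentError("density must lie in [0, 1]"))
    n_conn = round(Int, density * n_src * n_dst)          # total number of connections
    conn = falses(n_src, n_dst)                           # conn[i, j] == true means i -> j
    conn[randperm(rng, n_src * n_dst)[1:n_conn]] .= true  # random pairs, linear indexing
    return conn
end

conn = random_density_connections(MersenneTwister(1), 100, 80, 0.2)
println("connections:      ", sum(conn))
println("in-degree range:  ", extrema(sum(conn; dims = 1)))   # not constrained
println("out-degree range: ", extrema(sum(conn; dims = 2)))   # not constrained
```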
Missing docstring for weight_matrix_connections.
NeurobloxBase.indegree_constrained_connections — Function
indegree_constrained_connections(neurons_src, neurons_dst, name_src, name_dst; kwargs...)
Create connections between two populations of neurons such that every destination neuron has the same in-degree. No guarantees are made for the degrees of source neurons.
Keyword arguments:
- connection_matrix: pre-specify the connection matrix.
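The in-degree constraint can be sketched as follows. Each destination neuron draws exactly k distinct sources uniformly at random, and nothing controls how often each source is chosen. The helper fixed_indegree_connections and its parameter k are hypothetical. The real function also accepts a pre-specified connection_matrix, which this sketch does not model.

```julia
using Random

# Hypothetical sketch: every destination neuron draws exactly k distinct sources
# uniformly at random; source out-degrees are left unconstrained.
function fixed_indegree_connections(rng::AbstractRNG, n_src::Int, n_dst::Int, k::Int)
    0 <= k <= n_src || throw(ArgumentError("k must lie in [0, n_src]"))
    conn = falses(n_src, n_dst)                       # conn[i, j] == true means i -> j
    for j in 1:n_dst
        conn[randperm(rng, n_src)[1:k], j] .= true    # k distinct random sources
    end
    return conn
end

conn = fixed_indegree_connections(MersenneTwister(7), 20, 15, 4)
println("in-degrees:  ", unique(vec(sum(conn; dims = 1))))  # a single value, 4
println("out-degrees: ", extrema(sum(conn; dims = 2)))       # varies between sources
```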
Plotting
This section documents helpers for generating plots from solutions to Neuroblox simulations. Neuroblox uses Makie for plotting, so calling these functions requires a Makie backend such as CairoMakie or GLMakie to be installed and loaded.
NeurobloxBase.meanfield — Function
meanfield(blox, sol)
Plot the mean-field voltage (in mV) as a function of time (in ms) for a blox.
Note: this function requires Makie to be loaded.
See also meanfield!, meanfield_timeseries.
NeurobloxBase.meanfield! — Function
meanfield!(ax::Axis, blox, sol)
Update an existing plot to show the mean-field voltage (in mV) as a function of time (in ms) for a blox.
Note: this function requires Makie to be loaded.
See also meanfield, meanfield_timeseries.
NeurobloxBase.rasterplot — Function
rasterplot(blox, sol; threshold = nothing, kwargs...)
Create a scatterplot of neuron firing events, where the x-axis is time (in ms) and the y-axis is the neuron's index. Internally calls detect_spikes. The threshold kwarg is propagated to detect_spikes, while the remaining kwargs are Makie kwargs.
Note: this function requires Makie to be loaded.
See also rasterplot!, detect_spikes.
NeurobloxBase.rasterplot! — Function
rasterplot!(ax::Axis, blox, sol; threshold = nothing, kwargs...)
Update an existing plot to show a scatterplot of neuron firing events, where the x-axis is time (in ms) and the y-axis is the neuron's index. Internally calls detect_spikes. The threshold kwarg is the voltage threshold for detect_spikes, while the remaining kwargs are Makie kwargs.
Note: this function requires Makie to be loaded.
See also rasterplot, detect_spikes.
NeurobloxBase.stackplot — Function
stackplot(blox, sol)
Plot the voltage timeseries of the neurons in a Blox, stacked on top of each other.
Note: this function requires Makie to be loaded.
See also stackplot!.
NeurobloxBase.stackplot! — Function
stackplot!(ax::Axis, blox, sol)
Update an existing plot to show the voltage timeseries of the neurons in a Blox, stacked on top of each other.
Note: this function requires Makie to be loaded.
See also stackplot.
NeurobloxBase.frplot — Function
frplot(blox, sol; win_size = 10, overlap = 0, transient = 0, threshold = nothing, kwargs...)
Plot the firing frequency of a blox as a function of time (in s). For a single neuron this is its individual firing frequency; for a population it is the mean firing frequency. The named keyword arguments are propagated to firing_rate, while the remaining kwargs are propagated to Makie for plotting.
Note: this function requires Makie to be loaded.
See also frplot!, firing_rate.
NeurobloxBase.frplot! — Function
frplot!(ax::Axis, blox, sol; win_size = 10, overlap = 0, transient = 0, threshold = nothing, kwargs...)
Update an existing plot with the firing frequency of a blox as a function of time (in s). For a single neuron this is its individual firing frequency; for a population it is the mean firing frequency. The named keyword arguments are propagated to firing_rate, while the remaining kwargs are propagated to Makie for plotting.
Note: this function requires Makie to be loaded.
See also frplot, firing_rate.
NeurobloxBase.voltage_stack — Function
voltage_stack(blox, sol; kwargs...)
Create and display a stackplot of the voltage timeseries.
Note: this function requires Makie to be loaded.
NeurobloxBase.powerspectrumplot — Function
powerspectrumplot(blox, sol; sampling_rate = nothing, method = nothing, window = nothing, kwargs...)
Plot the power spectrum of the solution (intensity as a function of frequency). The named keyword arguments are propagated to the internal powerspectrum call, while the remaining keyword arguments are propagated to Makie for plotting.
Note: this function requires Makie to be loaded.
See also powerspectrumplot!, powerspectrum.
NeurobloxBase.powerspectrumplot! — Function
powerspectrumplot!(ax::Axis, blox, sol; sampling_rate = nothing, method = nothing, window = nothing, kwargs...)
Update an existing plot with the power spectrum of the solution (intensity as a function of frequency). The named keyword arguments are propagated to the internal powerspectrum call, while the remaining keyword arguments are propagated to Makie for plotting.
Note: this function requires Makie to be loaded.
See also powerspectrumplot, powerspectrum.
Additionally, there are several helpers for extracting useful information from simulation solutions, such as spike times and firing rates. Several of these are called by the plotting functions.
NeurobloxBase.detect_spikes — Function
detect_spikes(blox::AbstractNeuron, sol; threshold = nothing, tolerance = 1e-3, ts = nothing, scheduler = :serial)
Find the spikes of a timeseries, where spikes are defined as having a voltage greater than the threshold. Return a SparseVector that is equal to 1 at the time indices of the spikes.
Keyword arguments:
- threshold: threshold voltage for a spike
- tolerance: the range around the threshold value within which a maximum counts as a spike
- ts: the time points at which to evaluate the voltage
- scheduler
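The following sketch shows one plausible reading of the threshold and tolerance arguments. A sample counts as a spike if it is a local maximum of the voltage trace and lies above threshold - tolerance. The result is a SparseVector with a 1 at each spike index. The function detect_spikes_sketch and this exact rule are assumptions, not Neuroblox's implementation. The sketch also works on a plain voltage vector rather than on a Blox and a solution.

```julia
using SparseArrays

# Hypothetical sketch (an assumed reading of `threshold` and `tolerance`):
# a spike is a local maximum of the trace lying above threshold - tolerance.
function detect_spikes_sketch(V::AbstractVector{<:Real}; threshold::Real, tolerance::Real = 1e-3)
    idx = Int[]
    for i in 2:length(V)-1
        is_peak = V[i-1] < V[i] >= V[i+1]            # local maximum
        if is_peak && V[i] > threshold - tolerance
            push!(idx, i)
        end
    end
    # 1 at the time indices of the spikes, structural zeros elsewhere
    return sparsevec(idx, ones(Int, length(idx)), length(V))
end

t = 0:0.1:100                                          # time in ms
V = -65 .+ 80 .* max.(0, sin.(2π .* t ./ 25)) .^ 20    # toy trace with four sharp peaks
spikes = detect_spikes_sketch(V; threshold = 0.0)
println("spike times (ms): ", t[findnz(spikes)[1]])
```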
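NeurobloxBase.firing_rate — Function
firing_rate(blox, sol; transient = 0, win_size = last(sol.t) - transient, overlap = 0, threshold = nothing, scheduler = :serial, kwargs...)
Compute the firing rate of blox from the solution sol in time windows of length win_size. For a population, this is the mean firing rate.
Keyword arguments:
- transient: the initial period of the simulation that is excluded
- win_size: the length of each time window (by default, the whole simulation after the transient)
- overlap: the overlap between consecutive windows
- threshold: the voltage threshold for spike detection (see detect_spikes)
- scheduler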
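Windowed firing-rate estimation can be sketched as counting spikes in windows of length win_size. The windows start after the transient and advance by win_size * (1 - overlap). The helper windowed_rate is hypothetical. Treating overlap as a fraction of the window, times in ms, and rates in Hz are all assumptions.

```julia
# Hypothetical sketch (assumptions: spike times in ms, `overlap` as a fraction
# of the window, rates in Hz). Not Neuroblox's implementation.
function windowed_rate(spike_times::AbstractVector{<:Real}, t_end::Real;
                       transient::Real = 0, win_size::Real = t_end - transient,
                       overlap::Real = 0)
    0 <= overlap < 1 || throw(ArgumentError("overlap must lie in [0, 1)"))
    step = win_size * (1 - overlap)                    # shift between window starts
    starts = collect(transient:step:(t_end - win_size))
    rates = [count(s -> t0 <= s < t0 + win_size, spike_times) / win_size * 1000
             for t0 in starts]                         # spikes per ms -> spikes per s
    return starts, rates
end

spike_times = collect(5.0:10.0:995.0)                  # regular 100 Hz spiking over 1 s
starts, rates = windowed_rate(spike_times, 1000.0; transient = 100, win_size = 200, overlap = 0.5)
println(starts)   # window start times (ms)
println(rates)    # about 100 Hz in every window
```

With the default win_size, the sketch produces a single window covering everything after the transient. This mirrors the default in the signature above.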
NeurobloxBase.inter_spike_intervals — Function
inter_spike_intervals(blox::AbstractNeuron, sol; threshold, ts)
Return the time intervals between subsequent spikes in the solution of a single neuron.
inter_spike_intervals(blox::AbstractNeuron, sol; threshold, ts)
Return the time intervals between subsequent spikes in the solution of a Blox, as a matrix in which each row holds the interspike intervals of a single neuron.
NeurobloxBase.flat_inter_spike_intervals — Function
flat_inter_spike_intervals(blox::AbstractNeuron, sol; threshold, ts)
Return the time intervals between subsequent spikes in the solution of a Blox, with the interspike intervals of all neurons concatenated into a single vector.
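In terms of spike times, a neuron's interspike intervals are the successive differences of its spike times. The flattened version concatenates these per-neuron vectors. The sketch below works on plain vectors of spike times, and the helper isi is hypothetical. It returns a vector of vectors rather than a matrix, since different neurons generally fire different numbers of spikes.

```julia
# Hypothetical sketch: interspike intervals from spike times (ms).
isi(spike_times::AbstractVector{<:Real}) = diff(spike_times)

# One vector of spike times per neuron in a small population
population = [[10.0, 25.0, 45.0], [5.0, 30.0], [12.0, 20.0, 28.0, 36.0]]
per_neuron = isi.(population)          # one ISI vector per neuron
flat = reduce(vcat, per_neuron)        # all intervals in a single vector
println(per_neuron)   # [[15.0, 20.0], [25.0], [8.0, 8.0, 8.0]]
println(flat)         # [15.0, 20.0, 25.0, 8.0, 8.0, 8.0]
```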
NeurobloxBase.powerspectrum — Function
powerspectrum(cb, sol; sampling_rate = nothing, method = periodogram, window = nothing)
Compute the power spectrum of the voltage timeseries for a set of Blox.
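The signature lists method = periodogram as the default. To illustrate what a periodogram computes, the sketch below evaluates the squared magnitude of a naive discrete Fourier transform of a mean-subtracted signal. The helper naive_periodogram is hypothetical. It uses an O(N^2) transform with no windowing and does not match the normalization of any particular library. It only shows where the spectral peak of a toy 40 Hz rhythm lands.

```julia
# Hypothetical sketch of a basic periodogram: naive DFT, no windowing,
# no one-sided scaling. Not Neuroblox's implementation.
function naive_periodogram(x::AbstractVector{<:Real}, fs::Real)
    N = length(x)
    x0 = x .- sum(x) / N                          # remove the DC offset
    ks = 0:N÷2                                    # non-negative frequency bins
    freqs = collect(ks .* (fs / N))               # bin frequencies in Hz, up to Nyquist
    power = [abs2(sum(x0[n+1] * cis(-2π * k * n / N) for n in 0:N-1)) / (fs * N)
             for k in ks]
    return freqs, power
end

fs = 1000.0                                       # sampling rate (Hz), i.e. dt = 1 ms
t = (0:999) ./ fs                                 # 1 s of samples
v = -60 .+ 5 .* sin.(2π .* 40 .* t)               # toy voltage with a 40 Hz rhythm
freqs, power = naive_periodogram(v, fs)
println("peak at ", freqs[argmax(power)], " Hz")
```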
NeurobloxBase.voltage_timeseries — Function
voltage_timeseries(cb, sol; ts)
Return the voltage timeseries of a Blox or collection of Blox.
NeurobloxBase.meanfield_timeseries — Function
meanfield_timeseries(cb, sol, state; ts)
Return the timeseries of the average value of the state variable state over a collection of neurons or a composite. Provide the optional kwarg ts to return the variable's value at specific times ts.
NeurobloxBase.state_timeseries — Function
state_timeseries(blox, sol::SciMLBase.AbstractSolution, state::String; ts = nothing)
Return the timeseries for the state variable named state as a vector. Provide the optional kwarg ts to return the variable's value at specific times ts.
state_timeseries(cb::Union{AbstractComposite, AbstractVector{<:AbstractBlox}}, sol::SciMLBase.AbstractSolution, state::String; ts = nothing)
Return the state_timeseries of the state variable state for each Blox in a composite or vector of Blox. The resulting timeseries are stacked as rows of a matrix. Provide the optional kwarg ts to return the variable's value at specific times ts.
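The relationship between stacked per-Blox timeseries and the mean-field timeseries can be illustrated with plain matrices. Stacking one trace per neuron as rows gives a matrix. Averaging over the rows at each time point then gives the mean field. The toy traces below are made up, and the code does not call Neuroblox.

```julia
using Statistics

# Toy voltage traces (mV), one per neuron, sampled at the same five time points
neuron_traces = [[-65.0, -60.0, -55.0, -62.0, -64.0],
                 [-70.0, -66.0, -50.0, -58.0, -69.0],
                 [-60.0, -63.0, -57.0, -61.0, -59.0]]

# Stack as rows: one row per neuron, one column per time point
V = reduce(vcat, permutedims.(neuron_traces))   # 3×5 matrix

# Mean-field timeseries: average over neurons at each time point
meanfield = vec(mean(V; dims = 1))
println(size(V))   # (3, 5)
println(meanfield)
```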
Reinforcement Learning
The following section documents the infrastructure for performing reinforcement learning on neural systems. The neural system acts as the agent, while the environment is a series of sensory stimuli presented to the model. The agent's action is the classification of the stimuli, and the choice follows some policy. Learning occurs as the connection weights of the system are updated according to some learning rule after each choice made by the agent.
NeurobloxPharma.Agent — Function
Agent(g::MetaDiGraph; name, graphdynamics=false, kwargs...)
Create an RL agent from a graph representing a neural circuit. The agent contains the system constructed from the graph, as well as its policy, its connections, and the learning rule of each connection, all extracted from the graph. The graphdynamics kwarg sets whether to construct a GraphDynamics system or a ModelingToolkit system from the graph.
NeurobloxPharma.ClassificationEnvironment — Type
ClassificationEnvironment(stim::ImageStimulus, N_trials; name, namespace, t_stimulus, t_pause)
Create an environment for reinforcement learning. A set of images is presented to the agent to be classified. This struct stores the correct class for each image and the current trial of the experiment.
Arguments:
- stim: The ImageStimulus, created from a set of images
- N_trials: Number of trials. The agent performs one classification each trial.
- t_stimulus: The length of time the stimulus is on (ms)
- t_pause: The length of time the stimulus is off (ms)
The following policies are implemented in Neuroblox. Policies are represented by the AbstractActionSelection type. They are added as nodes to the graph, with the set of actions represented by incoming connections.
NeurobloxPharma.GreedyPolicy — Type
GreedyPolicy(; name, t_decision, namespace, competitor_states = Num[], competitor_params = Num[])
A policy that performs classification by picking the state with the highest value among competitor_states. t_decision is the time of the decision.
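Greedy selection amounts to an argmax over the competitor values at the decision time. The helper greedy_action and the example values are hypothetical. Julia's argmax returns the first index when there is a tie, which may differ from how GreedyPolicy breaks ties.

```julia
# Hypothetical sketch of greedy action selection: given the values of the
# competitor states at t_decision, choose the index of the largest one.
greedy_action(values::AbstractVector{<:Real}) = argmax(values)

competitor_values = [0.12, 0.87, 0.45]   # e.g. activity of three competing populations
action = greedy_action(competitor_values)
println("chosen action: ", action)        # 2
```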
NeurobloxBase.action_selection_from_graph — Function
action_selection_from_graph(g::MetaDiGraph)
If one of the Blox in the graph g is a policy for reinforcement learning, return it; otherwise return nothing. A graph can have at most one action selection Blox.
The following learning rules are implemented in Neuroblox:
NeurobloxPharma.HebbianPlasticity — Type
HebbianPlasticity(; K, W_lim, state_pre = nothing, state_post = nothing, t_pre = nothing, t_post = nothing)
Hebbian learning rule. In every trial of the RL experiment, the weight is updated according to
\[ w_{j+1} = w_j + \text{feedback} \times K x_\text{pre} x_\text{post} (W_\text{lim} - w_j) \]
where feedback indicates the correctness of the agent's action during the trial, and x_pre and x_post indicate the activities of the pre- and postsynaptic neurons.
Arguments:
- K: the learning rate of the connection
- W_lim: the maximum weight for the connection
See also HebbianModulationPlasticity.
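The update rule above can be written as a one-line Julia function and iterated over trials. The helper hebbian_update is hypothetical. Encoding feedback as 1 for a correct action and 0 otherwise is an assumption. With positive feedback, the weight approaches W_lim geometrically, because each trial shrinks the gap W_lim - w_j by the factor 1 - K x_pre x_post.

```julia
# Hypothetical sketch of the HebbianPlasticity update for one trial.
# feedback: assumed 1 for a correct action, 0 otherwise.
# x_pre, x_post: activities of the pre- and postsynaptic neurons during the trial.
hebbian_update(w, feedback, K, W_lim, x_pre, x_post) =
    w + feedback * K * x_pre * x_post * (W_lim - w)

weights = [0.5]                      # initial weight
for trial in 1:5                     # five correct trials in a row
    push!(weights, hebbian_update(weights[end], 1, 0.2, 1.0, 0.8, 0.9))
end
println(weights)                     # increases monotonically toward W_lim = 1.0
```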
NeurobloxPharma.HebbianModulationPlasticity — Type
HebbianModulationPlasticity(; K, decay, α, θₘ, state_pre = nothing, state_post = nothing, t_pre = nothing, t_post = nothing, t_mod = nothing, modulator = nothing)
Hebbian learning rule, modulated by the dopamine reward prediction error. The weight update is largest when the reward prediction error is far from the modulation threshold θₘ.
\[ \epsilon = \text{feedback} - (\text{DA}_b - \text{DA}) \]
\[ w_{j+1} = w_j + \max\left(K x_\text{pre} x_\text{post}\, \epsilon (\epsilon + \theta_m)\, d\sigma\big(\alpha (\epsilon + \theta_m)\big) - \text{decay} \times w_j,\ -w_j\right) \]
Here:
- feedback indicates the correctness of the agent's action during the trial
- DA_b is the baseline dopamine level, and DA is the modulator's dopamine release
- dσ is the derivative of the logistic function

The decay term prevents the weights from diverging. Taking the maximum with -w_j ensures that the updated weight is never negative.
Arguments:
- K: the learning rate of the connection
- decay: decay of the weight update
- α: the selectivity of the derivative of the logistic function
- θₘ: the modulation threshold for the reward prediction error
See also HebbianPlasticity.
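The modulated update can likewise be transcribed from the update rule in this entry. The helper modulated_hebbian_update and the example values are hypothetical. dσ is computed as σ(x)(1 - σ(x)) for the logistic function σ(x) = 1 / (1 + e^(-x)). The max with -w keeps the new weight nonnegative.

```julia
# Hypothetical sketch of the HebbianModulationPlasticity update for one trial.
# Not Neuroblox's implementation.
logistic(x) = 1 / (1 + exp(-x))
dlogistic(x) = logistic(x) * (1 - logistic(x))   # derivative of the logistic function

function modulated_hebbian_update(w, feedback, DA_b, DA; K, decay, α, θₘ, x_pre, x_post)
    ϵ = feedback - (DA_b - DA)                    # reward prediction error
    Δw = K * x_pre * x_post * ϵ * (ϵ + θₘ) * dlogistic(α * (ϵ + θₘ)) - decay * w
    return w + max(Δw, -w)                        # clamp so the new weight is >= 0
end

w_new = modulated_hebbian_update(0.4, 1.0, 0.2, 0.5;
                                 K = 0.1, decay = 0.01, α = 2.0, θₘ = 0.5,
                                 x_pre = 0.8, x_post = 0.7)
println(w_new)
```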
NeurobloxBase.weight_gradient — Function
weight_gradient(lr::AbstractLearningRule, sol, w, feedback)
Calculate how the weight w should change, based on the solution of the reinforcement learning experiment.
The following functions are used to run reinforcement learning experiments.
NeurobloxBase.run_warmup — Function
run_warmup(agent::AbstractAgent, env::AbstractEnvironment, t_warmup; kwargs...)
Run the initial solve of the RL experiment for a duration of t_warmup.
NeurobloxBase.run_trial! — Function
run_trial!(agent::AbstractAgent, env::AbstractEnvironment, weights, u0; kwargs...)
Run a single trial of an RL experiment and update the connection weights according to the learning rules.
NeurobloxBase.run_experiment! — Function
run_experiment!(agent::AbstractAgent, env::AbstractEnvironment; verbose=false, t_warmup=0, kwargs...)
Perform a full RL experiment with agent in the environment env. The experiment runs until the maximum number of trials in env is reached.