kinase

package
v2.0.0-dev0.2.1
Warning

This package is not in the latest version of its module.

Published: Jul 16, 2024 License: BSD-3-Clause Imports: 7 Imported by: 0

README

Kinase Learning Implementation

This implements central elements of the Kinase learning rule, including variables with associated time constants used for integrating calcium signals through the cascade of progressively longer time-integrals, from Ca -> CaM calmodulin (at MTau) -> CaMKII (CaP at PTau for LTP role) -> DAPK1 (CaD at DTau for LTD role).

See the kinaseq example for an exploration of the implemented equations, and the kinase repository for documentation and simulations of the biophysical basis of the equations.

Time constants and Variables

  • MTau (2 or 5) = calmodulin (CaM) time constant in cycles (msec). For synaptic-level integration, this integrates on top of the Ca signal from send->CaSyn * recv->CaSyn, each of which is typically integrated with a 30 msec Tau.

  • PTau (40) = LTP spike-driven Ca factor (CaP) time constant in cycles (msec), simulating CaMKII in the Kinase framework, with 40 on top of MTau roughly tracking the biophysical rise time. Computationally, CaP represents the plus phase learning signal that reflects the most recent past information.

  • DTau (40) = LTD spike-driven Ca factor (CaD) time constant in cycles (msec), simulating DAPK1 in the Kinase framework. Computationally, CaD represents the minus phase learning signal that reflects the expectation representation prior to experiencing the outcome (in addition to the outcome).
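The cascade above can be pictured as three chained exponential integrators, each stepped once per cycle (msec). The following is a minimal sketch, not the package's code: it assumes a per-cycle update of the form x += dt * (input - x) at each level, with dt = 1 / tau, and the type and function names are hypothetical. Driving it with a single unit Ca pulse shows the progressively delayed peaks of CaM, CaP, and CaD.

```go
package main

import "fmt"

// cascade holds the three integration levels; rates are 1/tau (per msec).
// Hypothetical stand-in, not the package's CaDtParams.
type cascade struct {
	mDt, pDt, dDt float32
	caM, caP, caD float32
}

// step advances the cascade by one cycle given the current Ca input,
// using an exponential (forward-Euler) integration at each level.
func (c *cascade) step(ca float32) {
	c.caM += c.mDt * (ca - c.caM)
	c.caP += c.pDt * (c.caM - c.caP)
	c.caD += c.dDt * (c.caP - c.caD)
}

// peakCycles drives the cascade with a single unit Ca pulse at cycle 0
// and returns the cycle at which each level first reaches its maximum.
func peakCycles(mTau, pTau, dTau float32, ncycles int) (pkM, pkP, pkD int) {
	c := cascade{mDt: 1 / mTau, pDt: 1 / pTau, dDt: 1 / dTau}
	var maxM, maxP, maxD float32
	for cyc := 0; cyc < ncycles; cyc++ {
		ca := float32(0)
		if cyc == 0 {
			ca = 1
		}
		c.step(ca)
		if c.caM > maxM {
			maxM, pkM = c.caM, cyc
		}
		if c.caP > maxP {
			maxP, pkP = c.caP, cyc
		}
		if c.caD > maxD {
			maxD, pkD = c.caD, cyc
		}
	}
	return
}

func main() {
	m, p, d := peakCycles(5, 40, 40, 400)
	fmt.Printf("CaM peaks at cycle %d, CaP at %d, CaD at %d\n", m, p, d)
}
```

With MTau = 5 and PTau = DTau = 40, CaM peaks on the pulse cycle itself, and CaP and CaD peak progressively later.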

TODO

  • remove NMDA and CaLrn from neuron if not used
  • use neuron-level recv, send final CaSpk* values as regressors! key!

Closed-form Expression for cascading integration (not faster)

Rishi Chaudhri suggested the following approach:

Wolfram Alpha Solution

But unfortunately all the FastExp calls end up being slower than directly computing the cascaded equations.

// Equations below, courtesy of Rishi Chaudhri.
//
// CaAtT computes the 3 Ca values at (currentTime + ti), assuming 0
// new Ca incoming (no spiking). It uses closed-form exponential functions
// for the CaM -> CaP -> CaD cascade.
// Note: the closed form divides by the pairwise differences of MDt, PDt,
// and DDt, so it is undefined (yields Inf or NaN) when any two are equal,
// e.g., with the default PTau == DTau == 40.
func (kp *CaDtParams) CaAtT(ti int32, caM, caP, caD *float32) {
	t := float32(ti)
	mdt := kp.MDt
	pdt := kp.PDt
	ddt := kp.DDt
	if kp.ExpAdj.IsTrue() { // adjust for discrete-time integration
		mdt *= 1.11
		pdt *= 1.03
		ddt *= 1.03
	}
	mi := *caM
	pi := *caP
	di := *caD

	*caM = mi * math32.FastExp(-t*mdt)

	em := math32.FastExp(t * mdt)
	ep := math32.FastExp(t * pdt)

	*caP = pi*math32.FastExp(-t*pdt) - (pdt*mi*math32.FastExp(-t*(mdt+pdt))*(em-ep))/(pdt-mdt)

	epd := math32.FastExp(t * (pdt + ddt))
	emd := math32.FastExp(t * (mdt + ddt))
	emp := math32.FastExp(t * (mdt + pdt))

	// CaD = contribution from initial CaM + contribution from initial CaP
	// + decay of initial CaD.
	fromM := pdt * ddt * mi * math32.FastExp(-t*(mdt+pdt+ddt)) *
		(ddt*(emd-epd) + pdt*(epd-emp) + mdt*(emp-emd)) /
		((mdt - pdt) * (mdt - ddt) * (pdt - ddt))
	fromP := -ddt * pi * math32.FastExp(-t*(pdt+ddt)) * (ep - math32.FastExp(t*ddt)) / (ddt - pdt)
	*caD = fromM + fromP + di*math32.FastExp(-t*ddt)
}

Documentation


Constants

This section is empty.

Variables

This section is empty.

Functions

This section is empty.

Types

type BinWeights

type BinWeights struct {
	Bin0, Bin1, Bin2, Bin3, Bin4, Bin5, Bin6, Bin7 float32
}

BinWeights are 8 coefficients for computing Ca based on binned spike counts, for linear regression computation.

func (*BinWeights) Init

func (bw *BinWeights) Init(b0, b1, b2, b3, b4, b5, b6, b7 float32)

func (*BinWeights) Product

func (bw *BinWeights) Product(b0, b1, b2, b3, b4, b5, b6, b7 float32) float32

Product returns the sum of the weights times the corresponding bin values.
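Product is most naturally read as a weighted sum (dot product) of the eight coefficients with eight bin values. A minimal sketch under that assumption, re-declaring BinWeights locally so the example is self-contained; the weights in main are made up:

```go
package main

import "fmt"

// BinWeights mirrors the documented struct: 8 regression coefficients.
type BinWeights struct {
	Bin0, Bin1, Bin2, Bin3, Bin4, Bin5, Bin6, Bin7 float32
}

// Init sets all 8 coefficients.
func (bw *BinWeights) Init(b0, b1, b2, b3, b4, b5, b6, b7 float32) {
	bw.Bin0, bw.Bin1, bw.Bin2, bw.Bin3 = b0, b1, b2, b3
	bw.Bin4, bw.Bin5, bw.Bin6, bw.Bin7 = b4, b5, b6, b7
}

// Product returns the sum over bins of weight * bin value (assumed semantics).
func (bw *BinWeights) Product(b0, b1, b2, b3, b4, b5, b6, b7 float32) float32 {
	return bw.Bin0*b0 + bw.Bin1*b1 + bw.Bin2*b2 + bw.Bin3*b3 +
		bw.Bin4*b4 + bw.Bin5*b5 + bw.Bin6*b6 + bw.Bin7*b7
}

func main() {
	var bw BinWeights
	bw.Init(0, 0, 0, 0, 0, 0.5, 1, 2)              // hypothetical weights on late bins
	fmt.Println(bw.Product(1, 1, 1, 1, 1, 2, 2, 2)) // 0.5*2 + 1*2 + 2*2 = 7
}
```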

type CaDtParams

type CaDtParams struct {

	// CaM (calmodulin) time constant in cycles (msec),
	// which is the first level integration.
	// For CaLearn, 2 is best; for CaSpk, 5 is best.
	// For synaptic-level integration this integrates on top of Ca
	// signal from send->CaSyn * recv->CaSyn, each of which is
	// typically integrated with a 30 msec Tau.
	MTau float32 `default:"2,5" min:"1"`

	// LTP spike-driven potentiation Ca factor (CaP) time constant
	// in cycles (msec), simulating CaMKII in the Kinase framework,
	// cascading on top of MTau.
	// Computationally, CaP represents the plus phase learning signal that
	// reflects the most recent past information.
	// Value tracks linearly with number of cycles per learning trial:
	// 200 = 40, 300 = 60, 400 = 80
	PTau float32 `default:"40,60,80" min:"1"`

	// LTD spike-driven depression Ca factor (CaD) time constant
	// in cycles (msec), simulating DAPK1 in the Kinase framework,
	// cascading on top of PTau.
	// Computationally, CaD represents the minus phase learning signal that
	// reflects the expectation representation prior to experiencing the
	// outcome (in addition to the outcome).
	// Value tracks linearly with number of cycles per learning trial:
	// 200 = 40, 300 = 60, 400 = 80
	DTau float32 `default:"40,60,80" min:"1"`

	// rate = 1 / tau
	MDt float32 `display:"-" json:"-" xml:"-" edit:"-"`

	// rate = 1 / tau
	PDt float32 `display:"-" json:"-" xml:"-" edit:"-"`

	// rate = 1 / tau
	DDt float32 `display:"-" json:"-" xml:"-" edit:"-"`
	// contains filtered or unexported fields
}

CaDtParams has rate constants for integrating Ca calcium at different time scales, including final CaP = CaMKII and CaD = DAPK1 timescales for LTP potentiation vs. LTD depression factors.

func (*CaDtParams) Defaults

func (kp *CaDtParams) Defaults()

func (*CaDtParams) FromCa

func (kp *CaDtParams) FromCa(ca float32, caM, caP, caD *float32)

FromCa updates CaM, CaP, CaD from the given current calcium value, which is typically a faster time-integral of calcium.

func (*CaDtParams) PDTauForNCycles

func (kp *CaDtParams) PDTauForNCycles(ncycles int)

PDTauForNCycles sets the PTau and DTau parameters in proportion to the total number of cycles per theta learning trial, e.g., 200 = 40, 280 = 56, 300 = 60.
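The documented scaling (200 = 40, 300 = 60, 400 = 80) amounts to tau = 40 * ncycles / 200, one fifth of the cycles per trial, with each rate set to 1 / tau. A sketch assuming exactly that rule, using a simplified, hypothetical stand-in for CaDtParams:

```go
package main

import "fmt"

// dtParams is a simplified, hypothetical stand-in for CaDtParams.
type dtParams struct {
	MTau, PTau, DTau float32
	MDt, PDt, DDt    float32
}

// update recomputes the rate constants as 1 / tau.
func (kp *dtParams) update() {
	kp.MDt = 1 / kp.MTau
	kp.PDt = 1 / kp.PTau
	kp.DDt = 1 / kp.DTau
}

// pdTauForNCycles scales PTau and DTau linearly with cycles per trial,
// anchored at 200 cycles = 40 msec, then refreshes the rates.
func (kp *dtParams) pdTauForNCycles(ncycles int) {
	tau := 40 * float32(ncycles) / 200
	kp.PTau, kp.DTau = tau, tau
	kp.update()
}

func main() {
	kp := dtParams{MTau: 5}
	for _, n := range []int{200, 280, 300, 400} {
		kp.pdTauForNCycles(n)
		fmt.Printf("%d cycles: PTau = DTau = %g\n", n, kp.PTau)
	}
}
```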

func (*CaDtParams) Update

func (kp *CaDtParams) Update()

type Linear

type Linear struct {
	// Kinase Neuron params
	Neuron NeurCaParams

	// Kinase Synapse params
	Synapse SynCaParams

	// total number of cycles (1 msec each) to run per learning trial
	NCycles int `min:"10" default:"200"`

	// number of plus-phase cycles
	PlusCycles int `default:"50"`

	// NumBins is the number of bins to accumulate spikes over NCycles
	NumBins int `default:"8"`

	// CyclesPerBin = NCycles / NumBins
	CyclesPerBin int `edit:"-"`

	// MaxHz is the maximum firing rate to sample in minus, plus phases
	MaxHz int `default:"120"`

	// StepHz is the step size for sampling Hz
	StepHz int `default:"10"`

	// NTrials is the number of trials per Hz case
	NTrials int `default:"100"`

	// TotalTrials is the total number of trials across all data
	TotalTrials int `edit:"-"`

	// Sending neuron
	Send Neuron

	// Receiving neuron
	Recv Neuron

	// Standard synapse values
	StdSyn Synapse

	// Linear synapse values
	LinearSyn Synapse

	// ErrDWt is the target error dwt: PlusHz - MinusHz
	ErrDWt float32

	// binned integration of send, recv spikes
	SpikeBins []float32

	// Data to fit the regression
	Data table.Table
}

Linear performs a linear regression to approximate the synaptic Ca integration between send and recv neurons.
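A sketch of the binned-spike inputs, assuming each trial's NCycles are split into NumBins equal bins (CyclesPerBin = NCycles / NumBins, 25 with the defaults) and that the synaptic regressor for each bin is the product of the sending and receiving spike counts in that bin, as the SynCaLinear description suggests. binSpikes and synBins are hypothetical helpers, not package functions.

```go
package main

import "fmt"

// binSpikes counts spikes (1 = spike, 0 = none) per bin, where each bin
// covers len(spikes)/numBins consecutive cycles.
func binSpikes(spikes []float32, numBins int) []float32 {
	cyclesPerBin := len(spikes) / numBins
	bins := make([]float32, numBins)
	for cyc, s := range spikes {
		b := cyc / cyclesPerBin
		if b >= numBins { // any leftover cycles go into the last bin
			b = numBins - 1
		}
		bins[b] += s
	}
	return bins
}

// synBins multiplies sending and receiving bin counts elementwise,
// giving one synaptic regressor per bin.
func synBins(send, recv []float32) []float32 {
	out := make([]float32, len(send))
	for i := range send {
		out[i] = send[i] * recv[i]
	}
	return out
}

func main() {
	const nCycles, numBins = 200, 8
	send := make([]float32, nCycles)
	recv := make([]float32, nCycles)
	for cyc := 0; cyc < nCycles; cyc += 10 { // sender spikes every 10 cycles
		send[cyc] = 1
	}
	for cyc := 0; cyc < nCycles; cyc += 25 { // receiver spikes every 25 cycles
		recv[cyc] = 1
	}
	sb, rb := binSpikes(send, numBins), binSpikes(recv, numBins)
	fmt.Println(sb, rb, synBins(sb, rb))
}
```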

func (*Linear) Cycle

func (ls *Linear) Cycle(nr *Neuron, expInt float32, cyc int)

Cycle does one cycle of neuron updating, with given exponential spike interval based on target spiking firing rate.

func (*Linear) Defaults

func (ls *Linear) Defaults()

func (*Linear) Init

func (ls *Linear) Init()

func (*Linear) InitTable

func (ls *Linear) InitTable()

func (*Linear) Regress

func (ls *Linear) Regress()

Regress runs the linear regression on the data

func (*Linear) Run

func (ls *Linear) Run()

Run generates data

func (*Linear) SetBins

func (ls *Linear) SetBins(sn, rn *Neuron, off, row int)

func (*Linear) SetSynState

func (ls *Linear) SetSynState(sy *Synapse, row int)

func (*Linear) StartTrial

func (ls *Linear) StartTrial()

func (*Linear) Trial

func (ls *Linear) Trial(sendMinusHz, sendPlusHz, recvMinusHz, recvPlusHz float32, ti, row int)

Trial runs one trial

func (*Linear) Update

func (ls *Linear) Update()

type NeurCaParams

type NeurCaParams struct {

	// SpikeG is a gain multiplier on spike impulses for computing CaSpk:
	// increasing this directly affects the magnitude of the trace values,
	// learning rate in Target layers, and other factors that depend on CaSpk
	// values, including RLRate, UpdateThr.
	// Larger networks require higher gain factors at the neuron level:
	// 12, vs 8 for smaller.
	SpikeG float32 `default:"8,12"`

	// time constant for integrating spike-driven calcium trace at sender and recv
	// neurons, CaSyn, which then drives synapse-level integration of the
	// joint pre * post synapse-level activity, in cycles (msec).
	// Note: if this param is changed, then there will be a change in effective
	// learning rate that can be compensated for by multiplying
	// PathParams.Learn.KinaseCa.CaScale by sqrt(30 / sqrt(SynTau))
	SynTau float32 `default:"30" min:"1"`

	// rate = 1 / tau
	SynDt float32 `display:"-" json:"-" xml:"-" edit:"-"`

	// time constants for integrating CaSpk across M, P and D cascading levels.
	// Typically the same as in CaLrn and Path level for synaptic integration.
	Dt CaDtParams `display:"inline"`
	// contains filtered or unexported fields
}

NeurCaParams parameterizes the neuron-level spike-driven calcium signals, starting with CaSyn that is integrated at the neuron level and drives synapse-level, pre * post Ca integration, which provides the Tr trace that multiplies error signals, and drives learning directly for Target layers. CaSpk* values are integrated separately at the Neuron level and used for UpdateThr and RLRate as a proxy for the activation (spiking) based learning signal.

func (*NeurCaParams) CaFromSpike

func (np *NeurCaParams) CaFromSpike(spike float32, caSyn, caM, caP, caD *float32)

CaFromSpike updates Ca variables from spike input which is either 0 or 1
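A sketch of the neuron-level update, assuming (the documentation does not show the body) that the spike is scaled by SpikeG, integrated into CaSyn at rate SynDt = 1 / SynTau, and also passed through the same per-cycle M, P, D cascade to produce CaSpkM, CaSpkP, and CaSpkD. neurCa is a simplified, hypothetical stand-in for NeurCaParams, using the documented defaults SpikeG = 8, SynTau = 30, MTau = 5 (the CaSpk recommendation), and PTau = DTau = 40.

```go
package main

import "fmt"

// neurCa is a simplified, hypothetical stand-in for NeurCaParams.
type neurCa struct {
	SpikeG, SynDt float32
	MDt, PDt, DDt float32
}

// caFromSpike updates the calcium variables from a 0/1 spike (assumed form).
func (np *neurCa) caFromSpike(spike float32, caSyn, caM, caP, caD *float32) {
	nsp := np.SpikeG * spike
	*caSyn += np.SynDt * (nsp - *caSyn)
	*caM += np.MDt * (nsp - *caM)
	*caP += np.PDt * (*caM - *caP)
	*caD += np.DDt * (*caP - *caD)
}

func main() {
	np := neurCa{SpikeG: 8, SynDt: 1.0 / 30, MDt: 1.0 / 5, PDt: 1.0 / 40, DDt: 1.0 / 40}
	var caSyn, caM, caP, caD float32
	np.caFromSpike(1, &caSyn, &caM, &caP, &caD) // a single spike
	fmt.Println(caSyn, caM, caP, caD)
}
```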

func (*NeurCaParams) Defaults

func (np *NeurCaParams) Defaults()

func (*NeurCaParams) Update

func (np *NeurCaParams) Update()

type Neuron

type Neuron struct {
	// Neuron spiking (0,1)
	Spike float32

	// Neuron probability of spiking
	SpikeP float32

	// CaSyn is spike-driven calcium trace for synapse-level Ca-driven learning:
	// exponential integration of SpikeG * Spike at SynTau time constant (typically 30).
	// Synapses integrate send.CaSyn * recv.CaSyn across M, P, D time integrals for
	// the synaptic trace driving credit assignment in learning.
	// Time constant reflects binding time of Glu to NMDA and Ca buffering postsynaptically,
	// and determines time window where pre * post spiking must overlap to drive learning.
	CaSyn float32

	// neuron-level spike-driven Ca integration
	CaSpkM, CaSpkP, CaSpkD float32

	TotalSpikes float32

	// binned count of spikes, for regression learning
	SpikeBins []float32
}

Neuron holds the spiking and calcium state of a single neuron.

func (*Neuron) Init

func (kn *Neuron) Init()

func (*Neuron) StartTrial

func (kn *Neuron) StartTrial()

type SynCaLinear

type SynCaLinear struct {
	CaP BinWeights `display:"inline"`
	CaD BinWeights `display:"inline"`

	// CaGain is extra multiplier for Synaptic Ca
	CaGain float32 `default:"1"`
	// contains filtered or unexported fields
}

SynCaLinear computes synaptic calcium using linear equations fit to cascading Ca integration, for computing final CaP = CaMKII (LTP) and CaD = DAPK1 (LTD) factors as a function of product of binned spike totals on the sending and receiving neurons.

func (*SynCaLinear) Defaults

func (kp *SynCaLinear) Defaults()

func (*SynCaLinear) FinalCa

func (kp *SynCaLinear) FinalCa(b0, b1, b2, b3, b4, b5, b6, b7 float32, caP, caD *float32)

FinalCa uses a linear regression to compute the final Ca values
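FinalCa can be sketched as two weighted sums over the eight binned send * recv spike products, one using the CaP weights and one using the CaD weights, each multiplied by CaGain. This assumes the weights enter linearly, as in BinWeights.Product; the real coefficients set by Theta200plus50 and Theta280plus70 are not listed here, so the ones below are made up so that equal activity in all bins gives CaP = CaD.

```go
package main

import "fmt"

// weights is a hypothetical stand-in for BinWeights (8 coefficients).
type weights [8]float32

// dot returns the weighted sum of the 8 bin values.
func (w weights) dot(b [8]float32) float32 {
	var s float32
	for i := range w {
		s += w[i] * b[i]
	}
	return s
}

// synCaLinear is a hypothetical stand-in for SynCaLinear.
type synCaLinear struct {
	CaP, CaD weights
	CaGain   float32
}

// finalCa computes CaP and CaD from the binned send*recv spike products.
func (kp *synCaLinear) finalCa(bins [8]float32, caP, caD *float32) {
	*caP = kp.CaGain * kp.CaP.dot(bins)
	*caD = kp.CaGain * kp.CaD.dot(bins)
}

func main() {
	kp := synCaLinear{
		// Made-up weights: both sum to 1.5, but CaP is concentrated in the last bins.
		CaP:    weights{0, 0, 0, 0, 0.1, 0.2, 0.4, 0.8},
		CaD:    weights{0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.25, 0.25},
		CaGain: 1,
	}
	flat := [8]float32{2, 2, 2, 2, 2, 2, 2, 2}
	late := [8]float32{2, 2, 2, 2, 2, 2, 4, 4}
	var p, d float32
	kp.finalCa(flat, &p, &d)
	fmt.Printf("flat: CaP=%.2f CaD=%.2f DWt=%.2f\n", p, d, p-d)
	kp.finalCa(late, &p, &d)
	fmt.Printf("late: CaP=%.2f CaD=%.2f DWt=%.2f\n", p, d, p-d)
}
```

With these illustrative weights, uniform activity yields DWt = CaP - CaD of zero, and extra activity in the last bins yields a positive DWt.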

func (*SynCaLinear) Theta200plus50

func (kp *SynCaLinear) Theta200plus50()

Theta200plus50 sets bin weights for a theta cycle learning trial of 200 cycles and a plus phase of 50

func (*SynCaLinear) Theta280plus70

func (kp *SynCaLinear) Theta280plus70()

Theta280plus70 sets bin weights for a theta cycle learning trial of 280 cycles and a plus phase of 70, with PTau & DTau at 56 (PDTauForNCycles)

func (*SynCaLinear) Update

func (kp *SynCaLinear) Update()

func (*SynCaLinear) WtsForNCycles

func (kp *SynCaLinear) WtsForNCycles(ncycles int)

WtsForNCycles sets the linear bin weights for the given number of cycles per learning trial.

type SynCaParams

type SynCaParams struct {
	// CaScale is a scaling multiplier on synaptic Ca values,
	// which due to the multiplication of send * recv are smaller in magnitude.
	// The default 12 value keeps them in roughly the unit scale,
	// and affects effective learning rate.
	CaScale float32 `default:"12"`

	// time constants for integrating at M, P, and D cascading levels
	Dt CaDtParams `display:"inline"`
	// contains filtered or unexported fields
}

SynCaParams has rate constants for integrating spike-driven Ca calcium at different time scales, including final CaP = CaMKII and CaD = DAPK1 timescales for LTP potentiation vs. LTD depression factors.

func (*SynCaParams) Defaults

func (kp *SynCaParams) Defaults()

func (*SynCaParams) FromCa

func (kp *SynCaParams) FromCa(ca float32, caM, caP, caD *float32)

FromCa updates CaM, CaP, CaD from given current synaptic calcium value, which is a faster time-integral of calcium typically. ca is multiplied by CaScale.
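The Neuron.CaSyn description states that synapses integrate send.CaSyn * recv.CaSyn. A sketch assuming that FromCa multiplies that product by CaScale and then applies the same per-cycle M, P, D cascade; synCa is a hypothetical stand-in for SynCaParams, MTau = 2 follows the CaLearn recommendation in the CaDtParams comments, and the CaSyn levels are illustrative.

```go
package main

import "fmt"

// synCa is a hypothetical stand-in for SynCaParams.
type synCa struct {
	CaScale       float32
	MDt, PDt, DDt float32
}

// fromCa scales the raw synaptic Ca by CaScale and integrates it through M, P, D.
func (kp *synCa) fromCa(ca float32, caM, caP, caD *float32) {
	ca *= kp.CaScale
	*caM += kp.MDt * (ca - *caM)
	*caP += kp.PDt * (*caM - *caP)
	*caD += kp.DDt * (*caP - *caD)
}

func main() {
	kp := synCa{CaScale: 12, MDt: 1.0 / 2, PDt: 1.0 / 40, DDt: 1.0 / 40}
	sendCaSyn, recvCaSyn := float32(0.3), float32(0.4) // illustrative CaSyn levels
	var caM, caP, caD float32
	kp.fromCa(sendCaSyn*recvCaSyn, &caM, &caP, &caD) // one cycle of integration
	fmt.Println(caM, caP, caD)
}
```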

func (*SynCaParams) Update

func (kp *SynCaParams) Update()

type Synapse

type Synapse struct {
	CaSyn float32

	// CaM is first stage running average (mean) Ca calcium level (like CaM = calmodulin), feeds into CaP
	CaM float32

	// CaP is shorter timescale integrated CaM value, representing the plus, LTP direction of weight change and capturing the function of CaMKII in the Kinase learning rule
	CaP float32

	// CaD is longer timescale integrated CaP value, representing the minus, LTD direction of weight change and capturing the function of DAPK1 in the Kinase learning rule
	CaD float32

	// DWt is the weight change, computed as CaP - CaD
	DWt float32
}

Synapse holds the calcium and weight-change state of a single synapse.

func (*Synapse) Init

func (ks *Synapse) Init()

Directories

Path Synopsis
synca_plot plots kinase SynCa update equations
