Source code for arborize.definitions

import abc
import dataclasses
import typing
from abc import abstractmethod

from ._util import Assert, Copy, Iterable, MechId, MechIdTuple, Merge
from .exceptions import ModelDefinitionError

if typing.TYPE_CHECKING:  # pragma: nocover
    from .parameter import Parameter


@dataclasses.dataclass
class CableProperties(Copy, Merge, Assert, Iterable):
    """
    Passive cable properties of a cable type.

    Every property defaults to ``None`` (unset).
    """

    #: Axial resistivity in ohm·cm.
    Ra: float | None = None
    #: Specific membrane capacitance (presumably in µF/cm², the NEURON
    #: convention — TODO confirm).
    cm: float | None = None

    def to_dict(self) -> dict:
        """Return the properties as a plain ``{field_name: value}`` dict."""
        return dataclasses.asdict(self)
class CablePropertiesDict(typing.TypedDict, total=False):
    """Dict form of :class:`CableProperties`; every key is optional."""

    Ra: float
    cm: float
@dataclasses.dataclass
class Ion(Copy, Merge, Assert, Iterable):
    """
    Ion settings of a cable type.

    Every value defaults to ``None`` (unset).
    """

    #: Reversal potential (presumably mV — TODO confirm).
    rev_pot: float | None = None
    #: Presumably the internal (intracellular) concentration, in mM — TODO confirm.
    int_con: float | None = None
    #: Presumably the external (extracellular) concentration, in mM — TODO confirm.
    ext_con: float | None = None

    def to_dict(self) -> dict:
        """Return the ion values as a plain ``{field_name: value}`` dict."""
        return dataclasses.asdict(self)
class IonDict(typing.TypedDict, total=False):
    """
    Dict form of :class:`Ion`; every key is optional. When parsed, the dict is
    passed as keyword arguments to the ion class.
    """

    rev_pot: float
    int_con: float
    ext_con: float
class Mechanism:
    """
    A mechanism, described by a mapping of parameter names to values.

    :param parameters: Parameter values. The mapping is stored by reference,
      not copied.
    """

    def __init__(self, parameters: dict[str, float]):
        super().__init__()
        self.parameters = parameters

    def merge(self, other):
        """
        Overlay ``other``'s parameters onto this mechanism, in place.

        Parameters present in both take ``other``'s value; parameters only
        present on ``self`` are kept.
        """
        self.parameters.update(other.parameters)

    def copy(self):
        """
        Return a copy that owns its own parameter dict.

        ``type(self)`` is used so that subclasses keep their type when copied.
        Mechanism classes are always constructed from just a parameter dict
        (see ``_parse_mech_def``), so the single-argument call is valid for them.
        """
        return type(self)(self.parameters.copy())

    def to_dict(self):
        """Return a shallow copy of the parameters as a plain dict."""
        return dict(self.parameters)
class Synapse(Mechanism):
    """
    A synapse type: parameter values plus the id of the mechanism it uses.

    :param parameters: Mapping of parameter names to values.
    :param mech_id: Mechanism id, normalized to tuple form with
      :func:`to_mech_id` (which rejects ``None`` with a ``ValueError``).
    """

    mech_id: MechIdTuple

    def __init__(self, parameters, mech_id: MechId):
        super().__init__(parameters)
        normalized = to_mech_id(mech_id)
        self.mech_id = normalized

    def copy(self):
        """Return a copy with its own parameter dict and the same mechanism id."""
        synapse_cls = type(self)
        params = self.parameters.copy()
        return synapse_cls(params, to_mech_id(self.mech_id))
class ExpandedSynapseDict(typing.TypedDict, total=False):
    """
    Long-form synapse definition: an explicit mechanism id plus its parameters.
    """

    mechanism: MechId
    parameters: dict[str, float]


# A synapse definition is either the short form (a plain ``{parameter: value}``
# dict, whose mechanism id is taken from the synapse label) or an
# ``ExpandedSynapseDict``.
SynapseDict = dict[str, float] | ExpandedSynapseDict
def is_mech_id(mech_id) -> bool:
    """
    Check whether ``mech_id`` is a valid mechanism id.

    A valid id is a string, or a tuple of 1 to 3 strings.

    :param mech_id: Candidate value of any type.
    :returns: ``True`` if valid, ``False`` otherwise. Values that are neither
      strings nor iterable (e.g. ``None`` or numbers) are reported as invalid
      instead of raising ``TypeError``.
    """
    if str(mech_id) == mech_id:
        return True
    try:
        as_tuple = tuple(mech_id)
    except TypeError:
        # Not iterable, so it cannot be a tuple-form mechanism id.
        return False
    # ``tuple(x) == x`` rejects lists and other non-tuple sequences.
    return (
        as_tuple == mech_id
        and 0 < len(mech_id) < 4
        and all(str(part) == part for part in mech_id)
    )
def to_mech_id(mech_id: MechId) -> MechIdTuple:
    """
    Normalize a mechanism id to its tuple form.

    Tuples are returned as a plain ``tuple``; any other value is wrapped in a
    1-tuple (``"hh"`` -> ``("hh",)``).

    :raises ValueError: If ``mech_id`` is ``None``.
    """
    if mech_id is None:
        raise ValueError("Mech id may not be None")
    if isinstance(mech_id, tuple):
        return tuple(mech_id)
    return (mech_id,)
class CableType:
    """
    Definition of a cable type: its cable properties, ions, mechanisms and
    synapses.
    """

    cable: CableProperties
    ions: dict[str, Ion]
    mechs: dict[MechId, Mechanism]
    synapses: dict[MechId, Synapse]

    def __init__(self, cable_property_class=CableProperties):
        """
        :param cable_property_class: Class instantiated without arguments to
          create the initial :attr:`cable` properties.
        """
        self.cable = cable_property_class()
        self.ions = {}
        self.mechs = {}
        self.synapses = {}

    def copy(self):
        """Return a copy; cable, ions, mechanisms and synapses are copied too."""
        def_ = type(self)()
        def_.cable = self.cable.copy()
        def_.ions = {k: v.copy() for k, v in self.ions.items()}
        def_.mechs = {k: v.copy() for k, v in self.mechs.items()}
        def_.synapses = {k: v.copy() for k, v in self.synapses.items()}
        return def_

    def set(self, param: "Parameter"):
        """
        Apply ``param`` to this cable type.

        ``param`` is duck-typed: ``set_cable_params`` is called with
        :attr:`cable` and ``set_mech_params`` with :attr:`mechs`, for whichever
        of those methods it has.
        """
        if hasattr(param, "set_cable_params"):
            param.set_cable_params(self.cable)
        if hasattr(param, "set_mech_params"):
            param.set_mech_params(self.mechs)

    def to_dict(self):
        """
        Return a plain dict with ``ions``, ``cable``, ``mechanisms`` and
        ``synapses`` keys, built from each part's ``to_dict``.
        """
        cable_dict = {}
        cable_dict["ions"] = {k: v.to_dict() for k, v in self.ions.items()}
        cable_dict["cable"] = self.cable.to_dict()
        cable_dict["mechanisms"] = {k: v.to_dict() for k, v in self.mechs.items()}
        cable_dict["synapses"] = {k: v.to_dict() for k, v in self.synapses.items()}
        return cable_dict

    @classmethod
    def anchor(
        cls,
        defs: typing.Iterable["CableType"],
        synapses: dict[MechId, Synapse] | None = None,
        use_defaults: bool = False,
        ion_class=Ion,
    ) -> "CableType":
        """
        Combine several cable types into one new cable type.

        :param defs: Cable types merged in order, so later items have the final
          say. ``None`` items are skipped.
        :param synapses: Global synapse types that serve as the base for the
          local synapse definitions in ``defs``. Neither this dict nor the
          synapse objects in it are modified.
        :param use_defaults: Start from :meth:`default` instead of an empty
          cable type.
        :param ion_class: Ion class passed to :meth:`default`.
        :returns: A new cable type.
        """
        def_ = cls.default(ion_class) if use_defaults else cls()
        if synapses is not None:
            # The local synapses must be merged on top of the global ones without
            # touching the caller's globals. The merges below modify synapse
            # objects in place (``_mergedict`` calls ``.merge``), so work on
            # *copies* of the global synapses.
            globaldef = cls()
            for key, value in synapses.items():
                globaldef.add_synapse(key, value.copy())
            globaldef._mergedict(globaldef.synapses, def_.synapses)
            def_.synapses = globaldef.synapses
        for def_right in defs:
            if def_right is None:
                continue
            def_.merge(def_right)
        return def_

    def merge(self, def_right: "CableType"):
        """
        Merge ``def_right`` into this cable type, in place.

        The cable properties are merged with their own ``merge``; ions,
        mechanisms and synapses are merged entry by entry with
        :meth:`_mergedict`.
        """
        self.cable.merge(def_right.cable)
        self._mergedict(self.ions, def_right.ions)
        self._mergedict(self.mechs, def_right.mechs)
        self._mergedict(self.synapses, def_right.synapses)

    def _mergedict(self, dself, dother):
        # Merge every entry of ``dother`` into ``dself`` in place: entries that
        # already exist are merged with their own ``merge`` method, new entries
        # are added as copies.
        for key, value in dother.items():
            if key in dself:
                dself[key].merge(value)
            else:
                dself[key] = value.copy()

    def assert_(self):
        """
        Check that the cable properties and every ion are complete.

        :raises ValueError: Re-raised from an ion's ``assert_`` with a message
          naming the ion. This assumes the ion's ``ValueError`` carries the
          missing field name as ``args[1]`` — TODO confirm against ``Assert``.
        """
        self.cable.assert_()
        for ion_name, ion in self.ions.items():
            try:
                ion.assert_()
            except ValueError as e:
                raise ValueError(
                    f"Missing '{e.args[1]}' value in ion '{ion_name}'",
                    ion_name,
                    e.args[1],
                ) from None

    @classmethod
    def default(cls, ion_class=Ion):
        """
        Return a cable type with ``Ra = 35.4``, ``cm = 1`` and a
        ``default_ions_dict`` that fills in default values for known ions.
        """
        default = cls()
        default.cable.Ra = 35.4
        default.cable.cm = 1
        default.ions = default_ions_dict(ion_class)
        return default

    def add_ion(self, key: str, ion: Ion):
        """
        Add ``ion`` under the name ``key``.

        :raises KeyError: If an ion with that name already exists.
        """
        if key in self.ions:
            raise KeyError(f"An ion named '{key}' already exists.")
        self.ions[key] = ion

    def add_mech(self, mech_id: MechId, mech: Mechanism):
        """
        Add ``mech`` under ``mech_id``.

        :raises ValueError: If ``mech_id`` is not a valid mechanism id.
        :raises KeyError: If a mechanism with that id already exists.
        """
        if not is_mech_id(mech_id):
            raise ValueError(f"'{mech_id}' is not a valid mechanism id.")
        if mech_id in self.mechs:
            raise KeyError(f"A mechanism with id '{mech_id}' already exists.")
        self.mechs[mech_id] = mech

    def add_synapse(self, label: str | MechId, synapse: Synapse):
        """
        Add ``synapse`` under ``label``.

        :raises ValueError: If the synapse's mechanism id (falling back to the
          label) is not a valid mechanism id.
        :raises KeyError: If a synapse with that label already exists.
        """
        mech_id = synapse.mech_id or to_mech_id(label)
        if not is_mech_id(mech_id):
            raise ValueError(f"'{mech_id}' is not a valid mechanism id.")
        if label in self.synapses:
            raise KeyError(f"A synapse with label '{label}' already exists.")
        self.synapses[label] = synapse
class CableTypeDict(typing.TypedDict, total=False):
    """
    Dict form of a cable type, as parsed by ``define_model``; every key is
    optional.
    """

    cable: CablePropertiesDict
    # Ion definitions per ion name.
    ions: dict[str, IonDict]
    # Parameter values per mechanism id.
    mechanisms: dict[MechId, dict[str, float]]
    # Synapse definitions per synapse label.
    synapses: dict[MechId, SynapseDict]
class default_ions_dict(dict):
    """
    Ion dict that completes well-known ions with default values on insertion.

    When a *new* key ``"na"``, ``"k"``, ``"ca"`` or ``"h"`` is assigned, the
    assigned ion object is merged in place with that ion's default values.
    Which fields end up filled depends on the ion class's ``merge``; presumably
    only unset ones are filled. Ions under any other name are stored unchanged.

    :param ion_class: Class used to construct the default ion objects.
    """

    def __init__(self, ion_class, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._ion_class = ion_class

    def _make_defaults(self):
        # Built lazily, on the first insertion of a new key.
        self._defaults = {
            "na": self._ion_class(rev_pot=50.0, int_con=10.0, ext_con=140.0),
            "k": self._ion_class(rev_pot=-77.0, int_con=54.4, ext_con=2.5),
            "ca": self._ion_class(rev_pot=132.4579341637009, int_con=5e-05, ext_con=2.0),
            "h": self._ion_class(rev_pot=0.0, int_con=1.0, ext_con=1.0),
        }

    def __setitem__(self, key, ion):
        if key not in self:
            if not hasattr(self, "_defaults"):
                if not hasattr(self, "_ion_class"):
                    # ``__init__`` was bypassed (presumably during unpickling,
                    # which sets items before restoring attributes); fall back
                    # to the inserted ion's class.
                    self._ion_class = type(ion)
                self._make_defaults()
            # Only well-known ions have defaults; any other ion (e.g. a custom
            # "cl") is stored as-is instead of raising KeyError.
            if key in self._defaults:
                value = self._defaults[key].copy()
                # Criss-cross merge: overlay the given ion onto a copy of the
                # defaults, then merge that combined result back into the given
                # ion object, so the caller's object holds the merged values.
                value.merge(ion)
                ion.merge(value)
        super().__setitem__(key, ion)
# Type variables that parametrize ``Definition`` over the concrete classes a
# model definition is built from. The strings below each TypeVar are kept as
# attribute docs (picked up by Sphinx autodoc).
CT = typing.TypeVar("CT", bound=CableType)
""" Type variable for cable types. """
CP = typing.TypeVar("CP", bound=CableProperties)
""" Type variable for cable properties. """
I = typing.TypeVar("I", bound=Ion)  # noqa: E741
""" Type variable for ions. """
M = typing.TypeVar("M", bound=Mechanism)  # noqa: E741
""" Type variable for mechanisms. """
S = typing.TypeVar("S", bound=Synapse)  # noqa: E741
""" Type variable for synapses. """
class Definition(typing.Generic[CT, CP, I, M, S], abc.ABC):
    """
    Abstract model definition: a collection of labelled cable types and
    synapse types.

    Subclasses declare which concrete classes are used to build each part.

    :param use_defaults: Stored as ``self.use_defaults`` (presumably consulted
      when cable types are anchored with defaults — TODO confirm at call sites).
    """

    # NOTE(review): ``@classmethod`` wrapping ``@property`` is deprecated since
    # Python 3.11 and no longer behaves as a class property in 3.13. Concrete
    # subclasses should override these with plain class attributes.
    @classmethod
    @property
    @abstractmethod
    def cable_type_class(cls) -> type[CT]:  # pragma: nocover
        pass

    @classmethod
    @property
    @abstractmethod
    def cable_properties_class(cls) -> type[CP]:  # pragma: nocover
        pass

    @classmethod
    @property
    @abstractmethod
    def ion_class(cls) -> type[I]:  # pragma: nocover
        pass

    @classmethod
    @property
    @abstractmethod
    def mechanism_class(cls) -> type[M]:  # pragma: nocover
        pass

    @classmethod
    @property
    @abstractmethod
    def synapse_class(cls) -> type[S]:  # pragma: nocover
        pass

    def __init__(self, use_defaults=False):
        self._cable_types: dict[str, CT] = {}
        self._synapse_types: dict[MechId, S] = {}
        self.use_defaults = use_defaults

    def copy(self):
        """
        Return a copy of this definition.

        Cable types and synapse types are both copied, so modifying the copy
        (e.g. merging into it) does not affect this definition's parts.
        """
        model = type(self)(self.use_defaults)
        for label, def_ in self._cable_types.items():
            model.add_cable_type(label, def_.copy())
        for label, def_ in self._synapse_types.items():
            # Copy rather than share: ``merge`` modifies synapse types in place.
            model.add_synapse_type(label, def_.copy())
        return model

    def merge(self, other: "Definition") -> None:
        """
        Merge ``other`` into this definition, in place.

        Cable types and synapse types present in both are merged with their own
        ``merge`` method, with ``other``'s values overlaid on this definition's.
        Those only present in ``other`` are added as copies.
        """
        self._merge_into(self._cable_types, other._cable_types)
        self._merge_into(self._synapse_types, other._synapse_types)

    @staticmethod
    def _merge_into(target: dict, source: dict) -> None:
        # Same strategy as ``CableType._mergedict``: merge existing entries in
        # place, add new entries as copies.
        for label, value in source.items():
            if label in target:
                target[label].merge(value)
            else:
                target[label] = value.copy()

    def get_cable_types(self) -> dict[str, CT]:
        """Return copies of all cable types, keyed by label."""
        return {k: v.copy() for k, v in self._cable_types.items()}

    def get_synapse_types(self) -> dict[str, S]:
        """Return copies of all synapse types, keyed by label."""
        return {k: v.copy() for k, v in self._synapse_types.items()}

    def add_cable_type(self, label: str, def_: CT):
        """
        Add the cable type ``def_`` under ``label``.

        :raises KeyError: If a cable type with that label already exists.
        """
        if label in self._cable_types:
            raise KeyError(f"Cable type {label} already exists.")
        self._cable_types[label] = def_

    def add_synapse_type(self, label: str | MechId, synapse: S):
        """
        Add the synapse type ``synapse`` under ``label``.

        :raises ValueError: If the synapse's mechanism id (falling back to the
          label) is not a valid mechanism id.
        :raises KeyError: If a synapse type with that label already exists.
        """
        mech_id = synapse.mech_id or to_mech_id(label)
        if not is_mech_id(mech_id):
            raise ValueError(f"'{mech_id}' is not a valid synapse mechanism.")
        if label in self._synapse_types:
            raise KeyError(f"Synapse type {label} already exists.")
        self._synapse_types[label] = synapse

    def to_dict(self) -> dict:
        """Return a plain dict with ``cable_types`` and ``synapse_types`` keys."""
        cable_dict = {k: v.to_dict() for k, v in self._cable_types.items()}
        synapse_dict = {k: v.to_dict() for k, v in self._synapse_types.items()}
        model_dict = {"cable_types": cable_dict, "synapse_types": synapse_dict}
        return model_dict
class ModelDefinition(Definition[CableType, CableProperties, Ion, Mechanism, Synapse]):
    """
    Model definition built from the default classes of this module.

    The component classes are plain class attributes. Wrapping ``property`` in
    ``classmethod`` is deprecated since Python 3.11 and no longer works as a
    class property in Python 3.13. Plain attributes give the same
    ``cls.<name>`` / ``self.<name>`` access on every version, and they still
    override the abstract declarations on ``Definition``.
    """

    #: Class used for cable types.
    cable_type_class = CableType
    #: Class used for cable properties.
    cable_properties_class = CableProperties
    #: Class used for ions.
    ion_class = Ion
    #: Class used for mechanisms.
    mechanism_class = Mechanism
    #: Class used for synapses.
    synapse_class = Synapse
class ModelDefinitionDict(typing.TypedDict, total=False):
    """
    Dict form of a model definition, accepted by ``define_model``; both keys are
    optional.
    """

    # Cable type definitions per cable type label.
    cable_types: dict[str, CableTypeDict]
    # Synapse type definitions per synapse label.
    synapse_types: dict[MechId, SynapseDict]
@typing.overload
def define_model(
    template: ModelDefinition,
    definition: ModelDefinitionDict,
    /,
    use_defaults: bool = ...,
) -> ModelDefinition: ...


@typing.overload
def define_model(
    definition: ModelDefinitionDict, /, use_defaults: bool = ...
) -> ModelDefinition: ...


def define_model(templ_or_def, def_dict=None, /, use_defaults=False) -> ModelDefinition:
    """
    Build a ``ModelDefinition`` from a definition dict.

    ``define_model(definition)`` parses the dict into a new model.
    ``define_model(template, definition)`` copies the template and merges the
    parsed dict into that copy.

    :param use_defaults: Assigned to the resulting model's ``use_defaults``.
    """
    if def_dict is not None:
        # Template form: overlay the parsed definition on a copy of the template.
        model = templ_or_def.copy()
        model.merge(_parse_dict_def(ModelDefinition, def_dict))
    else:
        model = _parse_dict_def(ModelDefinition, templ_or_def)
    model.use_defaults = use_defaults
    return model
D = typing.TypeVar("D", bound=Definition)


def _parse_dict_def(cls: type[D], def_dict: ModelDefinitionDict) -> D:
    """
    Build an instance of the definition class ``cls`` from a definition dict.

    The component classes declared on ``cls`` are used to build each part.

    :raises ModelDefinitionError: If a cable type or synapse type is invalid.
    """
    model = cls()
    for cable_label, cable_input in def_dict.get("cable_types", {}).items():
        cable_type = _parse_cable_type(cls, cable_input)
        model.add_cable_type(cable_label, cable_type)
    for synapse_label, synapse_input in def_dict.get("synapse_types", {}).items():
        synapse_type = _parse_synapse_def(cls, synapse_label, synapse_input)
        model.add_synapse_type(synapse_label, synapse_type)
    return model


def _parse_cable_type(cls: type[Definition], cable_dict: CableTypeDict):
    """
    Parse one cable type dict using the component classes of ``cls``.

    :raises ModelDefinitionError: Wrapping any error raised while parsing.
    """
    try:
        cable_type = cls.cable_type_class(cls.cable_properties_class)
        cable_type.cable = cls.cable_properties_class(**cable_dict.get("cable", {}))
        for ion_name, ion_input in cable_dict.get("ions", {}).items():
            ion = _parse_ion_def(cls, ion_input)
            cable_type.add_ion(ion_name, ion)
        for mech_id, mech_input in cable_dict.get("mechanisms", {}).items():
            cable_type.add_mech(mech_id, _parse_mech_def(cls, mech_input))
        for synapse_label, synapse_input in cable_dict.get("synapses", {}).items():
            synapse = _parse_synapse_def(cls, synapse_label, synapse_input)
            cable_type.add_synapse(synapse_label, synapse)
    except Exception as e:
        raise ModelDefinitionError(
            f"{cable_dict} is not a valid cable type definition."
        ) from e
    return cable_type


def _parse_ion_def(cls: type[Definition], ion_dict: IonDict):
    """
    Build an ion by passing ``ion_dict`` as keyword arguments to ``cls.ion_class``.

    :raises ModelDefinitionError: Wrapping any construction error.
    """
    try:
        ion = cls.ion_class(**ion_dict)
    except Exception as e:
        raise ModelDefinitionError(f"{ion_dict} is not a valid ion definition.") from e
    return ion


def _parse_mech_def(cls: type[Definition], mech_dict: dict[str, float]):
    """
    Build a mechanism from a copy of the parameter dict ``mech_dict``.

    :raises ModelDefinitionError: Wrapping any construction error.
    """
    try:
        mech = cls.mechanism_class(mech_dict.copy())
    except Exception as e:
        raise ModelDefinitionError(
            f"{mech_dict} is not a valid mechanism definition."
        ) from e
    return mech


def _parse_synapse_def(cls: type[Definition], key, synapse_dict: SynapseDict):
    """
    Build a synapse from its short or expanded dict form.

    :param key: The synapse label, used as mechanism id in the short form.
    :raises ModelDefinitionError: Wrapping any parsing or construction error.
    """
    try:
        if "mechanism" in synapse_dict:
            # Expanded form: explicit mechanism id; parameters default to none.
            parameters = synapse_dict.get("parameters", {})
            mech_id = synapse_dict["mechanism"]
        else:
            # Short form: the label is the mechanism id, and the parameters are
            # under "parameters" if given, else the dict items themselves.
            parameters = synapse_dict.get("parameters", synapse_dict)
            mech_id = key
        synapse = cls.synapse_class(parameters.copy(), mech_id)
    except Exception as e:
        raise ModelDefinitionError(
            f"{synapse_dict} is not a valid synapse definition."
        ) from e
    return synapse
class mechdict(dict):
    """
    Dict keyed by mechanism ids, where a string key ``"x"`` is an alias for the
    1-tuple ``("x",)``.

    Normalization applies to item access, assignment, deletion, membership
    tests and :meth:`get`. As for any ``dict`` subclass, the constructor and
    ``update`` bypass it because they do not call ``__setitem__``.
    """

    @staticmethod
    def _key(key):
        # String keys are stored as 1-tuples; every other key is used as-is.
        return (key,) if isinstance(key, str) else key

    def __getitem__(self, item):
        return super().__getitem__(self._key(item))

    def __setitem__(self, key, value):
        return super().__setitem__(self._key(key), value)

    def __delitem__(self, key):
        super().__delitem__(self._key(key))

    def __contains__(self, key):
        return super().__contains__(self._key(key))

    def get(self, key, default=None):
        return super().get(self._key(key), default)