# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
"""
Bare example of how this module works::

    import k1lib
    class CbA(k1lib.Callback):
        def __init__(self):
            super().__init__()
            self.initialState = 3
        def startBatch(self):
            print("startBatch - CbA")
        def startPass(self):
            print("startPass - CbA")
    class CbB(k1lib.Callback):
        def startBatch(self):
            print("startBatch - CbB")
        def endLoss(self):
            print("endLoss - CbB")
    # initialization
    cbs = k1lib.Callbacks()
    cbs.add(CbA()).add(CbB())
    model = lambda xb: xb + 3
    lossF = lambda y, yb: y - yb
    # training loop
    cbs("startBatch"); xb = 6; yb = 2
    cbs("startPass"); y = model(xb); cbs("endPass")
    cbs("startLoss"); loss = lossF(y, yb); cbs("endLoss")
    cbs("endBatch")
    print(cbs.CbA) # can reference the Callback object directly

In short: you define :class:`Callback` classes that each implement some
checkpoint functions, like ``startBatch``. Then, you create a
:class:`Callbacks` object that holds those Callback objects. Calling
``cbs("checkpoint")`` executes ``cb.checkpoint()`` on every Callback
object that defines it (and isn't suspended).

Pretty much everything here is built upon this. The core training loop
has nothing ML-specific in it; in fact, it's just a bunch of
``cbs("...")`` statements. Everything meaningful about the training
loop comes from different Callback classes. The advantage is that you
can tack on wildly different functionality, have the pieces play nicely
with each other, and remove entire complex features by commenting out a
single line."""
import k1lib, time, os, logging, numpy as np
plt = k1lib.dep("matplotlib.pyplot"); import k1lib.cli as cli
from typing import Set, List, Union, Callable, ContextManager, Iterator
from collections import OrderedDict
__all__ = ["Callback", "Callbacks", "Cbs"]
class Callback:
    r"""Represents a callback. Define specific functions
inside to intercept certain parts of the training
loop. Can access :class:`k1lib.Learner` like this::

    self.l.xb = self.l.xb[None]

This takes x batch of learner, unsqueezes it at
the 0 position, then sets the x batch again.

Normally, you will define a subclass of this and
define specific intercept functions, but if you
want to create a throwaway callback, then do this::

    Callback().withCheckpoint("startRun", lambda cb: print("start running"))

You can use :attr:`~k1lib.callbacks.callbacks.Cbs` (automatically exposed) for
a list of default Callback classes, for any particular needs.

**order**

You can also use `.order` to set the order of execution of the callback.
The higher, the later it gets executed. Value suggestions:

- 7: pre-default runs, like LossLandscape
- 10: default runs, like DontTrainValid
- 13: custom mods, like ModifyBatch
- 15: pre-recording mod
- 17: recording mods, like Profiler.memory
- 20: default recordings, like Loss
- 23: post-default recordings, like ParamFinder
- 25: guards, like TimeLimit, CancelOnExplosion

Just leave as default (10) if you don't know what values to choose.

**dependsOn**

If you're going to extend this class, you can also specify dependencies
(by class name) like this::

    class CbC(k1lib.Callback):
        def __init__(self):
            super().__init__()
            self.dependsOn = {"Loss", "Accuracy"}

This is so that if somewhere, ``Loss`` callback class is temporarily
suspended, then CbC will be suspended also, therefore avoiding errors.

**Suspension**

If your Callback is mainly dormant, then you can do something like this::

    class CbD(k1lib.Callback):
        def __init__(self):
            super().__init__()
            self.suspended = True
        def startBatch(self):
            # these types of methods will only execute
            # if ``self.suspended = False``
            pass
        def analyze(self):
            self.suspended = False
            # do something that sometimes calls ``startBatch``
            self.suspended = True
    cbs = k1lib.Callbacks().add(CbD())
    # dormant phase:
    cbs("startBatch") # does not execute CbD.startBatch()
    # active phase
    cbs.CbD.analyze() # does execute CbD.startBatch()

So yeah, you can easily make every checkpoint active/dormant by changing
a single variable, how convenient. See :meth:`Callbacks.suspend`
for more."""
    def __init__(self):
        # learner, parent Callbacks (both set when added to a Callbacks) and the dormant flag
        self.l = None; self.cbs = None; self.suspended = False
        self.name = self.__class__.__name__; self.dependsOn:Set[str] = set()
        self.order = 10 # can be modified by subclasses. A smaller order will be executed first
    def suspend(self):
        """Checkpoint, called when the Callback is temporarily suspended. Overridable"""
        pass
    def restore(self):
        """Checkpoint, called when the Callback is back from suspension. Overridable"""
        pass
    def __getstate__(self):
        # the learner and parent Callbacks aren't pickled along with the callback
        state = dict(self.__dict__); state.pop("l", None); state.pop("cbs", None); return state
    def __setstate__(self, state):
        self.__dict__.update(state)
        # __getstate__ dropped these, so mark the callback as detached instead of leaving
        # the attributes missing; Callbacks.add re-links them when (re)attaching
        self.__dict__.setdefault("l", None); self.__dict__.setdefault("cbs", None)
    def __repr__(self): return f"{self._reprHead}, can...\n{self._reprCan}"
    @property
    def _reprHead(self): return f"Callback `{self.name}`"
    @property
    def _reprCan(self): return """- cb.something: to get specific attribute "something" from learner if not available
- cb.withCheckpoint(checkpoint, f): to quickly insert an event handler
- cb.detach(): to remove itself from its parent Callbacks"""
    def withCheckpoint(self, checkpoint:str, f:Callable[["Callback"], None]):
        """Quickly set a checkpoint, for simple, inline-able functions

:param checkpoint: checkpoints like "startRun"
:param f: function that takes in the Callback itself
:returns: this Callback, so calls can be chained"""
        setattr(self, checkpoint, lambda: f(self)); return self
    def __call__(self, checkpoint):
        """Runs ``checkpoint`` if this Callback defines it and isn't suspended.

:returns: True if the checkpoint returned something other than None, False if it
    returned None, and None if the checkpoint wasn't run at all"""
        if not self.suspended and hasattr(self, checkpoint):
            # identity check: ``!= None`` would compare elementwise for array results
            return getattr(self, checkpoint)() is not None
    def attached(self):
        """Called when this is added to a :class:`Callbacks`. Override this to
do custom stuff when this happens."""
        pass
    def detach(self):
        """Detaches from the parent :class:`Callbacks`"""
        self.cbs.remove(self.name); return self
# Namespace holding the default Callback classes, looked up by attribute (e.g. ``Cbs.Loss()``
# in Callbacks.withBasics). NOTE(review): it starts empty here, so presumably other modules
# register their classes on it - confirm.
Cbs = k1lib.Object()                                                             # Callback
# NOTE(review): class-level attribute shared by every Callback; nothing in this chunk reads or
# fills it - presumably used elsewhere to hold loss-related classes, confirm.
Callback.lossCls = k1lib.Object()                                                # Callback
class Timings:
    """List of checkpoint timings. Not intended to be instantiated by the end user.
Used within :class:`~k1lib.callbacks.callbacks.Callbacks`, accessible via
:attr:`Callbacks.timings` to record time taken to execute a single
checkpoint. This is useful for profiling stuff.

Reading any public attribute that hasn't been set yet creates it with value
``0``, so durations can be accumulated with ``t[checkpoint] += duration``."""
    @property
    def state(self):
        """Copy of every recorded attribute, minus ``getdoc``.

NOTE(review): ``getdoc`` is presumably dropped because some introspection tool
probes for it and the auto-declaring ``__getattr__`` records it - confirm."""
        answer = dict(self.__dict__); answer.pop("getdoc", None); return answer
    @property
    def checkpoints(self) -> List[str]:
        """List of all checkpoints encountered (recorded attributes with numeric values)"""
        return [cp for cp in self.state if k1lib.isNumeric(self[cp])]
    def __getattr__(self, attr):
        # only reached for missing attributes. Names starting with "_" are never
        # auto-declared, so private/dunder lookups raise AttributeError normally
        if attr.startswith("_"): raise AttributeError()
        self.__dict__[attr] = 0; return 0
    def __getitem__(self, idx): return getattr(self, idx)
    def __setitem__(self, idx, value): setattr(self, idx, value)
    def plot(self):
        """Plot all checkpoints' execution times. The unit (s/ms/us/ns) is picked so
that it fits the largest timing. Does nothing if no checkpoint was recorded yet."""
        checkpoints = self.checkpoints
        if len(checkpoints) == 0: return # nothing to plot; ``.max()`` of an empty array raises
        timings = np.array([self[cp] for cp in checkpoints])
        maxTiming = timings.max()
        # first unit whose threshold the largest timing reaches; ns covers everything smaller
        units = [(1, 1, "s"), (1e-3, 1e3, "ms"), (1e-6, 1e6, "us")]
        scale, unit = next(((s, u) for threshold, s, u in units if maxTiming >= threshold), (1e9, "ns"))
        plt.figure(dpi=100)
        plt.bar(checkpoints, timings*scale); plt.ylabel(f"Time ({unit})")
        plt.xticks(rotation="vertical"); plt.show()
    def clear(self):
        """Clears all timing data"""
        for cp in self.checkpoints: self[cp] = 0
    def __repr__(self):
        cps = '\n'.join([f'- {cp}: {self[cp]}' for cp in self.checkpoints])
        return f"""Timings object. Checkpoints:\n{cps}\n
Can...
- t.startRun: to get specific checkpoint's execution time
- t.plot(): to plot all checkpoints"""
# Module-level alias of time.time, used by Callbacks.__call__ to time each checkpoint.
# NOTE(review): presumably aliased to save an attribute lookup per call - confirm.
_time = time.time                                                                # Timings
class Callbacks:
    """Ordered collection of :class:`Callback` objects.

Calling it with checkpoint names, like ``cbs("startBatch")``, runs those
checkpoints on every contained callback, in ascending :attr:`Callback.order`
(callbacks with equal order keep their insertion order)."""
    def __init__(self):
        self._l: k1lib.Learner = None; self.cbsDict = {}
        # contexts is a stack of frames used by context(): every newly added callback is
        # recorded in the last frame. The root frame is never popped.
        self._timings = Timings(); self.contexts = [[]]
    @property
    def timings(self) -> Timings:
        """Returns :class:`~k1lib.callbacks.callbacks.Timings` object"""
        return self._timings
    @property
    def l(self) -> "k1lib.Learner":
        """:class:`k1lib.Learner` object. Will be set automatically when
you set :attr:`k1lib.Learner.cbs` to this :class:`Callbacks`. Setting it also
sets ``l`` on every contained callback."""
        return self._l
    @l.setter
    def l(self, learner):
        self._l = learner
        for cb in self.cbs: cb.l = learner
    @property
    def cbs(self) -> List[Callback]:
        """List of :class:`Callback`. A fresh copy, so it's safe to add/remove while looping over it."""
        return [*self.cbsDict.values()]
    def _sort(self) -> "Callbacks":
        # sorted() is stable, so callbacks with equal order keep their insertion order
        self.cbsDict = OrderedDict(sorted(self.cbsDict.items(), key=(lambda o: o[1].order))); return self
    def add(self, cb:Callback, name:str=None):
        """Adds a callback to the collection.

Example::

    cbs = k1lib.Callbacks()
    cbs.add(k1lib.Callback().withCheckpoint("startBatch", lambda self: print("here")))

If you just want to insert a simple callback with a single checkpoint, then
you can do something like::

    cbs.add(["startBatch", lambda _: print("here")])

If a callback with the same name already exists, a numeric suffix is appended
(``"Loss"`` -> ``"Loss0"``, ``"Loss1"``, ...). Adding a callback that's already
in the collection only refreshes its ``l``/``cbs`` links.

:param cb: the :class:`Callback`, or a ``[checkpoint, function]`` pair
:param name: name to register it under. Defaults to ``cb.name``
:raises RuntimeError: if ``cb`` is not a :class:`Callback`"""
        if isinstance(cb, (list, tuple)): # [checkpoint, function] shorthand; keep the requested name
            return self.add(Callback().withCheckpoint(cb[0], cb[1]), name)
        if not isinstance(cb, Callback): raise RuntimeError("`cb` is not a callback!")
        cb.l = self.l; cb.cbs = self
        if cb in self.cbsDict.values(): return self # already added: links refreshed above
        name = name or cb.name
        if name in self.cbsDict: # never clobber an existing callback
            i = 0
            while f"{name}{i}" in self.cbsDict: i += 1
            name = f"{name}{i}"
        cb.name = name; self.cbsDict[name] = cb; self._sort()
        self._appendContext_append(cb); cb("attached"); return self
    def __contains__(self, e:str) -> bool:
        """Whether a specific Callback name is in this :class:`Callbacks`."""
        return e in self.cbsDict
    def remove(self, *names:str):
        """Removes callbacks by name from the collection. A name that isn't found
is reported with a printed message and skipped; the other names are still removed.

:returns: this :class:`Callbacks`, so calls can be chained"""
        for name in names:
            if name not in self.cbsDict:
                print(f"Callback `{name}` not found"); continue
            cb = self.cbsDict.pop(name); cb("detached")
        self._sort(); return self
    def removePrefix(self, prefix:str):
        """Removes any callback with the specified prefix"""
        for cb in self.cbs: # self.cbs is a copy, so removing while looping is safe
            if cb.name.startswith(prefix): self.remove(cb.name)
        return self
    def __call__(self, *checkpoints:str) -> bool:
        """Calls a number of checkpoints one after another.
Returns True if any of the checkpoints return anything at all. Time spent in
each checkpoint is accumulated into :attr:`timings`."""
        self._checkpointGraph_call(checkpoints)
        answer = False
        for checkpoint in checkpoints:
            beginTime = _time()
            # a list, not a generator: every callback must run, so any() must not short-circuit
            answer |= any([cb(checkpoint) for cb in self.cbs])
            self._timings[checkpoint] += _time() - beginTime
        return answer
    def __getitem__(self, idx:Union[int, str]) -> Callback:
        """Get specific cbs.

:param idx: if :class:`str`, then get the Callback with this specific name,
    if :class:`int`, then get the Callback in that index."""
        return self.cbs[idx] if isinstance(idx, int) else self.cbsDict[idx]
    def __iter__(self) -> Iterator[Callback]:
        """Iterates through all :class:`Callback`."""
        yield from self.cbsDict.values()
    def __len__(self):
        """How many :class:`Callback` are there in total?"""
        return len(self.cbsDict)
    def __getattr__(self, attr):
        # guard: while unpickling, cbsDict doesn't exist yet and looking it up here would recurse
        if attr == "cbsDict": raise AttributeError(attr)
        if attr in self.cbsDict: return self.cbsDict[attr]
        else: raise AttributeError(attr)
    def __getstate__(self):
        state = dict(self.__dict__); state.pop("_l", None); return state
    def __setstate__(self, state):
        self.__dict__.update(state)
        # __getstate__ drops the learner; without this, reading ``self.l`` would raise AttributeError
        self.__dict__.setdefault("_l", None)
        for cb in self.cbs: cb.cbs = self; cb.l = self._l
    def __dir__(self):
        answer = list(super().__dir__())
        answer.extend(self.cbsDict.keys())
        return answer
    def __repr__(self):
        return "Callbacks:\n" + '\n'.join([f"- {cbName}" for cbName in self.cbsDict if not cbName.startswith("_")]) + """\n
Use...
- cbs.add(cb[, name]): to add a callback with a name
- cbs("startRun"): to trigger a specific checkpoint, this case "startRun"
- cbs.Loss: to get a specific callback by name, this case "Loss"
- cbs[i]: to get specific callback by index
- cbs.timings: to get callback execution times
- cbs.checkpointGraph(): to graph checkpoint calling orders
- cbs.context(): context manager that will detach all Callbacks attached inside the context
- cbs.suspend("Loss", "Cuda"): context manager to temporarily prevent triggering checkpoints"""
    def withBasics(self):
        """Adds a bunch of very basic Callbacks that's needed for everything. Also
includes Callbacks that are not necessary, but don't slow things down"""
        self.add(Cbs.CoreNormal()).add(Cbs.Profiler()).add(Cbs.Recorder())
        self.add(Cbs.ProgressBar()).add(Cbs.Loss()).add(Cbs.Accuracy()).add(Cbs.DontTrainValid())
        return self.add(Cbs.CancelOnExplosion()).add(Cbs.ParamFinder())
    def withQOL(self):
        """Adds quality of life Callbacks. Currently adds nothing."""
        return self
    def withAdvanced(self):
        """Adds advanced Callbacks that do fancy stuff, but may slow things
down if not configured specifically."""
        return self.add(Cbs.HookModule().withMeanRecorder().withStdRecorder()).add(Cbs.HookParam())
@k1lib.patch(Callbacks)
def _resolveDependencies(self):
    """Recomputes, for every callback, which callbacks depend on it.

Afterwards ``cb._dependents`` holds every callback whose ``dependsOn`` contains
the class name of ``cb``, and every ``dependsOn`` has been normalized to a set."""
    allCbs = self.cbs
    for cb in allCbs:
        cb._dependents = set()
        cb.dependsOn = set(cb.dependsOn)
    # index callbacks by class name so each declared dependency is a single dict lookup
    byClassName = {}
    for cb in allCbs:
        byClassName.setdefault(cb.__class__.__name__, []).append(cb)
    for dependent in allCbs:
        for className in dependent.dependsOn:
            for dependency in byClassName.get(className, []):
                dependency._dependents.add(dependent)
class SuspendContext:
    """Context manager returned by :meth:`Callbacks.suspend` and
:meth:`Callbacks.suspendClasses`.

On enter, it suspends every callback matching ``cbsNames`` (by name) or
``cbsClasses`` (by class name). It also suspends, transitively, every callback
that depends on a suspended one. On exit, it restores each callback's previous
``suspended`` flag. Nested contexts share a stack stored on the parent
:class:`Callbacks` as ``suspendStack``."""
    def __init__(self, cbs:"Callbacks", cbsNames:List[str], cbsClasses:List[str]):
        self.cbs = cbs; self.cbsNames = cbsNames; self.cbsClasses = cbsClasses
        self.cbs.suspendStack = getattr(self.cbs, "suspendStack", [])
    def _toSuspend(self) -> list:
        """Callbacks to suspend: direct matches plus everything depending on them.
Iterates to a fixpoint with a visited set, so dependency cycles terminate and
the result doesn't depend on callback order."""
        cbsClasses = set(self.cbsClasses); cbsNames = set(self.cbsNames)
        allCbs = list(self.cbs)
        matches = lambda cb: cb.__class__.__name__ in cbsClasses or cb.name in cbsNames
        visited = set()
        while True:
            frontier = [cb for cb in allCbs if cb not in visited and matches(cb)]
            if not frontier: break
            for cb in frontier:
                visited.add(cb)
                # dependencies are declared by class name, so a dependent's whole class is suspended
                for dept in cb._dependents: cbsClasses.add(dept.__class__.__name__)
        return [cb for cb in allCbs if matches(cb)]
    def __enter__(self):
        self.cbs._resolveDependencies()
        stackFrame = {cb: cb.suspended for cb in self._toSuspend()} # remember previous flags
        for cb in stackFrame: cb.suspend(); cb.suspended = True
        self.cbs.suspendStack.append(stackFrame)
        return self.cbs
    def __exit__(self, *ignored):
        for cb, oldValue in self.cbs.suspendStack.pop().items():
            cb.suspended = oldValue; cb.restore()
    def __getattr__(self, attr):
        # guard: if ``cbs`` isn't set yet (e.g. mid-copy/unpickle), proxying would recurse forever
        if attr == "cbs": raise AttributeError(attr)
        return getattr(self.cbs, attr)
@k1lib.patch(Callbacks)
def suspend(self, *cbNames:List[str]) -> ContextManager:
    """Creates a suspension context for the Callbacks with the given names.
Usage::

    cbs = k1lib.Callbacks().add(CbA()).add(CbB()).add(CbC())
    with cbs.suspend("CbA", "CbC"):
        pass # inside here, only CbB will be active, and its checkpoints executed
    # CbA, CbB and CbC are all active

Callbacks that depend on a suspended one (see ``Callback.dependsOn``) get
suspended as well.

.. seealso:: :meth:`suspendClasses`"""
    return SuspendContext(self, cbsNames=cbNames, cbsClasses=[])
@k1lib.patch(Callbacks)
def suspendClasses(self, *classNames:List[str]) -> ContextManager:
    """Like :meth:`suspend`, but matches callbacks' class names to the given list,
instead of matching names. Meaning::

    cbs = k1lib.Callbacks().add(Cbs.Loss()).add(Cbs.Loss())
    # cbs now has 2 callbacks "Loss" and "Loss0"
    with cbs.suspendClasses("Loss"):
        pass # now both of them are suspended"""
    return SuspendContext(self, [], classNames)
@k1lib.patch(Callbacks)
def suspendEval(self, more:List[str]=(), less:List[str]=()) -> ContextManager:
    """Same as :meth:`suspendClasses`, but suspends a default set of classes
(typically ones you don't want active while evaluating). Just a convenience
method really. Currently includes:

- HookModule, HookParam, ProgressBar
- ParamScheduler, Loss, Accuracy, Autosave
- ConfusionMatrix

:param more: include more classes to be suspended (a single class name also works)
:param less: exclude classes supposed to be suspended by default (a single class name also works)"""
    # a bare string would otherwise be treated as an iterable of its characters
    if isinstance(more, str): more = [more]
    if isinstance(less, str): less = [less]
    classes = {"HookModule", "HookParam", "ProgressBar", "ParamScheduler", "Loss", "Accuracy", "Autosave", "ConfusionMatrix"}
    classes.update(more); classes -= set(less)
    return self.suspendClasses(*classes)
class AppendContext:
    """Context manager returned by :meth:`Callbacks.context`.

On enter, it pushes a new frame onto ``cbs.contexts`` and adds ``initCbs``.
Every callback added to ``cbs`` while that frame is on top is recorded in it.
On exit, the frame is popped and its callbacks that are still attached are
detached."""
    def __init__(self, cbs:"Callbacks", initCbs:"List[Callback]"=()):
        self.cbs = cbs; self.initCbs = initCbs
    def __enter__(self):
        self.cbs.contexts.append([])
        for cb in self.initCbs: self.cbs.add(cb)
        return self.cbs
    def __exit__(self, *ignored):
        for cb in self.cbs.contexts.pop():
            # skip callbacks already removed inside the context; detaching them
            # again would only print a spurious "not found"
            if self.cbs.cbsDict.get(cb.name) is cb: cb.detach()
@k1lib.patch(Callbacks)
def _appendContext_append(self, cb):
    """Records ``cb`` in the innermost (most recently pushed) context frame, so that
the AppendContext owning that frame can detach it on exit."""
    innermost = self.contexts[-1]
    innermost.append(cb)
@k1lib.patch(Callbacks)
def context(self, *initCbs:List[Callback]) -> ContextManager:
    """Add context. Callbacks added inside the context (including ``initCbs``)
are detached when it exits.
Works like this::

    cbs = k1lib.Callbacks().add(CbA())
    # CbA is available
    with cbs.context(CbE(), CbF()):
        cbs.add(CbB())
        # CbA, CbB, CbE and CbF available
        cbs.add(CbC())
        # all 5 are available
    # only CbA is available

For maximum shortness, you can do this::

    with k1lib.Callbacks().context(CbA()) as cbs:
        pass # CbA is available
"""
    return AppendContext(self, initCbs)
@k1lib.patch(Callbacks)
def _checkpointGraph_call(self, checkpoints:List[str]):
    """Counts checkpoint transitions for :meth:`checkpointGraph`.

``_checkpointGraphDict[a][b]`` is how many times checkpoint ``b`` was called right
after ``a`` (entries are auto-declared via ``withAutoDeclare``). The very first
checkpoint is counted as following ``"<root>"``."""
    if not hasattr(self, "_checkpointGraphDict"): # lazily created on the first call
        newRow = lambda: k1lib.Object().withAutoDeclare(lambda: 0)
        self._checkpointGraphDict = k1lib.Object().withAutoDeclare(newRow)
        self._lastCheckpoint = "<root>"
    transitions = self._checkpointGraphDict
    for current in checkpoints:
        transitions[self._lastCheckpoint][current] += 1
        self._lastCheckpoint = current
@k1lib.patch(Callbacks)
def checkpointGraph(self, highlightCb:Union[str, Callback]=None):
    """Graphs which checkpoints follow which checkpoints. Checkpoints have to be
called at least once first, otherwise the graph is empty. Requires graphviz
package though. Example::

    cbs = Callbacks()
    cbs("a", "b", "c", "d", "b")
    cbs.checkpointGraph() # returns graph object. Will display image if using notebooks

.. image:: ../images/checkpointGraph.png

:param highlightCb: if available, will highlight the checkpoints the callback
    uses. Can be name/class-name/class/self of callback. A string is matched
    against callback names first, then against class names.
:raises AttributeError: if ``highlightCb`` can't be resolved to a callback"""
    g = k1lib.digraph(); s = set()
    transitions = getattr(self, "_checkpointGraphDict", None) # only exists after the first checkpoint call
    if transitions is not None:
        for cp1, cp1o in transitions.state.items():
            for cp2, v in cp1o.state.items():
                g.edge(cp1, cp2, label=f"  {v}  "); s.add(cp2)
    if highlightCb is None: return g
    if isinstance(highlightCb, Callback): _cb = highlightCb
    elif isinstance(highlightCb, type) and issubclass(highlightCb, Callback): # first cb of that class
        _cb = next((cbo for cbo in self.cbs if isinstance(cbo, highlightCb)), None)
        if _cb is None: raise AttributeError(f"Can't find any Callback inside this Callbacks which is of type `{highlightCb.__name__}`")
    elif isinstance(highlightCb, str): # exact callback name wins over class name
        _cb = self.cbsDict.get(highlightCb)
        if _cb is None: _cb = next((cbo for cbo in self.cbs if type(cbo).__name__ == highlightCb), None)
        if _cb is None: raise AttributeError(f"Can't find any Callback inside this Callbacks with name or class `{highlightCb}`")
    else: raise AttributeError(f"Don't understand {highlightCb}")
    print(f"Highlighting callback `{_cb.name}`, of type `{type(_cb)}`")
    for cp in s:
        if hasattr(_cb, cp): g.node(cp, color="red")
    return g