k1lib¶

PyTorch is awesome, and it provides a very effective way to execute ML code fast. What it lacks is the surrounding infrastructure to make the general debugging and discovery process better. Other, more official wrapper frameworks sort of don't make sense to me, so this is an attempt at recreating a robust suite of tools that makes sense.

Table of contents:

  • Overview
    • ParamFinder
    • Loss
    • LossLandscape
    • HookParam
    • HookModule
  • CSS module selector
  • Data loader
  • Callbacks

Let's see an example:

Overview¶

In [1]:
from k1lib.imports import *

k1lib.imports is just a file that imports lots of common utilities, so that importing stuff is easier and quicker.
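It's not shown here, but judging from the names this notebook uses, the bulk of it is roughly equivalent to the following (an assumption about its contents, not the actual file):

# A rough sketch of what this notebook relies on (assumption, not the actual file):
import math, torch, torch.nn as nn, torch.optim as optim
import matplotlib.pyplot as plt
from functools import partial
import k1lib
# ...plus k1lib-specific names used below: Cbs, apply, batched, transpose, toTensor, ...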

In [2]:
class SkipBlock(nn.Module):
    def __init__(self, hiddenDim=10):
        super().__init__()
        def gen(): return nn.Linear(hiddenDim, hiddenDim), nn.LeakyReLU()
        self.seq = nn.Sequential(*gen(), *gen(), *gen())
    def forward(self, x):
        return self.seq(x) + x
In [3]:
class Network(nn.Module):
    def __init__(self, hiddenDim=10, blocks=3, block=SkipBlock):
        super().__init__()
        layers = [nn.Linear(1, hiddenDim), nn.LeakyReLU()]
        layers += [block(hiddenDim) for _ in range(blocks)]
        layers += [nn.Linear(hiddenDim, 1)]
        self.bulk = nn.Sequential(*layers)
    def forward(self, x):
        return self.bulk(x)

Here is our network. Just a normal feed-forward network, with skip blocks in the middle.
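A quick sanity check of the input/output shapes (a small sketch, not part of the original notebook):

net = Network()
x = torch.randn(32, 1)          # a fake batch of 32 one-dimensional inputs
print(net(x).shape)             # torch.Size([32, 1])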

In [4]:
def dataF(bs=32, epochs=200):
    return torch.linspace(-5, 5, 1000) | apply(op().item()) | apply(lambda x: (x, math.exp(x))) | randomize(None) | splitW() |\
    (repeatFrom() | batched(bs) | (transpose() | ((unsqueeze(1) | toTensor()) + toTensor())).all()).all() | stagger.tv(epochs) | toList()
def newL(*args, **kwargs):
    l = k1lib.Learner()
    l.model = Network(*args, **kwargs)
    l.data = dataF(64, 200)
    l.opt = optim.Adam(l.model.parameters(), lr=1e-2)
    l.lossF = lambda x, y: ((x.squeeze() - y)**2).mean()

    l.cbs.add(Cbs.ModifyBatch(lambda x, y: (x[:, None], y)))
    l.cbs.add(Cbs.DType(torch.float32))
    l.cbs.add(Cbs.CancelOnLowLoss(1, epochMode=True))
    l.css = """SkipBlock #0: HookParam
SkipBlock: HookModule"""

    def evaluate(self):
        xbs, ybs, ys = self.Recorder.record(1, 3)
        xbs = torch.vstack(xbs).squeeze()
        ybs = torch.vstack([yb[:, None] for yb in ybs]).squeeze()
        ys = torch.vstack(ys).squeeze()
        plt.plot(xbs, ys.detach(), ".")
    l.evaluate = partial(evaluate, l)
    return l
l = newL()
l.run(10);
Progress:  30%, epoch:  3/10, batch:   0/200, elapsed:   1.49s, loss: 0.018542366102337837             Run cancelled: Low loss 1 ([10.633015524595976, 2.2217107348144056, 0.0817870583734475] actual) achieved!.

Here is where things get a little more interesting. k1lib.Learner is the main wrapper where training will take place. It has 4 basic parameters that must be set before training: model, data loader, optimizer, and loss function.

Tip: docs are tailored for each object, so you can do print(obj) or just evaluate obj in a code cell

In [5]:
l.cbs
Out[5]:
Callbacks:
- CoreNormal
- Profiler
- ProgressBar
- DontTrainValid
- HookModule
- HookParam
- LossF
- DType
- ModifyBatch
- Recorder
- Loss
- Accuracy
- ParamFinder
- CancelOnExplosion
- CancelOnLowLoss

Use...
- cbs.add(cb[, name]): to add a callback with a name
- cbs("startRun"): to trigger a specific checkpoint, this case "startRun"
- cbs.Loss: to get a specific callback by name, this case "Loss"
- cbs[i]: to get specific callback by index
- cbs.timings: to get callback execution times
- cbs.checkpointGraph(): to graph checkpoint calling orders
- cbs.context(): context manager that will detach all Callbacks attached inside the context
- cbs.suspend("Loss", "Cuda"): context manager to temporarily prevent triggering checkpoints
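For instance, based on the usage listed above, temporarily silencing a couple of callbacks for one run might look like this (a sketch, untested):

with l.cbs.suspend("ProgressBar", "Loss"):   # these two won't be triggered inside the block
    l.run(1)
l.cbs.timings                                # then check how much time each callback took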

There're lots of Callbacks. What they are will be discussed later, but here's a tour of a few of them:

ParamFinder¶

In [6]:
l = newL(); l.ParamFinder.plot(samples=1000)[:0.99]
Progress:   0%, epoch:    1/1000, batch:  40/200, elapsed:   0.53s, loss: 2318.115478515625             Run cancelled: Loss increases significantly.
Suggested param: 4.100944749601106e-05
[plot output]
Out[6]:
Sliceable plot. Can...
- p[a:b]: to focus on a specific range of the plot
- p.yscale("log"): to perform operation as if you're using plt

Reminder: slice range here is actually [0, 1], because it's kinda hard to slice the normal way

As advertised, this callback sweeps a parameter (here, the learning rate) and suggests a good value for it.
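Assuming the suggested value is meant to be fed back in as the learning rate (a sketch, not part of the original notebook):

lr = 4.100944749601106e-05                           # the "Suggested param" printed above
l = newL()
l.opt = optim.Adam(l.model.parameters(), lr=lr)      # same optimizer as in newL(), new learning rate
l.run(10)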

Loss¶

In [7]:
l = newL(); l.run(10); l.Loss
Progress:  20%, epoch:  2/10, batch:   0/200, elapsed:   1.05s, loss: 0.33111751079559326             Run cancelled: Low loss 1 ([9.34933760613203, 0.24827303597703576] actual) achieved!.
Out[7]:
Callback `Loss`, use...
- cb.train: for all training losses over all epochs and batches (#epochs * #batches)
- cb.valid: for all validation losses over all epochs and batches (#epochs * #batches)
- cb.plot(): to plot the 2 above
- cb.epoch: for average losses of each epochs
- cb.Landscape: for loss-landscape-plotting Callback
- cb.something: to get specific attribute "something" from learner if not available
- cb.withCheckpoint(checkpoint, f): to quickly insert an event handler
- cb.detach(): to remove itself from its parent Callbacks
In [8]:
l.Loss.plot()
[plot output]
Out[8]:
Sliceable plot. Can...
- p[a:b]: to focus on a specific range of the plot
- p.yscale("log"): to perform operation as if you're using plt

Reminder: the actual slice you put in is for the training plot. The valid loss's plot will update automatically to be in the same time frame

The returned object is a k1lib.viz.SliceablePlot, so you can zoom into a specific range of the plot, like this:

In [9]:
l.Loss.plot()[120:]
[plot output]
Out[9]:
Sliceable plot. Can...
- p[a:b]: to focus on a specific range of the plot
- p.yscale("log"): to perform operation as if you're using plt

Reminder: the actual slice you put in is for the training plot. The valid loss's plot will update automatically to be in the same time frame

Notice how the original train range is [0, 250] while the valid range is [0, 60]. When sliced with [120:], the train plot is cut as expected from roughly the middle to the end, and the valid plot adapts, slicing itself over the same relative window (roughly [30:], since 120/250 ≈ 30/60).

LossLandscape¶

In [10]:
l.Loss.Landscape.plot()
Progress: 100%, 4s           8/8 Finished [-2.818, 2.818] range              Run cancelled: Landscape finished.
[plot output]
In [11]:
l.Loss.Landscape.plot()
Progress: 100%, 4s           8/8 Finished [-2.818, 2.818] range              Run cancelled: Landscape finished.
[plot output]

Oh and yeah, this callback can give you a quick view of what the loss landscape looks like. In both plots the center point (0, 0) sits at the lowest portion of the landscape, which tells us the network has indeed learned something.

HookParam¶

In [12]:
l.HookParam
Out[12]:
Callback `HookParam`: 6 params, 134 means and stds each:
  0. bulk.2.seq.0.weight
  1. bulk.2.seq.0.bias
  2. bulk.3.seq.0.weight
  3. bulk.3.seq.0.bias
  4. bulk.4.seq.0.weight
  5. bulk.4.seq.0.bias

Use...
- p.plot(): to quickly look at everything
- p[i]: to view a single param
- p[a:b]: to get a new HookParam with selected params
- p.css("..."): to select a specific subset of modules only
- cb.something: to get specific attribute "something" from learner if not available
- cb.withCheckpoint(checkpoint, f): to quickly insert an event handler
- cb.detach(): to remove itself from its parent Callbacks
In [13]:
l.HookParam.plot()
[plot output]
Out[13]:
Sliceable plot. Can...
- p[a:b]: to focus on a specific range of the plot
- p.yscale("log"): to perform operation as if you're using plt

This tracks the parameters' means, stds, mins and maxs while training. You can also display only a subset of the parameters:

In [14]:
l.HookParam[::2].plot()[50:]
[plot output]
Out[14]:
Sliceable plot. Can...
- p[a:b]: to focus on a specific range of the plot
- p.yscale("log"): to perform operation as if you're using plt

HookModule¶

In [15]:
l.HookModule.plot()
[plot output]
Out[15]:
Sliceable plot. Can...
- p[a:b]: to focus on a specific range of the plot
- p.yscale("log"): to perform operation as if you're using plt

Pretty much the same thing as before. This callback hooks into the selected modules and captures their forward and backward passes. Both HookParam and HookModule only hook into selected modules (by default, everything is selected):

In [16]:
l.selector
Out[16]:
ModuleSelector:
root: Network                       
    bulk: Sequential                
        0: Linear                       
        1: LeakyReLU                    
        2: SkipBlock                HookModule
            seq: Sequential         
                0: Linear           HookParam    
                1: LeakyReLU            
                2: Linear               
                3: LeakyReLU            
                4: Linear               
                5: LeakyReLU            
        3: SkipBlock                HookModule
            seq: Sequential         
                0: Linear           HookParam    
                1: LeakyReLU            
                2: Linear               
                3: LeakyReLU            
                4: Linear               
                5: LeakyReLU            
        4: SkipBlock                HookModule
            seq: Sequential         
                0: Linear           HookParam    
                1: LeakyReLU            
                2: Linear               
                3: LeakyReLU            
                4: Linear               
                5: LeakyReLU            
        5: Linear                       

Can...
- mS.deepestDepth: get deepest depth possible
- mS.nn: get the underlying nn.Module object
- mS.apply(f): apply to self and all descendants
- "HookModule" in mS: whether this module has a specified prop
- mS.highlight(prop): highlights all modules with specified prop
- mS.parse([..., ...]): parses extra css
- mS.directParams: get Dict[str, nn.Parameter] that are directly under this module

CSS module selector¶

You can select specific modules by setting l.css = ..., kinda like this:

In [17]:
l = newL()
l.css = """
#bulk > Linear: a
#bulk > #1: b
SkipBlock Sequential: c
SkipBlock LeakyReLU
"""
l.selector
Out[17]:
ModuleSelector:
root: Network                       
    bulk: Sequential                
        0: Linear                   a    
        1: LeakyReLU                b    
        2: SkipBlock                
            seq: Sequential         c
                0: Linear               
                1: LeakyReLU        *    
                2: Linear               
                3: LeakyReLU        *    
                4: Linear               
                5: LeakyReLU        *    
        3: SkipBlock                
            seq: Sequential         c
                0: Linear               
                1: LeakyReLU        *    
                2: Linear               
                3: LeakyReLU        *    
                4: Linear               
                5: LeakyReLU        *    
        4: SkipBlock                
            seq: Sequential         c
                0: Linear               
                1: LeakyReLU        *    
                2: Linear               
                3: LeakyReLU        *    
                4: Linear               
                5: LeakyReLU        *    
        5: Linear                   a    

Can...
- mS.deepestDepth: get deepest depth possible
- mS.nn: get the underlying nn.Module object
- mS.apply(f): apply to self and all descendants
- "HookModule" in mS: whether this module has a specified prop
- mS.highlight(prop): highlights all modules with specified prop
- mS.parse([..., ...]): parses extra css
- mS.directParams: get Dict[str, nn.Parameter] that are directly under this module

Essentially, you can:

  • #a: to select modules with name "a"
  • b: to select modules with class name "b"
  • a #b: to select modules with name "b" under modules with class "a"
  • a > #b: to select modules with name "b" directly under modules with class "a"
  • "#a: infinity war": to assign selected module with properties "infinity" and "war"

Different callbacks will recognize certain props. HookModule will hook all modules with props "all" or "HookModule". Likewise, HookParam will hook all parameters with props "all" or "HookParam".
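For example, following the selector syntax above, hooking every Linear layer inside the skip blocks could look like this (a sketch, not from the original notebook):

l = newL()
l.css = "SkipBlock Linear: HookModule HookParam"   # every Linear under a SkipBlock gets both props
l.selector                                         # inspect which modules picked up the props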

Data loader¶

In [18]:
l.data
Out[18]:
[<k1lib.cli.modifier.StaggeredStream at 0x7f8ab54c6ca0>,
 <k1lib.cli.modifier.StaggeredStream at 0x7f8ab54fbb50>]
In [19]:
for xb, yb in l.data[0]:
    print(xb.shape, yb.shape)
    break
torch.Size([32]) torch.Size([32])

It's simple, really! l.data contains a train and valid data loader, and each "dispenses" a batch as usual.

Callbacks¶

Let's look at l again:

In [20]:
l
Out[20]:
l.model:
    Network(
      (bulk): Sequential(
        (0): Linear(in_features=1, out_features=10, bias=True)
        (1): LeakyReLU(negative_slope=0.01)
        (2): SkipBlock(
          (seq): Sequential(
            (0): Linear(in_features=10, out_features=10, bias=True)
            (1): LeakyReLU(negative_slope=0.01)
            (2): Linear(in_features=10, out_features=10, bias=True)
            (3): LeakyReLU(negative_slope=0.01)
    .....
l.opt:
    Adam (
    Parameter Group 0
        amsgrad: False
        betas: (0.9, 0.999)
        eps: 1e-08
        lr: 0.01
        weight_decay: 0
    )
l.cbs:
    Callbacks:
    - CoreNormal
    - Profiler
    - ProgressBar
    - DontTrainValid
    - HookModule
    - HookParam
    - LossF
    - DType
    - ModifyBatch
    .....
Use...
- l.model = ...: to specify a nn.Module object
- l.data = ...: to specify data object
- l.opt = ...: to specify an optimizer
- l.lossF = ...: to specify a loss function
- l.css = ...: to select modules using CSS. "#root" for root model
- l.cbs = ...: to use a custom `Callbacks` object
- l.selector: to get the modules selected by `l.css`
- l.run(epochs): to run the network
- l.Loss: to get a specific callback, this case "Loss"

l.model and l.opt are simple enough; they're just PyTorch primitives. The part where most of the magic lies is l.cbs, an object of type k1lib.Callbacks, which is a container of k1lib.Callback objects. Notice the final "s" in the name.

A callback is pretty simple. While training, you may want to insert functionality here and there. Let's say you want the program to print out a progress bar while training. You could edit the learning loop directly, with some internal variables to keep track of the current epoch and batch, like this:

import time

startTime = time.time()
for epoch in range(epochs):
    for batch in range(batches):
        # do training (getData() and train() are placeholders)
        data = getData()
        train(data)

        # calculate progress
        elapsedTime = time.time() - startTime
        progress = round((batch / batches + epoch) / epochs * 100)
        print(f"\rProgress: {progress}%, elapsed: {round(elapsedTime, 2)}s         ", end="")

But this means that when you don't want that functionality anymore, you have to know which internal variables belong to the progress bar, and you have to delete them yourself. With callbacks, things work a little bit differently:

class ProgressBar(k1lib.Callback):
    def startRun(self):
        pass # could e.g. record the start time here
    def startBatch(self):
        # attributes not found on the callback (epoch, batches, ...) are looked up on the learner
        self.progress = round((self.batch / self.batches + self.epoch) / self.epochs * 100)
        a = f"Progress: {self.progress}%"
        b = f"epoch: {self.epoch}/{self.epochs}"
        c = f"batch: {self.batch}/{self.batches}"
        print(f"{a}, {b}, {c}")

class Learner:
    def run(self):
        self.epochs = 1; self.batches = 10

        self.cbs = k1lib.Callbacks()
        self.cbs.add(ProgressBar())

        self.cbs("startRun")
        for self.epoch in range(self.epochs):
            self.cbs("startEpoch")
            for self.batch in range(self.batches):
                self.xb, self.yb = getData()
                self.cbs("startBatch")

                # do training
                self.y = self.model(self.xb); self.cbs("endPass")
                self.loss = self.lossF(self.y, self.yb); self.cbs("endLoss")
                if self.cbs("startBackward"): self.loss.backward()

                self.cbs("endBatch")
            self.cbs("endEpoch")
        self.cbs("endRun")

This is a stripped-down version of k1lib.Learner, just to get the idea across. The point is, whenever you do self.cbs("startRun"), it runs through all the k1lib.Callback objects it holds (just ProgressBar in this example), checks whether each one implements startRun, and if so, executes it.
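To make that concrete, here's a minimal sketch of the dispatch idea (my own illustration, not k1lib's actual implementation):

class MiniCallbacks:                                     # hypothetical, for illustration only
    def __init__(self):
        self.cbs = []
    def add(self, cb):
        self.cbs.append(cb); return self
    def __call__(self, checkpoint):
        ok = True
        for cb in self.cbs:
            if hasattr(cb, checkpoint):                  # does this callback implement the checkpoint?
                if getattr(cb, checkpoint)() is False:   # a callback may return False to cancel
                    ok = False                           # e.g. skip loss.backward() at "startBackward"
        return ok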

Inside ProgressBar's startBatch, you can access the learner's current epoch by doing self.learner.epoch. But you can also just do self.epoch: if the attribute is not defined on the callback itself, it will automatically be looked up on self.learner.
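A minimal sketch of how that fallback could be implemented (an assumption on my part, not k1lib's actual code):

class MiniCallback:                          # hypothetical, for illustration only
    learner = None                           # set when the callback gets added to a Callbacks object
    def __getattr__(self, name):
        # __getattr__ only fires when normal attribute lookup fails,
        # so anything not defined on the callback falls through to the learner
        return getattr(self.learner, name)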

As you can see, if you want to get rid of the progress bar without using k1lib.Callbacks, you have to delete the startTime line and the "calculate progress" lines, which requires you to remember which lines belong to which functionality. If you use the k1lib.Callbacks mechanism instead, you can just comment out self.cbs.add(ProgressBar()), and that's it. This makes swapping components in and out extremely easy, repeatable, and beautiful.

Other use cases include intercepting at startBatch and pushing all the training data to the GPU. You can also reshape the data however you want, insert different loss mechanisms (endLoss) in addition to lossF, or quickly inspect the model's output. You can also change learning rates while training (startEpoch) according to some schedule. The possibilities are literally endless.
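For example, a GPU-transfer callback built on the stripped-down Learner above might look like this (a sketch with a hypothetical name; the earlier cbs.suspend example suggests k1lib already ships a Cuda callback for this, so in practice you'd use that instead):

class ToGPU(k1lib.Callback):                 # hypothetical name, for illustration only
    def startBatch(self):
        # move the current batch to the GPU before the forward pass
        self.learner.xb = self.learner.xb.cuda()
        self.learner.yb = self.learner.yb.cuda()

l.cbs.add(ToGPU())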
