Source code for k1lib.callbacks.landscape

# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
from .callbacks import Callback, Callbacks, Cbs
import k1lib, numpy as np, time
import matplotlib.pyplot as plt
from typing import Callable
import k1lib.cli as cli
try: import torch; hasTorch = True
except: hasTorch = False
__all__ = ["Landscape"]
# Geometry of the landscape sweep performed by the Landscape callback below.
spacing = 0.35 # orders of magnitude between consecutive plotted ranges
offset = -2 # orders of magnitude shift of the smallest range (10**-2)
res = 20 # resolution: grid points per axis, so each surface costs res*res evaluations
# Half-widths of the [-scale, scale] plotting ranges: 8 of them, one per subplot of
# the 2x4 grid that Landscape.plot draws on. Rounded to 3 decimals so they print neatly.
scales = [round(scale, 3) for scale in 10**(np.arange(8)*spacing + offset)]
F = Callable[["k1lib.Learner"], float] # signature of Landscape's `propertyF`
@k1lib.patch(Cbs)
class Landscape(Callback):
    """Plots the landscape of a network: a user-supplied float property evaluated
    on a grid of points around the current parameters, at several zoom levels."""
    def __init__(self, propertyF:F, name:str=None):
        """Plots the landscape of the network.

:param propertyF: a function that takes in :class:`k1lib.Learner` and outputs the
    desired float property
:param name: optional name for this callback; falls back to the existing
    ``self.name`` (presumably set up by :class:`Callback`)

.. warning::

    Remember to detach anything you get from :class:`k1lib.Learner` in your
    function, or else you're gonna cause a huge memory leak."""
        super().__init__(); self.propertyF = propertyF
        self.suspended = True # only switched on for the duration of plot()
        self.name = name or self.name; self.order = 23; self.parent:Callback = None
    def startRun(self):
        # snapshot the parameters, so every grid point is an offset from the same origin
        self.originalParams = self.l.model.exportParams()
    def endRun(self):
        # put the network back exactly as it was before plotting
        self.l.model.importParams(self.originalParams)
    def startPass(self):
        """Advances to the next grid point and moves the parameters there, i.e.
        ``param = original + x*v1 + y*v2`` for every parameter tensor."""
        next(self.iter) # may raise CancelRunException once every scale is done
        for param, og, v1, v2 in zip(self.l.model.parameters(), self.originalParams, *self.vs):
            param.data = og + self.x * v1 + self.y * v2
    def endLoss(self):
        """Records the property for the current grid point and shows progress."""
        prop = self.propertyF(self.l)
        self.zs[self.ix, self.iy] = prop if prop == prop else 0 # NaN != NaN, so NaNs become 0
        if self.l.batch % 10 == 0: # throttle the progress display to every 10th batch
            print(f"\rProgress: {round(100*(self.ix+self.iy/res)/res)}%, {round(time.time()-self.beginTime)}s ", end="")
    # Returning True from these checkpoints presumably tells the Learner to skip
    # backward/step/zero_grad, so sweeping the landscape never trains the network.
    def startBackward(self): return True
    def startStep(self): return True
    def startZeroGrad(self): return True
    def __iter__(self):
        """The core running loop.

The sweep is driven by the Learner's "startPass" checkpoint rather than by this
object, so the loop is written as a generator: every ``next()`` (from
:meth:`startPass`) moves to the next grid point. After each scale its surface is
drawn, and after the last scale the run is cancelled."""
        self.zss = [] # debug data: one (res, res) grid of values per scale
        total = min(len(scales), len(self.axes)) # zip below stops at the shorter one
        for i, (scale, ax) in enumerate(zip(scales, self.axes)):
            a = torch.linspace(-scale, scale, res)
            xs, ys = np.meshgrid(a, a)
            self.zs = np.zeros((res, res)) # zeros, not garbage, for any point left unfilled
            xs = torch.tensor(xs); ys = torch.tensor(ys)
            for ix in range(res):
                for iy in range(res):
                    self.x = xs[ix, iy]; self.y = ys[ix, iy]
                    self.ix, self.iy = ix, iy; yield True
            self.zs[np.isinf(self.zs)] = 0 # both +inf and -inf would wreck the surface's z-range
            ax.plot_surface(xs, ys, self.zs, cmap=plt.cm.coolwarm)
            self.zss.append(self.zs)
            print(f" {i+1}/{total} Finished [{-scale}, {scale}] range ", end="")
        raise k1lib.CancelRunException("Landscape finished")
    def plot(self):
        """Creates the landscapes and show plots.

Runs the Learner with this callback active until :meth:`__iter__` has visited
every grid point of every scale, drawing one 3d surface per scale on a 2x4 grid.
The network's parameters are restored afterwards (see :meth:`endRun`)."""
        self.suspended = False; self.iter = iter(self); self.beginTime = time.time()
        def inner():
            # two direction vectors spanning the plotted plane; presumably random
            # directions in parameter space — each is zipped per-parameter in startPass
            self.vs = [self.l.model.getParamsVector(), self.l.model.getParamsVector()]
            _, axes = plt.subplots(2, 4, subplot_kw={"projection": "3d"}, figsize=(16, 8), dpi=120)
            self.axes = axes.flatten()
            self.l.run(1000000) # effectively "run forever"; __iter__ cancels the run when done
        try:
            with self.cbs.suspendEval(), torch.no_grad(): inner()
        except k1lib.CancelRunException: pass # the normal way a landscape run ends
        except Exception as e: print(f"\nLandscape.plot() stopped early: {e!r}") # stay best-effort, but don't hide errors
        finally: self.suspended = True; self.iter = None
    def __repr__(self):
        return f"""{super()._reprHead}, use...
- l.plot(): to plot everything
{super()._reprCan}"""