# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
from .callbacks import Callback, Callbacks, Cbs
import k1lib, numpy as np, time
import matplotlib.pyplot as plt
from typing import Callable
import k1lib.cli as cli
try: import torch; hasTorch = True
except: hasTorch = False
# Public API of this module: star-imports expose only the Landscape callback.
__all__ = [
    "Landscape",
]
# Plotting ranges used by Landscape: consecutive entries are `spacing` orders
# of magnitude apart, starting at 10**offset, each rounded to 3 decimals.
# Every range is swept as [-scale, scale] along both perturbation directions.
spacing = 0.35 # orders of magnitude between consecutive scales
offset = -2 # log10 of the smallest scale
res = 20 # grid points per axis; each surface needs res*res property evaluations
scales = [round(scale, 3) for scale in 10**(np.arange(8)*spacing + offset)]
F = Callable[["k1lib.Learner"], float] # type of Landscape's `propertyF` argument
@k1lib.patch(Cbs)
class Landscape(Callback):
    """Callback that evaluates a scalar property of the model on a 2D grid of
    parameter perturbations and draws one 3D surface per range in ``scales``."""
    def __init__(self, propertyF:F, name:str=None):
        """Plots the landscape of the network.

        :param propertyF: a function that takes in :class:`k1lib.Learner` and outputs the
            desired float property
        :param name: optional callback name; when not given, the existing ``self.name``
            (presumably set by :class:`Callback`) is kept

        .. warning::

            Remember to detach anything you get from :class:`k1lib.Learner` in your
            function, or else you're gonna cause a huge memory leak.
        """
        super().__init__(); self.propertyF = propertyF
        self.suspended = True # dormant until plot() switches it on
        self.name = name or self.name; self.order = 23; self.parent:Callback = None
    def startRun(self):
        # Snapshot the current parameters; every grid point is an offset from them.
        self.originalParams = self.l.model.exportParams()
    def endRun(self):
        # Put the snapshotted parameters back once the sweep is over.
        self.l.model.importParams(self.originalParams)
    def startPass(self):
        """Advances the sweep to the next grid point and moves the model there:
        ``param = original + x*v1 + y*v2`` for every parameter tensor."""
        next(self.iter)
        for param, og, v1, v2 in zip(self.l.model.parameters(), self.originalParams, *self.vs):
            param.data = og + self.x * v1 + self.y * v2
    def endLoss(self):
        """Records the property at the current grid point and reports progress."""
        prop = self.propertyF(self.l)
        self.zs[self.ix, self.iy] = prop if prop == prop else 0 # NaN != NaN, so NaNs become 0
        if self.l.batch % 10 == 0: # throttle the progress line to every 10th batch
            print(f"\rProgress: {round(100*(self.ix+self.iy/res)/res)}%, {round(time.time()-self.beginTime)}s ", end="")
    # Only the property at each grid point is needed. These checkpoints return
    # True, which presumably tells the Learner to skip backward/step/zero_grad.
    # NOTE(review): confirm against k1lib's checkpoint semantics.
    def startBackward(self): return True
    def startStep(self): return True
    def startZeroGrad(self): return True
    def __iter__(self):
        """This one is the "core running loop", if you'd like to say so. Because
        this needs to be sort of event-triggered (by checkpoint "startPass"), so kinda have
        to put this into an iterator so that it's not the driving thread.

        Each ``next()`` sets ``self.x``/``self.y`` (grid coordinates) and
        ``self.ix``/``self.iy`` (their indices). After a full ``res x res`` grid the
        surface is drawn on the matching subplot. Once every range is done,
        :class:`k1lib.CancelRunException` is raised to end the run."""
        self.zss = [] # debug data: one finished grid per range
        nRanges = min(len(scales), len(self.axes)) # how many ranges zip() below will visit
        for i, (scale, ax) in enumerate(zip(scales, self.axes)):
            a = torch.linspace(-scale, scale, res)
            xs, ys = np.meshgrid(a, a); self.zs = np.empty((res, res))
            xs = torch.tensor(xs); ys = torch.tensor(ys)
            for ix in range(res):
                for iy in range(res):
                    self.x = xs[ix, iy]; self.y = ys[ix, iy]
                    self.ix, self.iy = ix, iy; yield True
            # NaNs were already zeroed in endLoss; zero +inf and -inf too so the surface stays plottable
            self.zs[~np.isfinite(self.zs)] = 0
            ax.plot_surface(xs, ys, self.zs, cmap=plt.cm.coolwarm)
            self.zss.append(self.zs)
            print(f" {i+1}/{nRanges} Finished [{-scale}, {scale}] range ", end="")
        raise k1lib.CancelRunException("Landscape finished")
    def plot(self):
        """Creates the landscapes and show plots"""
        self.suspended = False; self.iter = iter(self); self.beginTime = time.time()
        def inner():
            # Two perturbation directions, one entry per parameter tensor (zipped in startPass).
            # NOTE(review): presumably random vectors from getParamsVector - confirm.
            self.vs = [self.l.model.getParamsVector(), self.l.model.getParamsVector()]
            _, axes = plt.subplots(2, 4, subplot_kw={"projection": "3d"}, figsize=(16, 8), dpi=120)
            # The sweep ends when __iter__ raises CancelRunException, so the epoch
            # count is just an upper bound that should never be reached.
            self.axes = axes.flatten(); self.l.run(1000000)
        try:
            with self.cbs.suspendEval(), torch.no_grad(): inner()
        except k1lib.CancelRunException:
            pass # normal end of the sweep, in case the Learner re-raises it
        except KeyboardInterrupt:
            print("\nLandscape: interrupted")
        except Exception as e:
            # best-effort plotting: report the problem instead of silently hiding it
            print(f"\nLandscape: stopped early because of {type(e).__name__}: {e}")
        finally:
            self.suspended = True; self.iter = None
    def __repr__(self):
        return f"""{super()._reprHead}, use...
- l.plot(): to plot everything
{super()._reprCan}"""