# AUTOGENERATED FILE! PLEASE DON'T EDIT
from k1lib.callbacks import Callback, Callbacks
import k1lib, torch, numpy as np, time
import matplotlib.pyplot as plt
from typing import Callable
# Public names exported by `from <module> import *`. Only the callback class is listed;
# `withLandscape` below is registered on `Callbacks` through `k1lib.patch` instead.
__all__ = ["Landscape"]
# Plot ranges used by Landscape: subplot i covers [-scales[i], scales[i]] on both
# axes. The 8 scales are spaced evenly in log10, starting at 10**offset.
spacing = 0.35 # orders of magnitude between consecutive scales
offset = -2 # orders of magnitude shift of the smallest scale (10**-2 == 0.01)
res = 30 # resolution: grid points per axis, i.e. res*res evaluations per subplot
scales = 10**(np.arange(8)*spacing + offset)
scales = [round(scale, 3) for scale in scales]
# Signature expected for `propertyF`: takes the Learner, returns a float to plot.
F = Callable[["k1lib.Learner"], float]
@k1lib.patch(Callback.cls)
class Landscape(Callback):
    " "
    # Callback that perturbs the model's weights over a 2D grid (per scale in
    # `scales`), evaluates `propertyF` at each grid point and draws one 3D surface
    # per scale. It stays suspended until `plot()` is called.
    def __init__(self, propertyF:F, name:str=None):
        """Plots the landscape of the network.

        :param propertyF: a function that takes in :class:`k1lib.Learner` and outputs the
            desired float property
        :param name: optional callback name; when omitted, the name already set up by
            :class:`Callback` is kept"""
        super().__init__(); self.propertyF = propertyF; self.suspended = True
        self.name = name or self.name; self.order = 23; self.parent:Callback = None

    def startRun(self):
        # Snapshot the weights: every grid point is an offset from this origin, and
        # endRun puts them back afterwards.
        self.originalParams = self.l.model.exportParams()

    def endRun(self):
        self.l.model.importParams(self.originalParams)

    def startPass(self):
        # Advance the driving generator (see __iter__) to the next grid point. Once every
        # scale has been plotted it raises CancelRunException instead.
        next(self.iter)
        # Move weights to og + x*v1 + y*v2: a point on the plane spanned by the two
        # direction vectors from plot(), centered on the original weights.
        for param, og, v1, v2 in zip(self.l.model.parameters(), self.originalParams, *self.vs):
            param.data = og + self.x * v1 + self.y * v2

    def endLoss(self):
        prop = self.propertyF(self.l)
        # NaN is the only value not equal to itself, so NaNs are stored as 0
        self.zs[self.ix, self.iy] = prop if prop == prop else 0
        # throttle progress output to every 10th batch (the old `% 10` printed on 9 of 10)
        if self.l.batch % 10 == 0:
            print(f"\rProgress: {round(100*(self.ix+self.iy/res)/res)}%, {round(time.time()-self.beginTime)}s ", end="")

    # Returning True from these checkpoints presumably makes the Learner skip
    # backward/step/zero_grad, so weights are only changed by startPass.
    # NOTE(review): confirm against k1lib's checkpoint semantics.
    def startBackward(self): return True
    def startStep(self): return True
    def startZeroGrad(self): return True

    def __iter__(self):
        """This one is the "core running loop", if you'd like to say so. Because
        this needs to be sort of event-triggered (by checkpoint "startPass"), so kinda have
        to put this into an iterator so that it's not the driving thread."""
        for i, (scale, ax) in enumerate(zip(scales, self.axes)):
            a = torch.linspace(-scale, scale, res)
            # torch.meshgrid without `indexing` uses "ij" ordering (newer torch warns about it)
            xs, ys = torch.meshgrid(a, a); self.zs = np.empty((res, res))
            for ix in range(res):
                for iy in range(res):
                    # publish the grid point for startPass/endLoss, then hand control back
                    self.x = xs[ix, iy]; self.y = ys[ix, iy]
                    self.ix, self.iy = ix, iy; yield True
            ax.plot_surface(xs, ys, self.zs, cmap=plt.cm.coolwarm)
            print(f" {i+1}/{len(scales)} Finished [{-scale}, {scale}] range ", end="")
        # all scales done: stop the (effectively endless) l.run started in plot()
        raise k1lib.CancelRunException("Landscape finished")

    def plot(self):
        """Creates the landscapes and show plots"""
        self.suspended = False; self.iter = iter(self); self.beginTime = time.time()
        def inner():
            # two direction vectors spanning the plane that gets explored
            self.vs = [self.l.model.getParamsVector(), self.l.model.getParamsVector()]
            _, axes = plt.subplots(2, 4, subplot_kw={"projection": "3d"}, figsize=(16, 8), dpi=120)
            # huge epoch count: the run only ends via CancelRunException from __iter__
            self.axes = axes.flatten(); self.l.run(1000000)
        try:
            if isinstance(self.parent, k1lib.callbacks.Accuracy):
                with self.cbs.suspendEval(less=["Accuracy"]), torch.no_grad(), self.parent.pause():
                    inner()
            else:
                with self.cbs.suspendEval(), torch.no_grad(): inner()
        finally:
            # Always re-suspend, even on error or KeyboardInterrupt. Otherwise later normal
            # runs would call next() on a dead iterator and have backward/step skipped.
            self.suspended = True; self.iter = None

    def __repr__(self):
        return f"""{super()._reprHead}, use...
- l.plot(): to plot everything
{super()._reprCan}"""
@k1lib.patch(Callbacks, docs=Landscape.__init__)
def withLandscape(self, propertyF:F, name:str=None):
    # Convenience builder: construct a Landscape callback, register it on this
    # Callbacks collection and hand back whatever `append` returns.
    landscapeCb = Landscape(propertyF, name)
    return self.append(landscapeCb)