# AUTOGENERATED FILE! PLEASE DON'T EDIT HERE. EDIT THE SOURCE NOTEBOOKS INSTEAD
from .callbacks import Callback, Callbacks, Cbs
import k1lib, numpy as np, time
plt = k1lib.dep("matplotlib.pyplot")
from typing import Callable
import k1lib.cli as cli
try: import torch; hasTorch = True
except: hasTorch = False
__all__ = ["Landscape"]
spacing = 0.35 # orders of magnitude between consecutive plot scales
offset = -2 # orders of magnitude shift of the smallest scale (10**-2)
res = 20 # resolution: number of grid points along each axis of a plot
# 8 half-widths (one per subplot of Landscape.plot's 2x4 grid), log-spaced from
# 10**offset upwards in steps of `spacing` decades, rounded to 3 decimals
scales = [round(scale, 3) for scale in 10**(np.arange(8)*spacing + offset)]
F = Callable[["k1lib.Learner"], float] # property function: Learner -> float
@k1lib.patch(Cbs)
class Landscape(Callback):
    """Callback that evaluates a float property of the model over a 2D grid of
parameter perturbations and draws one 3D surface per entry of ``scales``."""
    def __init__(self, propertyF:F, name:str=None):
        """Plots the landscape of the network.
:param propertyF: a function that takes in :class:`k1lib.Learner` and outputs the
    desired float property
:param name: optional name for this callback; when not given, the existing
    ``self.name`` is kept
.. warning::
    Remember to detach anything you get from :class:`k1lib.Learner` in your
    function, or else you're gonna cause a huge memory leak.
"""
        super().__init__(); self.propertyF = propertyF
        # starts suspended; plot() un-suspends it only for the duration of the sweep
        self.suspended = True
        self.name = name or self.name; self.order = 23; self.parent:Callback = None
    def startRun(self):
        """Snapshot the model's parameters so :meth:`endRun` can restore them."""
        self.originalParams = self.l.model.exportParams()
    def endRun(self):
        """Restore the parameters captured in :meth:`startRun`."""
        self.l.model.importParams(self.originalParams)
    def startPass(self):
        """Advance to the next grid point, then set every parameter to
``original + x*v1 + y*v2``, where ``v1``/``v2`` come from ``self.vs``."""
        next(self.iter) # sets self.x, self.y, self.ix, self.iy (or raises when done)
        for param, og, v1, v2 in zip(self.l.model.parameters(), self.originalParams, *self.vs):
            param.data = og + self.x * v1 + self.y * v2
    def endLoss(self):
        """Store ``propertyF`` for the current grid point and occasionally
print progress within the current scale."""
        prop = self.propertyF(self.l)
        self.zs[self.ix, self.iy] = prop if prop == prop else 0 # NaN != NaN, so NaN is stored as 0
        if self.l.batch % 10 == 0: # throttle output: only every 10th batch
            print(f"\rProgress: {round(100*(self.ix+self.iy/res)/res)}%, {round(time.time()-self.beginTime)}s      ", end="")
    # Returning True from these hooks presumably tells the Learner to skip the
    # backward pass / optimizer step / grad zeroing, so probing never trains the
    # model. NOTE(review): confirm against Learner's callback contract.
    def startBackward(self): return True
    def startStep(self): return True
    def startZeroGrad(self): return True
    def __iter__(self):
        """Core loop of the sweep, written as a generator.

The Learner drives execution: each "startPass" checkpoint calls ``next`` on this
iterator. That is why the sweep over grid points is an iterator rather than the
controlling loop. For every scale it visits all ``res x res`` grid points. It
exposes the offsets as ``self.x``/``self.y`` and the indices as
``self.ix``/``self.iy``. Once the grid is filled, it zeroes non-finite values and
draws the surface. After the last scale it raises
:class:`k1lib.CancelRunException` to stop the run."""
        self.zss = [] # debug data: one filled grid per scale
        for i, (scale, ax) in enumerate(zip(scales, self.axes)):
            a = torch.linspace(-scale, scale, res)
            xs, ys = np.meshgrid(a, a); self.zs = np.empty((res, res))
            xs = torch.tensor(xs); ys = torch.tensor(ys)
            for ix in range(res):
                for iy in range(res):
                    self.x = xs[ix, iy]; self.y = ys[ix, iy]
                    self.ix, self.iy = ix, iy; yield True
            # zero +inf and -inf alike so plot_surface gets a finite grid
            self.zs[~np.isfinite(self.zs)] = 0
            ax.plot_surface(xs, ys, self.zs, cmap=plt.cm.coolwarm)
            self.zss.append(self.zs)
            print(f"     {i+1}/{len(scales)} Finished [{-scale}, {scale}] range              ", end="")
        raise k1lib.CancelRunException("Landscape finished")
    def plot(self):
        """Creates the landscapes and show plots"""
        self.suspended = False; self.iter = iter(self); self.beginTime = time.time()
        def inner():
            # two perturbation directions, one per plot axis
            self.vs = [self.l.model.getParamsVector(), self.l.model.getParamsVector()]
            fig, axes = plt.subplots(2, 4, subplot_kw={"projection": "3d"}, figsize=(16, 8), dpi=120)
            # huge epoch count: the run is ended by __iter__ raising CancelRunException
            self.axes = axes.flatten(); self.l.run(1000000)
        try:
            with self.cbs.suspendEval(), torch.no_grad(): inner()
        except k1lib.CancelRunException: pass # expected way for the sweep to finish
        finally: self.suspended = True; self.iter = None # always go back to dormant state
    def __repr__(self):
        return f"""{super()._reprHead}, use...
- l.plot(): to plot everything
{super()._reprCan}"""