# AUTOGENERATED FILE! PLEASE DON'T EDIT
from .callbacks import Callback, Callbacks, Cbs
import k1lib, torch, numpy as np, time
import matplotlib.pyplot as plt
from typing import Callable
__all__ = ["Landscape"]
# Zoom levels used by Landscape: each entry `s` of `scales` defines one
# [-s, s] window, and consecutive windows are evenly spaced on a log10 axis.
spacing = 0.35 # orders of magnitude between consecutive scales
offset = -2 # orders of magnitude shift; the smallest scale is 10**offset
res = 20 # resolution: number of grid points per axis inside each window
# 8 scales, one per subplot of the 2x4 figure created in Landscape.plot()
scales = 10**(np.arange(8)*spacing + offset)
scales = [round(scale, 3) for scale in scales]
# Signature of the user-supplied probe: takes the learner, returns a float
F = Callable[["k1lib.Learner"], float]
@k1lib.patch(Cbs)
class Landscape(Callback):
    """Callback that evaluates a user-supplied scalar property of the learner on
2D grids of perturbed parameters and draws each grid as a 3D surface."""
    def __init__(self, propertyF:F, name:str=None):
        """Plots the landscape of the network.

:param propertyF: a function that takes in :class:`k1lib.Learner` and outputs the
    desired float property
:param name: optional name for this callback; when omitted (or falsy) the
    existing ``self.name`` is kept

.. warning::

    Remember to detach anything you get from :class:`k1lib.Learner` in your
    function, or else you're gonna cause a huge memory leak.
"""
        super().__init__(); self.propertyF = propertyF
        self.suspended = True # plot() flips this to False only while it is running
        self.name = name or self.name; self.order = 23; self.parent:Callback = None
    def startRun(self):
        # snapshot the parameters so that endRun can put them back
        self.originalParams = self.l.model.exportParams()
    def endRun(self): self.l.model.importParams(self.originalParams)
    def startPass(self):
        # advance the scan by one grid point; __iter__ updates self.x and self.y
        next(self.iter)
        # move every parameter to: original + x * direction1 + y * direction2
        for param, og, v1, v2 in zip(self.l.model.parameters(), self.originalParams, *self.vs):
            param.data = og + self.x * v1 + self.y * v2
    def endLoss(self):
        prop = self.propertyF(self.l)
        self.zs[self.ix, self.iy] = prop if prop == prop else 0 # check for nan
        # NOTE(review): this prints whenever batch is NOT a multiple of 10, which
        # looks inverted for a throttle; left unchanged pending confirmation.
        if self.l.batch % 10: print(f"\rProgress: {round(100*(self.ix+self.iy/res)/res)}%, {round(time.time()-self.beginTime)}s      ", end="")
    # The 3 checkpoints below all return True; presumably that tells the learner
    # to skip the backward pass, optimizer step and zero-grad - TODO confirm
    # against Callback's checkpoint contract.
    def startBackward(self): return True
    def startStep(self): return True
    def startZeroGrad(self): return True
    def __iter__(self):
        """This one is the "core running loop", if you'd like to say so. Because
this needs to be sort of event-triggered (by checkpoint "startPass"), so kinda have
to put this into an iterator so that it's not the driving thread.

Each ``yield`` hands out one grid point through ``self.x``, ``self.y``, ``self.ix``
and ``self.iy``. After every scale has been scanned and plotted, raises
``k1lib.CancelRunException``."""
        self.zss = [] # debug data
        for i, (scale, ax) in enumerate(zip(scales, self.axes)):
            a = torch.linspace(-scale, scale, res)
            xs, ys = np.meshgrid(a, a); self.zs = np.empty((res, res))
            xs = torch.tensor(xs); ys = torch.tensor(ys)
            for ix in range(res):
                for iy in range(res):
                    self.x = xs[ix, iy]; self.y = ys[ix, iy]
                    self.ix, self.iy = ix, iy; yield True
            # zero out every non-finite sample (+inf and -inf alike) before plotting
            self.zs[~np.isfinite(self.zs)] = 0
            ax.plot_surface(xs, ys, self.zs, cmap=plt.cm.coolwarm)
            self.zss.append(self.zs)
            print(f"     {i+1}/{len(scales)} Finished [{-scale}, {scale}] range              ", end="")
        raise k1lib.CancelRunException("Landscape finished")
    def plot(self):
        """Creates the landscapes and show plots.

Unexpected errors during the scan are reported as a warning instead of being
silently discarded. ``KeyboardInterrupt`` and other non-``Exception`` errors
propagate, after this callback's state has been reset."""
        import warnings
        self.suspended = False; self.iter = iter(self); self.beginTime = time.time()
        def inner():
            # 2 parameter vectors, used as the 2 directions that span each slice
            self.vs = [self.l.model.getParamsVector(), self.l.model.getParamsVector()]
            fig, axes = plt.subplots(2, 4, subplot_kw={"projection": "3d"}, figsize=(16, 8), dpi=120)
            # large epoch count: the run is ended by the exception raised in __iter__
            self.axes = axes.flatten(); self.l.run(1000000)
        try:
            with self.cbs.suspendEval(), torch.no_grad(): inner()
        except k1lib.CancelRunException: pass # expected signal: the scan is complete
        except Exception as e: # keep this best-effort, but tell the user what went wrong
            warnings.warn(f"Landscape.plot() stopped early: {type(e).__name__}: {e}")
        finally: self.suspended = True; self.iter = None # always reset state
    def __repr__(self):
        return f"""{super()._reprHead}, use...
- l.plot(): to plot everything
{super()._reprCan}"""
@k1lib.patch(Callbacks, docs=Landscape.__init__)
def withLandscape(self, propertyF:F, name:str=None):
    # Build the Landscape callback first, then add it to this collection and
    # pass along whatever `append` returns.
    landscape = Landscape(propertyF, name)
    return self.append(landscape)