Monkey patched classes
class torch.nn.modules.Module
Module.importParams(params: List[torch.nn.parameter.Parameter])
Given a list of torch.nn.parameter.Parameter / torch.Tensor, updates the current torch.nn.Module's parameters with them.
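As a rough illustration of what importing parameters involves, here is a minimal plain-PyTorch sketch. It is not the patched method itself: the function name import_params is hypothetical, and it assumes the given tensors line up one-to-one, in order, with module.parameters().

```python
from typing import List

import torch
import torch.nn as nn


def import_params(module: nn.Module, params: List[torch.Tensor]) -> None:
    """Hypothetical stand-in for Module.importParams.

    Assumes `params` matches module.parameters() in order and shape.
    """
    own = list(module.parameters())
    if len(own) != len(params):
        raise ValueError(f"expected {len(own)} tensors, got {len(params)}")
    # no_grad so the in-place copy is not recorded by autograd
    with torch.no_grad():
        for p, new in zip(own, params):
            # copy_ accepts both nn.Parameter and plain torch.Tensor
            p.copy_(new)


m = nn.Linear(3, 4)
import_params(m, [torch.zeros(4, 3), torch.ones(4)])
```

Copying in place keeps the module's existing Parameter objects, so anything already holding references to them (an optimizer, for example) keeps working; this is a design choice of the sketch.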
Module.exportParams() → List[torch.Tensor]
Gets the data of each torch.nn.parameter.Parameter, returned as a list of tensors.
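The export direction can be sketched the same way. The name export_params is hypothetical, and returning detached clones (rather than views of the live .data) is an assumption made here so the snapshot does not change when the model keeps training.

```python
from typing import List

import torch
import torch.nn as nn


def export_params(module: nn.Module) -> List[torch.Tensor]:
    """Hypothetical stand-in for Module.exportParams."""
    # detach(): drop autograd tracking; clone(): independent storage
    return [p.detach().clone() for p in module.parameters()]


m = nn.Linear(3, 4)
snapshot = export_params(m)  # [weight (4x3), bias (4,)]
```

A list produced this way has the same order and shapes as module.parameters(), so it could be fed back through an import step to restore the snapshot.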
Module.getParamsVector() → List[torch.Tensor]
For each parameter, returns a normally distributed random tensor with the same standard deviation as the original parameter.
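One way to produce such tensors in plain PyTorch is sketched below. The name get_params_vector is hypothetical, and two details are assumptions rather than documented behavior: the noise has mean roughly 0, and it is rescaled so its sample standard deviation matches the parameter's exactly instead of only in expectation.

```python
from typing import List

import torch
import torch.nn as nn


def get_params_vector(module: nn.Module) -> List[torch.Tensor]:
    """Hypothetical stand-in for Module.getParamsVector."""
    out = []
    for p in module.parameters():
        data = p.detach()
        target_std = data.std()  # NaN for single-element parameters
        noise = torch.randn_like(data)  # mean ~0, std ~1, same shape
        # rescale so the sample std equals the parameter's std
        out.append(noise / noise.std() * target_std)
    return out


torch.manual_seed(0)
vecs = get_params_vector(nn.Linear(3, 4))
```

Random tensors shaped like this are commonly used as perturbation directions, for instance when probing how sensitive a loss is to the scale of each parameter.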
Module.preserveDevice() → AbstractContextManager
Preserves the device of the model across whatever operations run inside this context manager. Example:
    import torch.nn as nn
    m = nn.Linear(3, 4)
    with m.preserveDevice():
        m.cuda()  # moves the whole model to CUDA
    # on exit, the model is automatically moved back to the CPU
This works even if the model's tensors are spread across many different devices (for example, 10); each one is returned to the device it started on.
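The behavior above can be approximated with contextlib. This is a sketch, not the patched method: preserve_device is a hypothetical name, and it assumes that recording the device of every named parameter and buffer on entry, then moving each one back on exit, is enough to "preserve the device".

```python
import contextlib
from typing import Dict, Iterator

import torch
import torch.nn as nn


@contextlib.contextmanager
def preserve_device(module: nn.Module) -> Iterator[nn.Module]:
    """Hypothetical stand-in for Module.preserveDevice."""
    # remember where every parameter and buffer lives right now
    saved: Dict[str, torch.device] = {
        name: t.device
        for name, t in [*module.named_parameters(), *module.named_buffers()]
    }
    try:
        yield module
    finally:
        # runs even if the block raised, so devices are always restored
        for name, t in [*module.named_parameters(), *module.named_buffers()]:
            original = saved.get(name)
            if original is not None and t.device != original:
                # assigning .data moves the tensor without replacing the object
                t.data = t.data.to(original)


m = nn.Linear(3, 4)
with preserve_device(m):
    pass  # e.g. m.cuda() on a machine with a GPU
```

Because each tensor is tracked by name, a model whose tensors start out on several different devices gets each of them sent back individually, rather than the whole model being moved to a single device.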