flexthink committed
Commit 0bd83c7
Parent: 4101f55

Fix device issues

Files changed (1):
  custom_interface.py (+2 -1)
custom_interface.py CHANGED
@@ -80,7 +80,8 @@ class Discrete_EmbeddingLayer(torch.nn.Module):
         self.init = init
         self.layers = layers
         self.available_layers = available_layers
-        self.offsets = self.build_offsets()
+        self.register_buffer("offsets", self.build_offsets())
+        self.register_buffer("layer_embs", self.compute_layer_embs())
         self.chunk_size = chunk_size
 
     def init_embedding(self, weights):
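
Why this fixes the device issue (a minimal sketch, not code from this repo): a tensor stored as a plain attribute in __init__ stays on the device it was created on, so it remains on CPU even after the module is moved to GPU; a tensor registered via register_buffer is moved together with the module's parameters by .to()/.cuda(). The Toy module below is a hypothetical illustration of that difference.

import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Plain attribute: not tracked by the module, stays on CPU.
        self.plain = torch.arange(4)
        # Registered buffer: moved by .to()/.cuda() and included in the
        # state dict, like the offsets/layer_embs buffers in this commit.
        self.register_buffer("buffered", torch.arange(4))

if torch.cuda.is_available():
    m = Toy().to("cuda")
    print(m.plain.device)     # cpu -> would cause a device-mismatch error
    print(m.buffered.device)  # cuda:0

One side effect to be aware of: registered buffers are also serialized with the model's state dict by default; passing persistent=False to register_buffer keeps the device-following behavior without adding the tensor to checkpoints.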