Joshua Lochner commited on
Commit
5dd37ab
1 Parent(s): bb74d9f

Ensure `input_ids` are on the correct device when predicting

Browse files
Files changed (1) hide show
  1. src/predict.py +1 -1
src/predict.py CHANGED
@@ -171,7 +171,7 @@ DEFAULT_TOKEN_PREFIX = 'summarize: '
171
  def predict_sponsor_text(text, model, tokenizer):
172
  """Given a body of text, predict the words which are part of the sponsor"""
173
  input_ids = tokenizer(
174
- f'{DEFAULT_TOKEN_PREFIX}{text}', return_tensors='pt', truncation=True).input_ids
175
 
176
  # Can't be longer than input length + SAFETY_TOKENS or model input dim
177
  max_out_len = min(len(input_ids[0]) + SAFETY_TOKENS, model.model_dim)
 
171
  def predict_sponsor_text(text, model, tokenizer):
172
  """Given a body of text, predict the words which are part of the sponsor"""
173
  input_ids = tokenizer(
174
+ f'{DEFAULT_TOKEN_PREFIX}{text}', return_tensors='pt', truncation=True).input_ids.to(device())
175
 
176
  # Can't be longer than input length + SAFETY_TOKENS or model input dim
177
  max_out_len = min(len(input_ids[0]) + SAFETY_TOKENS, model.model_dim)