harry85 commited on
Commit
44a3a0e
1 Parent(s): 6323488

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -1,5 +1,5 @@
1
  # Install the necessary packages
2
- # pip install accelerate transformers fastapi pydantic torch
3
 
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
  import torch
@@ -9,14 +9,15 @@ from fastapi import FastAPI
9
  # Initialize the FastAPI app
10
  app = FastAPI(docs_url="/")
11
 
12
- # Load the model and tokenizer once at startup
13
- device = "cuda" # the device to load the model onto
14
 
 
15
  model = AutoModelForCausalLM.from_pretrained(
16
  "Qwen/Qwen1.5-0.5B-Chat",
17
  torch_dtype="auto",
18
  device_map="auto"
19
- )
20
  tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")
21
 
22
  # Define the request model
 
1
  # Install the necessary packages
2
+ # pip install accelerate transformers fastapi pydantic torch jinja2
3
 
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
  import torch
 
9
  # Initialize the FastAPI app
10
  app = FastAPI(docs_url="/")
11
 
12
+ # Determine the device to use
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
14
 
15
+ # Load the model and tokenizer once at startup
16
  model = AutoModelForCausalLM.from_pretrained(
17
  "Qwen/Qwen1.5-0.5B-Chat",
18
  torch_dtype="auto",
19
  device_map="auto"
20
+ ).to(device)
21
  tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")
22
 
23
  # Define the request model