tbboukhari commited on
Commit
9eec875
1 Parent(s): 9398202

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -2,6 +2,8 @@ import streamlit as st
2
  #from streamlit_chat import message as st_message
3
  from streamlit_chat import message as st_message
4
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
 
5
 
6
  st.title("Chatbot Produit")
7
 
@@ -9,12 +11,16 @@ if "history" not in st.session_state:
9
  st.session_state.history = []
10
 
11
  def get_models():
 
 
 
12
 
13
- model_name = "tbboukhari/chatbot-produit-fr"
 
14
 
15
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="auto")
16
- tokenizer = AutoTokenizer.from_pretrained(model_name)
17
-
18
  return tokenizer, model
19
 
20
  def generate_answer():
 
2
  #from streamlit_chat import message as st_message
3
  from streamlit_chat import message as st_message
4
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
5
+ import torch
6
+ from peft import PeftModel, PeftConfig
7
 
8
  st.title("Chatbot Produit")
9
 
 
11
  st.session_state.history = []
12
 
13
  def get_models():
14
+
15
+ peft_model_id = "tbboukhari/chatbot-produit-fr"
16
+ config = PeftConfig.from_pretrained(peft_model_id)
17
 
18
+ model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path, torch_dtype="auto", device_map="auto")
19
+ tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
20
 
21
+ # Load the Lora model
22
+ model = PeftModel.from_pretrained(model, peft_model_id)
23
+
24
  return tokenizer, model
25
 
26
  def generate_answer():