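"""Gradio chat demo: a retrieval-augmented chatbot over a Pinecone vector
index, answered by an OctoAI-hosted Llama 3 70B model, with a "Credit
response" button that attributes each claim in the latest answer to its
supporting sources.

Requires the OCTOML_KEY and PINECONE_API_KEY environment variables.
"""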
import os

# Raw OctoAI SDK client. It is not used directly below; the LlamaIndex
# OctoAI wrapper further down is what serves the chat requests.
import octoai
octoai_client = octoai.client.Client(token=os.getenv('OCTOML_KEY'))

# Pinecone client for the pre-built vector index (opened by name below).
from pinecone import Pinecone
pc = Pinecone(api_key=os.getenv('PINECONE_API_KEY'))
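# The 'prorata-postman-ds-256' index is assumed to already exist. A minimal
# sketch of how such a serverless index could be created (dimension 256 is a
# guess from the index name suffix; cloud/region are placeholders):
#
#   from pinecone import ServerlessSpec
#   pc.create_index(
#       name='prorata-postman-ds-256',
#       dimension=256,
#       metric='cosine',
#       spec=ServerlessSpec(cloud='aws', region='us-east-1'),
#   )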


from llama_index.vector_stores.pinecone import PineconeVectorStore
from llama_index.core import VectorStoreIndex
from llama_index.llms.octoai import OctoAI

# LlamaIndex wrapper around an OctoAI-hosted Llama 3 70B endpoint; named
# `octoai_llm` so it does not shadow the `octoai` module imported above.
octoai_llm = OctoAI(
    token=os.getenv('OCTOML_KEY'),
    model="meta-llama-3-70b-instruct",
    max_tokens=512,
    temperature=0.1,
)


from llama_index.core.memory import ChatMemoryBuffer

import gradio as gr
from io import StringIO

import util as cu  # local helpers for claim splitting, matching, and crediting
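# `util` is a local module not shown here; from its use below it is assumed
# to expose:
#   sentence_splitter.split_text(text)                  -> list of claim strings
#   get_atom_topk_matches_l_concurrent(atoms, max_workers)
#   aggregate_atom_topkmatches_l(topkmatches)
#   get_atmom_support_l_from_atomidx_w_single_url_aggmatch_l_concurrent(
#       atoms, aggmatches, max_workers)
#   credit_atom_support_list(supports)                  -> credit distribution
#   print_credit_dist(dist, prefix, url_to_id, file)
#   print_atom_support(support, prefix, file)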

def get_credit_dist(history):
    # Split the latest bot response into atomic claims, dropping fragments
    # too short to attribute meaningfully.
    atoms_l = cu.sentence_splitter.split_text(history[-1][1])
    atoms_l = [a for a in atoms_l if len(a) > 50]

    # Retrieve the top-k source matches per claim, aggregate them per URL,
    # and score how well each source supports each claim.
    atom_topkmatches_l = cu.get_atom_topk_matches_l_concurrent(atoms_l, max_workers=8)
    atomidx_w_single_url_aggmatch_l = cu.aggregate_atom_topkmatches_l(atom_topkmatches_l)
    atom_support_l = cu.get_atmom_support_l_from_atomidx_w_single_url_aggmatch_l_concurrent(
        atoms_l, atomidx_w_single_url_aggmatch_l, max_workers=8)

    # Convert per-claim support into a credit distribution over sources.
    credit_dist = cu.credit_atom_support_list(atom_support_l)

    # Render the report into a string buffer for display in the textbox.
    _out = StringIO()
    print("Credit distribution to sources:\n", file=_out)
    cu.print_credit_dist(credit_dist, prefix='    ', url_to_id=None, file=_out)
    print(file=_out)

    print("Per claim support:\n", file=_out)
    for j, atom_support in enumerate(atom_support_l):
        print(f"    Claim {j+1}: \"{atoms_l[j]}\"\n", file=_out)
        cu.print_atom_support(atom_support, prefix='            ', file=_out)
        print(file=_out)

    return _out.getvalue()


with gr.Blocks() as demo:
    chatbot = gr.Chatbot(height=800)
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    credit_box = gr.Textbox(label="Credit distribution", lines=20, autoscroll=False)
    credit_btn = gr.Button("Credit response")

    def get_chat_engine():
        # Wrap the existing Pinecone index in a LlamaIndex vector store and
        # build a context chat engine over it with a bounded chat memory.
        vector_store = PineconeVectorStore(pinecone_index=pc.Index('prorata-postman-ds-256'))
        vindex = VectorStoreIndex.from_vector_store(vector_store)

        memory = ChatMemoryBuffer.from_defaults(token_limit=5000)
        return vindex.as_chat_engine(
            chat_mode="context",
            llm=octoai_llm,
            memory=memory,
            system_prompt="You are a chatbot, able to have normal interactions, as well as talk about news events.",
        )

    # Callable default: each browser session gets its own chat engine.
    chat_engine_var = gr.State(get_chat_engine)

    def user(user_message, history):
        # Append the user turn with an empty bot slot and clear the textbox.
        return "", history + [[user_message, None]]

    def bot(history, chat_engine):
        # Stream the response token by token into the last history entry.
        response = chat_engine.stream_chat(history[-1][0])
        history[-1][1] = ""
        yield history  # show the empty bot bubble before tokens arrive
        for token in response.response_gen:
            history[-1][1] += token
            yield history

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot, chat_engine_var], chatbot)
    # engine.reset() clears the chat memory and returns None, which also
    # empties the chatbot display.
    clear.click(lambda engine: engine.reset(), chat_engine_var, chatbot, queue=False)

    credit_btn.click(get_credit_dist, chatbot, credit_box)

if __name__ == "__main__":
    demo.queue()
    demo.launch()