import os

import gradio as gr
import torch
import numpy as np
import librosa

from efficientat.models.MobileNetV3 import get_model as get_mobilenet
from efficientat.models.preprocess import AugmentMelSTFT
from efficientat.helpers.utils import NAME_TO_WIDTH, labels

from torch import autocast
from contextlib import nullcontext

from langchain import OpenAI, LLMChain, PromptTemplate
from langchain.chains.conversation.memory import ConversationalBufferWindowMemory
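
# App flow: the uploaded clip is tagged with an EfficientAT (MobileNetV3) AudioSet
# classifier, and the top-scoring label is used to build a LangChain prompt so the
# chatbot can answer "as" the detected sound source.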

MODEL_NAME = "mn40_as"

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = get_mobilenet(width_mult=NAME_TO_WIDTH(MODEL_NAME), pretrained_name=MODEL_NAME)
model.to(device)
model.eval()    

# Shared state: the most recently detected audio class and the chain built for it.
cached_audio_class = None
chain = None

def format_classname(classname):
    return classname.capitalize()

def audio_tag(
    audio_path,
    sample_rate=32000,
    window_size=800,
    hop_size=320,
    n_mels=128,
    cuda=True,
):
    """Tag an audio file with its most likely AudioSet class and (re)build the chat chain."""
    global chain

    waveform, _ = librosa.core.load(audio_path, sr=sample_rate, mono=True)
    mel = AugmentMelSTFT(n_mels=n_mels, sr=sample_rate, win_length=window_size, hopsize=hop_size)
    mel.to(device)
    mel.eval()
    waveform = torch.from_numpy(waveform[None, :]).to(device)

    # The models are trained in half precision (torch.float16).
    # Run on CUDA with torch.float16 for the best performance;
    # running on CPU with torch.float32 gives similar results, torch.bfloat16 is worse.
    with torch.no_grad(), autocast(device_type=device.type) if cuda and torch.cuda.is_available() else nullcontext():
        spec = mel(waveform)
        preds, _ = model(spec.unsqueeze(0))
    preds = torch.sigmoid(preds.float()).squeeze().cpu().numpy()

    # Pick the top-scoring AudioSet label and build the chain for that entity.
    sorted_indexes = np.argsort(preds)[::-1]
    formatted_classname = format_classname(labels[sorted_indexes[0]])
    chain = construct_langchain(formatted_classname)
    return formatted_classname

def construct_langchain(audio_class):
    """Build (or reuse) an LLMChain that role-plays as the detected audio class."""
    global cached_audio_class

    # Reuse the existing chain if the detected class has not changed.
    if cached_audio_class == audio_class and chain is not None:
        return chain
    cached_audio_class = audio_class

    prefix = f"""You are going to act as a magical tool that allows humans to communicate with non-human entities like
    rocks, crackling fire, trees, animals, and the wind. To do this, you will be given the human's text input for the conversation.
    Your goal is to embody that non-human entity and converse with the human.

    Examples:

    Non-human Entity: Tree
    Human Input: Hello tree
    Tree: Hello human, I am a tree

    Let's begin:
    Non-human Entity: {audio_class}"""

    # {history} and {human_input} are left as literal placeholders for the PromptTemplate.
    suffix = f'''
    Source: {audio_class}
    Length of Audio in Seconds: 2 seconds
    {{history}}
    Human Input: {{human_input}}
    {audio_class} Response:'''
    template = prefix + suffix

    prompt = PromptTemplate(
        input_variables=["history", "human_input"],
        template=template,
    )

    chatgpt_chain = LLMChain(
        # Assumes the OpenAI key is supplied via the OPENAI_API_KEY environment variable.
        llm=OpenAI(temperature=0.5, openai_api_key=os.environ.get("OPENAI_API_KEY")),
        prompt=prompt,
        verbose=True,
        memory=ConversationalBufferWindowMemory(k=2, ai_prefix=audio_class),
    )

    return chatgpt_chain

def predict(audio_path, user_input, history=None):
    history = history or []
    # Tag the audio first so the chain is built for the detected entity.
    audio_tag(audio_path)
    response = chain.predict(human_input=user_input)
    history.append((user_input, response))
    return history, history
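
# Gradio wiring: the user uploads an audio clip and types a message; the chatbot
# replies in the voice of the detected sound source, and "state" carries the history.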

demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Audio(source="upload", type="filepath", label="Your audio"),
        "text",
        "state",
    ],
    outputs=["chatbot", "state"],
    title="AnyChat",
    description="Non-human entities have many things to say, listen to them!",
).launch(debug=True)