File size: 6,138 Bytes
dd8990d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import inspect
import logging
from typing import Awaitable, Callable, get_type_hints

from open_webui.apps.webui.models.tools import Tools
from open_webui.apps.webui.models.users import UserModel
from open_webui.apps.webui.utils import load_toolkit_module_by_id
from open_webui.utils.schemas import json_schema_to_model

log = logging.getLogger(__name__)


def apply_extra_params_to_tool_function(
    function: Callable, extra_params: dict
) -> Callable[..., Awaitable]:
    """Wrap ``function`` in a coroutine that injects matching extra params.

    Only the entries of ``extra_params`` whose key names a parameter in the
    signature of ``function`` are kept; that selection happens once, when the
    wrapper is built. The returned wrapper accepts keyword arguments only and,
    on a name clash, the injected value replaces the one passed by the caller.

    Args:
        function: The sync or async callable to wrap.
        extra_params: Candidate values to inject, keyed by parameter name.

    Returns:
        An async function. A coroutine ``function`` is awaited; a plain one
        is called directly and its result returned as-is.
    """
    accepted_names = inspect.signature(function).parameters
    injected = {}
    for name, value in extra_params.items():
        if name in accepted_names:
            injected[name] = value
    needs_await = inspect.iscoroutinefunction(function)

    async def new_function(**kwargs):
        # Injected values are merged last so they win over caller-supplied keys.
        call_kwargs = {**kwargs, **injected}
        result = function(**call_kwargs)
        if needs_await:
            result = await result
        return result

    return new_function


# Mutation on extra_params
def get_tools(
    webui_app, tool_ids: list[str], user: UserModel, extra_params: dict
) -> dict[str, dict]:
    """Build a table of callable tools for the given toolkit ids.

    For every id that resolves via ``Tools.get_tool_by_id`` the toolkit module
    is taken from ``webui_app.state.TOOLS`` (loaded and stored there on a
    miss), its ``Valves``/``UserValves`` are instantiated from the stored
    values, and each entry of ``toolkit.specs`` is turned into a tool entry.

    Note: ``extra_params`` is mutated in place. ``"__id__"`` is overwritten
    for every toolkit and ``extra_params["__user__"]["valves"]`` is set for
    toolkits that define ``UserValves`` (which requires ``"__user__"`` to be
    present and to be a dict).

    Args:
        webui_app: Application object exposing ``state.TOOLS`` as a mapping.
        tool_ids: Toolkit ids to load; unknown ids are skipped.
        user: The user whose per-user valves are looked up (``user.id``).
        extra_params: Values offered to every tool function; each function
            only receives the ones its signature names.

    Returns:
        A dict keyed by function name. Each value holds ``toolkit_id``,
        ``callable``, ``spec``, ``pydantic_model``, ``file_handler`` and
        ``citation``. On a name collision the first registered tool wins.
    """
    tools = {}
    for tool_id in tool_ids:
        toolkit = Tools.get_tool_by_id(tool_id)
        if toolkit is None:
            continue

        module = webui_app.state.TOOLS.get(tool_id, None)
        if module is None:
            module, _ = load_toolkit_module_by_id(tool_id)
            webui_app.state.TOOLS[tool_id] = module

        extra_params["__id__"] = tool_id
        if hasattr(module, "valves") and hasattr(module, "Valves"):
            valves = Tools.get_tool_valves_by_id(tool_id) or {}
            module.valves = module.Valves(**valves)

        if hasattr(module, "UserValves"):
            # Fall back to {} like the Valves path above, so a missing
            # per-user record does not break the ** unpacking.
            user_valves = (
                Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) or {}
            )
            extra_params["__user__"]["valves"] = module.UserValves(  # type: ignore
                **user_valves
            )

        # Snapshot the params for this toolkit. The wrapper keeps references
        # to the values it is handed, so without a private copy of the
        # "__user__" dict every earlier toolkit would end up seeing the
        # "valves" written for the last toolkit in the loop.
        toolkit_params = dict(extra_params)
        if isinstance(toolkit_params.get("__user__"), dict):
            toolkit_params["__user__"] = dict(toolkit_params["__user__"])

        for spec in toolkit.specs:
            # TODO: Fix hack for OpenAI API
            for val in spec.get("parameters", {}).get("properties", {}).values():
                # Not every property schema carries a "type" key.
                if val.get("type") == "str":
                    val["type"] = "string"
            function_name = spec["name"]

            # convert to function that takes only model params and inserts custom params
            original_func = getattr(module, function_name)
            tool_callable = apply_extra_params_to_tool_function(
                original_func, toolkit_params
            )
            if hasattr(original_func, "__doc__"):
                tool_callable.__doc__ = original_func.__doc__

            # TODO: This needs to be a pydantic model
            tool_dict = {
                "toolkit_id": tool_id,
                "callable": tool_callable,
                "spec": spec,
                "pydantic_model": json_schema_to_model(spec),
                "file_handler": hasattr(module, "file_handler") and module.file_handler,
                "citation": hasattr(module, "citation") and module.citation,
            }

            # TODO: if collision, prepend toolkit name
            if function_name in tools:
                # Report the toolkit that already owns the name, not the
                # current one twice.
                existing_id = tools[function_name]["toolkit_id"]
                log.warning(f"Tool {function_name} already exists in another toolkit!")
                log.warning(f"Collision between {existing_id} and {tool_id}.")
                log.warning(f"Discarding {tool_id}.{function_name}")
            else:
                tools[function_name] = tool_dict
    return tools


def doc_to_dict(docstring):
    """Extract a short description and ``:param`` entries from a docstring.

    The description is the first non-blank line that is not a field line
    (one starting with ``:``), so both ``\"\"\"Summary.\"\"\"`` and the
    summary-on-second-line layout work, as does a bare one-line string.
    Lines containing ``:param`` are parsed as ``:param <name>: <text>``;
    lines without a colon after the name are ignored.

    Args:
        docstring: The raw docstring text. ``None`` or ``""`` is accepted.

    Returns:
        ``{"description": str, "params": {name: text}}``. The description is
        ``""`` when no suitable line exists.
    """
    if not docstring:
        return {"description": "", "params": {}}

    description = ""
    param_dict = {}
    for raw_line in docstring.split("\n"):
        line = raw_line.strip()
        if not line:
            continue

        _, marker, rest = line.partition(":param")
        if marker:
            # Split on the first colon only so the text may contain colons.
            name, separator, text = rest.partition(":")
            name = name.strip()
            if separator and name:
                param_dict[name] = text.strip()
            continue

        if not description and not line.startswith(":"):
            description = line

    return {"description": description, "params": param_dict}


def get_tools_specs(tools) -> list[dict]:
    """Derive one function spec per public callable found on ``tools``.

    Every attribute listed by ``dir(tools)`` that is callable, is not a
    class, and whose name does not start with ``__`` yields a spec with
    ``name``, ``description`` and ``parameters``. Descriptions come from
    ``doc_to_dict`` applied to the docstring (or to the function name when
    there is no docstring). Parameters come from the function's type hints,
    and ``required`` lists the signature parameters without a default.
    Names of the form ``__x__`` are left out of both.

    Args:
        tools: The object whose callables are inspected.

    Returns:
        A list of spec dicts, in ``dir(tools)`` order.
    """
    # Collect the candidates first, then describe them.
    candidates = []
    for attr_name in dir(tools):
        member = getattr(tools, attr_name)
        if not callable(member):
            continue
        if attr_name.startswith("__") or inspect.isclass(member):
            continue
        candidates.append((attr_name, member))

    specs = []
    for function_name, function in candidates:
        function_doc = doc_to_dict(function.__doc__ or function_name)
        # TODO: multi-line desc?
        description = function_doc.get("description", function_name)
        documented_params = function_doc.get("params", {})

        properties = {}
        for param_name, param_annotation in get_type_hints(function).items():
            if param_name == "return":
                continue
            if param_name.startswith("__") and param_name.endswith("__"):
                continue

            # Key order is type, then optional enum, then description.
            entry = {"type": param_annotation.__name__.lower()}
            if hasattr(param_annotation, "__args__"):
                entry["enum"] = str(param_annotation.__args__)
            entry["description"] = documented_params.get(param_name, param_name)
            properties[param_name] = entry

        required = []
        for name, param in inspect.signature(function).parameters.items():
            is_internal = name.startswith("__") and name.endswith("__")
            if param.default is param.empty and not is_internal:
                required.append(name)

        specs.append(
            {
                "name": function_name,
                "description": description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            }
        )

    return specs