import { z } from "zod"; import { openAICompletionToTextGenerationStream } from "./openAICompletionToTextGenerationStream"; import { openAIChatToTextGenerationStream } from "./openAIChatToTextGenerationStream"; import { buildPrompt } from "$lib/buildPrompt"; import { OPENAI_API_KEY } from "$env/static/private"; import type { Endpoint } from "../endpoints"; import { format } from "date-fns"; export const endpointOAIParametersSchema = z.object({ weight: z.number().int().positive().default(1), model: z.any(), type: z.literal("openai"), baseURL: z.string().url().default("https://api.openai.com/v1"), apiKey: z.string().default(OPENAI_API_KEY ?? "sk-"), completion: z .union([z.literal("completions"), z.literal("chat_completions")]) .default("chat_completions"), }); export async function endpointOai( input: z.input ): Promise { const { baseURL, apiKey, completion, model } = endpointOAIParametersSchema.parse(input); let OpenAI; try { OpenAI = (await import("openai")).OpenAI; } catch (e) { throw new Error("Failed to import OpenAI", { cause: e }); } const openai = new OpenAI({ apiKey: apiKey ?? "sk-", baseURL: baseURL, }); if (completion === "completions") { return async ({ conversation }) => { return openAICompletionToTextGenerationStream( await openai.completions.create({ model: model.id ?? model.name, prompt: await buildPrompt({ messages: conversation.messages, webSearch: conversation.messages[conversation.messages.length - 1].webSearch, preprompt: conversation.preprompt, model, }), stream: true, max_tokens: model.parameters?.max_new_tokens, stop: model.parameters?.stop, temperature: model.parameters?.temperature, top_p: model.parameters?.top_p, frequency_penalty: model.parameters?.repetition_penalty, }) ); }; } else if (completion === "chat_completions") { return async ({ conversation }) => { let messages = conversation.messages; const webSearch = conversation.messages[conversation.messages.length - 1].webSearch; if (webSearch && webSearch.context) { const lastMsg = messages.slice(-1)[0]; const messagesWithoutLastUsrMsg = messages.slice(0, -1); const previousUserMessages = messages.filter((el) => el.from === "user").slice(0, -1); const previousQuestions = previousUserMessages.length > 0 ? `Previous questions: \n${previousUserMessages .map(({ content }) => `- ${content}`) .join("\n")}` : ""; const currentDate = format(new Date(), "MMMM d, yyyy"); messages = [ ...messagesWithoutLastUsrMsg, { from: "user", content: `I searched the web using the query: ${webSearch.searchQuery}. Today is ${currentDate} and here are the results: ===================== ${webSearch.context} ===================== ${previousQuestions} Answer the question: ${lastMsg.content} `, }, ]; } const messagesOpenAI = messages.map((message) => ({ role: message.from, content: message.content, })); return openAIChatToTextGenerationStream( await openai.chat.completions.create({ model: model.id ?? model.name, messages: conversation.preprompt ? [{ role: "system", content: conversation.preprompt }, ...messagesOpenAI] : messagesOpenAI, stream: true, max_tokens: model.parameters?.max_new_tokens, stop: model.parameters?.stop, temperature: model.parameters?.temperature, top_p: model.parameters?.top_p, frequency_penalty: model.parameters?.repetition_penalty, }) ); }; } else { throw new Error("Invalid completion type"); } }