File size: 3,053 Bytes
ad02fa3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import { HF_TOKEN } from '$env/static/private';
import { PUBLIC_MODEL_ENDPOINT } from '$env/static/public';
import { buildPrompt } from '$lib/buildPrompt.js';
import { collections } from '$lib/server/database.js';
import type { Message } from '$lib/types/Message.js';
import { streamToAsyncIterable } from '$lib/utils/streamToAsyncIterable';
import { sum } from '$lib/utils/sum';
import { error } from '@sveltejs/kit';
import { ObjectId } from 'mongodb';

/**
 * Handles a new user message for an existing conversation.
 *
 * Looks up the conversation matching `params.id` and `locals.sessionId`,
 * appends the user's input to its messages, builds the prompt with
 * `buildPrompt` and forwards the request to `PUBLIC_MODEL_ENDPOINT`.
 * On success the model response is split in two: one branch is returned to
 * the client as-is, the other is parsed in the background so the assistant's
 * reply can be stored on the conversation.
 *
 * @returns A `Response` mirroring the model endpoint's body, headers and status.
 * @throws 400 when `params.id` is not a valid ObjectId or `inputs` is not a string.
 * @throws 404 when no conversation matches the id for this session.
 */
export async function POST({ request, fetch, locals, params }) {
	// `new ObjectId()` throws on malformed input, which would otherwise surface as a 500.
	if (!ObjectId.isValid(params.id)) {
		throw error(400, 'Invalid conversation id');
	}
	const convId = new ObjectId(params.id);

	const conv = await collections.conversations.findOne({
		_id: convId,
		sessionId: locals.sessionId
	});

	if (!conv) {
		throw error(404, 'Conversation not found');
	}

	// Todo: validate the whole payload with zod or arktype
	const json = await request.json();

	// The input becomes message content and part of the prompt, so it has to be text.
	if (typeof json?.inputs !== 'string') {
		throw error(400, 'Expected "inputs" to be a string');
	}

	const messages = [...conv.messages, { from: 'user', content: json.inputs }] satisfies Message[];

	// The model receives the whole conversation rendered as a single prompt.
	json.inputs = buildPrompt(messages);

	const resp = await fetch(PUBLIC_MODEL_ENDPOINT, {
		headers: {
			'Content-Type': request.headers.get('Content-Type') ?? 'application/json',
			Authorization: `Basic ${HF_TOKEN}`
		},
		method: 'POST',
		body: JSON.stringify(json)
	});

	const responseInit = {
		headers: Object.fromEntries(resp.headers.entries()),
		status: resp.status,
		statusText: resp.statusText
	};

	// A failed or body-less response contains no generated text to persist:
	// forward it untouched instead of trying to parse it.
	if (!resp.ok || !resp.body) {
		return new Response(resp.body, responseInit);
	}

	// One branch goes to the client, the other is consumed to save the reply.
	const [clientStream, saveStream] = resp.body.tee();

	async function saveMessage() {
		const generated_text = await parseGeneratedText(saveStream);

		messages.push({ from: 'assistant', content: generated_text });

		console.log('updating conversation', convId, messages);

		await collections.conversations.updateOne(
			{
				_id: convId
			},
			{
				$set: {
					messages,
					updatedAt: new Date()
				}
			}
		);
	}

	// Fire-and-forget: a failure to save is logged but does not interrupt the stream.
	saveMessage().catch(console.error);

	// Todo: maybe we should wait for the message to be saved before ending the response - in case of errors
	return new Response(clientStream, responseInit);
}

/**
 * Reads a streamed model response to the end and extracts the generated text.
 *
 * The response is expected to contain lines of the form `data:<json>`; the
 * last such line must hold an object with a string `generated_text` field.
 *
 * @param stream Byte stream of the model response.
 * @returns The `generated_text` value of the last `data:` line.
 * @throws Error when the response has no `data:` line or the payload has no
 *   string `generated_text`; `SyntaxError` when the payload is not valid JSON.
 */
async function parseGeneratedText(stream: ReadableStream): Promise<string> {
	const chunks: Uint8Array[] = [];
	for await (const chunk of streamToAsyncIterable(stream)) {
		chunks.push(chunk);
	}

	// Merge the chunks before decoding so that a multi-byte character split
	// across two chunks is still decoded correctly.
	const merged = new Uint8Array(sum(chunks.map((chunk) => chunk.length)));
	let offset = 0;
	for (const chunk of chunks) {
		merged.set(chunk, offset);
		offset += chunk.length;
	}

	const message = new TextDecoder().decode(merged);

	// Only a line that *starts* with "data:" counts; the word "data" appearing
	// elsewhere (e.g. inside an error message) must not be picked up.
	const dataLine = message
		.split('\n')
		.reverse()
		.find((line) => line.startsWith('data:'));

	if (dataLine === undefined) {
		throw new Error('Could not find a "data:" line in the model response');
	}

	const parsed: unknown = JSON.parse(dataLine.slice('data:'.length));

	// Narrow manually: the payload may be any JSON value, including null.
	const generatedText =
		typeof parsed === 'object' && parsed !== null
			? (parsed as { generated_text?: unknown }).generated_text
			: undefined;

	if (typeof generatedText !== 'string') {
		throw new Error('Could not parse generated text');
	}

	return generatedText;
}