nsarrazin HF staff commited on
Commit
cb000d3
1 Parent(s): a01ed5a

Implement Cloudflare Workers AI endpoint (#907) (#972)

Browse files

* Implement Cloudflare Workers AI endpoint (#907)

* Renamed to Cloudflare Workers AI in docs

* Add note about sampling parameters

* clean up env example

.env CHANGED
@@ -8,8 +8,11 @@ MONGODB_DIRECT_CONNECTION=false
8
  COOKIE_NAME=hf-chat
9
  HF_TOKEN=#hf_<token> from https://huggingface.co/settings/token
10
  HF_API_ROOT=https://api-inference.huggingface.co/models
 
11
  OPENAI_API_KEY=#your openai api key here
12
  ANTHROPIC_API_KEY=#your anthropic api key here
 
 
13
 
14
  HF_ACCESS_TOKEN=#LEGACY! Use HF_TOKEN instead
15
 
 
8
  COOKIE_NAME=hf-chat
9
  HF_TOKEN=#hf_<token> from https://huggingface.co/settings/token
10
  HF_API_ROOT=https://api-inference.huggingface.co/models
11
+
12
  OPENAI_API_KEY=#your openai api key here
13
  ANTHROPIC_API_KEY=#your anthropic api key here
14
+ CLOUDFLARE_ACCOUNT_ID=#your cloudflare account id here
15
+ CLOUDFLARE_API_TOKEN=#your cloudflare api token here
16
 
17
  HF_ACCESS_TOKEN=#LEGACY! Use HF_TOKEN instead
18
 
README.md CHANGED
@@ -528,6 +528,38 @@ You can also set `"service" : "lambda"` to use a lambda instance.
528
 
529
  You can get the `accessKey` and `secretKey` from your AWS user, under programmatic access.
530
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
531
  ##### Google Vertex models
532
 
533
  Chat UI can connect to the google Vertex API endpoints ([List of supported models](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models)).
 
528
 
529
  You can get the `accessKey` and `secretKey` from your AWS user, under programmatic access.
530
 
531
+ #### Cloudflare Workers AI
532
+
533
+ You can also use Cloudflare Workers AI to run your own models with serverless inference.
534
+
535
+ You will need to have a Cloudflare account, then get your [account ID](https://developers.cloudflare.com/fundamentals/setup/find-account-and-zone-ids/) as well as your [API token](https://developers.cloudflare.com/workers-ai/get-started/rest-api/#1-get-an-api-token) for Workers AI.
536
+
537
+ You can either specify them directly in your `.env.local` using the `CLOUDFLARE_ACCOUNT_ID` and `CLOUDFLARE_API_TOKEN` variables, or you can set them directly in the endpoint config.
538
+
539
+ You can find the list of models available on Cloudflare [here](https://developers.cloudflare.com/workers-ai/models/#text-generation).
540
+
541
+ ```env
542
+ {
543
+ "name" : "nousresearch/hermes-2-pro-mistral-7b",
544
+ "tokenizer": "nousresearch/hermes-2-pro-mistral-7b",
545
+ "parameters": {
546
+ "stop": ["<|im_end|>"]
547
+ },
548
+ "endpoints" : [
549
+ {
550
+ "type" : "cloudflare"
551
+ <!-- optionally specify these
552
+ "accountId": "your-account-id",
553
+ "authToken": "your-api-token"
554
+ -->
555
+ }
556
+ ]
557
+ }
558
+ ```
559
+
560
+ > [!NOTE]
561
+ > Cloudlare Workers AI currently do not support custom sampling parameters like temperature, top_p, etc.
562
+
563
  ##### Google Vertex models
564
 
565
  Chat UI can connect to the google Vertex API endpoints ([List of supported models](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models)).
src/lib/server/endpoints/cloudflare/endpointCloudflare.ts ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { z } from "zod";
2
+ import type { Endpoint } from "../endpoints";
3
+ import type { TextGenerationStreamOutput } from "@huggingface/inference";
4
+ import { CLOUDFLARE_ACCOUNT_ID, CLOUDFLARE_API_TOKEN } from "$env/static/private";
5
+
6
+ export const endpointCloudflareParametersSchema = z.object({
7
+ weight: z.number().int().positive().default(1),
8
+ model: z.any(),
9
+ type: z.literal("cloudflare"),
10
+ accountId: z.string().default(CLOUDFLARE_ACCOUNT_ID),
11
+ apiToken: z.string().default(CLOUDFLARE_API_TOKEN),
12
+ });
13
+
14
+ export async function endpointCloudflare(
15
+ input: z.input<typeof endpointCloudflareParametersSchema>
16
+ ): Promise<Endpoint> {
17
+ const { accountId, apiToken, model } = endpointCloudflareParametersSchema.parse(input);
18
+ const apiURL = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/@hf/${model.id}`;
19
+
20
+ return async ({ messages, preprompt }) => {
21
+ let messagesFormatted = messages.map((message) => ({
22
+ role: message.from,
23
+ content: message.content,
24
+ }));
25
+
26
+ if (messagesFormatted?.[0]?.role !== "system") {
27
+ messagesFormatted = [{ role: "system", content: preprompt ?? "" }, ...messagesFormatted];
28
+ }
29
+
30
+ const payload = JSON.stringify({
31
+ messages: messagesFormatted,
32
+ stream: true,
33
+ });
34
+
35
+ const res = await fetch(apiURL, {
36
+ method: "POST",
37
+ headers: {
38
+ Authorization: `Bearer ${apiToken}`,
39
+ "Content-Type": "application/json",
40
+ },
41
+ body: payload,
42
+ });
43
+
44
+ if (!res.ok) {
45
+ throw new Error(`Failed to generate text: ${await res.text()}`);
46
+ }
47
+
48
+ const encoder = new TextDecoderStream();
49
+ const reader = res.body?.pipeThrough(encoder).getReader();
50
+
51
+ return (async function* () {
52
+ let stop = false;
53
+ let generatedText = "";
54
+ let tokenId = 0;
55
+ let accumulatedData = ""; // Buffer to accumulate data chunks
56
+
57
+ while (!stop) {
58
+ const out = await reader?.read();
59
+
60
+ // If it's done, we cancel
61
+ if (out?.done) {
62
+ reader?.cancel();
63
+ return;
64
+ }
65
+
66
+ if (!out?.value) {
67
+ return;
68
+ }
69
+
70
+ // Accumulate the data chunk
71
+ accumulatedData += out.value;
72
+
73
+ // Process each complete JSON object in the accumulated data
74
+ while (accumulatedData.includes("\n")) {
75
+ // Assuming each JSON object ends with a newline
76
+ const endIndex = accumulatedData.indexOf("\n");
77
+ let jsonString = accumulatedData.substring(0, endIndex).trim();
78
+
79
+ // Remove the processed part from the buffer
80
+ accumulatedData = accumulatedData.substring(endIndex + 1);
81
+
82
+ if (jsonString.startsWith("data: ")) {
83
+ jsonString = jsonString.slice(6);
84
+ let data = null;
85
+
86
+ if (jsonString === "[DONE]") {
87
+ stop = true;
88
+
89
+ yield {
90
+ token: {
91
+ id: tokenId++,
92
+ text: "",
93
+ logprob: 0,
94
+ special: true,
95
+ },
96
+ generated_text: generatedText,
97
+ details: null,
98
+ } satisfies TextGenerationStreamOutput;
99
+ reader?.cancel();
100
+
101
+ continue;
102
+ }
103
+
104
+ try {
105
+ data = JSON.parse(jsonString);
106
+ } catch (e) {
107
+ console.error("Failed to parse JSON", e);
108
+ console.error("Problematic JSON string:", jsonString);
109
+ continue; // Skip this iteration and try the next chunk
110
+ }
111
+
112
+ // Handle the parsed data
113
+ if (data.response) {
114
+ generatedText += data.response ?? "";
115
+ const output: TextGenerationStreamOutput = {
116
+ token: {
117
+ id: tokenId++,
118
+ text: data.response ?? "",
119
+ logprob: 0,
120
+ special: false,
121
+ },
122
+ generated_text: null,
123
+ details: null,
124
+ };
125
+ yield output;
126
+ }
127
+ }
128
+ }
129
+ }
130
+ })();
131
+ };
132
+ }
133
+
134
+ export default endpointCloudflare;
src/lib/server/endpoints/endpoints.ts CHANGED
@@ -13,6 +13,9 @@ import {
13
  endpointAnthropicParametersSchema,
14
  } from "./anthropic/endpointAnthropic";
15
  import type { Model } from "$lib/types/Model";
 
 
 
16
 
17
  // parameters passed when generating text
18
  export interface EndpointParameters {
@@ -42,6 +45,7 @@ export const endpoints = {
42
  llamacpp: endpointLlamacpp,
43
  ollama: endpointOllama,
44
  vertex: endpointVertex,
 
45
  };
46
 
47
  export const endpointSchema = z.discriminatedUnion("type", [
@@ -52,5 +56,6 @@ export const endpointSchema = z.discriminatedUnion("type", [
52
  endpointLlamacppParametersSchema,
53
  endpointOllamaParametersSchema,
54
  endpointVertexParametersSchema,
 
55
  ]);
56
  export default endpoints;
 
13
  endpointAnthropicParametersSchema,
14
  } from "./anthropic/endpointAnthropic";
15
  import type { Model } from "$lib/types/Model";
16
+ import endpointCloudflare, {
17
+ endpointCloudflareParametersSchema,
18
+ } from "./cloudflare/endpointCloudflare";
19
 
20
  // parameters passed when generating text
21
  export interface EndpointParameters {
 
45
  llamacpp: endpointLlamacpp,
46
  ollama: endpointOllama,
47
  vertex: endpointVertex,
48
+ cloudflare: endpointCloudflare,
49
  };
50
 
51
  export const endpointSchema = z.discriminatedUnion("type", [
 
56
  endpointLlamacppParametersSchema,
57
  endpointOllamaParametersSchema,
58
  endpointVertexParametersSchema,
59
+ endpointCloudflareParametersSchema,
60
  ]);
61
  export default endpoints;
src/lib/server/models.ts CHANGED
@@ -130,6 +130,8 @@ const addEndpoint = (m: Awaited<ReturnType<typeof processModel>>) => ({
130
  return endpoints.ollama(args);
131
  case "vertex":
132
  return await endpoints.vertex(args);
 
 
133
  default:
134
  // for legacy reason
135
  return endpoints.tgi(args);
 
130
  return endpoints.ollama(args);
131
  case "vertex":
132
  return await endpoints.vertex(args);
133
+ case "cloudflare":
134
+ return await endpoints.cloudflare(args);
135
  default:
136
  // for legacy reason
137
  return endpoints.tgi(args);