nsarrazin HF staff commited on
Commit
922b1b2
1 Parent(s): b46dc11

very basic rate limiter (#320)

Browse files
.env CHANGED
@@ -70,6 +70,7 @@ PARQUET_EXPORT_DATASET=
70
  PARQUET_EXPORT_HF_TOKEN=
71
  PARQUET_EXPORT_SECRET=
72
 
 
73
 
74
  PUBLIC_APP_NAME=ChatUI # name used as title throughout the app
75
  PUBLIC_APP_ASSETS=chatui # used to find logos & favicons in static/$PUBLIC_APP_ASSETS
 
70
  PARQUET_EXPORT_HF_TOKEN=
71
  PARQUET_EXPORT_SECRET=
72
 
73
+ RATE_LIMIT= # requests per minute
74
 
75
  PUBLIC_APP_NAME=ChatUI # name used as title throughout the app
76
  PUBLIC_APP_ASSETS=chatui # used to find logos & favicons in static/$PUBLIC_APP_ASSETS
src/lib/server/database.ts CHANGED
@@ -6,6 +6,7 @@ import type { WebSearch } from "$lib/types/WebSearch";
6
  import type { AbortedGeneration } from "$lib/types/AbortedGeneration";
7
  import type { Settings } from "$lib/types/Settings";
8
  import type { User } from "$lib/types/User";
 
9
 
10
  if (!MONGODB_URL) {
11
  throw new Error(
@@ -27,6 +28,7 @@ const abortedGenerations = db.collection<AbortedGeneration>("abortedGenerations"
27
  const settings = db.collection<Settings>("settings");
28
  const users = db.collection<User>("users");
29
  const webSearches = db.collection<WebSearch>("webSearches");
 
30
 
31
  export { client, db };
32
  export const collections = {
@@ -36,6 +38,7 @@ export const collections = {
36
  settings,
37
  users,
38
  webSearches,
 
39
  };
40
 
41
  client.on("open", () => {
@@ -59,4 +62,5 @@ client.on("open", () => {
59
  settings.createIndex({ userId: 1 }, { unique: true, sparse: true }).catch(console.error);
60
  users.createIndex({ hfUserId: 1 }, { unique: true }).catch(console.error);
61
  users.createIndex({ sessionId: 1 }, { unique: true, sparse: true }).catch(console.error);
 
62
  });
 
6
  import type { AbortedGeneration } from "$lib/types/AbortedGeneration";
7
  import type { Settings } from "$lib/types/Settings";
8
  import type { User } from "$lib/types/User";
9
+ import type { MessageEvent } from "$lib/types/MessageEvent";
10
 
11
  if (!MONGODB_URL) {
12
  throw new Error(
 
28
  const settings = db.collection<Settings>("settings");
29
  const users = db.collection<User>("users");
30
  const webSearches = db.collection<WebSearch>("webSearches");
31
+ const messageEvents = db.collection<MessageEvent>("messageEvents");
32
 
33
  export { client, db };
34
  export const collections = {
 
38
  settings,
39
  users,
40
  webSearches,
41
+ messageEvents,
42
  };
43
 
44
  client.on("open", () => {
 
62
  settings.createIndex({ userId: 1 }, { unique: true, sparse: true }).catch(console.error);
63
  users.createIndex({ hfUserId: 1 }, { unique: true }).catch(console.error);
64
  users.createIndex({ sessionId: 1 }, { unique: true, sparse: true }).catch(console.error);
65
+ messageEvents.createIndex({ createdAt: 1 }, { expireAfterSeconds: 60 }).catch(console.error);
66
  });
src/lib/stores/errors.ts CHANGED
@@ -3,6 +3,7 @@ import { writable } from "svelte/store";
3
  export const ERROR_MESSAGES = {
4
  default: "Oops, something went wrong.",
5
  authOnly: "You have to be logged in.",
 
6
  };
7
 
8
  export const error = writable<string | null>(null);
 
3
  export const ERROR_MESSAGES = {
4
  default: "Oops, something went wrong.",
5
  authOnly: "You have to be logged in.",
6
+ rateLimited: "You are sending too many messages. Try again later.",
7
  };
8
 
9
  export const error = writable<string | null>(null);
src/lib/types/MessageEvent.ts ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import type { Timestamps } from "./Timestamps";
2
+ import type { User } from "./User";
3
+
4
+ export interface MessageEvent extends Pick<Timestamps, "createdAt"> {
5
+ userId: User["_id"] | User["sessionId"];
6
+ }
src/routes/conversation/[id]/+page.svelte CHANGED
@@ -207,6 +207,8 @@
207
  } catch (err) {
208
  if (err instanceof Error && err.message.includes("overloaded")) {
209
  $error = "Too much traffic, please try again.";
 
 
210
  } else if (err instanceof Error) {
211
  $error = err.message;
212
  } else {
 
207
  } catch (err) {
208
  if (err instanceof Error && err.message.includes("overloaded")) {
209
  $error = "Too much traffic, please try again.";
210
+ } else if (err instanceof Error && err.message.includes("429")) {
211
+ $error = ERROR_MESSAGES.rateLimited;
212
  } else if (err instanceof Error) {
213
  $error = err.message;
214
  } else {
src/routes/conversation/[id]/+server.ts CHANGED
@@ -1,3 +1,4 @@
 
1
  import { buildPrompt } from "$lib/buildPrompt";
2
  import { PUBLIC_SEP_TOKEN } from "$lib/constants/publicSepToken";
3
  import { abortedGenerations } from "$lib/server/abortedGenerations";
@@ -5,6 +6,7 @@ import { authCondition } from "$lib/server/auth";
5
  import { collections } from "$lib/server/database";
6
  import { modelEndpoint } from "$lib/server/modelEndpoint";
7
  import { models } from "$lib/server/models";
 
8
  import type { Message } from "$lib/types/Message";
9
  import { concatUint8Arrays } from "$lib/utils/concatUint8Arrays";
10
  import { streamToAsyncIterable } from "$lib/utils/streamToAsyncIterable";
@@ -20,6 +22,12 @@ export async function POST({ request, fetch, locals, params }) {
20
  const convId = new ObjectId(id);
21
  const date = new Date();
22
 
 
 
 
 
 
 
23
  const conv = await collections.conversations.findOne({
24
  _id: convId,
25
  ...authCondition(locals),
@@ -29,6 +37,12 @@ export async function POST({ request, fetch, locals, params }) {
29
  throw error(404, "Conversation not found");
30
  }
31
 
 
 
 
 
 
 
32
  const model = models.find((m) => m.id === conv.model);
33
 
34
  if (!model) {
@@ -118,6 +132,11 @@ export async function POST({ request, fetch, locals, params }) {
118
  id: (responseId as Message["id"]) || crypto.randomUUID(),
119
  });
120
 
 
 
 
 
 
121
  await collections.conversations.updateOne(
122
  {
123
  _id: convId,
 
1
+ import { RATE_LIMIT } from "$env/static/private";
2
  import { buildPrompt } from "$lib/buildPrompt";
3
  import { PUBLIC_SEP_TOKEN } from "$lib/constants/publicSepToken";
4
  import { abortedGenerations } from "$lib/server/abortedGenerations";
 
6
  import { collections } from "$lib/server/database";
7
  import { modelEndpoint } from "$lib/server/modelEndpoint";
8
  import { models } from "$lib/server/models";
9
+ import { ERROR_MESSAGES } from "$lib/stores/errors.js";
10
  import type { Message } from "$lib/types/Message";
11
  import { concatUint8Arrays } from "$lib/utils/concatUint8Arrays";
12
  import { streamToAsyncIterable } from "$lib/utils/streamToAsyncIterable";
 
22
  const convId = new ObjectId(id);
23
  const date = new Date();
24
 
25
+ const userId = locals.user?._id ?? locals.sessionId;
26
+
27
+ if (!userId) {
28
+ throw error(401, "Unauthorized");
29
+ }
30
+
31
  const conv = await collections.conversations.findOne({
32
  _id: convId,
33
  ...authCondition(locals),
 
37
  throw error(404, "Conversation not found");
38
  }
39
 
40
+ const nEvents = await collections.messageEvents.countDocuments({ userId });
41
+
42
+ if (RATE_LIMIT != "" && nEvents > parseInt(RATE_LIMIT)) {
43
+ throw error(429, ERROR_MESSAGES.rateLimited);
44
+ }
45
+
46
  const model = models.find((m) => m.id === conv.model);
47
 
48
  if (!model) {
 
132
  id: (responseId as Message["id"]) || crypto.randomUUID(),
133
  });
134
 
135
+ await collections.messageEvents.insertOne({
136
+ userId: userId,
137
+ createdAt: new Date(),
138
+ });
139
+
140
  await collections.conversations.updateOne(
141
  {
142
  _id: convId,