davidberenstein1957 HF staff committed on
Commit
119215e
1 Parent(s): f20ab91

Add overview of magpie battle

Browse files
Files changed (4) hide show
  1. .gitignore +1 -0
  2. README.md +2 -2
  3. app.py +31 -43
  4. requirements.txt +1 -1
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .env
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
- title: Data Viber
3
- emoji: 📉
4
  colorFrom: red
5
  colorTo: yellow
6
  sdk: gradio
 
1
  ---
2
+ title: Dataset Viber - chat preference magpie battle
3
+ emoji: ⚔️
4
  colorFrom: red
5
  colorTo: yellow
6
  sdk: gradio
app.py CHANGED
@@ -1,54 +1,42 @@
1
  import os
2
- import io
3
  import random
4
 
5
- import requests
6
- from PIL import Image
7
  from dataset_viber import AnnotatorInterFace
8
-
9
- HF_TOKEN = os.environ["HF_TOKEN"]
10
- HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"}
11
- DATASET_SERVER_URL = "https://datasets-server.huggingface.co"
12
- DATASET_NAME = "poloclub%2Fdiffusiondb&config=2m_random_1k&split=train"
13
- MODEL_URL = (
14
- "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
15
- )
16
-
17
-
18
- def retrieve_sample(idx):
19
- api_url = f"{DATASET_SERVER_URL}/rows?dataset={DATASET_NAME}&offset={idx}&length=1"
20
- response = requests.get(api_url, headers=HEADERS)
21
- data = response.json()
22
- img_url = data["rows"][0]["row"]["image"]["src"]
23
- prompt = data["rows"][0]["row"]["prompt"]
24
- return img_url, prompt
25
-
26
-
27
- def get_rows():
28
- api_url = f"{DATASET_SERVER_URL}/size?dataset={DATASET_NAME}"
29
- response = requests.get(api_url, headers=HEADERS)
30
- num_rows = response.json()["size"]["config"]["num_rows"]
31
- return num_rows
32
-
33
-
34
- def generate_response(prompt):
35
- payload = {
36
- "inputs": prompt,
37
- }
38
- response = requests.post(MODEL_URL, headers=HEADERS, json=payload)
39
- image = Image.open(io.BytesIO(response.content))
40
- return image
41
-
42
 
43
  def next_input(_prompt, _completion_a, _completion_b):
44
- random_idx = random.randint(0, get_rows()) - 1
45
- img_url, prompt = retrieve_sample(random_idx)
46
- generated_image = generate_response(prompt)
47
- return (prompt, img_url, generated_image)
 
 
 
 
 
48
 
49
  if __name__ == "__main__":
50
- interface = AnnotatorInterFace.for_image_generation_preference(
51
- fn=next_input,
52
  dataset_name=None,
53
  )
54
  interface.launch()
 
1
  import os
 
2
  import random
3
 
 
 
4
  from dataset_viber import AnnotatorInterFace
5
+ from datasets import load_dataset
6
+ from huggingface_hub import InferenceClient
7
+
8
# Contender models for the chat-preference "magpie battle".
MODEL_IDS = [
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "microsoft/Phi-3-mini-4k-instruct",
    "mistralai/Mistral-7B-Instruct-v0.2",
]

# One serverless inference client per contender model.
# NOTE(review): reads HF_AUTH_TOKEN_PERSONAL eagerly — a missing variable
# raises KeyError at import time.
CLIENTS = [
    InferenceClient(model_id, token=os.environ["HF_AUTH_TOKEN_PERSONAL"])
    for model_id in MODEL_IDS
]

# Source conversations with reference responses; loaded once at import time.
dataset = load_dataset("argilla/magpie-ultra-v0.1", split="train")
16
+
17
def _get_response(messages):
    """Generate one assistant reply for ``messages`` from a random contender.

    A client is drawn uniformly from ``CLIENTS``, so repeated calls spread
    generations across the configured models. Returns the reply text only.
    """
    chosen = random.choice(CLIENTS)
    completion = chosen.chat_completion(
        messages=messages,
        stream=False,
        max_tokens=2000,
    )
    return completion.choices[0].message.content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
def next_input(_prompt, _completion_a, _completion_b):
    """Serve the next battle round: a conversation plus two candidate replies.

    The previous round's widget values (``_prompt``, ``_completion_a``,
    ``_completion_b``) are received from the UI but intentionally ignored.

    Returns:
        tuple: ``(messages, completion_a, completion_b)`` where ``messages``
        is the conversation with its final reference turn removed, and the
        two completions are a uniformly random ordered pair drawn from three
        candidates: the dataset's stored reference response and two fresh
        model generations.
    """
    # Index a single random row directly instead of dataset.shuffle(), which
    # builds a full permutation index over the whole split on every call.
    row = dataset[random.randrange(len(dataset))]
    # Drop the final (reference) assistant turn; the models re-answer it.
    messages = row["messages"][:-1]

    # Candidate slots: 0 = stored reference response, 1 and 2 = fresh
    # generations.  Choose the two displayed slots FIRST so we only pay for
    # inference calls that are actually shown — the previous implementation
    # always made two calls and discarded one candidate.  Because generated
    # slots are independent draws, the output distribution is unchanged.
    first, second = random.sample(range(3), 2)

    def _realize(slot):
        # Slot 0 is the dataset reference; any other slot is a generation.
        return row["response"] if slot == 0 else _get_response(messages)

    return messages, _realize(first), _realize(second)
36
 
37
if __name__ == "__main__":
    # Build the chat-preference annotation UI around next_input and serve it.
    annotator = AnnotatorInterFace.for_chat_generation_preference(
        fn_next_input=next_input,
        dataset_name=None,  # annotations are not pushed to a Hub dataset
    )
    annotator.launch()
requirements.txt CHANGED
@@ -1 +1 @@
1
- git+https://github.com/burtenshaw/data-viber.git@example/image-generation-preference#egg=dataset_viber
 
1
+ dataset-viber==0.2.1