Desm0nt commited on
Commit
bf4ab95
1 Parent(s): 557fcf1

Upload new_captioner.py

Browse files
Files changed (1) hide show
  1. new_captioner.py +109 -0
new_captioner.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import requests
3
+ import os
4
+ from openai import OpenAI
5
+ from tqdm import tqdm
6
+ import time
7
+ import sys
8
+
9
+ # Проверка наличия аргумента командной строки
10
+ if len(sys.argv) < 2:
11
+ print("Please, provide the path to image folder.")
12
+ sys.exit(1)
13
+
14
+ # Get the path to image dir from command line.
15
+ image_dir = sys.argv[1]
16
+
17
+ openai_api_key = "EMPTY"
18
+ openai_api_base = "http://localhost:8000/v1"
19
+ client = OpenAI(
20
+ api_key=openai_api_key,
21
+ base_url=openai_api_base,
22
+ )
23
+
24
+ model_type = client.models.list().data[0].id
25
+ print(f'model_type: {model_type}')
26
+
27
+ # Function to encode the image
28
+ def encode_image(image_path):
29
+ with open(image_path, "rb") as image_file:
30
+ return base64.b64encode(image_file.read()).decode('utf-8')
31
+
32
+ # Directories
33
+ #dir with tags captions from wd tagger
34
+ txt_dir = './txt/'
35
+ #dir with result captions
36
+ maintxt_dir = './maintxt/'
37
+ image_path =''
38
+
39
+ # Ensure the output directory exists
40
+ os.makedirs(maintxt_dir, exist_ok=True)
41
+
42
+ # Get list of all JPEG images in the directory
43
+ image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.jpg', '.jpeg'))]
44
+
45
+ total_files = len(image_files)
46
+ start_time = time.time()
47
+
48
+ progress_bar = tqdm(total=total_files, unit='file', bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]')
49
+ total_elapsed_time = 0
50
+ processed_files = 0
51
+
52
+ # Process all images in the image directory
53
+ for image_file in image_files:
54
+ image_path = os.path.join(image_dir, image_file)
55
+ txt_file = os.path.join(txt_dir, os.path.splitext(image_file)[0] + '.txt')
56
+ output_file = os.path.join(maintxt_dir, os.path.splitext(image_file)[0] + '.txt')
57
+
58
+ # Read tags from the corresponding txt file
59
+ with open(txt_file, 'r') as f:
60
+ tags = f.read().strip()
61
+
62
+ base64_image = encode_image(image_path)
63
+
64
+ step_start_time = time.time()
65
+
66
+ chat_response = client.chat.completions.create(
67
+ model="./phi3_v14_800-merged",
68
+ messages=[{
69
+ "role": "user",
70
+ "content": [
71
+ {"type": "text", "text": f"Make a caption that describe this image. Here is the tags for this image: {tags}"},
72
+ {
73
+ "type": "image_url",
74
+ "image_url": {
75
+ "url": f"data:image/jpeg;base64,{base64_image}"
76
+ },
77
+ },
78
+ ],
79
+ }],
80
+ extra_body={'repetition_penalty': 1.05, 'top_k': -1,'top_p': 1,'temperature': 0, 'use_beam_search': True, 'best_of':5},
81
+ )
82
+
83
+ step_end_time = time.time()
84
+ step_time = step_end_time - step_start_time
85
+ total_elapsed_time += step_time
86
+ remaining_time = (total_elapsed_time / (processed_files + 1)) * (total_files - processed_files - 1)
87
+
88
+ # Convert remaining time to hours, minutes and seconds
89
+ remaining_hours = int(remaining_time // 3600)
90
+ remaining_minutes = int((remaining_time % 3600) // 60)
91
+ remaining_seconds = int(remaining_time % 60)
92
+
93
+ # Extract the content from the response
94
+ content = chat_response.choices[0].message.content
95
+ content = content.lstrip()
96
+ # Write the content to the output file
97
+ with open(output_file, 'w', encoding='utf-8') as f:
98
+ f.write(content)
99
+
100
+ print(f"\n\nFile {image_file}\nProcessing time: {step_time:.2f} seconds\n{content}")
101
+ print(f"Response saved to file: {output_file}")
102
+
103
+ processed_files += 1
104
+ progress_bar.update(1)
105
+ progress_bar.set_postfix(remaining=f'{remaining_hours:02d}:{remaining_minutes:02d}:{remaining_seconds:02d}', refresh=True)
106
+
107
+ progress_bar.close()
108
+ print("All images processed.")
109
+ print(f"Total time: {time.time() - start_time:.2f} seconds")