Upload folder using huggingface_hub
Browse files- README.md +54 -0
- app.py +259 -0
- requirements.txt +8 -0
- wpkklhc6/config.yaml +32 -0
- wpkklhc6/image_adapter.pt +3 -0
README.md
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Here's a concise, structured, and aesthetically formatted markdown description for your GitHub repo README:
|
2 |
+
|
3 |
+
# Image Captioning App
|
4 |
+
|
5 |
+
## Overview
|
6 |
+
|
7 |
+
This application generates descriptive captions for images using advanced ML models. It processes single images or entire directories, leveraging CLIP and LLM models for accurate and contextual captions. It has NSFW captioning support with natural language.
|
8 |
+
|
9 |
+
## Features
|
10 |
+
|
11 |
+
- Single image and batch processing
|
12 |
+
- Multiple directory support
|
13 |
+
- Custom output directory
|
14 |
+
- Adjustable batch size
|
15 |
+
- Progress tracking
|
16 |
+
|
17 |
+
## Usage
|
18 |
+
|
19 |
+
| Command | Description |
|
20 |
+
|---------|-------------|
|
21 |
+
| `python app.py image.jpg` | Process a single image |
|
22 |
+
| `python app.py /path/to/directory` | Process all images in a directory |
|
23 |
+
| `python app.py /path/to/dir1 /path/to/dir2` | Process multiple directories |
|
24 |
+
| `python app.py /path/to/dir --output /path/to/output` | Specify output directory |
|
25 |
+
| `python app.py /path/to/dir --bs 8` | Set batch size (default: 4) |
|
26 |
+
|
27 |
+
## Technical Details
|
28 |
+
|
29 |
+
- **Models**: CLIP (vision), LLM (language), custom ImageAdapter
|
30 |
+
- **Optimization**: CUDA-enabled GPU support
|
31 |
+
- **Error Handling**: Skips problematic images in batch processing
|
32 |
+
|
33 |
+
## Requirements
|
34 |
+
|
35 |
+
- Python 3.x
|
36 |
+
- PyTorch
|
37 |
+
- Transformers library
|
38 |
+
- CUDA-capable GPU (recommended)
|
39 |
+
|
40 |
+
## Installation
|
41 |
+
|
42 |
+
```bash
|
43 |
+
git clone https://huggingface.co/Wi-zz/joy-caption-pre-alpha
|
44 |
+
cd joy-caption-pre-alpha
|
45 |
+
pip install -r requirements.txt
|
46 |
+
```
|
47 |
+
|
48 |
+
## Contributing
|
49 |
+
|
50 |
+
Contributions are welcome! Please feel free to submit a Pull Request.
|
51 |
+
|
52 |
+
## License
|
53 |
+
|
54 |
+
This project is licensed under the [MIT License](LICENSE).
|
app.py
ADDED
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# # For a single image
|
2 |
+
# python app.py image.jpg
|
3 |
+
|
4 |
+
# # For a single directory
|
5 |
+
# python app.py /path/to/directory
|
6 |
+
|
7 |
+
# # For multiple directories
|
8 |
+
# python app.py /path/to/directory1 /path/to/directory2 /path/to/directory3
|
9 |
+
|
10 |
+
# # With output directory specified
|
11 |
+
# python app.py /path/to/directory1 /path/to/directory2 --output /path/to/output
|
12 |
+
|
13 |
+
# # With batch size specified
|
14 |
+
# python app.py /path/to/directory1 /path/to/directory2 --bs 8
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.amp.autocast_mode
|
18 |
+
import os
|
19 |
+
import sys
|
20 |
+
import logging
|
21 |
+
import warnings
|
22 |
+
import argparse
|
23 |
+
from PIL import Image
|
24 |
+
from pathlib import Path
|
25 |
+
from tqdm import tqdm
|
26 |
+
from torch import nn
|
27 |
+
from transformers import AutoModel, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM
|
28 |
+
from typing import List
|
29 |
+
|
30 |
+
CLIP_PATH = "google/siglip-so400m-patch14-384"
|
31 |
+
VLM_PROMPT = "A descriptive caption for this image:\n"
|
32 |
+
MODEL_PATH = "unsloth/Meta-Llama-3.1-8B-bnb-4bit"
|
33 |
+
CHECKPOINT_PATH = Path("wpkklhc6")
|
34 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
35 |
+
|
36 |
+
class ImageAdapter(nn.Module):
|
37 |
+
def __init__(self, input_features: int, output_features: int):
|
38 |
+
super().__init__()
|
39 |
+
self.linear1 = nn.Linear(input_features, output_features)
|
40 |
+
self.activation = nn.GELU()
|
41 |
+
self.linear2 = nn.Linear(output_features, output_features)
|
42 |
+
|
43 |
+
def forward(self, vision_outputs: torch.Tensor):
|
44 |
+
x = self.linear1(vision_outputs)
|
45 |
+
x = self.activation(x)
|
46 |
+
x = self.linear2(x)
|
47 |
+
return x
|
48 |
+
|
49 |
+
# Load CLIP
|
50 |
+
print("Loading CLIP ๐")
|
51 |
+
clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
|
52 |
+
clip_model = AutoModel.from_pretrained(CLIP_PATH)
|
53 |
+
clip_model = clip_model.vision_model
|
54 |
+
clip_model.eval()
|
55 |
+
clip_model.requires_grad_(False)
|
56 |
+
clip_model.to("cuda")
|
57 |
+
|
58 |
+
# Tokenizer
|
59 |
+
print("Loading tokenizer ๐ช")
|
60 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
|
61 |
+
assert isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast), f"Tokenizer is of type {type(tokenizer)}"
|
62 |
+
|
63 |
+
# LLM
|
64 |
+
print("Loading LLM ๐ค")
|
65 |
+
logging.getLogger("transformers").setLevel(logging.ERROR)
|
66 |
+
text_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16)
|
67 |
+
text_model.eval()
|
68 |
+
|
69 |
+
# Image Adapter
|
70 |
+
print("Loading image adapter ๐ผ๏ธ")
|
71 |
+
image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size)
|
72 |
+
image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu", weights_only=True))
|
73 |
+
image_adapter.eval()
|
74 |
+
image_adapter.to("cuda")
|
75 |
+
|
76 |
+
@torch.no_grad()
|
77 |
+
def stream_chat(input_images: List[Image.Image], batch_size=4, pbar=None):
|
78 |
+
torch.cuda.empty_cache()
|
79 |
+
all_captions = []
|
80 |
+
|
81 |
+
if not isinstance(input_images, list):
|
82 |
+
input_images = [input_images]
|
83 |
+
|
84 |
+
for i in range(0, len(input_images), batch_size):
|
85 |
+
batch = input_images[i:i+batch_size]
|
86 |
+
|
87 |
+
# Preprocess image batch
|
88 |
+
try:
|
89 |
+
images = clip_processor(images=batch, return_tensors='pt', padding=True).pixel_values
|
90 |
+
except ValueError as e:
|
91 |
+
print(f"Error processing image batch: {e}")
|
92 |
+
print("Skipping this batch and continuing...")
|
93 |
+
continue
|
94 |
+
|
95 |
+
images = images.to('cuda')
|
96 |
+
|
97 |
+
# Embed image batch
|
98 |
+
with torch.amp.autocast_mode.autocast('cuda', enabled=True):
|
99 |
+
vision_outputs = clip_model(pixel_values=images, output_hidden_states=True)
|
100 |
+
image_features = vision_outputs.hidden_states[-2]
|
101 |
+
embedded_images = image_adapter(image_features)
|
102 |
+
embedded_images = embedded_images.to(dtype=torch.bfloat16)
|
103 |
+
|
104 |
+
# Embed prompt
|
105 |
+
prompt = tokenizer.encode(VLM_PROMPT, return_tensors='pt')
|
106 |
+
prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda')).to(dtype=torch.bfloat16)
|
107 |
+
embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64)).to(dtype=torch.bfloat16)
|
108 |
+
|
109 |
+
# Construct prompts
|
110 |
+
inputs_embeds = torch.cat([
|
111 |
+
embedded_bos.expand(embedded_images.shape[0], -1, -1),
|
112 |
+
embedded_images,
|
113 |
+
prompt_embeds.expand(embedded_images.shape[0], -1, -1),
|
114 |
+
], dim=1).to(dtype=torch.bfloat16)
|
115 |
+
|
116 |
+
input_ids = torch.cat([
|
117 |
+
torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).expand(embedded_images.shape[0], -1),
|
118 |
+
torch.zeros((embedded_images.shape[0], embedded_images.shape[1]), dtype=torch.long),
|
119 |
+
prompt.expand(embedded_images.shape[0], -1),
|
120 |
+
], dim=1).to('cuda')
|
121 |
+
|
122 |
+
attention_mask = torch.ones_like(input_ids)
|
123 |
+
|
124 |
+
generate_ids = text_model.generate(
|
125 |
+
input_ids=input_ids,
|
126 |
+
inputs_embeds=inputs_embeds,
|
127 |
+
attention_mask=attention_mask,
|
128 |
+
max_new_tokens=300,
|
129 |
+
do_sample=True,
|
130 |
+
top_k=10,
|
131 |
+
temperature=0.5,
|
132 |
+
)
|
133 |
+
|
134 |
+
if pbar:
|
135 |
+
pbar.update(len(batch))
|
136 |
+
|
137 |
+
# Trim off the prompt
|
138 |
+
generate_ids = generate_ids[:, input_ids.shape[1]:]
|
139 |
+
|
140 |
+
for ids in generate_ids:
|
141 |
+
if ids[-1] == tokenizer.eos_token_id:
|
142 |
+
ids = ids[:-1]
|
143 |
+
caption = tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
144 |
+
# Remove any remaining special tokens
|
145 |
+
caption = caption.replace('<|end_of_text|>', '').replace('<|finetune_right_pad_id|>', '').strip()
|
146 |
+
all_captions.append(caption)
|
147 |
+
|
148 |
+
return all_captions
|
149 |
+
|
150 |
+
def preprocess_image(img):
|
151 |
+
return img.convert('RGBA')
|
152 |
+
|
153 |
+
def process_image(image_path, output_path, pbar=None):
|
154 |
+
try:
|
155 |
+
with Image.open(image_path) as img:
|
156 |
+
# Convert image to RGB
|
157 |
+
img = img.convert('RGB')
|
158 |
+
caption = stream_chat([img], pbar=pbar)[0]
|
159 |
+
with open(output_path, 'w', encoding='utf-8') as f:
|
160 |
+
f.write(caption)
|
161 |
+
except Exception as e:
|
162 |
+
print(f"Error processing {image_path}: {e}")
|
163 |
+
if pbar:
|
164 |
+
pbar.update(1)
|
165 |
+
return
|
166 |
+
|
167 |
+
with Image.open(image_path) as img:
|
168 |
+
# Pass the image as a list to stream_chat
|
169 |
+
caption = stream_chat([img], pbar=pbar)[0] # Get the first (and only) caption
|
170 |
+
|
171 |
+
with open(output_path, 'w', encoding='utf-8') as f:
|
172 |
+
f.write(caption)
|
173 |
+
|
174 |
+
def process_directory(input_dir, output_dir, batch_size):
|
175 |
+
input_path = Path(input_dir)
|
176 |
+
output_path = Path(output_dir)
|
177 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
178 |
+
|
179 |
+
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')
|
180 |
+
image_files = [f for f in input_path.iterdir() if f.suffix.lower() in image_extensions]
|
181 |
+
|
182 |
+
# Create a list to store images that need processing
|
183 |
+
images_to_process = []
|
184 |
+
|
185 |
+
# Check which images need processing
|
186 |
+
for file in image_files:
|
187 |
+
output_file = output_path / (file.stem + '.txt')
|
188 |
+
if not output_file.exists():
|
189 |
+
images_to_process.append(file)
|
190 |
+
else:
|
191 |
+
print(f"Skipping {file.name} - Caption already exists")
|
192 |
+
|
193 |
+
# Process images in batches
|
194 |
+
with tqdm(total=len(images_to_process), desc="Processing images", unit="image") as pbar:
|
195 |
+
for i in range(0, len(images_to_process), batch_size):
|
196 |
+
batch_files = images_to_process[i:i+batch_size]
|
197 |
+
batch_images = []
|
198 |
+
for f in batch_files:
|
199 |
+
try:
|
200 |
+
img = Image.open(f).convert('RGB')
|
201 |
+
batch_images.append(img)
|
202 |
+
except Exception as e:
|
203 |
+
print(f"Error opening {f}: {e}")
|
204 |
+
continue
|
205 |
+
|
206 |
+
if batch_images:
|
207 |
+
captions = stream_chat(batch_images, batch_size, pbar)
|
208 |
+
for file, caption in zip(batch_files, captions):
|
209 |
+
output_file = output_path / (file.stem + '.txt')
|
210 |
+
with open(output_file, 'w', encoding='utf-8') as f:
|
211 |
+
f.write(caption)
|
212 |
+
|
213 |
+
# Close the image files
|
214 |
+
for img in batch_images:
|
215 |
+
img.close()
|
216 |
+
|
217 |
+
def parse_arguments():
|
218 |
+
parser = argparse.ArgumentParser(description="Process images and generate captions.")
|
219 |
+
parser.add_argument("input", nargs='+', help="Input image file or directory (or multiple directories)")
|
220 |
+
parser.add_argument("--output", help="Output directory (optional)")
|
221 |
+
parser.add_argument("--bs", type=int, default=4, help="Batch size (default: 4)")
|
222 |
+
return parser.parse_args()
|
223 |
+
|
224 |
+
def is_image_file(file_path):
|
225 |
+
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')
|
226 |
+
return Path(file_path).suffix.lower() in image_extensions
|
227 |
+
|
228 |
+
# Main execution
|
229 |
+
if __name__ == "__main__":
|
230 |
+
args = parse_arguments()
|
231 |
+
input_paths = [Path(input_path) for input_path in args.input]
|
232 |
+
batch_size = args.bs
|
233 |
+
|
234 |
+
for input_path in input_paths:
|
235 |
+
if input_path.is_file() and is_image_file(input_path):
|
236 |
+
# Single file processing
|
237 |
+
output_path = input_path.with_suffix('.txt')
|
238 |
+
print(f"Processing single image ๐๏ธ: {input_path.name}")
|
239 |
+
with tqdm(total=1, desc="Processing image", unit="image") as pbar:
|
240 |
+
process_image(input_path, output_path, pbar)
|
241 |
+
print(f"Output saved to {output_path}")
|
242 |
+
elif input_path.is_dir():
|
243 |
+
# Directory processing
|
244 |
+
output_path = Path(args.output) if args.output else input_path
|
245 |
+
print(f"Processing directory ๐: {input_path}")
|
246 |
+
print(f"Output directory ๐ฆ: {output_path}")
|
247 |
+
print(f"Batch size ๐๏ธ: {batch_size}")
|
248 |
+
process_directory(input_path, output_path, batch_size)
|
249 |
+
else:
|
250 |
+
print(f"Invalid input: {input_path}")
|
251 |
+
print("Skipping...")
|
252 |
+
|
253 |
+
if not input_paths:
|
254 |
+
print("Usage:")
|
255 |
+
print("For single image: python app.py [image_file] [--bs batch_size]")
|
256 |
+
print("For directory (same input/output): python app.py [directory] [--bs batch_size]")
|
257 |
+
print("For directory (separate input/output): python app.py [directory] --output [output_directory] [--bs batch_size]")
|
258 |
+
print("For multiple directories: python app.py [directory1] [directory2] ... [--output output_directory] [--bs batch_size]")
|
259 |
+
sys.exit(1)
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
huggingface_hub==0.24.3
|
2 |
+
accelerate
|
3 |
+
torch
|
4 |
+
transformers==4.43.3
|
5 |
+
sentencepiece
|
6 |
+
bitsandbytes
|
7 |
+
Pillow
|
8 |
+
protobuf
|
wpkklhc6/config.yaml
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
wandb_project: joy-caption-1
|
2 |
+
device_batch_size: 2
|
3 |
+
batch_size: 256
|
4 |
+
learning_rate: 0.001
|
5 |
+
warmup_samples: 18000
|
6 |
+
max_samples: 600000
|
7 |
+
save_every: 50000
|
8 |
+
test_every: 50000
|
9 |
+
use_amp: true
|
10 |
+
grad_scaler: true
|
11 |
+
lr_scheduler_type: cosine
|
12 |
+
min_lr_ratio: 0.0
|
13 |
+
allow_tf32: true
|
14 |
+
seed: 42
|
15 |
+
num_workers: 8
|
16 |
+
optimizer_type: adamw
|
17 |
+
adam_beta1: 0.9
|
18 |
+
adam_beta2: 0.999
|
19 |
+
adam_eps: 1.0e-08
|
20 |
+
adam_weight_decay: 0.0
|
21 |
+
clip_grad_norm: 1.0
|
22 |
+
dataset: fancyfeast/joy-captioning-20240729a
|
23 |
+
clip_model: google/siglip-so400m-patch14-384
|
24 |
+
text_model: meta-llama/Meta-Llama-3.1-8B
|
25 |
+
resume: null
|
26 |
+
gradient_checkpointing: false
|
27 |
+
test_size: 2048
|
28 |
+
grad_scaler_init: 65536.0
|
29 |
+
max_caption_length: 257
|
30 |
+
num_image_tokens: 32
|
31 |
+
adapter_type: mlp
|
32 |
+
text_model_dtype: float16
|
wpkklhc6/image_adapter.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2ebb1d1437bbb3264a6f25a896b25a7c7dd06c570c5de909dc2f19d3a5c5c110
|
3 |
+
size 86018240
|