Some questions about Paligemma’s segmentation capabilities

#1
by cslys1999 - opened

I downloaded checkpoints from these two collections:
https://huggingface.co/collections/google/paligemma-release-6643a9ffbf57de2ae0448dda
https://huggingface.co/collections/google/paligemma-ft-models-6643b03efb769dad650d2dda
This gave me three checkpoints:
google/paligemma-3b-ft-refcoco-seg-224,
google/paligemma-3b-mix-224,
google/paligemma-3b-ft-refcoco-seg-224
I want to use PaliGemma for an instance segmentation task. However, all three checkpoints perform noticeably worse at instance segmentation than the model in the Hugging Face demo. Why is this?

My inference code looks like this:

```python
from PIL import Image
import requests
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
import torch
from loguru import logger

# ref: https://colab.research.google.com/drive/1gOhRCFyt9yIoasJkd4VoaHcIqJPdJnlg?usp=sharing#scrollTo=u9kau5IOjNt9

# Load the locally downloaded fine-tuned checkpoint and its processor.
model = PaliGemmaForConditionalGeneration.from_pretrained("./paligemma_refcoco_ft/refcoco_ft", torch_dtype=torch.float32)
processor = AutoProcessor.from_pretrained("./paligemma_refcoco_ft/refcoco_ft")
model = model.to("cuda:0")

# url = "https://huggingface.co/gv-hf/PaliGemma-test-224px-hf/resolve/main/cow_beach_1.png"
# image = Image.open(requests.get(url, stream=True).raw)
while True:
    try:
        image_path = input("the image file path: ")
        prompt = input("please enter the prompt: ")
        image = Image.open(image_path).convert("RGB")
        inputs = processor(text=prompt, images=image, return_tensors="pt", padding=True)
        # BatchFeature.to(dtype=...) casts only the floating-point tensors (pixel_values);
        # input_ids and attention_mask keep their integer dtype.
        inputs = inputs.to(dtype=model.dtype)
        inputs = {key: value.to("cuda:0") for key, value in inputs.items()}
        # Generate (the IPython embed() breakpoint that paused every iteration is removed)
        with torch.inference_mode():
            generate_ids = model.generate(**inputs, max_new_tokens=100)

        result = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        print(result)
    except Exception as e:
        logger.error(e)
```
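To compare the checkpoints against the demo more objectively than by eye, it may help to parse the raw generated text. The sketch below assumes the output format used in Google's PaliGemma demo code: for a prompt such as `segment cow`, each object is four `<locNNNN>` tokens (y_min, x_min, y_max, x_max, each quantized into 1024 bins), optionally sixteen `<segNNN>` tokens, then the label, with `;` separating objects. `parse_segmentation` and `SegObject` are hypothetical helpers, not part of transformers. The sketch only recovers pixel boxes and the raw mask codes; turning the 16 codes into a pixel mask needs PaliGemma's separate VQ-VAE mask decoder, which is not shown. If `skip_special_tokens=True` strips these tokens in your transformers version, decode with it set to `False` before parsing.

```python
import re
from typing import List, NamedTuple, Optional, Tuple

# Assumed output format (hypothetical parser, based on the PaliGemma demo's convention):
# <loc y_min><loc x_min><loc y_max><loc x_max> [16 x <segNNN>] label [; next object]
_SEG_RE = re.compile(
    r"<loc(\d{4})>" * 4          # groups 1-4: quantized box corners (0..1023)
    + r"\s*"
    + r"((?:<seg\d{3}>){16})?"   # group 5: optional block of 16 mask codes
    + r"\s*([^;<>]*)"            # group 6: label, up to ';' or the next token
)


class SegObject(NamedTuple):
    box: Tuple[int, int, int, int]  # (x_min, y_min, x_max, y_max) in pixels
    seg_codes: Optional[List[int]]  # 16 VQ codes, or None for detection-only output
    label: str


def parse_segmentation(text: str, width: int, height: int) -> List[SegObject]:
    """Hypothetical helper: extract boxes, mask codes and labels from decoded text."""
    objects = []
    for m in _SEG_RE.finditer(text):
        # Convert 1024-bin coordinates to fractions of the image size.
        y1, x1, y2, x2 = (int(v) / 1024 for v in m.group(1, 2, 3, 4))
        box = (round(x1 * width), round(y1 * height), round(x2 * width), round(y2 * height))
        seg_block = m.group(5)
        codes = [int(c) for c in re.findall(r"<seg(\d{3})>", seg_block)] if seg_block else None
        objects.append(SegObject(box, codes, m.group(6).strip()))
    return objects
```

Usage: call `parse_segmentation(result[0], *image.size)` on the decoded string from the loop above; `image.size` is `(width, height)` in PIL, which matches the argument order.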

I'm looking forward to your answers!
