How to get register token output values?

#4
by rohitchaudhari25 - opened

Are the first 5 tokens [CLS] + 4x [REG] tokens?

PyTorch Image Models org

@rohitchaudhari25

Yes, num_prefix_tokens is 1 cls + 4 x reg for these models: https://github.com/huggingface/pytorch-image-models/blob/a6fe31b09670289dbc8e99a0cfae23de355534c9/timm/models/vision_transformer.py#L497-L498

The easiest way to get them is forward_features() and take [1:5] of the flattened output, or you can use forward_intermediates() to get the prefix tokens for all blocks.

>>> oo = mm.forward_intermediates(torch.randn(2, 3, 518, 518), return_prefix_tokens=True)
>>> oo[1][-1][1].shape
torch.Size([2, 5, 768])
PyTorch Image Models org

The output there is a tuple of (final features, per-block output features); when return_prefix_tokens is set to True, each block output is itself a tuple of (spatial features, prefix tokens).

Hi! I would just like to ask if what I did below is correct (please see screenshot).
image.png

So you said that the easiest way to get the prefix token embeddings is to use forward_features() and take the first 5 tokens in the sequence. I did that (top) and compared it to using forward_intermediates(); however, their outputs are different. Is there something I have missed? Would appreciate your help! Thank you so much :)

EDIT: I was able to show that they're the same... I forgot to add norm=True argument in forward_intermediates(). Hope this helps!

Hi! It's me again. Just one more question:

Screenshot below is taken from (https://github.com/facebookresearch/dinov2/blob/main/MODEL_CARD.md)
image.png

As I understand it, there's a total of 261 tokens (1 class + 4 register + 256 patch tokens). Now, going back to the timm version, the output shape is (1, 1374, 768). Is the 1374 semantically equivalent to the 261, i.e., is 1374 the length of the token sequence? How does it arrive at 1374 versus the 261? Thank you :-)

PyTorch Image Models org

dinov2 models I think are 518x518 by default ... so 37*37 spatial patches + 1 cls token + 4 reg tokens = 1374 ... it would be 261 if you resized and used 224x224 images
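The arithmetic above as a quick sanity check (patch size 14 and 1 cls + 4 reg prefix tokens, per the models linked earlier):

```python
def vit_seq_len(img_size: int, patch_size: int = 14, num_prefix_tokens: int = 5) -> int:
    """Token count: (img_size // patch_size)**2 spatial patches + prefix tokens."""
    grid = img_size // patch_size
    return grid * grid + num_prefix_tokens

print(vit_seq_len(518))  # 37*37 + 5 = 1374
print(vit_seq_len(224))  # 16*16 + 5 = 261
```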

https://github.com/huggingface/pytorch-image-models/blob/a6fe31b09670289dbc8e99a0cfae23de355534c9/timm/models/vision_transformer.py#L1383-L1433

Your snippets above are correct if you want both cls + reg tokens together; if you want just the registers, slice [1:5] to get the 4 reg tokens.
