# NOTE(review): scraped Hugging Face Spaces page banner; status shown was "Runtime error".
import datasets | |
import evaluate | |
from typing import List, Union | |
import torch | |
import torch.nn.functional as F | |
import torch.nn as nn | |
# Human-readable metric description; passed to evaluate.MetricInfo as `description`.
_DESCRIPTION = """
Average cosine similarity between pairs of embeddings, where the two embeddings in each pair
are meant to represent the semantics of the same object.
"""
# Input/output documentation; passed to evaluate.MetricInfo as `inputs_description`.
_KWARGS_DESCRIPTION = """
Args:
    predictions (`list` of `list` of `float`): a group of embeddings, one embedding per row.
    references (`list` of `list` of `float`): the other group of embeddings, paired row-by-row
        with `predictions`.
Returns:
    cos_similarity (`float`): average cosine similarity over all (prediction, reference) pairs.
Examples:
    Example 1-A simple example
        >>> cos_similarity_metrics = evaluate.load("ahnyeonchan/cosine_sim_btw_embeddings_of_same_semantics")
        >>> results = cos_similarity_metrics.compute(references=[[1.0, 1.0], [0.0, 1.0]], predictions=[[1.0, 1.0], [0.0, 1.0]])
        >>> print(results)
        {'cos_similarity': 1.0}
"""
# Empty: no citation is provided for this metric.
_CITATION = ""
class CosSim(evaluate.Metric):
    """Mean cosine similarity between paired prediction/reference embeddings.

    Row ``i`` of ``predictions`` is compared with row ``i`` of ``references``
    and the per-pair cosine similarities are averaged into a single score.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # dim=1 (also nn.CosineSimilarity's default): compare row i of one
        # 2-D input with row i of the other.
        self.cossim = nn.CosineSimilarity(dim=1)

    def _info(self):
        """Describe the metric and its expected input features for ``evaluate``."""
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    # Each example is one embedding: a sequence of float32 values.
                    "predictions": datasets.Sequence(datasets.Value("float32")),
                    "references": datasets.Sequence(datasets.Value("float32")),
                }
            ),
            reference_urls=[],
        )

    @staticmethod
    def _to_float_tensor(values, name: str) -> torch.Tensor:
        """Convert a tensor or (nested) list/tuple of numbers to a float32 tensor.

        Args:
            values: a ``torch.Tensor`` or a list/tuple of embeddings.
            name: argument name, used only in error messages.

        Raises:
            NotImplementedError: if ``values`` is not a tensor, list, or tuple.
            ValueError: if ``values`` contains no elements.
        """
        if not isinstance(values, (torch.Tensor, list, tuple)):
            raise NotImplementedError(
                f"`{name}` must be a torch.Tensor or a list, got {type(values).__name__}"
            )
        # as_tensor also handles non-float tensors, which torch.Tensor(t) may reject.
        tensor = torch.as_tensor(values, dtype=torch.float32)
        if tensor.numel() == 0:
            raise ValueError(f"`{name}` must not be empty")
        return tensor

    def _compute(self, predictions: List[List], references: List[List]):
        """Return the mean cosine similarity over all prediction/reference pairs.

        Args:
            predictions: embeddings, presumably shaped (num_pairs, embedding_dim).
            references: embeddings paired row-by-row with ``predictions``.

        Returns:
            ``{"cos_similarity": float}`` with the averaged similarity.

        Raises:
            NotImplementedError: if an input is not a tensor, list, or tuple.
            ValueError: if an input is empty.
        """
        predictions = self._to_float_tensor(predictions, "predictions")
        references = self._to_float_tensor(references, "references")
        cosims = self.cossim(predictions, references)
        # Average the per-pair similarities into a single Python float.
        val = torch.mean(cosims).item()
        return {"cos_similarity": float(val)}