from typing import List

import datasets
import evaluate
import torch
import torch.nn as nn

_DESCRIPTION = """
Average cosine similarity between two groups of paired embeddings, where each embedding represents the semantics of an object.
"""
_KWARGS_DESCRIPTION = """
Args:
    predictions (`list` of `list` of `float`): a group of embeddings
    references (`list` of `list` of `float`): the other group of embeddings, paired element-wise with the predictions
Returns:
    cos_similarity (`float`): average cosine similarity over all embedding pairs
Examples:
    Example 1 - A simple example
        >>> cos_similarity_metric = evaluate.load("ahnyeonchan/cosine_sim_btw_embeddings_of_same_semantics")
        >>> results = cos_similarity_metric.compute(references=[[1.0, 1.0], [0.0, 1.0]], predictions=[[1.0, 1.0], [0.0, 1.0]])
        >>> print(results)
        {'cos_similarity': 1.0}
"""
_CITATION = """"""

@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class CosSim(evaluate.Metric):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Row-wise cosine similarity (dim=1) over batches of embeddings.
        self.cossim = nn.CosineSimilarity(dim=1)

    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Sequence(datasets.Value("float32")),
                    "references": datasets.Sequence(datasets.Value("float32")),
                }
            ),
            reference_urls=[],
        )

    def _compute(self, predictions: List[List[float]], references: List[List[float]]):
        # Accept either tensors or (nested) lists of floats.
        if isinstance(predictions, torch.Tensor):
            predictions = predictions.float()
        elif isinstance(predictions, list):
            predictions = torch.tensor(predictions, dtype=torch.float32)
        else:
            raise NotImplementedError("predictions must be a torch.Tensor or a list of embeddings")

        if isinstance(references, torch.Tensor):
            references = references.float()
        elif isinstance(references, list):
            references = torch.tensor(references, dtype=torch.float32)
        else:
            raise NotImplementedError("references must be a torch.Tensor or a list of embeddings")

        # Cosine similarity per embedding pair, averaged over the batch.
        cosims = self.cossim(predictions, references)
        val = torch.mean(cosims).item()
        return {"cos_similarity": float(val)}
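

# A minimal local sanity check (a sketch, not part of the published metric API):
# it assumes the class can be instantiated directly, which should give the same
# result as loading the metric via evaluate.load(...) from the Hub id in the docstring.
if __name__ == "__main__":
    metric = CosSim()
    result = metric.compute(
        predictions=[[1.0, 0.0], [0.0, 1.0]],
        references=[[1.0, 0.0], [1.0, 0.0]],
    )
    # The first pair is identical (similarity 1.0) and the second is orthogonal (0.0),
    # so the averaged score is expected to be 0.5.
    print(result)  # {'cos_similarity': 0.5}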