import transformers
from transformers import AutoTokenizer
from transformers import pipeline, set_seed, LogitsProcessor
from transformers.generation.logits_process import TopPLogitsWarper, TopKLogitsWarper
import torch
from scipy.special import gamma, gammainc, gammaincc, betainc
from scipy.optimize import fminbound
import numpy as np

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
def hash_tokens(input_ids: torch.LongTensor, key: int):
    seed = key
    salt = 35317
    for i in input_ids:
        seed = (seed * salt + i.item()) % (2 ** 64 - 1)
    return seed
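
# Illustrative sketch, not part of the original Space: the token ids below are
# arbitrary placeholders. It shows that the rolling hash is a pure function of
# (key, window), so the encoder and the detector derive the same generator seed
# from the same context window of token ids.
def _demo_hash_tokens():
    window = torch.tensor([101, 2023, 2003, 102])
    # same key and same window -> same seed on both sides
    assert hash_tokens(window, 42) == hash_tokens(window.clone(), 42)
    print("seed(key=42):", hash_tokens(window, 42))
    print("seed(key=43):", hash_tokens(window, 43))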
class WatermarkingLogitsProcessor(LogitsProcessor):
    def __init__(self, n, key, messages, window_size, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.batch_size = len(messages)
        self.generators = [torch.Generator(device=device) for _ in range(self.batch_size)]
        self.n = n
        self.key = key
        self.window_size = window_size
        if not self.window_size:
            for b in range(self.batch_size):
                self.generators[b].manual_seed(self.key)
        self.messages = messages
class WatermarkingAaronsonLogitsProcessor(WatermarkingLogitsProcessor):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # get random uniform variables
        B, V = scores.shape
        r = torch.zeros((B, self.n), dtype=scores.dtype, device=scores.device)
        for b in range(B):
            if self.window_size:
                window = input_ids[b, -self.window_size:]
                seed = hash_tokens(window, self.key)
                self.generators[b].manual_seed(seed)
            r[b] = torch.rand(self.n, generator=self.generators[b], device=self.generators[b].device).log().roll(-self.messages[b])
        # generate n values but keep only V, so the pseudo-random sequences stay in sync with the decoder
        r = r[:, :V]
        # bias the decoding law towards u^(1/p): since a logits processor takes and returns
        # logits, we return log(u^(1/p)) = log(u) / p (r already holds log(u));
        # dividing by the unnormalized exp(scores) instead of softmax(scores) only rescales
        # each row by a positive constant, which leaves the greedy argmax unchanged
        return r / scores.exp()
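
# Illustrative sketch, not part of the original Space: the toy distribution and the
# trial count are arbitrary. It checks empirically the identity the Aaronson-style
# processor relies on: greedily picking argmax_v log(u_v) / p_v with u_v ~ U(0,1)
# returns token v with probability p_v, so watermarked generation still samples from
# the model distribution while the detector can re-derive the u_v from the key.
def _demo_aaronson_sampling(num_trials=20000, seed=0):
    g = torch.Generator().manual_seed(seed)
    p = torch.tensor([0.5, 0.3, 0.15, 0.05])  # toy next-token distribution
    counts = torch.zeros_like(p)
    for _ in range(num_trials):
        u = torch.rand(p.shape[0], generator=g)
        counts[torch.argmax(u.log() / p)] += 1
    return counts / num_trials  # empirical frequencies, close to p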
class WatermarkingKirchenbauerLogitsProcessor(WatermarkingLogitsProcessor):
    def __init__(self, *args,
                 gamma=0.25,
                 delta=15.0,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.gamma = gamma
        self.delta = delta

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        B, V = scores.shape
        for b in range(B):
            if self.window_size:
                window = input_ids[b, -self.window_size:]
                seed = hash_tokens(window, self.key)
                self.generators[b].manual_seed(seed)
            vocab_permutation = torch.randperm(self.n, generator=self.generators[b], device=self.generators[b].device)
            greenlist = vocab_permutation[:int(self.gamma * self.n)]  # gamma * n
            bias = torch.zeros(self.n).to(scores.device)
            bias[greenlist] = self.delta
            bias = bias.roll(-self.messages[b])[:V]
            scores[b] += bias  # add bias to greenlist words
        return scores
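
# Illustrative sketch, not part of the original Space: the toy vocabulary size and
# random logits are placeholders. It shows the effect the Kirchenbauer-style processor
# has on a single step: adding the constant `delta` to a pseudo-random gamma-fraction
# "greenlist" moves almost all of the softmax mass onto greenlist tokens, which is the
# statistic the detector later counts.
def _demo_kirchenbauer_bias(vocab_size=8, gamma=0.25, delta=15.0, seed=0):
    g = torch.Generator().manual_seed(seed)
    logits = torch.randn(vocab_size, generator=g)
    greenlist = torch.randperm(vocab_size, generator=g)[:int(gamma * vocab_size)]
    biased = logits.clone()
    biased[greenlist] += delta
    print("greenlist:", greenlist.tolist())
    print("greenlist mass before:", logits.softmax(-1)[greenlist].sum().item())
    print("greenlist mass after: ", biased.softmax(-1)[greenlist].sum().item())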
class Watermarker(object):
    def __init__(self, tokenizer=None, model=None, window_size=0, payload_bits=0, logits_processor=None, *args, **kwargs):
        self.tokenizer = tokenizer
        self.model = model
        self.model.eval()
        self.window_size = window_size
        # preprocessing wrappers
        self.logits_processor = logits_processor or []
        self.payload_bits = payload_bits
        self.V = max(2**payload_bits, self.model.config.vocab_size)
        self.generator = torch.Generator(device=device)
    def embed(self, key=42, messages=[1234], prompt="", max_length=30, method='aaronson'):
        B = len(messages)  # batch size
        length = max_length

        # check that every message fits the payload capacity
        if self.payload_bits:
            assert all(0 <= message < 2**self.payload_bits for message in messages)

        # tokenize prompt
        inputs = self.tokenizer([prompt] * B, return_tensors="pt")

        if method == 'aaronson':
            # generate with greedy search
            generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=False,
                                                logits_processor=self.logits_processor + [
                                                    WatermarkingAaronsonLogitsProcessor(n=self.V,
                                                                                        key=key,
                                                                                        messages=messages,
                                                                                        window_size=self.window_size)])
        elif method == 'kirchenbauer':
            # generate with sampling
            generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=True,
                                                logits_processor=self.logits_processor + [
                                                    WatermarkingKirchenbauerLogitsProcessor(n=self.V,
                                                                                            key=key,
                                                                                            messages=messages,
                                                                                            window_size=self.window_size)])
        elif method == 'greedy':
            # unwatermarked baseline: greedy search
            generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=False,
                                                logits_processor=self.logits_processor)
        elif method == 'sampling':
            # unwatermarked baseline: sampling
            generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=True,
                                                logits_processor=self.logits_processor)
        else:
            raise Exception('Unknown method %s' % method)

        decoded_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        return decoded_texts
    def detect(self, attacked_texts, key=42, method='aaronson', gamma=0.5, prompts=None):
        if prompts is None:
            prompts = [""] * len(attacked_texts)

        generator = self.generator

        print("attacked_texts = ", attacked_texts)
        print("prompts = ", prompts)

        cdfs = []
        ms = []

        MAX = 2**self.payload_bits

        # tokenize input
        inputs = self.tokenizer(attacked_texts, return_tensors="pt", padding=True, return_attention_mask=True)
        input_ids = inputs["input_ids"].to(self.model.device)
        attention_masks = inputs["attention_mask"].to(self.model.device)

        B, T = input_ids.shape
        if method == 'aaronson_neyman_pearson':
            # compute logits
            outputs = self.model.forward(input_ids, return_dict=True)
            logits = outputs['logits']
            # TODO: reapply the logits processors to recover the same distribution
            # for i in range(T):
            #     for processor in self.logits_processor:
            #         logits[:, i] = processor(input_ids[:, :i], logits[:, i])
            probs = logits.softmax(dim=-1)
            ps = torch.gather(probs, 2, input_ids[:, 1:, None]).squeeze_(-1)

        seq_len = input_ids.shape[1]
        length = seq_len

        V = self.V
        Z = torch.zeros(size=(B, V), dtype=torch.float32, device=device)

        # keep a history of contexts we have already seen,
        # to exclude them from score aggregation and allow
        # correct p-value computation under H0
        history = [set() for _ in range(B)]

        attention_masks_prompts = self.tokenizer(prompts, return_tensors="pt", padding=True, return_attention_mask=True)["attention_mask"]
        prompts_length = torch.sum(attention_masks_prompts, dim=1)

        for b in range(B):
            # mask out the prompt tokens so they do not contribute to the score
            attention_masks[b, :prompts_length[b]] = 0

            if not self.window_size:
                generator.manual_seed(key)

            # scoring could start at the prompt length instead of position 1,
            # but the prompt tokens are already masked out above
            for i in range(seq_len - 1):
                if self.window_size:
                    window = input_ids[b, max(0, i - self.window_size + 1):i + 1]
                    # print("window = ", window)
                    seed = hash_tokens(window, key)
                    if seed not in history[b]:
                        generator.manual_seed(seed)
                        history[b].add(seed)
                    else:
                        # skip tokens whose context window has already been scored
                        attention_masks[b, i + 1] = 0

                if not attention_masks[b, i + 1]:
                    continue

                token = int(input_ids[b, i + 1])

                if method in {'aaronson', 'aaronson_simplified', 'aaronson_neyman_pearson'}:
                    R = torch.rand(V, generator=generator, device=generator.device)
                    if method == 'aaronson':
                        r = -(1 - R).log()
                    elif method in {'aaronson_simplified', 'aaronson_neyman_pearson'}:
                        r = -R.log()
                elif method == 'kirchenbauer':
                    r = torch.zeros(V, device=device)
                    vocab_permutation = torch.randperm(V, generator=generator, device=generator.device)
                    greenlist = vocab_permutation[:int(gamma * V)]
                    r[greenlist] = 1
                else:
                    raise Exception('Unknown method %s' % method)

                if method in {'aaronson', 'aaronson_simplified', 'kirchenbauer'}:
                    # score is independent of the model probabilities
                    Z[b] += r.roll(-token)
                elif method == 'aaronson_neyman_pearson':
                    # Neyman-Pearson weighting by the model probabilities
                    Z[b] += r.roll(-token) * (1 / ps[b, i] - 1)
        for b in range(B):
            if method in {'aaronson', 'kirchenbauer'}:
                m = torch.argmax(Z[b, :MAX])
            elif method in {'aaronson_simplified', 'aaronson_neyman_pearson'}:
                m = torch.argmin(Z[b, :MAX])
            i = int(m)
            S = Z[b, i].item()
            m = i
            # actual number of scored tokens
            k = torch.sum(attention_masks[b]).item() - 1

            if method == 'aaronson':
                cdf = gammaincc(k, S)
            elif method == 'aaronson_simplified':
                cdf = gammainc(k, S)
            elif method == 'aaronson_neyman_pearson':
                # Chernoff bound
                ratio = ps[b, :k] / (1 - ps[b, :k])
                E = (1 / ratio).sum()
                if S > E:
                    cdf = 1.0
                else:
                    # to compute the p-value we must solve for c*:
                    # (1 / (c* + ps / (1 - ps))).sum() = S
                    func = lambda c: (((1 / (c + ratio)).sum() - S)**2).item()
                    c1 = (k / S - torch.min(ratio)).item()
                    print("max = ", c1)
                    c = fminbound(func, 0, c1)
                    print("solved c = ", c)
                    print("solved s = ", ((1 / (c + ratio)).sum()).item())
                    # upper bound
                    cdf = torch.exp(torch.sum(-torch.log(1 + c / ratio)) + c * S)
            elif method == 'kirchenbauer':
                cdf = betainc(S, k - S + 1, gamma)

            if cdf > min(1 / MAX, 1e-5):
                cdf = 1 - (1 - cdf)**MAX  # true value
            else:
                cdf = cdf * MAX  # numerically stable upper bound
            cdfs.append(float(cdf))
            ms.append(m)

        return cdfs, ms
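
if __name__ == "__main__":
    # Minimal end-to-end usage sketch, not taken from the original Space: the checkpoint
    # name ("gpt2"), the payload value, the prompt and the generation length below are
    # placeholder assumptions chosen only for illustration.
    from transformers import AutoModelForCausalLM

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # detect() pads batches; gpt2 has no pad token
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

    watermarker = Watermarker(tokenizer=tokenizer, model=model,
                              window_size=0, payload_bits=4)

    # hide the message 5 (4 payload bits -> messages in [0, 16)) in the generation
    prompt = "The history of watermarking"
    texts = watermarker.embed(key=42, messages=[5], prompt=prompt,
                              max_length=60, method='aaronson')
    print(texts)

    # recover the hidden message and a p-value from the (possibly attacked) text
    cdfs, ms = watermarker.detect(texts, key=42, method='aaronson', prompts=[prompt])
    print("p-values:", cdfs, "decoded messages:", ms)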