# Llama2_watermarking / watermark.py
# Antoine Chaffin
# Back to llama
# 013a7fe
import transformers
from transformers import AutoTokenizer
from transformers import pipeline, set_seed, LogitsProcessor
from transformers.generation.logits_process import TopPLogitsWarper, TopKLogitsWarper
import torch
from scipy.special import gamma, gammainc, gammaincc, betainc
from scipy.optimize import fminbound
import numpy as np
# Module-wide compute device: the first CUDA device when CUDA is available,
# otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
def hash_tokens(input_ids: torch.LongTensor, key: int):
    """Fold a secret key and a sequence of token ids into one integer seed.

    Starting from ``key``, each element of ``input_ids`` (visited in
    iteration order) updates the running value as
    ``(value * 35317 + token) % (2 ** 64 - 1)``.

    Args:
        input_ids: tensor that is iterated element by element; every element
            must support ``.item()``.
        key: initial value of the running hash.

    Returns:
        The final running value. It equals ``key`` when ``input_ids`` has no
        elements.
    """
    multiplier = 35317
    modulus = 2 ** 64 - 1
    state = key
    for token in input_ids:
        state = (state * multiplier + token.item()) % modulus
    return state
class WatermarkingLogitsProcessor(LogitsProcessor):
    """Common state shared by the watermarking logits processors of this file.

    Keeps one ``torch.Generator`` per batch element (created on the
    module-level ``device``) together with the secret key, the per-sequence
    messages and the hashing window size.
    """

    def __init__(self, n, key, messages, window_size, *args, **kwargs):
        """Store the watermark parameters and set up the per-sequence generators.

        Args:
            n: stored as ``self.n``; the subclasses in this file use it as the
                length of the pseudo-random vector drawn at each step.
            key: secret integer key used to seed the generators.
            messages: one message per batch element; its length defines
                ``self.batch_size``.
            window_size: number of trailing tokens hashed into the seed by
                the subclasses. When falsy, every generator is seeded once
                here with ``key`` and is not re-seeded by this class.
            *args, **kwargs: forwarded to ``LogitsProcessor.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.messages = messages
        self.batch_size = len(messages)
        self.n = n
        self.key = key
        self.window_size = window_size
        self.generators = [torch.Generator(device=device) for _ in range(self.batch_size)]
        if not window_size:
            # No context hashing: a single keyed stream per sequence.
            for sequence_generator in self.generators:
                sequence_generator.manual_seed(key)
class WatermarkingAaronsonLogitsProcessor(WatermarkingLogitsProcessor):
    """Logits processor that replaces the scores by ``log(r) / p``.

    For every sequence a keyed pseudo-random vector ``r`` of uniform values is
    drawn and cyclically shifted by the sequence's message; the returned
    scores are ``log(r) / exp(scores)``. Picking the arg-max of the returned
    scores therefore selects the token maximising ``r ** (1 / p)``.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``WatermarkingLogitsProcessor.__init__``."""
        super().__init__(*args, **kwargs)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Compute the watermarked scores for one generation step.

        Args:
            input_ids: token ids generated so far, indexed as
                ``input_ids[b, -window_size:]``.
            scores: 2-D tensor of shape ``(B, V)``.
                NOTE(review): ``scores.exp()`` is used as the token
                probability ``p``, so this presumably expects
                log-probabilities -- confirm against the transformers
                version in use.

        Returns:
            A ``(B, V)`` tensor holding ``log(r) / p``.
        """
        B, V = scores.shape
        r = torch.zeros_like(scores)
        for b in range(B):
            generator = self.generators[b]
            if self.window_size:
                # Re-seed from the secret key and the last `window_size` tokens.
                window = input_ids[b, -self.window_size:]
                seed = hash_tokens(window, self.key)
                generator.manual_seed(seed)
            # get random uniform variables (log-domain), shifted by the message
            log_u = torch.rand(self.n, generator=generator, device=generator.device).log()
            # Generate n values but keep only V, as we want to keep the
            # pseudo-random sequences in sync with the decoder. The truncation
            # must happen BEFORE the assignment: `r[b]` has only V entries, so
            # storing all n values fails whenever n > V.
            r[b] = log_u.roll(-self.messages[b])[:V]
        # modify law as r^(1/p)
        # Since we want to return logits (logits processor takes and outputs
        # logits), we return log(r^(1/p)) = log(r) / p.
        return r / scores.exp()
class WatermarkingKirchenbauerLogitsProcessor(WatermarkingLogitsProcessor):
    """Logits processor that adds a constant bias to a keyed "green list".

    At each step a keyed random permutation of ``n`` indices is drawn per
    sequence; its first ``int(gamma * n)`` entries form the green list, whose
    scores are increased by ``delta`` after a cyclic shift by the message.
    """

    def __init__(self, *args,
                 gamma = 0.25,
                 delta = 15.0,
                 **kwargs):
        """Store the green-list parameters.

        Args:
            *args, **kwargs: forwarded to ``WatermarkingLogitsProcessor``.
            gamma: fraction of the ``n`` indices placed in the green list.
            delta: value added to the scores of green-list tokens.
        """
        super().__init__(*args, **kwargs)
        self.gamma = gamma
        self.delta = delta

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Add the green-list bias to ``scores`` in place and return it.

        Args:
            input_ids: token ids generated so far, indexed as
                ``input_ids[b, -window_size:]``.
            scores: 2-D tensor of shape ``(B, V)``; it is modified in place.

        Returns:
            The same ``scores`` tensor, with ``delta`` added on green tokens.
        """
        batch, vocab = scores.shape
        for row in range(batch):
            rng = self.generators[row]
            if self.window_size:
                # Re-seed from the secret key and the last `window_size` tokens.
                context = input_ids[row, -self.window_size:]
                rng.manual_seed(hash_tokens(context, self.key))
            permutation = torch.randperm(self.n, generator=rng, device=rng.device)
            green_count = int(self.gamma * self.n)  # gamma * n
            boost = torch.zeros(self.n).to(scores.device)
            boost[permutation[:green_count]] = self.delta
            # Shift by the message, then keep only the V entries that map to
            # actual score columns.
            scores[row] += boost.roll(-self.messages[row])[:vocab]  # add bias to greenlist words
        return scores
class Watermarker(object):
    """Generate watermarked text with a language model and detect the watermark.

    ``embed`` runs ``model.generate`` with one of the watermarking logits
    processors of this file appended to the user-supplied processors.
    ``detect`` re-creates the keyed pseudo-random sequences from the tokens of
    a text, aggregates one score per candidate message and converts the best
    score into a p-value.
    """

    def __init__(self, tokenizer=None, model=None, window_size = 0, payload_bits = 0, logits_processor = None, *args, **kwargs):
        """Store the model, tokenizer and watermark configuration.

        Args:
            tokenizer: callable tokenizer; it is called with a list of strings
                and must also provide ``batch_decode``.
            model: language model; ``eval()`` is called on it and
                ``model.config.vocab_size`` is read, so it must not be None.
            window_size: number of previous tokens hashed into the seed;
                a falsy value means a single keyed stream per sequence.
            payload_bits: number of message bits; messages lie in
                ``[0, 2**payload_bits)``.
            logits_processor: optional list of processors applied before the
                watermarking processor; defaults to an empty list.
            *args, **kwargs: accepted and ignored.
        """
        self.tokenizer = tokenizer
        self.model = model
        self.model.eval()
        self.window_size = window_size
        # preprocessing wrappers
        self.logits_processor = logits_processor or []
        self.payload_bits = payload_bits
        # Length of the pseudo-random vectors: large enough for both the
        # vocabulary and the 2**payload_bits possible cyclic shifts.
        self.V = max(2**payload_bits, self.model.config.vocab_size)
        self.generator = torch.Generator(device=device)

    def embed(self, key=42, messages=None, prompt="", max_length=30, method='aaronson'):
        """Generate one text per message, optionally watermarked.

        Args:
            key: secret integer key of the watermark.
            messages: list of integer messages, one generated text per entry;
                defaults to ``[1234]``. When ``payload_bits`` is non-zero each
                message must lie in ``[0, 2**payload_bits)``.
            prompt: prompt string, repeated for every message.
            max_length: forwarded to ``model.generate``.
            method: ``'aaronson'`` (watermark, no sampling), ``'kirchenbauer'``
                (watermark, sampling), ``'greedy'`` or ``'sampling'`` (no
                watermark).

        Returns:
            The list of decoded texts returned by ``tokenizer.batch_decode``.

        Raises:
            AssertionError: a message is outside the payload range.
            ValueError: ``method`` is not one of the names above.
        """
        if messages is None:
            messages = [1234]
        B = len(messages)  # batch size

        # compute capacity
        if self.payload_bits:
            assert all(0 <= message < 2**self.payload_bits for message in messages), \
                "messages must lie in [0, 2**payload_bits)"

        # Select the decoding mode and the extra watermarking processor.
        if method == 'aaronson':
            # generate with greedy search
            do_sample = False
            watermark = [WatermarkingAaronsonLogitsProcessor(n=self.V,
                                                             key=key,
                                                             messages=messages,
                                                             window_size=self.window_size)]
        elif method == 'kirchenbauer':
            # use sampling
            do_sample = True
            watermark = [WatermarkingKirchenbauerLogitsProcessor(n=self.V,
                                                                 key=key,
                                                                 messages=messages,
                                                                 window_size=self.window_size)]
        elif method == 'greedy':
            do_sample = False
            watermark = []
        elif method == 'sampling':
            do_sample = True
            watermark = []
        else:
            raise ValueError('Unknown method %s' % method)

        # tokenize prompt
        inputs = self.tokenizer([ prompt ] * B, return_tensors="pt")
        generated_ids = self.model.generate(inputs.input_ids.to(device),
                                            max_length=max_length,
                                            do_sample=do_sample,
                                            logits_processor=self.logits_processor + watermark)
        decoded_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        return decoded_texts

    def detect(self, attacked_texts, key=42, method='aaronson', gamma=0.5, prompts=None):
        """Decode the most likely message of each text and compute its p-value.

        Args:
            attacked_texts: list of texts to analyse.
            key: secret integer key used at embedding time.
            method: ``'aaronson'``, ``'aaronson_simplified'``,
                ``'aaronson_neyman_pearson'`` or ``'kirchenbauer'``.
            gamma: green-list fraction for ``'kirchenbauer'``.
                NOTE(review): the default here is 0.5 while
                ``WatermarkingKirchenbauerLogitsProcessor`` defaults to 0.25
                and ``embed`` does not override it -- confirm which value is
                intended.
            prompts: optional list of prompts (one per text) whose tokens are
                excluded from scoring; defaults to empty prompts.

        Returns:
            ``(cdfs, ms)``: for each text, the p-value (corrected for the
            ``2**payload_bits`` candidate messages) and the decoded message.

        Raises:
            ValueError: ``method`` is not one of the names above.
        """
        if method not in {'aaronson', 'aaronson_simplified', 'aaronson_neyman_pearson', 'kirchenbauer'}:
            raise ValueError('Unknown method %s' % method)
        if prompts is None:
            prompts = [""] * len(attacked_texts)
        generator = self.generator
        print("attacked_texts = ", attacked_texts)
        print("prompts = ", prompts)
        cdfs = []
        ms = []
        MAX = 2**self.payload_bits  # number of candidate messages

        # tokenize input
        inputs = self.tokenizer(attacked_texts, return_tensors="pt", padding=True, return_attention_mask=True)
        input_ids = inputs["input_ids"].to(self.model.device)
        attention_masks = inputs["attention_mask"].to(self.model.device)
        B, T = input_ids.shape

        if method == 'aaronson_neyman_pearson':
            # compute logits; no gradients are needed for detection
            with torch.no_grad():
                outputs = self.model.forward(input_ids, return_dict=True)
            logits = outputs['logits']
            # TODO: reapply self.logits_processor to the logits so that the
            # distribution matches the one used at generation time.
            probs = logits.softmax(dim=-1)
            # ps[b, i] = probability assigned at position i to token i+1
            ps = torch.gather(probs, 2, input_ids[:,1:,None]).squeeze_(-1)

        seq_len = input_ids.shape[1]
        V = self.V
        # One aggregated score per candidate shift; allocated on the
        # generator's device so it always matches the random draws.
        Z = torch.zeros(size=(B, V), dtype=torch.float32, device=generator.device)

        # keep a history of contexts we have already seen,
        # to exclude them from score aggregation and allow
        # correct p-value computation under H0
        history = [set() for _ in range(B)]

        attention_masks_prompts = self.tokenizer(prompts, return_tensors="pt", padding=True, return_attention_mask=True)["attention_mask"]
        prompts_length = torch.sum(attention_masks_prompts, dim=1)
        for b in range(B):
            # Do not score the prompt tokens.
            # NOTE(review): this masks the FIRST prompts_length[b] positions,
            # which presumably assumes right padding -- confirm.
            attention_masks[b, :prompts_length[b]] = 0
            if not self.window_size:
                generator.manual_seed(key)
            # We can go from seq_len - prompt_len, need to change +1 to + prompt_len
            for i in range(seq_len-1):
                if self.window_size:
                    window = input_ids[b, max(0, i-self.window_size+1):i+1]
                    seed = hash_tokens(window, key)
                    if seed not in history[b]:
                        generator.manual_seed(seed)
                        history[b].add(seed)
                    else:
                        # context already seen: ignore the token
                        attention_masks[b, i+1] = 0
                if not attention_masks[b,i+1]:
                    continue
                token = int(input_ids[b,i+1])
                if method in {'aaronson', 'aaronson_simplified', 'aaronson_neyman_pearson'}:
                    R = torch.rand(V, generator = generator, device = generator.device)
                    if method == 'aaronson':
                        r = -(1-R).log()
                    else:
                        r = -R.log()
                else:
                    # 'kirchenbauer': indicator of the green list
                    r = torch.zeros(V, device=generator.device)
                    vocab_permutation = torch.randperm(V, generator = generator, device=generator.device)
                    greenlist = vocab_permutation[:int(gamma * V)]
                    r[greenlist] = 1
                if method == 'aaronson_neyman_pearson':
                    # Neyman-Pearson weighting
                    Z[b] += r.roll(-token) * (1/ps[b,i] - 1)
                else:
                    # independent of probs
                    Z[b] += r.roll(-token)

        for b in range(B):
            if method in {'aaronson', 'kirchenbauer'}:
                m = int(torch.argmax(Z[b,:MAX]))
            else:
                m = int(torch.argmin(Z[b,:MAX]))
            S = Z[b, m].item()

            # Number of tokens that actually contributed to Z[b]: the loop
            # above scores token i+1 iff attention_masks[b, i+1] is set, so
            # position 0 is never scored and must not be counted. Counting
            # mask[1:] stays correct whether or not the prompt already
            # zeroed position 0.
            scored = attention_masks[b, 1:].bool()
            k = int(scored.sum().item())

            if method == 'aaronson':
                cdf = gammaincc(k, S)
            elif method == 'aaronson_simplified':
                cdf = gammainc(k, S)
            elif method == 'aaronson_neyman_pearson':
                # Chernoff bound, over the probabilities of the scored tokens
                p_scored = ps[b, scored]
                ratio = p_scored / (1 - p_scored)
                E = (1/ratio).sum()
                if S > E:
                    cdf = 1.0
                else:
                    # to compute p-value we must solve for c*:
                    # (1/(c* + ps/(1-ps))).sum() = S
                    func = lambda c : (((1 / (c + ratio)).sum() - S)**2).item()
                    c1 = (k / S - torch.min(ratio)).item()
                    print("max = ", c1)
                    c = fminbound(func, 0, c1)
                    print("solved c = ", c)
                    print("solved s = ", ((1/(c + ratio)).sum()).item())
                    # upper bound
                    cdf = torch.exp(torch.sum(-torch.log(1 + c / ratio)) + c * S)
            else:
                # 'kirchenbauer'
                cdf = betainc(S, k - S + 1, gamma)

            # Correct for having picked the best of MAX candidate messages.
            if cdf > min(1 / MAX, 1e-5):
                cdf = 1 - (1 - cdf)**MAX # true value
            else:
                cdf = cdf * MAX # numerically stable upper bound
            cdfs.append(float(cdf))
            ms.append(m)

        return cdfs, ms