Edit model card

Usage:

import torch, transformers, pyreft
import argparse
# Command-line interface for the demo. Note that arguments are parsed at
# import time, so merely importing this module consumes and validates sys.argv.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--device",
    default="cuda",
    type=str,
)
args = parser.parse_args()


def generate_response():
    """Run an interactive chat loop with the ReFT-adapted TinyLlama chat model.

    Loads the TinyLlama-1.1B-Chat base model and tokenizer, loads the
    ``benchang1110/Tinyllama-1.1B-Chat-REFT-v1.0`` ReFT model on top of it,
    then repeatedly reads a user message from stdin and generates a reply
    through a ``TextStreamer``. Typing ``exit`` (surrounding whitespace is
    ignored) or sending EOF (e.g. Ctrl-D) ends the loop.

    Relies on a module-level ``device`` string (assigned in the ``__main__``
    block) naming the torch device to load the model onto.
    """
    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16, device_map=device
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, use_fast=False)
    # skip_prompt=True: only stream newly generated tokens, not the echoed input.
    streamer = transformers.TextStreamer(tokenizer, skip_prompt=True)
    reft_model = pyreft.ReftModel.load("benchang1110/Tinyllama-1.1B-Chat-REFT-v1.0", model)
    reft_model.set_device(device)

    # Intervention hyperparameters. The original author notes these have to be
    # set for the model to work; presumably they must match the values used
    # when the ReFT weights were trained -- TODO confirm against training config.
    first_n = 8  # number of leading prompt positions to intervene on
    last_n = 8  # number of trailing prompt positions to intervene on
    # One position list per transformer layer (the list is shared, not copied).
    num_layers = model.config.num_hidden_layers

    while True:
        try:
            user_text = input('USER:')
        except EOFError:
            # stdin closed (Ctrl-D / piped input exhausted): end the chat cleanly.
            print()
            break
        if user_text.strip() == "exit":
            break

        messages = [
            {'content': user_text, 'role': 'user'},
        ]
        chat_prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        print(chat_prompt)  # echo the fully templated prompt
        # Print the label right before generation so the streamed reply follows it.
        print("Assistant: ")
        # Move the tokenized prompt to the same device as the model.
        inputs = tokenizer(chat_prompt, return_tensors="pt").to(device)

        last_position = inputs["input_ids"].shape[-1] - 1
        # The first `first_n` tokens plus the last `last_n` tokens, counting
        # back from the final prompt token.
        # NOTE(review): for prompts shorter than first_n + last_n tokens these
        # ranges overlap, and below last_n tokens they go negative -- confirm
        # pyreft tolerates this.
        positions = list(range(first_n)) + [last_position - i for i in range(last_n)]
        base_unit_locations = [[positions]] * num_layers

        # The reply is emitted through `streamer`; the return value is not needed.
        reft_model.generate(
            inputs,
            unit_locations={"sources->base": (None, base_unit_locations)},
            intervene_on_prompt=True,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.3,
            repetition_penalty=1.1,
            streamer=streamer,
        )
        
if __name__ == '__main__':
    # generate_response() reads this module-level `device`. Honour the
    # --device flag, but fall back to CPU when a CUDA device is requested and
    # none is available, so the default still works on CPU-only machines.
    device = args.device
    if device.startswith("cuda") and not torch.cuda.is_available():
        device = "cpu"
    generate_response()
Downloads last month: 11

Safetensors
Model size: 1.1B params
Tensor type: BF16
Inference API: Unable to determine this model’s pipeline type. Check the docs.