Style-Instruct Mistral 7B
Mistral 7B Instruct fine-tuned on the neuralwork/fashion-style-instruct dataset using LoRA with 4-bit quantization. See the blog post and GitHub repository for training details. The model takes a description of the user's body type and personal style as input and a target event (e.g. casual date, business meeting) as context, and outputs outfit combination suggestions. For the full list of event types, refer to the Gradio demo file in the GitHub repository.
Usage
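This repository contains only the LoRA adapter weights of the fine-tuned Mistral 7B model; AutoPeftModelForCausalLM loads the base model referenced in the adapter config and applies the adapter on top of it. 4-bit loading requires a CUDA GPU and the bitsandbytes package. To run inference, load and use the model as follows: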
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, BitsAndBytesConfig
def format_instruction(input, event):
    return f"""You are a personal stylist recommending fashion advice and clothing combinations. Use the self body and style description below, combined with the event described in the context to generate 5 self-contained and complete outfit combinations.
### Input:
{input}
### Context:
I'm going to a {event}.
### Response:
"""
# input is a self description of your body type and personal style
prompt = "I'm an athletic and 171cm tall woman in my mid twenties, I have a rectangle shaped body with slightly broad shoulders and have a sleek, casual style. I usually prefer darker colors."
event = "business meeting"
prompt = format_instruction(prompt, event)
# load base LLM model, LoRA params and tokenizer
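model = AutoPeftModelForCausalLM.from_pretrained(
    "neuralwork/mistral-7b-style-instruct",
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    # 4-bit loading is configured through BitsAndBytesConfig; passing
    # load_in_4bit directly to from_pretrained is deprecated in recent transformers
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
    ),
)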
tokenizer = AutoTokenizer.from_pretrained("neuralwork/mistral-7b-style-instruct")
input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids.cuda()
# inference
with torch.inference_mode():
    outputs = model.generate(
        input_ids=input_ids,
        max_new_tokens=800,
        do_sample=True,
        top_p=0.9,
        temperature=0.9,
    )
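# decode only the newly generated tokens; generate() returns the prompt tokens
# followed by the completion, so slicing at the prompt length drops the prompt
generated_ids = outputs[:, input_ids.shape[1]:].cpu()
output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(output)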
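Because the prompt is plain string formatting, it is straightforward to build prompts for several target events and to recover the response from a fully decoded sequence. The sketch that follows is a hypothetical helper, not part of the neuralwork repository: the names TEMPLATE, build_prompts and extract_response are illustrative. It copies the template from format_instruction and splits on the first "### Response:" marker, as an alternative to slicing off the prompt by length.

```python
# Hypothetical helpers for illustration; not part of the neuralwork repository.
RESPONSE_MARKER = "### Response:"

# Same layout as format_instruction above, with named placeholders.
TEMPLATE = """You are a personal stylist recommending fashion advice and clothing combinations. Use the self body and style description below, combined with the event described in the context to generate 5 self-contained and complete outfit combinations.
### Input:
{description}
### Context:
I'm going to a {event}.
### Response:
"""


def build_prompts(description: str, events: list[str]) -> list[str]:
    """Return one prompt per target event, all sharing the same description."""
    return [TEMPLATE.format(description=description, event=event) for event in events]


def extract_response(decoded: str) -> str:
    """Return the text after the first response marker.

    If the marker is missing (e.g. the prompt was already stripped),
    the whole decoded string is returned, trimmed of whitespace.
    """
    _, sep, tail = decoded.partition(RESPONSE_MARKER)
    return tail.strip() if sep else decoded.strip()


if __name__ == "__main__":
    description = "I'm an athletic and 171cm tall woman in my mid twenties."
    for prompt in build_prompts(description, ["casual date", "business meeting"]):
        # the context line sits just before the trailing response marker
        print(prompt.splitlines()[-2])
```

Each prompt can then go through the tokenization and generate steps shown earlier. Splitting on the first marker assumes the user's description does not itself contain the string "### Response:"; if the model repeats the marker in its output, that text is kept as part of the response.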