---
language: en
license: apache-2.0
---
# Shears Model Card: shears-mpt-7b-50-gsm8k-super
The super-adapter fine-tuned on sparsified MPT-7B with the GSM8K dataset using Shears.
The super-network is released so that users can apply their own search algorithms and evaluation metrics to extract subnetworks suited to their specific needs.
## Model Details
### Information
- **Model name:** shears-mpt-7b-50-gsm8k-super
- **Base model:** [IntelLabs/MPT-7B-sparsity50](https://huggingface.co/IntelLabs/MPT-7B-sparsity50)
- **Sparsity:** 50%
- **Subnetwork version:** Super
- **NNCF Configuration:** [nncf_shears_mpt.json](https://github.com/IntelLabs/Hardware-Aware-Automated-Machine-Learning/tree/main/Shears/nncf_config/gsm8k/nncf_shears_mpt.json)
### Adapter Configuration
- **LoRA rank:** 32
- **LoRA alpha:** 64
- **LoRA target modules:** q_proj, k_proj, v_proj, out_proj, up_proj, down_proj
- **LoRA rank search space:** [32, 24, 16] (for each LoRA module)
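The sketch below turns the adapter configuration into concrete numbers: the LoRA parameter count of the super-adapter versus the smallest subnetwork, and the size of the rank search space. It is a back-of-the-envelope calculation under assumptions not stated in this card: MPT-7B has 32 decoder layers, hidden size 4096 and MLP width 16384; q_proj, k_proj and v_proj are separate 4096 x 4096 projections; and each module chooses its rank independently. A rank-r LoRA module on a d_in -> d_out layer adds r * (d_in + d_out) parameters.

```python
# Back-of-the-envelope sketch, not part of Shears.
# ASSUMED MPT-7B shapes (not stated in this card): 32 layers, d_model=4096, MLP width 16384.
N_LAYERS = 32
D_MODEL = 4096
D_FF = 4 * D_MODEL

# (in_features, out_features) of each LoRA target module in one decoder layer (assumed).
MODULE_SHAPES = {
    "q_proj": (D_MODEL, D_MODEL),
    "k_proj": (D_MODEL, D_MODEL),
    "v_proj": (D_MODEL, D_MODEL),
    "out_proj": (D_MODEL, D_MODEL),
    "up_proj": (D_MODEL, D_FF),
    "down_proj": (D_FF, D_MODEL),
}
RANK_CHOICES = [32, 24, 16]


def lora_params(d_in: int, d_out: int, rank: int) -> int:
    """Parameters in the LoRA factors A (rank x d_in) and B (d_out x rank)."""
    return rank * (d_in + d_out)


def adapter_params(rank: int) -> int:
    """Total LoRA parameters when every module in every layer uses the same rank."""
    per_layer = sum(lora_params(d_in, d_out, rank) for d_in, d_out in MODULE_SHAPES.values())
    return N_LAYERS * per_layer


# Assumes every LoRA module picks its rank independently from RANK_CHOICES.
num_modules = N_LAYERS * len(MODULE_SHAPES)
search_space_size = len(RANK_CHOICES) ** num_modules

print(f"super-adapter (rank 32 everywhere): {adapter_params(32):,} LoRA params")
print(f"smallest subnetwork (rank 16):      {adapter_params(16):,} LoRA params")
print(f"rank configurations: 3^{num_modules} ~ {search_space_size:.3e}")
```

Because the LoRA parameter count is linear in r, the all-rank-16 subnetwork carries exactly half the adapter parameters of the super-adapter under these assumptions.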
### Training Hyperparameters
- **Batch size:** 16
- **Learning rate:** 3e-4
- **Epochs:** 5
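For planning a comparable run, the hyperparameters above can be translated into an optimizer-step budget. This sketch assumes, without confirmation from this card, that training used only the 7,473-example GSM8K `main` train split and that the batch size of 16 is the effective batch per optimizer step (no extra gradient accumulation, last partial batch kept).

```python
import math

# ASSUMPTIONS (not stated in this card): GSM8K "main" train split only (7,473 examples),
# batch size 16 is the effective batch per optimizer step, last partial batch is kept.
NUM_TRAIN_EXAMPLES = 7473
BATCH_SIZE = 16
EPOCHS = 5

steps_per_epoch = math.ceil(NUM_TRAIN_EXAMPLES / BATCH_SIZE)
total_steps = steps_per_epoch * EPOCHS
print(f"{steps_per_epoch} steps per epoch, {total_steps} optimizer steps in total")
```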
### Training and Evaluation
GSM8K dataset: [https://huggingface.co/datasets/gsm8k](https://huggingface.co/datasets/gsm8k)
## How to use
See the example notebook [load_and_explore_supernet.ipynb](https://github.com/IntelLabs/Hardware-Aware-Automated-Machine-Learning/tree/main/Shears/search/load_and_explore_supernet.ipynb), which shows how to load a Shears super-network directly and extract different subnetworks from it.
This lets users apply their own search algorithms and evaluation metrics to extract subnetworks tailored to their requirements.
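As an illustration of plugging in a custom search algorithm, here is a minimal random-search sketch over per-module LoRA ranks. It is not the search procedure used by Shears; `random_search`, the module keys, and the `evaluate` callback are hypothetical. In practice `evaluate` would activate the given rank configuration on the super-network (as shown in the notebook) and return a score such as GSM8K accuracy on a held-out subset; here a toy scorer that prefers smaller adapters stands in for it.

```python
import random
from typing import Callable, Dict, List

# Rank choices from the card's search space; module keys below are hypothetical.
RANK_CHOICES = [32, 24, 16]


def random_search(
    module_keys: List[str],
    evaluate: Callable[[Dict[str, int]], float],
    num_trials: int = 20,
    seed: int = 0,
) -> Dict[str, int]:
    """Hypothetical sketch: return the sampled rank config with the highest score."""
    rng = random.Random(seed)
    # Start from the super-network: the maximal rank for every module.
    best_config = {key: max(RANK_CHOICES) for key in module_keys}
    best_score = evaluate(best_config)
    for _ in range(num_trials):
        # Sample one rank per module, independently.
        config = {key: rng.choice(RANK_CHOICES) for key in module_keys}
        score = evaluate(config)
        if score > best_score:
            best_config, best_score = config, score
    return best_config


if __name__ == "__main__":
    keys = [f"layer{i}.{name}" for i in range(2) for name in ("q_proj", "up_proj")]
    # Toy scorer (assumption): prefer configurations with fewer total ranks.
    best = random_search(keys, evaluate=lambda cfg: -sum(cfg.values()), num_trials=50)
    print(best)
```

Random search is used here only because it is the simplest baseline; evolutionary or heuristic searches can reuse the same `evaluate` interface.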
The super-network is itself the maximal subnetwork (every LoRA module at rank 32), so it can also be loaded directly:
```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer


def generate_prompt(instruction):
    # Alpaca-style instruction template.
    return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:
"""


# Load the 50%-sparse base model and attach the super-adapter (the maximal subnetwork).
base_model = AutoModelForCausalLM.from_pretrained("IntelLabs/MPT-7B-sparsity50", trust_remote_code=True)
model = PeftModel.from_pretrained(base_model, "IntelLabs/shears-mpt-7b-50-gsm8k-super")
model.eval()

# Count non-zero parameters (base-model weights plus LoRA adapter weights).
non_zero_params = sum((param.data != 0).sum().item() for _, param in model.named_parameters())
print(f"Number of all non-zero parameters: {non_zero_params}")

tokenizer = AutoTokenizer.from_pretrained("IntelLabs/MPT-7B-sparsity50", trust_remote_code=True)

instruction = "Edgar eats 18 pretzels a day. If his brother eats 1/2 as many, how many does his brother eat in a week?"
prompt = generate_prompt(instruction)
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to(model.device)

# Beam-search decoding without gradient tracking.
with torch.no_grad():
    generation_output = model.generate(
        input_ids=input_ids,
        return_dict_in_generate=True,
        output_scores=True,
        max_new_tokens=256,
        use_cache=True,
        num_beams=4,
    )
s = generation_output.sequences[0]
output = tokenizer.decode(s)  # prompt followed by the generated response
print(output)
```
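`tokenizer.decode(s)` returns the prompt together with the generated response, so scoring on GSM8K requires pulling a number out of the response. The card does not specify the answer-extraction rule Shears uses; the sketch below assumes the common heuristic of taking the last number in the response and compares it with the GSM8K reference solution, which ends with a `#### <answer>` line. Both function names are hypothetical. For the example question, the correct answer is 18 / 2 * 7 = 63 pretzels.

```python
import re
from typing import Optional

# Integers or decimals, optionally with thousands separators, e.g. "63", "-4", "1,200", "2.5".
NUMBER_RE = re.compile(r"-?\d[\d,]*(?:\.\d+)?")


def extract_predicted_answer(generated_text: str) -> Optional[float]:
    """Hypothetical heuristic: the last number after '### Response:' is the answer."""
    # Drop the prompt (which itself contains numbers) if the marker is present.
    response = generated_text.split("### Response:")[-1]
    numbers = NUMBER_RE.findall(response)
    if not numbers:
        return None
    return float(numbers[-1].replace(",", ""))


def extract_reference_answer(gsm8k_answer: str) -> float:
    """GSM8K reference solutions end with a line of the form '#### <number>'."""
    return float(gsm8k_answer.split("####")[-1].strip().replace(",", ""))


if __name__ == "__main__":
    generated = "### Response:\nHis brother eats 18 / 2 = 9 a day, so 9 * 7 = 63 a week."
    print(extract_predicted_answer(generated) == extract_reference_answer("#### 63"))
```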
## Evaluation Results
Results of the heuristic subnetwork discovered from the super-network:
| Model | Sparsity | GSM8K Accuracy (%) |
|-----------------------|-------------|-------|
| [**MPT-7B-Shears**](https://huggingface.co/IntelLabs/shears-mpt-7b-50-gsm8k-heuristic) | **50%** | 33.4 |
## Model Sources
- **Repository:** [https://github.com/IntelLabs/Hardware-Aware-Automated-Machine-Learning/tree/main/Shears](https://github.com/IntelLabs/Hardware-Aware-Automated-Machine-Learning/tree/main/Shears)
- **Paper:** [Shears: Unstructured Sparsity with Neural Low-rank Adapter Search](https://arxiv.org/abs/2404.10934)
## Citation
```bibtex
@article{munoz2024shears,
  title   = {Shears: Unstructured Sparsity with Neural Low-rank Adapter Search},
  author  = {J. Pablo Munoz and Jinjie Yuan and Nilesh Jain},
  journal = {The 2024 Annual Conference of the North American Chapter of the Association for Computational Linguistics (NAACL-2024)},
  year    = {2024}
}
```
## License
Apache-2.0