basic_demo / trans_stress_test.py
Jar2023's picture
Upload folder using huggingface_hub
25be583 verified
import argparse
import time
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
import torch
from threading import Thread
MODEL_PATH = 'THUDM/glm-4-9b-chat'
def stress_test(token_len, n, num_gpu):
device = torch.device(f"cuda:{num_gpu - 1}" if torch.cuda.is_available() and num_gpu > 0 else "cpu")
tokenizer = AutoTokenizer.from_pretrained(
MODEL_PATH,
trust_remote_code=True,
padding_side="left"
)
model = AutoModelForCausalLM.from_pretrained(
MODEL_PATH,
trust_remote_code=True,
torch_dtype=torch.bfloat16
).to(device).eval()
# Use INT4 weight infer
# model = AutoModelForCausalLM.from_pretrained(
# MODEL_PATH,
# trust_remote_code=True,
# quantization_config=BitsAndBytesConfig(load_in_4bit=True),
# low_cpu_mem_usage=True,
# ).eval()
times = []
decode_times = []
print("Warming up...")
vocab_size = tokenizer.vocab_size
warmup_token_len = 20
random_token_ids = torch.randint(3, vocab_size - 200, (warmup_token_len - 5,), dtype=torch.long)
start_tokens = [151331, 151333, 151336, 198]
end_tokens = [151337]
input_ids = torch.tensor(start_tokens + random_token_ids.tolist() + end_tokens, dtype=torch.long).unsqueeze(0).to(
device)
attention_mask = torch.ones_like(input_ids, dtype=torch.bfloat16).to(device)
position_ids = torch.arange(len(input_ids[0]), dtype=torch.bfloat16).unsqueeze(0).to(device)
warmup_inputs = {
'input_ids': input_ids,
'attention_mask': attention_mask,
'position_ids': position_ids
}
with torch.no_grad():
_ = model.generate(
input_ids=warmup_inputs['input_ids'],
attention_mask=warmup_inputs['attention_mask'],
max_new_tokens=2048,
do_sample=False,
repetition_penalty=1.0,
eos_token_id=[151329, 151336, 151338]
)
print("Warming up complete. Starting stress test...")
for i in range(n):
random_token_ids = torch.randint(3, vocab_size - 200, (token_len - 5,), dtype=torch.long)
input_ids = torch.tensor(start_tokens + random_token_ids.tolist() + end_tokens, dtype=torch.long).unsqueeze(
0).to(device)
attention_mask = torch.ones_like(input_ids, dtype=torch.bfloat16).to(device)
position_ids = torch.arange(len(input_ids[0]), dtype=torch.bfloat16).unsqueeze(0).to(device)
test_inputs = {
'input_ids': input_ids,
'attention_mask': attention_mask,
'position_ids': position_ids
}
streamer = TextIteratorStreamer(
tokenizer=tokenizer,
timeout=36000,
skip_prompt=True,
skip_special_tokens=True
)
generate_kwargs = {
"input_ids": test_inputs['input_ids'],
"attention_mask": test_inputs['attention_mask'],
"max_new_tokens": 512,
"do_sample": False,
"repetition_penalty": 1.0,
"eos_token_id": [151329, 151336, 151338],
"streamer": streamer
}
start_time = time.time()
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
first_token_time = None
all_token_times = []
for token in streamer:
current_time = time.time()
if first_token_time is None:
first_token_time = current_time
times.append(first_token_time - start_time)
all_token_times.append(current_time)
t.join()
end_time = time.time()
avg_decode_time_per_token = len(all_token_times) / (end_time - first_token_time) if all_token_times else 0
decode_times.append(avg_decode_time_per_token)
print(
f"Iteration {i + 1}/{n} - Prefilling Time: {times[-1]:.4f} seconds - Average Decode Time: {avg_decode_time_per_token:.4f} tokens/second")
torch.cuda.empty_cache()
avg_first_token_time = sum(times) / n
avg_decode_time = sum(decode_times) / n
print(f"\nAverage First Token Time over {n} iterations: {avg_first_token_time:.4f} seconds")
print(f"Average Decode Time per Token over {n} iterations: {avg_decode_time:.4f} tokens/second")
return times, avg_first_token_time, decode_times, avg_decode_time
def main():
parser = argparse.ArgumentParser(description="Stress test for model inference")
parser.add_argument('--token_len', type=int, default=1000, help='Number of tokens for each test')
parser.add_argument('--n', type=int, default=3, help='Number of iterations for the stress test')
parser.add_argument('--num_gpu', type=int, default=1, help='Number of GPUs to use for inference')
args = parser.parse_args()
token_len = args.token_len
n = args.n
num_gpu = args.num_gpu
stress_test(token_len, n, num_gpu)
if __name__ == "__main__":
main()