|
import unittest |
|
|
|
import tests.context |
|
from autogpt.token_counter import count_message_tokens, count_string_tokens |
|
|
|
|
|
class TestTokenCounter(unittest.TestCase):
    """Unit tests for ``count_message_tokens`` and ``count_string_tokens``."""

    @staticmethod
    def _sample_messages():
        """Return a fresh two-message user/assistant exchange.

        A new list is built on every call so that no test can observe
        changes made to the messages by another test.
        """
        return [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]

    def test_count_message_tokens(self):
        """The sample exchange counts as 17 tokens when no model is passed."""
        self.assertEqual(count_message_tokens(self._sample_messages()), 17)

    def test_count_message_tokens_with_name(self):
        """Adding a ``name`` key to a message leaves the expected count at 17."""
        messages = [
            {"role": "user", "content": "Hello", "name": "John"},
            {"role": "assistant", "content": "Hi there!"},
        ]
        self.assertEqual(count_message_tokens(messages), 17)

    def test_count_message_tokens_empty_input(self):
        """An empty message list is expected to count as 3 tokens."""
        self.assertEqual(count_message_tokens([]), 3)

    def test_count_message_tokens_invalid_model(self):
        """An unrecognised model name is expected to raise NotImplementedError.

        NOTE: this class previously defined this test twice. The earlier
        definition (expecting ``KeyError``) was silently replaced by the later
        one when the class body executed, so it never ran. Only the effective
        definition is kept here.
        """
        with self.assertRaises(NotImplementedError):
            count_message_tokens(self._sample_messages(), model="invalid_model")

    def test_count_message_tokens_gpt_4(self):
        """The sample exchange counts as 15 tokens for model ``gpt-4-0314``."""
        self.assertEqual(
            count_message_tokens(self._sample_messages(), model="gpt-4-0314"), 15
        )

    def test_count_string_tokens(self):
        """``"Hello, world!"`` counts as 4 tokens for ``gpt-3.5-turbo-0301``."""
        string = "Hello, world!"
        self.assertEqual(
            count_string_tokens(string, model_name="gpt-3.5-turbo-0301"), 4
        )

    def test_count_string_tokens_empty_input(self):
        """The empty string counts as 0 tokens."""
        self.assertEqual(count_string_tokens("", model_name="gpt-3.5-turbo-0301"), 0)

    def test_count_string_tokens_gpt_4(self):
        """``"Hello, world!"`` counts as 4 tokens for ``gpt-4-0314``."""
        string = "Hello, world!"
        self.assertEqual(count_string_tokens(string, model_name="gpt-4-0314"), 4)
|
|
|
|
|
# Allow running this test module directly as a script; unittest.main()
# discovers and runs the TestCase classes defined in this module.
if __name__ == "__main__":
    unittest.main()
|
|