forked from phoenix/litellm-mirror
add test for test_get_valid_models
This commit is contained in:
parent
57f15b379d
commit
a48cb49820
1 changed files with 23 additions and 6 deletions
|
@ -10,7 +10,7 @@ sys.path.insert(
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
import pytest
|
import pytest
|
||||||
import litellm
|
import litellm
|
||||||
from litellm.utils import trim_messages, get_token_count
|
from litellm.utils import trim_messages, get_token_count, get_valid_models
|
||||||
|
|
||||||
# Assuming your trim_messages, shorten_message_to_fit_limit, and get_token_count functions are all in a module named 'message_utils'
|
# Assuming your trim_messages, shorten_message_to_fit_limit, and get_token_count functions are all in a module named 'message_utils'
|
||||||
|
|
||||||
|
@ -22,7 +22,7 @@ def test_basic_trimming():
|
||||||
print(trimmed_messages)
|
print(trimmed_messages)
|
||||||
# print(get_token_count(messages=trimmed_messages, model="claude-2"))
|
# print(get_token_count(messages=trimmed_messages, model="claude-2"))
|
||||||
assert (get_token_count(messages=trimmed_messages, model="claude-2")) <= 8
|
assert (get_token_count(messages=trimmed_messages, model="claude-2")) <= 8
|
||||||
test_basic_trimming()
|
# test_basic_trimming()
|
||||||
|
|
||||||
def test_basic_trimming_no_max_tokens_specified():
|
def test_basic_trimming_no_max_tokens_specified():
|
||||||
messages = [{"role": "user", "content": "This is a long message that is definitely under the token limit."}]
|
messages = [{"role": "user", "content": "This is a long message that is definitely under the token limit."}]
|
||||||
|
@ -31,7 +31,7 @@ def test_basic_trimming_no_max_tokens_specified():
|
||||||
print(trimmed_messages)
|
print(trimmed_messages)
|
||||||
# print(get_token_count(messages=trimmed_messages, model="claude-2"))
|
# print(get_token_count(messages=trimmed_messages, model="claude-2"))
|
||||||
assert (get_token_count(messages=trimmed_messages, model="gpt-4")) <= litellm.model_cost['gpt-4']['max_tokens']
|
assert (get_token_count(messages=trimmed_messages, model="gpt-4")) <= litellm.model_cost['gpt-4']['max_tokens']
|
||||||
test_basic_trimming_no_max_tokens_specified()
|
# test_basic_trimming_no_max_tokens_specified()
|
||||||
|
|
||||||
def test_multiple_messages_trimming():
|
def test_multiple_messages_trimming():
|
||||||
messages = [
|
messages = [
|
||||||
|
@ -43,7 +43,7 @@ def test_multiple_messages_trimming():
|
||||||
print(trimmed_messages)
|
print(trimmed_messages)
|
||||||
# print(get_token_count(messages=trimmed_messages, model="gpt-3.5-turbo"))
|
# print(get_token_count(messages=trimmed_messages, model="gpt-3.5-turbo"))
|
||||||
assert(get_token_count(messages=trimmed_messages, model="gpt-3.5-turbo")) <= 20
|
assert(get_token_count(messages=trimmed_messages, model="gpt-3.5-turbo")) <= 20
|
||||||
test_multiple_messages_trimming()
|
# test_multiple_messages_trimming()
|
||||||
|
|
||||||
def test_multiple_messages_no_trimming():
|
def test_multiple_messages_no_trimming():
|
||||||
messages = [
|
messages = [
|
||||||
|
@ -55,7 +55,7 @@ def test_multiple_messages_no_trimming():
|
||||||
print(trimmed_messages)
|
print(trimmed_messages)
|
||||||
assert(messages==trimmed_messages)
|
assert(messages==trimmed_messages)
|
||||||
|
|
||||||
test_multiple_messages_no_trimming()
|
# test_multiple_messages_no_trimming()
|
||||||
|
|
||||||
|
|
||||||
def test_large_trimming():
|
def test_large_trimming():
|
||||||
|
@ -64,4 +64,21 @@ def test_large_trimming():
|
||||||
print("trimmed messages")
|
print("trimmed messages")
|
||||||
print(trimmed_messages)
|
print(trimmed_messages)
|
||||||
assert(get_token_count(messages=trimmed_messages, model="random")) <= 20
|
assert(get_token_count(messages=trimmed_messages, model="random")) <= 20
|
||||||
test_large_trimming()
|
# test_large_trimming()
|
||||||
|
|
||||||
|
def test_get_valid_models():
|
||||||
|
old_environ = os.environ
|
||||||
|
os.environ = {'OPENAI_API_KEY': 'temp'} # mock set only openai key in environ
|
||||||
|
|
||||||
|
valid_models = get_valid_models()
|
||||||
|
print(valid_models)
|
||||||
|
|
||||||
|
# list of openai supported llms on litellm
|
||||||
|
expected_models = litellm.open_ai_chat_completion_models + litellm.open_ai_text_completion_models
|
||||||
|
|
||||||
|
assert(valid_models == expected_models)
|
||||||
|
|
||||||
|
# reset replicate env key
|
||||||
|
os.environ = old_environ
|
||||||
|
|
||||||
|
# test_get_valid_models()
|
Loading…
Add table
Add a link
Reference in a new issue