forked from phoenix/litellm-mirror
Litellm router code coverage 3 (#6274)
* refactor(router.py): move assistants API endpoints to a single pass-through factory function — reduces code and increases testing coverage. * refactor(router.py): reduce the size of `_common_check_available_deployment` to make the code more maintainable and reduce the chance of errors. * test(router_code_coverage.py): include batch_utils and pattern matching in the enforced 100% code coverage — improves reliability. * fix(router.py): fix model id matching in model dump.
This commit is contained in:
parent
891e9001b5
commit
e22e8d24ef
8 changed files with 407 additions and 244 deletions
|
@ -75,29 +75,28 @@ def get_functions_from_router(file_path):
|
|||
|
||||
# Router functions exempt from the enforced 100% code-coverage check.
# These are presumably thin pass-through wrappers (assistants/batch/file
# endpoints) exercised indirectly by higher-level tests — TODO confirm.
ignored_function_names = [
    "__init__",
    "_acreate_file",
    "_acreate_batch",
    "acreate_assistants",
    "adelete_assistant",
    "aget_assistants",
    "acreate_thread",
    "aget_thread",
    "a_add_message",
    "aget_messages",
    "arun_thread",
    "try_retrieve_batch",
]
|
||||
|
||||
|
||||
def main():
|
||||
router_file = "./litellm/router.py" # Update this path if it's located elsewhere
|
||||
# router_file = "../../litellm/router.py" ## LOCAL TESTING
|
||||
router_file = [
|
||||
"./litellm/router.py",
|
||||
"./litellm/router_utils/batch_utils.py",
|
||||
"./litellm/router_utils/pattern_match_deployments.py",
|
||||
]
|
||||
# router_file = [
|
||||
# "../../litellm/router.py",
|
||||
# "../../litellm/router_utils/pattern_match_deployments.py",
|
||||
# "../../litellm/router_utils/batch_utils.py",
|
||||
# ] ## LOCAL TESTING
|
||||
tests_dir = (
|
||||
"./tests/" # Update this path if your tests directory is located elsewhere
|
||||
)
|
||||
# tests_dir = "../../tests/" # LOCAL TESTING
|
||||
|
||||
router_functions = get_functions_from_router(router_file)
|
||||
router_functions = []
|
||||
for file in router_file:
|
||||
router_functions.extend(get_functions_from_router(file))
|
||||
print("router_functions: ", router_functions)
|
||||
called_functions_in_tests = get_all_functions_called_in_tests(tests_dir)
|
||||
untested_functions = [
|
||||
|
|
66
tests/code_coverage_tests/router_enforce_line_length.py
Normal file
66
tests/code_coverage_tests/router_enforce_line_length.py
Normal file
|
@ -0,0 +1,66 @@
|
|||
import ast
|
||||
import os
|
||||
|
||||
# Maximum allowed lines per function; check_function_lengths raises when exceeded.
MAX_FUNCTION_LINES = 100
|
||||
|
||||
|
||||
def get_function_line_counts(file_path):
    """
    Extracts all function names and their line counts from a given Python file.

    Scans top-level functions (sync and async) and methods defined directly
    inside top-level classes; deeper nesting (functions inside functions,
    nested classes) is not traversed.

    Args:
        file_path: Path to the Python source file to analyze.

    Returns:
        List of (function_name, line_count) tuples in source order.
    """
    with open(file_path, "r") as file:
        tree = ast.parse(file.read())

    def _line_count(node):
        # Inclusive span from the `def` line to the last body line.
        return node.end_lineno - node.lineno + 1

    function_line_counts = []

    for node in tree.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            # Top-level functions
            function_line_counts.append((node.name, _line_count(node)))
        elif isinstance(node, ast.ClassDef):
            # Functions inside classes
            for class_node in node.body:
                if isinstance(class_node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                    function_line_counts.append(
                        (class_node.name, _line_count(class_node))
                    )

    return function_line_counts
|
||||
|
||||
|
||||
# Function names exempt from the MAX_FUNCTION_LINES enforcement.
ignored_functions = [
    "__init__",
]
|
||||
|
||||
|
||||
def check_function_lengths(router_file):
    """
    Checks if any function in the specified file exceeds the maximum allowed length.

    Prints each offending function and raises an Exception if any function
    (other than those in ``ignored_functions``) is longer than
    ``MAX_FUNCTION_LINES`` lines.
    """
    offenders = [
        (name, count)
        for name, count in get_function_line_counts(router_file)
        if name not in ignored_functions and count > MAX_FUNCTION_LINES
    ]

    if not offenders:
        # Guard clause: nothing to report.
        print("All functions in the router file are within the allowed line limit.")
        return

    print("The following functions exceed the allowed line count:")
    for name, count in offenders:
        print(f"- {name}: {count} lines")
    raise Exception(
        f"{len(offenders)} functions in {router_file} exceed {MAX_FUNCTION_LINES} lines"
    )
|
||||
|
||||
|
||||
def main():
    """Entry point: enforce the per-function line limit on router.py."""
    # CI runs from the repo root, so use the root-relative path — the
    # sibling code-coverage script follows the same convention.
    router_file = "./litellm/router.py"
    # router_file = "../../litellm/router.py"  # LOCAL TESTING

    check_function_lengths(router_file)


if __name__ == "__main__":
    main()
|
|
@ -2569,6 +2569,15 @@ async def test_router_batch_endpoints(provider):
|
|||
)
|
||||
print("Response from creating file=", file_obj)
|
||||
|
||||
## TEST 2 - test underlying create_file function
|
||||
file_obj = await router._acreate_file(
|
||||
model="my-custom-name",
|
||||
file=open(file_path, "rb"),
|
||||
purpose="batch",
|
||||
custom_llm_provider=provider,
|
||||
)
|
||||
print("Response from creating file=", file_obj)
|
||||
|
||||
await asyncio.sleep(10)
|
||||
batch_input_file_id = file_obj.id
|
||||
assert (
|
||||
|
@ -2583,6 +2592,15 @@ async def test_router_batch_endpoints(provider):
|
|||
custom_llm_provider=provider,
|
||||
metadata={"key1": "value1", "key2": "value2"},
|
||||
)
|
||||
## TEST 2 - test underlying create_batch function
|
||||
create_batch_response = await router._acreate_batch(
|
||||
model="my-custom-name",
|
||||
completion_window="24h",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id=batch_input_file_id,
|
||||
custom_llm_provider=provider,
|
||||
metadata={"key1": "value1", "key2": "value2"},
|
||||
)
|
||||
|
||||
print("response from router.create_batch=", create_batch_response)
|
||||
|
||||
|
|
84
tests/router_unit_tests/test_router_batch_utils.py
Normal file
84
tests/router_unit_tests/test_router_batch_utils.py
Normal file
|
@ -0,0 +1,84 @@
|
|||
import sys
|
||||
import os
|
||||
import traceback
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import Request
|
||||
from datetime import datetime
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
from litellm import Router
|
||||
import pytest
|
||||
import litellm
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
|
||||
import json
|
||||
from io import BytesIO
|
||||
from typing import Dict, List
|
||||
from litellm.router_utils.batch_utils import (
|
||||
replace_model_in_jsonl,
|
||||
_get_router_metadata_variable_name,
|
||||
)
|
||||
|
||||
|
||||
# Fixtures
|
||||
@pytest.fixture
def sample_jsonl_data() -> List[Dict]:
    """Fixture providing sample JSONL data"""
    first_request = {
        "body": {
            "model": "gpt-3.5-turbo",
            "messages": [{"role": "user", "content": "Hello"}],
        }
    }
    second_request = {
        "body": {"model": "gpt-4", "messages": [{"role": "user", "content": "Hi"}]}
    }
    return [first_request, second_request]
|
||||
|
||||
|
||||
@pytest.fixture
def sample_jsonl_bytes(sample_jsonl_data) -> bytes:
    """Fixture providing sample JSONL as bytes"""
    encoded_lines = [json.dumps(entry) for entry in sample_jsonl_data]
    return "\n".join(encoded_lines).encode("utf-8")
|
||||
|
||||
|
||||
@pytest.fixture
def sample_file_like(sample_jsonl_bytes):
    """Fixture providing a file-like object"""
    file_obj = BytesIO(sample_jsonl_bytes)
    return file_obj
|
||||
|
||||
|
||||
# Test cases
|
||||
def test_bytes_input(sample_jsonl_bytes):
    """Test with bytes input"""
    updated = replace_model_in_jsonl(sample_jsonl_bytes, "claude-3")
    assert updated is not None
|
||||
|
||||
|
||||
def test_tuple_input(sample_jsonl_bytes):
    """Test with tuple input"""
    file_tuple = ("test.jsonl", sample_jsonl_bytes, "application/json")
    updated = replace_model_in_jsonl(file_tuple, "claude-3")
    assert updated is not None
|
||||
|
||||
|
||||
def test_file_like_object(sample_file_like):
    """Test with file-like object input"""
    updated = replace_model_in_jsonl(sample_file_like, "claude-3")
    assert updated is not None
|
||||
|
||||
|
||||
def test_router_metadata_variable_name():
    """Test that the variable name is correct"""
    expected = {
        "completion": "metadata",
        "batch": "litellm_metadata",
    }
    for function_name, variable_name in expected.items():
        assert (
            _get_router_metadata_variable_name(function_name=function_name)
            == variable_name
        )
|
|
@ -41,6 +41,20 @@ def model_list():
|
|||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "*",
|
||||
"litellm_params": {
|
||||
"model": "openai/*",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "claude-*",
|
||||
"litellm_params": {
|
||||
"model": "anthropic/*",
|
||||
"api_key": os.getenv("ANTHROPIC_API_KEY"),
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
|
@ -834,3 +848,69 @@ def test_flush_cache(model_list):
|
|||
assert router.cache.get_cache("test") == "test"
|
||||
router.flush_cache()
|
||||
assert router.cache.get_cache("test") is None
|
||||
|
||||
|
||||
def test_initialize_assistants_endpoint(model_list):
    """Test if the 'initialize_assistants_endpoint' function is working correctly"""
    router = Router(model_list=model_list)
    router.initialize_assistants_endpoint()
    # Every assistants pass-through must be bound after initialization.
    assistant_endpoints = [
        router.acreate_assistants,
        router.adelete_assistant,
        router.aget_assistants,
        router.acreate_thread,
        router.aget_thread,
        router.arun_thread,
        router.aget_messages,
        router.a_add_message,
    ]
    for endpoint in assistant_endpoints:
        assert endpoint is not None
|
||||
|
||||
|
||||
def test_pass_through_assistants_endpoint_factory(model_list):
    """Test if the 'pass_through_assistants_endpoint_factory' function is working correctly"""
    router = Router(model_list=model_list)
    extra_kwargs = {}
    router._pass_through_assistants_endpoint_factory(
        original_function=litellm.acreate_assistants,
        custom_llm_provider="openai",
        client=None,
        **extra_kwargs,
    )
|
||||
|
||||
|
||||
def test_factory_function(model_list):
    """Test if the 'factory_function' function is working correctly"""
    wrapped = Router(model_list=model_list).factory_function(litellm.acreate_assistants)
|
||||
|
||||
|
||||
def test_get_model_from_alias(model_list):
    """Test if the 'get_model_from_alias' function is working correctly"""
    alias_map = {"gpt-4o": "gpt-3.5-turbo"}
    router = Router(model_list=model_list, model_group_alias=alias_map)
    resolved = router._get_model_from_alias(model="gpt-4o")
    assert resolved == "gpt-3.5-turbo"
|
||||
|
||||
|
||||
def test_get_deployment_by_litellm_model(model_list):
    """Test if the 'get_deployment_by_litellm_model' function is working correctly"""
    router = Router(model_list=model_list)
    found = router._get_deployment_by_litellm_model(model="gpt-3.5-turbo")
    assert found is not None
|
||||
|
||||
|
||||
def test_get_pattern(model_list):
    """Test that the pattern router resolves a pattern for 'claude-3'."""
    router = Router(model_list=model_list)
    matched = router.pattern_router.get_pattern(model="claude-3")
    assert matched is not None
|
||||
|
||||
|
||||
def test_deployments_by_pattern(model_list):
    """Test that deployments are resolvable by wildcard pattern for 'claude-3'."""
    router = Router(model_list=model_list)
    matched_deployments = router.pattern_router.get_deployments_by_pattern(
        model="claude-3"
    )
    assert matched_deployments is not None
|
||||
|
||||
|
||||
def test_replace_model_in_jsonl(model_list):
    """Test if the 'replace_model_in_jsonl' function is working correctly.

    NOTE(review): the previous body was a copy-paste of
    test_deployments_by_pattern and never exercised replace_model_in_jsonl.
    """
    from litellm.router_utils.batch_utils import replace_model_in_jsonl

    jsonl_bytes = (
        b'{"body": {"model": "gpt-3.5-turbo", '
        b'"messages": [{"role": "user", "content": "hi"}]}}'
    )
    result = replace_model_in_jsonl(jsonl_bytes, "claude-3")
    assert result is not None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue