Litellm router code coverage 3 (#6274)

* refactor(router.py): move Assistants API endpoints to a single pass-through factory function

Reduces code and increases test coverage; a short usage sketch of the factory-generated endpoints follows the commit metadata below

* refactor(router.py): reduce the size of the _common_checks_available_deployment function

Makes the code more maintainable and reduces the chance of errors

* test(router_code_coverage.py): include batch_utils + pattern matching in the enforced 100% code-coverage check

Improves reliability

* fix(router.py): return the matched model-ID deployment via model_dump
Krish Dholakia 2024-10-16 21:30:25 -07:00 committed by GitHub
parent 891e9001b5
commit e22e8d24ef
8 changed files with 407 additions and 244 deletions
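
The refactor above replaces eight hand-written assistants wrappers on the Router with attributes generated by factory_function. Below is a minimal usage sketch, not part of the diff, assuming an Azure assistants setup; the model entry, environment variables, and the two calls are illustrative placeholders, and the generated endpoints are intended to behave like the old hand-written methods.

import asyncio
import os

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-4o",
            "litellm_params": {
                "model": "azure/gpt-4o",
                "api_key": os.getenv("AZURE_API_KEY"),
                "api_base": os.getenv("AZURE_API_BASE"),
            },
        }
    ],
    # used by the factory-generated endpoints when no custom_llm_provider is passed
    assistants_config={
        "custom_llm_provider": "azure",
        "litellm_params": {
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_base": os.getenv("AZURE_API_BASE"),
        },
    },
)


async def main():
    # Each attribute below is created by factory_function and passes straight
    # through to the matching litellm.* coroutine.
    thread = await router.acreate_thread()
    assistants = await router.aget_assistants()
    print(thread.id, assistants)


asyncio.run(main())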

View file

@@ -25,3 +25,9 @@ model_list:
# guard_name: "gibberish_guard"
# mode: "post_call"
# api_base: os.environ/GUARDRAILS_AI_API_BASE
assistant_settings:
custom_llm_provider: azure
litellm_params:
api_key: os.environ/AZURE_API_KEY
api_base: os.environ/AZURE_API_BASE

View file

@@ -24,7 +24,18 @@ import traceback
import uuid
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
Tuple,
TypedDict,
Union,
)
import httpx
import openai
@@ -520,6 +531,19 @@ class Router:
if self.alerting_config is not None:
self._initialize_alerting()
self.initialize_assistants_endpoint()
def initialize_assistants_endpoint(self):
## INITIALIZE PASS THROUGH ASSISTANTS ENDPOINT ##
self.acreate_assistants = self.factory_function(litellm.acreate_assistants)
self.adelete_assistant = self.factory_function(litellm.adelete_assistant)
self.aget_assistants = self.factory_function(litellm.aget_assistants)
self.acreate_thread = self.factory_function(litellm.acreate_thread)
self.aget_thread = self.factory_function(litellm.aget_thread)
self.a_add_message = self.factory_function(litellm.a_add_message)
self.aget_messages = self.factory_function(litellm.aget_messages)
self.arun_thread = self.factory_function(litellm.arun_thread)
def validate_fallbacks(self, fallback_param: Optional[List]):
"""
Validate the fallbacks parameter.
@@ -2167,7 +2191,6 @@ class Router:
raise e
#### FILES API ####
async def acreate_file(
self,
model: str,
@@ -2504,114 +2527,29 @@ class Router:
#### ASSISTANTS API ####
async def acreate_assistants(
self,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> Assistant:
if custom_llm_provider is None:
if self.assistants_config is not None:
custom_llm_provider = self.assistants_config["custom_llm_provider"]
kwargs.update(self.assistants_config["litellm_params"])
else:
raise Exception(
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
)
return await litellm.acreate_assistants(
custom_llm_provider=custom_llm_provider, client=client, **kwargs
)
async def adelete_assistant(
self,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> AssistantDeleted:
if custom_llm_provider is None:
if self.assistants_config is not None:
custom_llm_provider = self.assistants_config["custom_llm_provider"]
kwargs.update(self.assistants_config["litellm_params"])
else:
raise Exception(
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
)
return await litellm.adelete_assistant(
custom_llm_provider=custom_llm_provider, client=client, **kwargs
)
async def aget_assistants(
self,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> AsyncCursorPage[Assistant]:
if custom_llm_provider is None:
if self.assistants_config is not None:
custom_llm_provider = self.assistants_config["custom_llm_provider"]
kwargs.update(self.assistants_config["litellm_params"])
else:
raise Exception(
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
)
return await litellm.aget_assistants(
custom_llm_provider=custom_llm_provider, client=client, **kwargs
)
async def acreate_thread(
self,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> Thread:
if custom_llm_provider is None:
if self.assistants_config is not None:
custom_llm_provider = self.assistants_config["custom_llm_provider"]
kwargs.update(self.assistants_config["litellm_params"])
else:
raise Exception(
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
)
return await litellm.acreate_thread(
custom_llm_provider=custom_llm_provider, client=client, **kwargs
)
async def aget_thread(
self,
thread_id: str,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> Thread:
if custom_llm_provider is None:
if self.assistants_config is not None:
custom_llm_provider = self.assistants_config["custom_llm_provider"]
kwargs.update(self.assistants_config["litellm_params"])
else:
raise Exception(
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
)
return await litellm.aget_thread(
custom_llm_provider=custom_llm_provider,
thread_id=thread_id,
client=client,
def factory_function(self, original_function: Callable):
async def new_function(
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional["AsyncOpenAI"] = None,
**kwargs,
)
):
return await self._pass_through_assistants_endpoint_factory(
original_function=original_function,
custom_llm_provider=custom_llm_provider,
client=client,
**kwargs,
)
async def a_add_message(
return new_function
async def _pass_through_assistants_endpoint_factory(
self,
thread_id: str,
role: Literal["user", "assistant"],
content: str,
attachments: Optional[List[Attachment]] = None,
metadata: Optional[dict] = None,
original_function: Callable,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> OpenAIMessage:
):
"""Internal helper function to pass through the assistants endpoint"""
if custom_llm_provider is None:
if self.assistants_config is not None:
custom_llm_provider = self.assistants_config["custom_llm_provider"]
@@ -2620,76 +2558,8 @@ class Router:
raise Exception(
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
)
return await litellm.a_add_message(
custom_llm_provider=custom_llm_provider,
thread_id=thread_id,
role=role,
content=content,
attachments=attachments,
metadata=metadata,
client=client,
**kwargs,
)
async def aget_messages(
self,
thread_id: str,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> AsyncCursorPage[OpenAIMessage]:
if custom_llm_provider is None:
if self.assistants_config is not None:
custom_llm_provider = self.assistants_config["custom_llm_provider"]
kwargs.update(self.assistants_config["litellm_params"])
else:
raise Exception(
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
)
return await litellm.aget_messages(
custom_llm_provider=custom_llm_provider,
thread_id=thread_id,
client=client,
**kwargs,
)
async def arun_thread(
self,
thread_id: str,
assistant_id: str,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
additional_instructions: Optional[str] = None,
instructions: Optional[str] = None,
metadata: Optional[dict] = None,
model: Optional[str] = None,
stream: Optional[bool] = None,
tools: Optional[Iterable[AssistantToolParam]] = None,
client: Optional[Any] = None,
**kwargs,
) -> Run:
if custom_llm_provider is None:
if self.assistants_config is not None:
custom_llm_provider = self.assistants_config["custom_llm_provider"]
kwargs.update(self.assistants_config["litellm_params"])
else:
raise Exception(
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
)
return await litellm.arun_thread(
custom_llm_provider=custom_llm_provider,
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
stream=stream,
tools=tools,
client=client,
**kwargs,
return await original_function( # type: ignore
custom_llm_provider=custom_llm_provider, client=client, **kwargs
)
#### [END] ASSISTANTS API ####
@@ -4609,7 +4479,7 @@ class Router:
"""
returned_models: List[DeploymentTypedDict] = []
for model in self.model_list:
if model["model_name"] == model_name:
if model_name is not None and model["model_name"] == model_name:
if model_alias is not None:
alias_model = copy.deepcopy(model)
alias_model["model_name"] = model_alias
@@ -5007,82 +4877,73 @@ class Router:
return _returned_deployments
def _get_model_from_alias(self, model: str) -> Optional[str]:
"""
Get the model from the alias.
Returns:
- str, the litellm model name
- None, if model is not in model group alias
"""
if model not in self.model_group_alias:
return None
_item = self.model_group_alias[model]
if isinstance(_item, str):
model = _item
else:
model = _item["model"]
return model
def _get_deployment_by_litellm_model(self, model: str) -> List:
"""
Get the deployment by litellm model.
"""
return [m for m in self.model_list if m["litellm_params"]["model"] == model]
def _common_checks_available_deployment(
self,
model: str,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
specific_deployment: Optional[bool] = False,
) -> Tuple[str, Union[list, dict]]:
) -> Tuple[str, Union[List, Dict]]:
"""
Common checks for 'get_available_deployment' across sync + async call.
If 'healthy_deployments' returned is None, this means the user chose a specific deployment
Returns
- Dict, if specific model chosen
- str, the litellm model name
- List, if multiple models chosen
- Dict, if specific model chosen
"""
# check if aliases set on litellm model alias map
if specific_deployment is True:
# users can also specify a specific deployment name. At this point we should check if they are just trying to call a specific deployment
for deployment in self.model_list:
deployment_model = deployment.get("litellm_params").get("model")
if deployment_model == model:
# User passed a specific deployment name on their config.yaml, example azure/chat-gpt-v-2
# return the first deployment where the `model` matches the specified deployment name
return deployment_model, deployment
raise ValueError(
f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.get_model_names()}"
)
return model, self._get_deployment_by_litellm_model(model=model)
elif model in self.get_model_ids():
deployment = self.get_model_info(id=model)
deployment = self.get_deployment(model_id=model)
if deployment is not None:
deployment_model = deployment.get("litellm_params", {}).get("model")
return deployment_model, deployment
deployment_model = deployment.litellm_params.model
return deployment_model, deployment.model_dump(exclude_none=True)
raise ValueError(
f"LiteLLM Router: Trying to call specific deployment, but Model ID :{model} does not exist in \
Model ID List: {self.get_model_ids}"
)
if model in self.model_group_alias:
_item = self.model_group_alias[model]
if isinstance(_item, str):
model = _item
else:
model = _item["model"]
_model_from_alias = self._get_model_from_alias(model=model)
if _model_from_alias is not None:
model = _model_from_alias
if model not in self.model_names:
# check for provider-specific wildcard routing using pattern matching
custom_llm_provider: Optional[str] = None
try:
(
_,
custom_llm_provider,
_,
_,
) = litellm.get_llm_provider(model=model)
except Exception:
# get_llm_provider raises exception when provider is unknown
pass
"""
self.pattern_router.route(model):
does exact pattern matching. Example openai/gpt-3.5-turbo gets routed to pattern openai/*
self.pattern_router.route(f"{custom_llm_provider}/{model}"):
does pattern matching using litellm.get_llm_provider(), example claude-3-5-sonnet-20240620 gets routed to anthropic/* since 'claude-3-5-sonnet-20240620' is an Anthropic Model
"""
_pattern_router_response = self.pattern_router.route(
model
) or self.pattern_router.route(f"{custom_llm_provider}/{model}")
if _pattern_router_response is not None:
provider_deployments = []
for deployment in _pattern_router_response:
dep = copy.deepcopy(deployment)
dep["litellm_params"]["model"] = model
provider_deployments.append(dep)
return model, provider_deployments
pattern_deployments = self.pattern_router.get_deployments_by_pattern(
model=model,
)
if pattern_deployments:
return model, pattern_deployments
# check if default deployment is set
if self.default_deployment is not None:
@@ -5094,12 +4955,11 @@ class Router:
## get healthy deployments
### get all deployments
healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
healthy_deployments = self._get_all_deployments(model_name=model)
if len(healthy_deployments) == 0:
# check if the user sent in a deployment name instead
healthy_deployments = [
m for m in self.model_list if m["litellm_params"]["model"] == model
]
healthy_deployments = self._get_deployment_by_litellm_model(model=model)
verbose_router_logger.debug(
f"initial list of deployments: {healthy_deployments}"
@@ -5151,7 +5011,6 @@ class Router:
input=input,
specific_deployment=specific_deployment,
) # type: ignore
if isinstance(healthy_deployments, dict):
return healthy_deployments
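
For reference, a short sketch, not part of the diff, of the two call shapes the slimmed-down _common_checks_available_deployment now resolves. The method is internal, and the model name and API key env var below are placeholders, so treat this as illustrative only.

import os

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        }
    ]
)

# Model group name -> (model, List of deployments) via _get_all_deployments()
model, deployments = router._common_checks_available_deployment(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
)

# Model ID -> (underlying litellm model, Dict from deployment.model_dump())
model_id = router.get_model_ids()[0]
litellm_model, deployment = router._common_checks_available_deployment(model=model_id)
print(model, len(deployments), litellm_model, deployment["model_name"])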

View file

@@ -2,9 +2,11 @@
Class to handle llm wildcard routing and regex pattern matching
"""
import copy
import re
from typing import Dict, List, Optional
from litellm import get_llm_provider
from litellm._logging import verbose_router_logger
@@ -82,6 +84,55 @@ class PatternMatchRouter:
return None # No matching pattern found
def get_pattern(
self, model: str, custom_llm_provider: Optional[str] = None
) -> Optional[List[Dict]]:
"""
Check if a pattern exists for the given model and custom llm provider
Args:
model: str
custom_llm_provider: Optional[str]
Returns:
Optional[List[Dict]]: the matching deployments if a pattern exists, None otherwise
"""
if custom_llm_provider is None:
try:
(
_,
custom_llm_provider,
_,
_,
) = get_llm_provider(model=model)
except Exception:
# get_llm_provider raises exception when provider is unknown
pass
return self.route(model) or self.route(f"{custom_llm_provider}/{model}")
def get_deployments_by_pattern(
self, model: str, custom_llm_provider: Optional[str] = None
) -> List[Dict]:
"""
Get the deployments by pattern
Args:
model: str
custom_llm_provider: Optional[str]
Returns:
List[Dict]: llm deployments matching the pattern
"""
pattern_match = self.get_pattern(model, custom_llm_provider)
if pattern_match:
provider_deployments = []
for deployment in pattern_match:
dep = copy.deepcopy(deployment)
dep["litellm_params"]["model"] = model
provider_deployments.append(dep)
return provider_deployments
return []
# Example usage:
# router = PatternRouter()
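
The two helpers added here back the router's wildcard routing. Below is a small illustrative sketch, not part of the diff, that mirrors the claude-* to anthropic/* fixture used in the tests later in this commit; the model name and the ANTHROPIC_API_KEY placeholder are assumptions for the example.

import os

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "claude-*",
            "litellm_params": {
                "model": "anthropic/*",
                "api_key": os.getenv("ANTHROPIC_API_KEY"),
            },
        }
    ]
)

# get_pattern() first tries a direct match against the registered model_name
# patterns (here "claude-*") and falls back to litellm.get_llm_provider() for
# provider-prefixed matching.
matched = router.pattern_router.get_pattern(model="claude-3-5-sonnet-20240620")

# get_deployments_by_pattern() deep-copies each matched deployment and rewrites
# litellm_params["model"] to the requested model before returning the list.
deployments = router.pattern_router.get_deployments_by_pattern(
    model="claude-3-5-sonnet-20240620"
)
print(matched, deployments)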

View file

@@ -75,29 +75,28 @@ def get_functions_from_router(file_path):
ignored_function_names = [
"__init__",
"_acreate_file",
"_acreate_batch",
"acreate_assistants",
"adelete_assistant",
"aget_assistants",
"acreate_thread",
"aget_thread",
"a_add_message",
"aget_messages",
"arun_thread",
"try_retrieve_batch",
]
def main():
router_file = "./litellm/router.py" # Update this path if it's located elsewhere
# router_file = "../../litellm/router.py" ## LOCAL TESTING
router_file = [
"./litellm/router.py",
"./litellm/router_utils/batch_utils.py",
"./litellm/router_utils/pattern_match_deployments.py",
]
# router_file = [
# "../../litellm/router.py",
# "../../litellm/router_utils/pattern_match_deployments.py",
# "../../litellm/router_utils/batch_utils.py",
# ] ## LOCAL TESTING
tests_dir = (
"./tests/" # Update this path if your tests directory is located elsewhere
)
# tests_dir = "../../tests/" # LOCAL TESTING
router_functions = get_functions_from_router(router_file)
router_functions = []
for file in router_file:
router_functions.extend(get_functions_from_router(file))
print("router_functions: ", router_functions)
called_functions_in_tests = get_all_functions_called_in_tests(tests_dir)
untested_functions = [

View file

@@ -0,0 +1,66 @@
import ast
import os
MAX_FUNCTION_LINES = 100
def get_function_line_counts(file_path):
"""
Extracts all function names and their line counts from a given Python file.
"""
with open(file_path, "r") as file:
tree = ast.parse(file.read())
function_line_counts = []
for node in tree.body:
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
# Top-level functions
line_count = node.end_lineno - node.lineno + 1
function_line_counts.append((node.name, line_count))
elif isinstance(node, ast.ClassDef):
# Functions inside classes
for class_node in node.body:
if isinstance(class_node, (ast.FunctionDef, ast.AsyncFunctionDef)):
line_count = class_node.end_lineno - class_node.lineno + 1
function_line_counts.append((class_node.name, line_count))
return function_line_counts
ignored_functions = [
"__init__",
]
def check_function_lengths(router_file):
"""
Checks if any function in the specified file exceeds the maximum allowed length.
"""
function_line_counts = get_function_line_counts(router_file)
long_functions = [
(name, count)
for name, count in function_line_counts
if count > MAX_FUNCTION_LINES and name not in ignored_functions
]
if long_functions:
print("The following functions exceed the allowed line count:")
for name, count in long_functions:
print(f"- {name}: {count} lines")
raise Exception(
f"{len(long_functions)} functions in {router_file} exceed {MAX_FUNCTION_LINES} lines"
)
else:
print("All functions in the router file are within the allowed line limit.")
def main():
# Update this path to point to the correct location of router.py
router_file = "../../litellm/router.py" # LOCAL TESTING
check_function_lengths(router_file)
if __name__ == "__main__":
main()

View file

@@ -2569,6 +2569,15 @@ async def test_router_batch_endpoints(provider):
)
print("Response from creating file=", file_obj)
## TEST 2 - test underlying create_file function
file_obj = await router._acreate_file(
model="my-custom-name",
file=open(file_path, "rb"),
purpose="batch",
custom_llm_provider=provider,
)
print("Response from creating file=", file_obj)
await asyncio.sleep(10)
batch_input_file_id = file_obj.id
assert (
@@ -2583,6 +2592,15 @@ async def test_router_batch_endpoints(provider):
custom_llm_provider=provider,
metadata={"key1": "value1", "key2": "value2"},
)
## TEST 2 - test underlying create_batch function
create_batch_response = await router._acreate_batch(
model="my-custom-name",
completion_window="24h",
endpoint="/v1/chat/completions",
input_file_id=batch_input_file_id,
custom_llm_provider=provider,
metadata={"key1": "value1", "key2": "value2"},
)
print("response from router.create_batch=", create_batch_response)

View file

@@ -0,0 +1,84 @@
import sys
import os
import traceback
from dotenv import load_dotenv
from fastapi import Request
from datetime import datetime
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from litellm import Router
import pytest
import litellm
from unittest.mock import patch, MagicMock, AsyncMock
import json
from io import BytesIO
from typing import Dict, List
from litellm.router_utils.batch_utils import (
replace_model_in_jsonl,
_get_router_metadata_variable_name,
)
# Fixtures
@pytest.fixture
def sample_jsonl_data() -> List[Dict]:
"""Fixture providing sample JSONL data"""
return [
{
"body": {
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "Hello"}],
}
},
{"body": {"model": "gpt-4", "messages": [{"role": "user", "content": "Hi"}]}},
]
@pytest.fixture
def sample_jsonl_bytes(sample_jsonl_data) -> bytes:
"""Fixture providing sample JSONL as bytes"""
jsonl_str = "\n".join(json.dumps(line) for line in sample_jsonl_data)
return jsonl_str.encode("utf-8")
@pytest.fixture
def sample_file_like(sample_jsonl_bytes):
"""Fixture providing a file-like object"""
return BytesIO(sample_jsonl_bytes)
# Test cases
def test_bytes_input(sample_jsonl_bytes):
"""Test with bytes input"""
new_model = "claude-3"
result = replace_model_in_jsonl(sample_jsonl_bytes, new_model)
assert result is not None
def test_tuple_input(sample_jsonl_bytes):
"""Test with tuple input"""
new_model = "claude-3"
test_tuple = ("test.jsonl", sample_jsonl_bytes, "application/json")
result = replace_model_in_jsonl(test_tuple, new_model)
assert result is not None
def test_file_like_object(sample_file_like):
"""Test with file-like object input"""
new_model = "claude-3"
result = replace_model_in_jsonl(sample_file_like, new_model)
assert result is not None
def test_router_metadata_variable_name():
"""Test that the variable name is correct"""
assert _get_router_metadata_variable_name(function_name="completion") == "metadata"
assert (
_get_router_metadata_variable_name(function_name="batch") == "litellm_metadata"
)

View file

@@ -41,6 +41,20 @@ def model_list():
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
{
"model_name": "*",
"litellm_params": {
"model": "openai/*",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
{
"model_name": "claude-*",
"litellm_params": {
"model": "anthropic/*",
"api_key": os.getenv("ANTHROPIC_API_KEY"),
},
},
]
@@ -834,3 +848,69 @@ def test_flush_cache(model_list):
assert router.cache.get_cache("test") == "test"
router.flush_cache()
assert router.cache.get_cache("test") is None
def test_initialize_assistants_endpoint(model_list):
"""Test if the 'initialize_assistants_endpoint' function is working correctly"""
router = Router(model_list=model_list)
router.initialize_assistants_endpoint()
assert router.acreate_assistants is not None
assert router.adelete_assistant is not None
assert router.aget_assistants is not None
assert router.acreate_thread is not None
assert router.aget_thread is not None
assert router.arun_thread is not None
assert router.aget_messages is not None
assert router.a_add_message is not None
def test_pass_through_assistants_endpoint_factory(model_list):
"""Test if the 'pass_through_assistants_endpoint_factory' function is working correctly"""
router = Router(model_list=model_list)
router._pass_through_assistants_endpoint_factory(
original_function=litellm.acreate_assistants,
custom_llm_provider="openai",
client=None,
**{},
)
def test_factory_function(model_list):
"""Test if the 'factory_function' function is working correctly"""
router = Router(model_list=model_list)
router.factory_function(litellm.acreate_assistants)
def test_get_model_from_alias(model_list):
"""Test if the 'get_model_from_alias' function is working correctly"""
router = Router(
model_list=model_list,
model_group_alias={"gpt-4o": "gpt-3.5-turbo"},
)
model = router._get_model_from_alias(model="gpt-4o")
assert model == "gpt-3.5-turbo"
def test_get_deployment_by_litellm_model(model_list):
"""Test if the 'get_deployment_by_litellm_model' function is working correctly"""
router = Router(model_list=model_list)
deployment = router._get_deployment_by_litellm_model(model="gpt-3.5-turbo")
assert deployment is not None
def test_get_pattern(model_list):
router = Router(model_list=model_list)
pattern = router.pattern_router.get_pattern(model="claude-3")
assert pattern is not None
def test_deployments_by_pattern(model_list):
router = Router(model_list=model_list)
deployments = router.pattern_router.get_deployments_by_pattern(model="claude-3")
assert deployments is not None
def test_replace_model_in_jsonl(model_list):
router = Router(model_list=model_list)
deployments = router.pattern_router.get_deployments_by_pattern(model="claude-3")
assert deployments is not None