forked from phoenix/litellm-mirror
Litellm router code coverage 3 (#6274)
* refactor(router.py): move assistants API endpoints to a single pass-through factory function. Reduces code and increases testing coverage.
* refactor(router.py): reduce _common_checks_available_deployment function size. Makes the code more maintainable and reduces possible errors.
* test(router_code_coverage.py): include batch_utils + pattern matching in enforced 100% code coverage. Improves reliability.
* fix(router.py): fix model id match model dump.
This commit is contained in: parent 891e9001b5, commit e22e8d24ef
8 changed files with 407 additions and 244 deletions
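The first refactor replaces eight near-identical assistants methods with one factory that resolves the provider and config once and then delegates to the underlying litellm function. Below is a minimal, standalone sketch of that pass-through pattern; the helper name and the dict-based assistants_config are illustrative, not the exact router code.

from typing import Any, Callable, Dict, Literal, Optional


def make_passthrough(
    original_function: Callable,
    assistants_config: Optional[Dict] = None,
) -> Callable:
    """Wrap an async litellm assistants function with shared provider/config resolution."""

    async def new_function(
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        client: Optional[Any] = None,
        **kwargs,
    ):
        if custom_llm_provider is None:
            if assistants_config is not None:
                # fall back to the router-level assistants config
                custom_llm_provider = assistants_config["custom_llm_provider"]
                kwargs.update(assistants_config["litellm_params"])
            else:
                raise Exception("'custom_llm_provider' must be set")
        return await original_function(
            custom_llm_provider=custom_llm_provider, client=client, **kwargs
        )

    return new_function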
@@ -25,3 +25,9 @@ model_list:
#       guard_name: "gibberish_guard"
#       mode: "post_call"
#       api_base: os.environ/GUARDRAILS_AI_API_BASE

assistant_settings:
  custom_llm_provider: azure
  litellm_params:
    api_key: os.environ/AZURE_API_KEY
    api_base: os.environ/AZURE_API_BASE
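The assistant_settings block above is the proxy-config counterpart of the router's assistants_config argument used throughout this diff. A hedged sketch of the equivalent programmatic setup follows; the empty model list and environment-variable lookups are placeholders.

import os

from litellm import Router

router = Router(
    model_list=[],  # placeholder; real deployments go here
    assistants_config={
        "custom_llm_provider": "azure",
        "litellm_params": {
            "api_key": os.environ.get("AZURE_API_KEY"),
            "api_base": os.environ.get("AZURE_API_BASE"),
        },
    },
)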
@@ -24,7 +24,18 @@ import traceback
import uuid
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Literal,
    Optional,
    Tuple,
    TypedDict,
    Union,
)

import httpx
import openai
@@ -520,6 +531,19 @@ class Router:
        if self.alerting_config is not None:
            self._initialize_alerting()

        self.initialize_assistants_endpoint()

    def initialize_assistants_endpoint(self):
        ## INITIALIZE PASS THROUGH ASSISTANTS ENDPOINT ##
        self.acreate_assistants = self.factory_function(litellm.acreate_assistants)
        self.adelete_assistant = self.factory_function(litellm.adelete_assistant)
        self.aget_assistants = self.factory_function(litellm.aget_assistants)
        self.acreate_thread = self.factory_function(litellm.acreate_thread)
        self.aget_thread = self.factory_function(litellm.aget_thread)
        self.a_add_message = self.factory_function(litellm.a_add_message)
        self.aget_messages = self.factory_function(litellm.aget_messages)
        self.arun_thread = self.factory_function(litellm.arun_thread)

    def validate_fallbacks(self, fallback_param: Optional[List]):
        """
        Validate the fallbacks parameter.
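Once initialize_assistants_endpoint has run, each attribute is an async pass-through to the matching litellm call. A hedged usage sketch is below; the provider and payloads are illustrative, and `router` is assumed to be a configured Router instance.

import asyncio


async def assistants_demo(router):
    # create a thread, add a user message, then read the messages back
    thread = await router.acreate_thread(custom_llm_provider="openai")
    await router.a_add_message(
        custom_llm_provider="openai",
        thread_id=thread.id,
        role="user",
        content="Hello!",
    )
    return await router.aget_messages(
        custom_llm_provider="openai", thread_id=thread.id
    )


# asyncio.run(assistants_demo(router))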
@@ -2167,7 +2191,6 @@ class Router:
            raise e

    #### FILES API ####

    async def acreate_file(
        self,
        model: str,
@@ -2504,114 +2527,29 @@ class Router:

    #### ASSISTANTS API ####

    async def acreate_assistants(
        self,
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> Assistant:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )

        return await litellm.acreate_assistants(
            custom_llm_provider=custom_llm_provider, client=client, **kwargs
        )

    async def adelete_assistant(
        self,
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> AssistantDeleted:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )

        return await litellm.adelete_assistant(
            custom_llm_provider=custom_llm_provider, client=client, **kwargs
        )

    async def aget_assistants(
        self,
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> AsyncCursorPage[Assistant]:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )

        return await litellm.aget_assistants(
            custom_llm_provider=custom_llm_provider, client=client, **kwargs
        )

    async def acreate_thread(
        self,
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> Thread:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )
        return await litellm.acreate_thread(
            custom_llm_provider=custom_llm_provider, client=client, **kwargs
        )

    async def aget_thread(
        self,
        thread_id: str,
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> Thread:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )
        return await litellm.aget_thread(
            custom_llm_provider=custom_llm_provider,
            thread_id=thread_id,
            client=client,
    def factory_function(self, original_function: Callable):
        async def new_function(
            custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
            client: Optional["AsyncOpenAI"] = None,
            **kwargs,
        )
        ):
            return await self._pass_through_assistants_endpoint_factory(
                original_function=original_function,
                custom_llm_provider=custom_llm_provider,
                client=client,
                **kwargs,
            )

    async def a_add_message(
        return new_function

    async def _pass_through_assistants_endpoint_factory(
        self,
        thread_id: str,
        role: Literal["user", "assistant"],
        content: str,
        attachments: Optional[List[Attachment]] = None,
        metadata: Optional[dict] = None,
        original_function: Callable,
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> OpenAIMessage:
    ):
        """Internal helper function to pass through the assistants endpoint"""
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
@@ -2620,76 +2558,8 @@
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )

        return await litellm.a_add_message(
            custom_llm_provider=custom_llm_provider,
            thread_id=thread_id,
            role=role,
            content=content,
            attachments=attachments,
            metadata=metadata,
            client=client,
            **kwargs,
        )

    async def aget_messages(
        self,
        thread_id: str,
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> AsyncCursorPage[OpenAIMessage]:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )
        return await litellm.aget_messages(
            custom_llm_provider=custom_llm_provider,
            thread_id=thread_id,
            client=client,
            **kwargs,
        )

    async def arun_thread(
        self,
        thread_id: str,
        assistant_id: str,
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        additional_instructions: Optional[str] = None,
        instructions: Optional[str] = None,
        metadata: Optional[dict] = None,
        model: Optional[str] = None,
        stream: Optional[bool] = None,
        tools: Optional[Iterable[AssistantToolParam]] = None,
        client: Optional[Any] = None,
        **kwargs,
    ) -> Run:

        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )

        return await litellm.arun_thread(
            custom_llm_provider=custom_llm_provider,
            thread_id=thread_id,
            assistant_id=assistant_id,
            additional_instructions=additional_instructions,
            instructions=instructions,
            metadata=metadata,
            model=model,
            stream=stream,
            tools=tools,
            client=client,
            **kwargs,
        return await original_function(  # type: ignore
            custom_llm_provider=custom_llm_provider, client=client, **kwargs
        )

    #### [END] ASSISTANTS API ####
@@ -4609,7 +4479,7 @@
        """
        returned_models: List[DeploymentTypedDict] = []
        for model in self.model_list:
            if model["model_name"] == model_name:
            if model_name is not None and model["model_name"] == model_name:
                if model_alias is not None:
                    alias_model = copy.deepcopy(model)
                    alias_model["model_name"] = model_alias
@@ -5007,82 +4877,73 @@

        return _returned_deployments

    def _get_model_from_alias(self, model: str) -> Optional[str]:
        """
        Get the model from the alias.

        Returns:
        - str, the litellm model name
        - None, if model is not in model group alias
        """
        if model not in self.model_group_alias:
            return None

        _item = self.model_group_alias[model]
        if isinstance(_item, str):
            model = _item
        else:
            model = _item["model"]

        return model

    def _get_deployment_by_litellm_model(self, model: str) -> List:
        """
        Get the deployment by litellm model.
        """
        return [m for m in self.model_list if m["litellm_params"]["model"] == model]

    def _common_checks_available_deployment(
        self,
        model: str,
        messages: Optional[List[Dict[str, str]]] = None,
        input: Optional[Union[str, List]] = None,
        specific_deployment: Optional[bool] = False,
    ) -> Tuple[str, Union[list, dict]]:
    ) -> Tuple[str, Union[List, Dict]]:
        """
        Common checks for 'get_available_deployment' across sync + async call.

        If 'healthy_deployments' returned is None, this means the user chose a specific deployment

        Returns
        - Dict, if specific model chosen
        - str, the litellm model name
        - List, if multiple models chosen
        - Dict, if specific model chosen
        """

        # check if aliases set on litellm model alias map
        if specific_deployment is True:
            # users can also specify a specific deployment name. At this point we should check if they are just trying to call a specific deployment
            for deployment in self.model_list:
                deployment_model = deployment.get("litellm_params").get("model")
                if deployment_model == model:
                    # User Passed a specific deployment name on their config.yaml, example azure/chat-gpt-v-2
                    # return the first deployment where the `model` matches the specified deployment name
                    return deployment_model, deployment
            raise ValueError(
                f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.get_model_names()}"
            )
            return model, self._get_deployment_by_litellm_model(model=model)
        elif model in self.get_model_ids():
            deployment = self.get_model_info(id=model)
            deployment = self.get_deployment(model_id=model)
            if deployment is not None:
                deployment_model = deployment.get("litellm_params", {}).get("model")
                return deployment_model, deployment
                deployment_model = deployment.litellm_params.model
                return deployment_model, deployment.model_dump(exclude_none=True)
            raise ValueError(
                f"LiteLLM Router: Trying to call specific deployment, but Model ID :{model} does not exist in \
                    Model ID List: {self.get_model_ids}"
            )

        if model in self.model_group_alias:
            _item = self.model_group_alias[model]
            if isinstance(_item, str):
                model = _item
            else:
                model = _item["model"]
        _model_from_alias = self._get_model_from_alias(model=model)
        if _model_from_alias is not None:
            model = _model_from_alias

        if model not in self.model_names:
            # check if provider/ specific wildcard routing use pattern matching
            custom_llm_provider: Optional[str] = None
            try:
                (
                    _,
                    custom_llm_provider,
                    _,
                    _,
                ) = litellm.get_llm_provider(model=model)
            except Exception:
                # get_llm_provider raises exception when provider is unknown
                pass

            """
            self.pattern_router.route(model):
            does exact pattern matching. Example openai/gpt-3.5-turbo gets routed to pattern openai/*

            self.pattern_router.route(f"{custom_llm_provider}/{model}"):
            does pattern matching using litellm.get_llm_provider(), example claude-3-5-sonnet-20240620 gets routed to anthropic/* since 'claude-3-5-sonnet-20240620' is an Anthropic Model
            """
            _pattern_router_response = self.pattern_router.route(
                model
            ) or self.pattern_router.route(f"{custom_llm_provider}/{model}")
            if _pattern_router_response is not None:
                provider_deployments = []
                for deployment in _pattern_router_response:
                    dep = copy.deepcopy(deployment)
                    dep["litellm_params"]["model"] = model
                    provider_deployments.append(dep)
                return model, provider_deployments
            pattern_deployments = self.pattern_router.get_deployments_by_pattern(
                model=model,
            )
            if pattern_deployments:
                return model, pattern_deployments

        # check if default deployment is set
        if self.default_deployment is not None:
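The inline docstring above describes a two-step wildcard lookup: try an exact pattern match on the raw model name first, then retry with a provider prefix derived from litellm.get_llm_provider(). A hedged, standalone sketch of that fallback order follows; the helper function is illustrative, and pattern_router stands in for the router's PatternMatchRouter instance.

from typing import Optional

import litellm


def resolve_wildcard(pattern_router, model: str) -> Optional[list]:
    # step 1: exact pattern match, e.g. "openai/gpt-3.5-turbo" against "openai/*"
    match = pattern_router.route(model)
    if match is not None:
        return match
    # step 2: infer the provider and retry with a provider-prefixed name,
    # e.g. "claude-3-5-sonnet-20240620" -> "anthropic/claude-3-5-sonnet-20240620"
    try:
        _, custom_llm_provider, _, _ = litellm.get_llm_provider(model=model)
    except Exception:
        return None  # unknown provider, nothing more to try
    return pattern_router.route(f"{custom_llm_provider}/{model}")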
@@ -5094,12 +4955,11 @@

        ## get healthy deployments
        ### get all deployments
        healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
        healthy_deployments = self._get_all_deployments(model_name=model)

        if len(healthy_deployments) == 0:
            # check if the user sent in a deployment name instead
            healthy_deployments = [
                m for m in self.model_list if m["litellm_params"]["model"] == model
            ]
            healthy_deployments = self._get_deployment_by_litellm_model(model=model)

        verbose_router_logger.debug(
            f"initial list of deployments: {healthy_deployments}"
@@ -5151,7 +5011,6 @@
            input=input,
            specific_deployment=specific_deployment,
        )  # type: ignore

        if isinstance(healthy_deployments, dict):
            return healthy_deployments
@@ -2,9 +2,11 @@
Class to handle llm wildcard routing and regex pattern matching
"""

import copy
import re
from typing import Dict, List, Optional

from litellm import get_llm_provider
from litellm._logging import verbose_router_logger
@@ -82,6 +84,55 @@ class PatternMatchRouter:

        return None  # No matching pattern found

    def get_pattern(
        self, model: str, custom_llm_provider: Optional[str] = None
    ) -> Optional[List[Dict]]:
        """
        Check if a pattern exists for the given model and custom llm provider

        Args:
            model: str
            custom_llm_provider: Optional[str]

        Returns:
            Optional[List[Dict]]: matching deployments if a pattern exists, None otherwise
        """
        if custom_llm_provider is None:
            try:
                (
                    _,
                    custom_llm_provider,
                    _,
                    _,
                ) = get_llm_provider(model=model)
            except Exception:
                # get_llm_provider raises exception when provider is unknown
                pass
        return self.route(model) or self.route(f"{custom_llm_provider}/{model}")

    def get_deployments_by_pattern(
        self, model: str, custom_llm_provider: Optional[str] = None
    ) -> List[Dict]:
        """
        Get the deployments by pattern

        Args:
            model: str
            custom_llm_provider: Optional[str]

        Returns:
            List[Dict]: llm deployments matching the pattern
        """
        pattern_match = self.get_pattern(model, custom_llm_provider)
        if pattern_match:
            provider_deployments = []
            for deployment in pattern_match:
                dep = copy.deepcopy(deployment)
                dep["litellm_params"]["model"] = model
                provider_deployments.append(dep)
            return provider_deployments
        return []


# Example usage:
# router = PatternRouter()
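A hedged usage sketch of the new get_deployments_by_pattern helper, mirroring the wildcard fixture in the unit tests further down; the API key lookup is illustrative.

import os

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "claude-*",
            "litellm_params": {
                "model": "anthropic/*",
                "api_key": os.getenv("ANTHROPIC_API_KEY"),
            },
        },
    ]
)

# "claude-3" matches the "claude-*" wildcard; each returned deployment has its
# litellm_params["model"] rewritten to the requested model name
deployments = router.pattern_router.get_deployments_by_pattern(model="claude-3")
assert len(deployments) >= 1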
@@ -75,29 +75,28 @@ def get_functions_from_router(file_path):

ignored_function_names = [
    "__init__",
    "_acreate_file",
    "_acreate_batch",
    "acreate_assistants",
    "adelete_assistant",
    "aget_assistants",
    "acreate_thread",
    "aget_thread",
    "a_add_message",
    "aget_messages",
    "arun_thread",
    "try_retrieve_batch",
]


def main():
    router_file = "./litellm/router.py"  # Update this path if it's located elsewhere
    # router_file = "../../litellm/router.py"  ## LOCAL TESTING
    router_file = [
        "./litellm/router.py",
        "./litellm/router_utils/batch_utils.py",
        "./litellm/router_utils/pattern_match_deployments.py",
    ]
    # router_file = [
    #     "../../litellm/router.py",
    #     "../../litellm/router_utils/pattern_match_deployments.py",
    #     "../../litellm/router_utils/batch_utils.py",
    # ]  ## LOCAL TESTING
    tests_dir = (
        "./tests/"  # Update this path if your tests directory is located elsewhere
    )
    # tests_dir = "../../tests/"  # LOCAL TESTING

    router_functions = get_functions_from_router(router_file)
    router_functions = []
    for file in router_file:
        router_functions.extend(get_functions_from_router(file))
    print("router_functions: ", router_functions)
    called_functions_in_tests = get_all_functions_called_in_tests(tests_dir)
    untested_functions = [
tests/code_coverage_tests/router_enforce_line_length.py (new file, 66 lines)
@@ -0,0 +1,66 @@
import ast
import os

MAX_FUNCTION_LINES = 100


def get_function_line_counts(file_path):
    """
    Extracts all function names and their line counts from a given Python file.
    """
    with open(file_path, "r") as file:
        tree = ast.parse(file.read())

    function_line_counts = []

    for node in tree.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            # Top-level functions
            line_count = node.end_lineno - node.lineno + 1
            function_line_counts.append((node.name, line_count))
        elif isinstance(node, ast.ClassDef):
            # Functions inside classes
            for class_node in node.body:
                if isinstance(class_node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                    line_count = class_node.end_lineno - class_node.lineno + 1
                    function_line_counts.append((class_node.name, line_count))

    return function_line_counts


ignored_functions = [
    "__init__",
]


def check_function_lengths(router_file):
    """
    Checks if any function in the specified file exceeds the maximum allowed length.
    """
    function_line_counts = get_function_line_counts(router_file)
    long_functions = [
        (name, count)
        for name, count in function_line_counts
        if count > MAX_FUNCTION_LINES and name not in ignored_functions
    ]

    if long_functions:
        print("The following functions exceed the allowed line count:")
        for name, count in long_functions:
            print(f"- {name}: {count} lines")
        raise Exception(
            f"{len(long_functions)} functions in {router_file} exceed {MAX_FUNCTION_LINES} lines"
        )
    else:
        print("All functions in the router file are within the allowed line limit.")


def main():
    # Update this path to point to the correct location of router.py
    router_file = "../../litellm/router.py"  # LOCAL TESTING

    check_function_lengths(router_file)


if __name__ == "__main__":
    main()
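For reference, the new check leans on ast's lineno/end_lineno attributes (available on Python 3.8+); a small self-contained illustration of the same counting arithmetic on an inline snippet:

import ast

source = '''
def short():
    return 1


class Router:
    def long_method(self):
        x = 1
        return x
'''

tree = ast.parse(source)
for node in ast.walk(tree):
    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
        # same arithmetic as get_function_line_counts above
        print(node.name, node.end_lineno - node.lineno + 1)
# short 2
# long_method 3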
@@ -2569,6 +2569,15 @@ async def test_router_batch_endpoints(provider):
    )
    print("Response from creating file=", file_obj)

    ## TEST 2 - test underlying create_file function
    file_obj = await router._acreate_file(
        model="my-custom-name",
        file=open(file_path, "rb"),
        purpose="batch",
        custom_llm_provider=provider,
    )
    print("Response from creating file=", file_obj)

    await asyncio.sleep(10)
    batch_input_file_id = file_obj.id
    assert (
@@ -2583,6 +2592,15 @@ async def test_router_batch_endpoints(provider):
        custom_llm_provider=provider,
        metadata={"key1": "value1", "key2": "value2"},
    )
    ## TEST 2 - test underlying create_batch function
    create_batch_response = await router._acreate_batch(
        model="my-custom-name",
        completion_window="24h",
        endpoint="/v1/chat/completions",
        input_file_id=batch_input_file_id,
        custom_llm_provider=provider,
        metadata={"key1": "value1", "key2": "value2"},
    )

    print("response from router.create_batch=", create_batch_response)
tests/router_unit_tests/test_router_batch_utils.py (new file, 84 lines)
@@ -0,0 +1,84 @@
import sys
import os
import traceback
from dotenv import load_dotenv
from fastapi import Request
from datetime import datetime

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
from litellm import Router
import pytest
import litellm
from unittest.mock import patch, MagicMock, AsyncMock

import json
from io import BytesIO
from typing import Dict, List
from litellm.router_utils.batch_utils import (
    replace_model_in_jsonl,
    _get_router_metadata_variable_name,
)


# Fixtures
@pytest.fixture
def sample_jsonl_data() -> List[Dict]:
    """Fixture providing sample JSONL data"""
    return [
        {
            "body": {
                "model": "gpt-3.5-turbo",
                "messages": [{"role": "user", "content": "Hello"}],
            }
        },
        {"body": {"model": "gpt-4", "messages": [{"role": "user", "content": "Hi"}]}},
    ]


@pytest.fixture
def sample_jsonl_bytes(sample_jsonl_data) -> bytes:
    """Fixture providing sample JSONL as bytes"""
    jsonl_str = "\n".join(json.dumps(line) for line in sample_jsonl_data)
    return jsonl_str.encode("utf-8")


@pytest.fixture
def sample_file_like(sample_jsonl_bytes):
    """Fixture providing a file-like object"""
    return BytesIO(sample_jsonl_bytes)


# Test cases
def test_bytes_input(sample_jsonl_bytes):
    """Test with bytes input"""
    new_model = "claude-3"
    result = replace_model_in_jsonl(sample_jsonl_bytes, new_model)

    assert result is not None


def test_tuple_input(sample_jsonl_bytes):
    """Test with tuple input"""
    new_model = "claude-3"
    test_tuple = ("test.jsonl", sample_jsonl_bytes, "application/json")
    result = replace_model_in_jsonl(test_tuple, new_model)

    assert result is not None


def test_file_like_object(sample_file_like):
    """Test with file-like object input"""
    new_model = "claude-3"
    result = replace_model_in_jsonl(sample_file_like, new_model)

    assert result is not None


def test_router_metadata_variable_name():
    """Test that the variable name is correct"""
    assert _get_router_metadata_variable_name(function_name="completion") == "metadata"
    assert (
        _get_router_metadata_variable_name(function_name="batch") == "litellm_metadata"
    )
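The fixtures above make the intent of replace_model_in_jsonl concrete: each batch request body gets its "model" field swapped for the deployment's model before the file is re-serialized. A hedged sketch of that transformation on plain dicts, illustrative only and not the actual litellm implementation:

import json
from typing import Dict, List


def swap_model_in_jsonl(jsonl_lines: List[Dict], new_model: str) -> bytes:
    updated = []
    for line in jsonl_lines:
        body = dict(line.get("body", {}))
        body["model"] = new_model  # rewrite the model per request body
        updated.append({**line, "body": body})
    return "\n".join(json.dumps(line) for line in updated).encode("utf-8")


sample = [
    {"body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}]}},
]
print(swap_model_in_jsonl(sample, "claude-3"))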
@@ -41,6 +41,20 @@ def model_list():
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        },
        {
            "model_name": "*",
            "litellm_params": {
                "model": "openai/*",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        },
        {
            "model_name": "claude-*",
            "litellm_params": {
                "model": "anthropic/*",
                "api_key": os.getenv("ANTHROPIC_API_KEY"),
            },
        },
    ]
@@ -834,3 +848,69 @@ def test_flush_cache(model_list):
    assert router.cache.get_cache("test") == "test"
    router.flush_cache()
    assert router.cache.get_cache("test") is None


def test_initialize_assistants_endpoint(model_list):
    """Test if the 'initialize_assistants_endpoint' function is working correctly"""
    router = Router(model_list=model_list)
    router.initialize_assistants_endpoint()
    assert router.acreate_assistants is not None
    assert router.adelete_assistant is not None
    assert router.aget_assistants is not None
    assert router.acreate_thread is not None
    assert router.aget_thread is not None
    assert router.arun_thread is not None
    assert router.aget_messages is not None
    assert router.a_add_message is not None


def test_pass_through_assistants_endpoint_factory(model_list):
    """Test if the 'pass_through_assistants_endpoint_factory' function is working correctly"""
    router = Router(model_list=model_list)
    router._pass_through_assistants_endpoint_factory(
        original_function=litellm.acreate_assistants,
        custom_llm_provider="openai",
        client=None,
        **{},
    )


def test_factory_function(model_list):
    """Test if the 'factory_function' function is working correctly"""
    router = Router(model_list=model_list)
    router.factory_function(litellm.acreate_assistants)


def test_get_model_from_alias(model_list):
    """Test if the 'get_model_from_alias' function is working correctly"""
    router = Router(
        model_list=model_list,
        model_group_alias={"gpt-4o": "gpt-3.5-turbo"},
    )
    model = router._get_model_from_alias(model="gpt-4o")
    assert model == "gpt-3.5-turbo"


def test_get_deployment_by_litellm_model(model_list):
    """Test if the 'get_deployment_by_litellm_model' function is working correctly"""
    router = Router(model_list=model_list)
    deployment = router._get_deployment_by_litellm_model(model="gpt-3.5-turbo")
    assert deployment is not None


def test_get_pattern(model_list):
    router = Router(model_list=model_list)
    pattern = router.pattern_router.get_pattern(model="claude-3")
    assert pattern is not None


def test_deployments_by_pattern(model_list):
    router = Router(model_list=model_list)
    deployments = router.pattern_router.get_deployments_by_pattern(model="claude-3")
    assert deployments is not None


def test_replace_model_in_jsonl(model_list):
    router = Router(model_list=model_list)
    deployments = router.pattern_router.get_deployments_by_pattern(model="claude-3")
    assert deployments is not None