Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)

Litellm router code coverage 3 (#6274)

* refactor(router.py): move the assistants API endpoints to a single pass-through factory function. Reduces code and increases testing coverage.
* refactor(router.py): reduce the size of _common_checks_available_deployment to make the code more maintainable and reduce possible errors.
* test(router_code_coverage.py): include batch_utils + pattern matching in the enforced 100% code coverage check. Improves reliability.
* fix(router.py): fix model id match model dump.

parent 891e9001b5
commit e22e8d24ef
8 changed files with 407 additions and 244 deletions
@@ -25,3 +25,9 @@ model_list:
 # guard_name: "gibberish_guard"
 # mode: "post_call"
 # api_base: os.environ/GUARDRAILS_AI_API_BASE
+
+assistant_settings:
+  custom_llm_provider: azure
+  litellm_params:
+    api_key: os.environ/AZURE_API_KEY
+    api_base: os.environ/AZURE_API_BASE
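The same assistants settings can also be passed to the Router programmatically. A minimal sketch, assuming litellm is installed and the referenced environment variables are set; the model_list entry is illustrative:

import os

from litellm import Router

# Mirrors the proxy config above. The assistants_config shape follows the hint in the
# router's own error message: Router(assistants_config={'custom_llm_provider': ..}).
router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        },
    ],
    assistants_config={
        "custom_llm_provider": "azure",
        "litellm_params": {
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_base": os.getenv("AZURE_API_BASE"),
        },
    },
)
# e.g. `await router.acreate_thread()` now passes these settings through to the provider.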
litellm/router.py

@@ -24,7 +24,18 @@ import traceback
 import uuid
 from collections import defaultdict
 from datetime import datetime
-from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    TypedDict,
+    Union,
+)
 
 import httpx
 import openai
@@ -520,6 +531,19 @@ class Router:
         if self.alerting_config is not None:
             self._initialize_alerting()
 
+        self.initialize_assistants_endpoint()
+
+    def initialize_assistants_endpoint(self):
+        ## INITIALIZE PASS THROUGH ASSISTANTS ENDPOINT ##
+        self.acreate_assistants = self.factory_function(litellm.acreate_assistants)
+        self.adelete_assistant = self.factory_function(litellm.adelete_assistant)
+        self.aget_assistants = self.factory_function(litellm.aget_assistants)
+        self.acreate_thread = self.factory_function(litellm.acreate_thread)
+        self.aget_thread = self.factory_function(litellm.aget_thread)
+        self.a_add_message = self.factory_function(litellm.a_add_message)
+        self.aget_messages = self.factory_function(litellm.aget_messages)
+        self.arun_thread = self.factory_function(litellm.arun_thread)
+
     def validate_fallbacks(self, fallback_param: Optional[List]):
         """
         Validate the fallbacks parameter.
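After this change each assistants endpoint is an ordinary attribute created during Router.__init__, so calling one is just awaiting it. A sketch only, assuming a Router configured with assistants_config (as in the config example above) and valid provider credentials; the parameter names come from the removed methods shown further down in this diff:

import asyncio

async def assistants_demo(router):
    # Each attribute below was produced by factory_function() at construction time.
    assistants = await router.aget_assistants()
    thread = await router.acreate_thread()
    await router.a_add_message(thread_id=thread.id, role="user", content="Hello")
    return await router.arun_thread(
        thread_id=thread.id, assistant_id=assistants.data[0].id
    )

# asyncio.run(assistants_demo(router))  # requires a configured Router and credentials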
@@ -2167,7 +2191,6 @@ class Router:
             raise e
 
     #### FILES API ####
-
     async def acreate_file(
         self,
         model: str,
@@ -2504,114 +2527,29 @@ class Router:
 
     #### ASSISTANTS API ####
 
-    async def acreate_assistants(
-        self,
-        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
-        client: Optional[AsyncOpenAI] = None,
-        **kwargs,
-    ) -> Assistant:
-        if custom_llm_provider is None:
-            if self.assistants_config is not None:
-                custom_llm_provider = self.assistants_config["custom_llm_provider"]
-                kwargs.update(self.assistants_config["litellm_params"])
-            else:
-                raise Exception(
-                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
-                )
-
-        return await litellm.acreate_assistants(
-            custom_llm_provider=custom_llm_provider, client=client, **kwargs
-        )
-
-    async def adelete_assistant(
-        self,
-        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
-        client: Optional[AsyncOpenAI] = None,
-        **kwargs,
-    ) -> AssistantDeleted:
-        if custom_llm_provider is None:
-            if self.assistants_config is not None:
-                custom_llm_provider = self.assistants_config["custom_llm_provider"]
-                kwargs.update(self.assistants_config["litellm_params"])
-            else:
-                raise Exception(
-                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
-                )
-
-        return await litellm.adelete_assistant(
-            custom_llm_provider=custom_llm_provider, client=client, **kwargs
-        )
-
-    async def aget_assistants(
-        self,
-        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
-        client: Optional[AsyncOpenAI] = None,
-        **kwargs,
-    ) -> AsyncCursorPage[Assistant]:
-        if custom_llm_provider is None:
-            if self.assistants_config is not None:
-                custom_llm_provider = self.assistants_config["custom_llm_provider"]
-                kwargs.update(self.assistants_config["litellm_params"])
-            else:
-                raise Exception(
-                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
-                )
-
-        return await litellm.aget_assistants(
-            custom_llm_provider=custom_llm_provider, client=client, **kwargs
-        )
-
-    async def acreate_thread(
-        self,
-        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
-        client: Optional[AsyncOpenAI] = None,
-        **kwargs,
-    ) -> Thread:
-        if custom_llm_provider is None:
-            if self.assistants_config is not None:
-                custom_llm_provider = self.assistants_config["custom_llm_provider"]
-                kwargs.update(self.assistants_config["litellm_params"])
-            else:
-                raise Exception(
-                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
-                )
-        return await litellm.acreate_thread(
-            custom_llm_provider=custom_llm_provider, client=client, **kwargs
-        )
-
-    async def aget_thread(
-        self,
-        thread_id: str,
-        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
-        client: Optional[AsyncOpenAI] = None,
-        **kwargs,
-    ) -> Thread:
-        if custom_llm_provider is None:
-            if self.assistants_config is not None:
-                custom_llm_provider = self.assistants_config["custom_llm_provider"]
-                kwargs.update(self.assistants_config["litellm_params"])
-            else:
-                raise Exception(
-                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
-                )
-        return await litellm.aget_thread(
-            custom_llm_provider=custom_llm_provider,
-            thread_id=thread_id,
-            client=client,
-            **kwargs,
-        )
-
-    async def a_add_message(
+    def factory_function(self, original_function: Callable):
+        async def new_function(
+            custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
+            client: Optional["AsyncOpenAI"] = None,
+            **kwargs,
+        ):
+            return await self._pass_through_assistants_endpoint_factory(
+                original_function=original_function,
+                custom_llm_provider=custom_llm_provider,
+                client=client,
+                **kwargs,
+            )
+
+        return new_function
+
+    async def _pass_through_assistants_endpoint_factory(
         self,
-        thread_id: str,
-        role: Literal["user", "assistant"],
-        content: str,
-        attachments: Optional[List[Attachment]] = None,
-        metadata: Optional[dict] = None,
+        original_function: Callable,
         custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
         client: Optional[AsyncOpenAI] = None,
         **kwargs,
-    ) -> OpenAIMessage:
+    ):
+        """Internal helper function to pass through the assistants endpoint"""
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
@@ -2620,76 +2558,8 @@ class Router:
                 raise Exception(
                     "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                 )
-
-        return await litellm.a_add_message(
-            custom_llm_provider=custom_llm_provider,
-            thread_id=thread_id,
-            role=role,
-            content=content,
-            attachments=attachments,
-            metadata=metadata,
-            client=client,
-            **kwargs,
-        )
-
-    async def aget_messages(
-        self,
-        thread_id: str,
-        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
-        client: Optional[AsyncOpenAI] = None,
-        **kwargs,
-    ) -> AsyncCursorPage[OpenAIMessage]:
-        if custom_llm_provider is None:
-            if self.assistants_config is not None:
-                custom_llm_provider = self.assistants_config["custom_llm_provider"]
-                kwargs.update(self.assistants_config["litellm_params"])
-            else:
-                raise Exception(
-                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
-                )
-        return await litellm.aget_messages(
-            custom_llm_provider=custom_llm_provider,
-            thread_id=thread_id,
-            client=client,
-            **kwargs,
-        )
-
-    async def arun_thread(
-        self,
-        thread_id: str,
-        assistant_id: str,
-        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
-        additional_instructions: Optional[str] = None,
-        instructions: Optional[str] = None,
-        metadata: Optional[dict] = None,
-        model: Optional[str] = None,
-        stream: Optional[bool] = None,
-        tools: Optional[Iterable[AssistantToolParam]] = None,
-        client: Optional[Any] = None,
-        **kwargs,
-    ) -> Run:
-
-        if custom_llm_provider is None:
-            if self.assistants_config is not None:
-                custom_llm_provider = self.assistants_config["custom_llm_provider"]
-                kwargs.update(self.assistants_config["litellm_params"])
-            else:
-                raise Exception(
-                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
-                )
-
-        return await litellm.arun_thread(
-            custom_llm_provider=custom_llm_provider,
-            thread_id=thread_id,
-            assistant_id=assistant_id,
-            additional_instructions=additional_instructions,
-            instructions=instructions,
-            metadata=metadata,
-            model=model,
-            stream=stream,
-            tools=tools,
-            client=client,
-            **kwargs,
-        )
+        return await original_function(  # type: ignore
+            custom_llm_provider=custom_llm_provider, client=client, **kwargs
+        )
 
     #### [END] ASSISTANTS API ####
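For readers unfamiliar with the pattern, here is a minimal standalone sketch of the same closure-factory idea. It is illustrative only: the class and function names are hypothetical and it does not touch litellm.

import asyncio
from typing import Any, Callable, Optional


class MiniRouter:
    """Toy illustration of a pass-through factory; names here are hypothetical."""

    def __init__(self, default_provider: str):
        self.default_provider = default_provider
        # One factory call per endpoint replaces several near-identical methods.
        self.acreate_thread = self.factory_function(self._fake_acreate_thread)
        self.aget_thread = self.factory_function(self._fake_aget_thread)

    def factory_function(self, original_function: Callable) -> Callable:
        async def new_function(custom_llm_provider: Optional[str] = None, **kwargs: Any):
            # Shared defaulting/validation lives in exactly one place.
            provider = custom_llm_provider or self.default_provider
            return await original_function(custom_llm_provider=provider, **kwargs)

        return new_function

    @staticmethod
    async def _fake_acreate_thread(custom_llm_provider: str, **kwargs: Any) -> str:
        return f"thread created via {custom_llm_provider}"

    @staticmethod
    async def _fake_aget_thread(custom_llm_provider: str, **kwargs: Any) -> str:
        return f"thread fetched via {custom_llm_provider}"


if __name__ == "__main__":
    router = MiniRouter(default_provider="azure")
    print(asyncio.run(router.acreate_thread()))  # thread created via azure
    print(asyncio.run(router.aget_thread(custom_llm_provider="openai")))  # thread fetched via openai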
@@ -4609,7 +4479,7 @@ class Router:
         """
         returned_models: List[DeploymentTypedDict] = []
         for model in self.model_list:
-            if model["model_name"] == model_name:
+            if model_name is not None and model["model_name"] == model_name:
                 if model_alias is not None:
                     alias_model = copy.deepcopy(model)
                     alias_model["model_name"] = model_alias
@@ -5007,82 +4877,73 @@ class Router:
 
         return _returned_deployments
 
+    def _get_model_from_alias(self, model: str) -> Optional[str]:
+        """
+        Get the model from the alias.
+
+        Returns:
+        - str, the litellm model name
+        - None, if model is not in model group alias
+        """
+        if model not in self.model_group_alias:
+            return None
+
+        _item = self.model_group_alias[model]
+        if isinstance(_item, str):
+            model = _item
+        else:
+            model = _item["model"]
+
+        return model
+
+    def _get_deployment_by_litellm_model(self, model: str) -> List:
+        """
+        Get the deployment by litellm model.
+        """
+        return [m for m in self.model_list if m["litellm_params"]["model"] == model]
+
     def _common_checks_available_deployment(
         self,
         model: str,
         messages: Optional[List[Dict[str, str]]] = None,
         input: Optional[Union[str, List]] = None,
         specific_deployment: Optional[bool] = False,
-    ) -> Tuple[str, Union[list, dict]]:
+    ) -> Tuple[str, Union[List, Dict]]:
         """
         Common checks for 'get_available_deployment' across sync + async call.
 
         If 'healthy_deployments' returned is None, this means the user chose a specific deployment
 
         Returns
-        - Dict, if specific model chosen
+        - str, the litellm model name
         - List, if multiple models chosen
+        - Dict, if specific model chosen
         """
 
         # check if aliases set on litellm model alias map
         if specific_deployment is True:
-            # users can also specify a specific deployment name. At this point we should check if they are just trying to call a specific deployment
-            for deployment in self.model_list:
-                deployment_model = deployment.get("litellm_params").get("model")
-                if deployment_model == model:
-                    # User Passed a specific deployment name on their config.yaml, example azure/chat-gpt-v-2
-                    # return the first deployment where the `model` matches the specificed deployment name
-                    return deployment_model, deployment
-            raise ValueError(
-                f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.get_model_names()}"
-            )
+            return model, self._get_deployment_by_litellm_model(model=model)
         elif model in self.get_model_ids():
-            deployment = self.get_model_info(id=model)
+            deployment = self.get_deployment(model_id=model)
             if deployment is not None:
-                deployment_model = deployment.get("litellm_params", {}).get("model")
-                return deployment_model, deployment
+                deployment_model = deployment.litellm_params.model
+                return deployment_model, deployment.model_dump(exclude_none=True)
             raise ValueError(
                 f"LiteLLM Router: Trying to call specific deployment, but Model ID :{model} does not exist in \
                     Model ID List: {self.get_model_ids}"
             )
 
-        if model in self.model_group_alias:
-            _item = self.model_group_alias[model]
-            if isinstance(_item, str):
-                model = _item
-            else:
-                model = _item["model"]
+        _model_from_alias = self._get_model_from_alias(model=model)
+        if _model_from_alias is not None:
+            model = _model_from_alias
 
         if model not in self.model_names:
             # check if provider/ specific wildcard routing use pattern matching
-            custom_llm_provider: Optional[str] = None
-            try:
-                (
-                    _,
-                    custom_llm_provider,
-                    _,
-                    _,
-                ) = litellm.get_llm_provider(model=model)
-            except Exception:
-                # get_llm_provider raises exception when provider is unknown
-                pass
-
-            """
-            self.pattern_router.route(model):
-            does exact pattern matching. Example openai/gpt-3.5-turbo gets routed to pattern openai/*
-
-            self.pattern_router.route(f"{custom_llm_provider}/{model}"):
-            does pattern matching using litellm.get_llm_provider(), example claude-3-5-sonnet-20240620 gets routed to anthropic/* since 'claude-3-5-sonnet-20240620' is an Anthropic Model
-            """
-            _pattern_router_response = self.pattern_router.route(
-                model
-            ) or self.pattern_router.route(f"{custom_llm_provider}/{model}")
-            if _pattern_router_response is not None:
-                provider_deployments = []
-                for deployment in _pattern_router_response:
-                    dep = copy.deepcopy(deployment)
-                    dep["litellm_params"]["model"] = model
-                    provider_deployments.append(dep)
-                return model, provider_deployments
+            pattern_deployments = self.pattern_router.get_deployments_by_pattern(
+                model=model,
+            )
+            if pattern_deployments:
+                return model, pattern_deployments
 
         # check if default deployment is set
         if self.default_deployment is not None:
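The new _get_model_from_alias helper isolates the two shapes a model_group_alias value can take, either a plain model-group name or a dict that carries the name under "model". A standalone sketch of that lookup logic (the extra dict field below is hypothetical):

from typing import Dict, Optional, Union

ModelGroupAlias = Dict[str, Union[str, Dict[str, str]]]


def resolve_alias(model: str, model_group_alias: ModelGroupAlias) -> Optional[str]:
    """Mirrors the shape of the _get_model_from_alias logic shown above (sketch)."""
    if model not in model_group_alias:
        return None
    _item = model_group_alias[model]
    if isinstance(_item, str):
        return _item
    return _item["model"]


if __name__ == "__main__":
    aliases: ModelGroupAlias = {
        "gpt-4o": "gpt-3.5-turbo",                          # plain string alias
        "my-alias": {"model": "gpt-4", "hidden": "true"},   # dict-style alias (extra key is illustrative)
    }
    assert resolve_alias("gpt-4o", aliases) == "gpt-3.5-turbo"
    assert resolve_alias("my-alias", aliases) == "gpt-4"
    assert resolve_alias("unknown", aliases) is None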
@@ -5094,12 +4955,11 @@ class Router:
 
         ## get healthy deployments
         ### get all deployments
-        healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
+        healthy_deployments = self._get_all_deployments(model_name=model)
 
         if len(healthy_deployments) == 0:
             # check if the user sent in a deployment name instead
-            healthy_deployments = [
-                m for m in self.model_list if m["litellm_params"]["model"] == model
-            ]
+            healthy_deployments = self._get_deployment_by_litellm_model(model=model)
 
         verbose_router_logger.debug(
             f"initial list of deployments: {healthy_deployments}"
@@ -5151,7 +5011,6 @@ class Router:
             input=input,
             specific_deployment=specific_deployment,
         ) # type: ignore
-
         if isinstance(healthy_deployments, dict):
             return healthy_deployments
 
litellm/router_utils/pattern_match_deployments.py

@@ -2,9 +2,11 @@
 Class to handle llm wildcard routing and regex pattern matching
 """
 
+import copy
 import re
 from typing import Dict, List, Optional
 
+from litellm import get_llm_provider
 from litellm._logging import verbose_router_logger
 
 
@@ -82,6 +84,55 @@ class PatternMatchRouter:
 
         return None  # No matching pattern found
 
+    def get_pattern(
+        self, model: str, custom_llm_provider: Optional[str] = None
+    ) -> Optional[List[Dict]]:
+        """
+        Check if a pattern exists for the given model and custom llm provider
+
+        Args:
+            model: str
+            custom_llm_provider: Optional[str]
+
+        Returns:
+            bool: True if pattern exists, False otherwise
+        """
+        if custom_llm_provider is None:
+            try:
+                (
+                    _,
+                    custom_llm_provider,
+                    _,
+                    _,
+                ) = get_llm_provider(model=model)
+            except Exception:
+                # get_llm_provider raises exception when provider is unknown
+                pass
+        return self.route(model) or self.route(f"{custom_llm_provider}/{model}")
+
+    def get_deployments_by_pattern(
+        self, model: str, custom_llm_provider: Optional[str] = None
+    ) -> List[Dict]:
+        """
+        Get the deployments by pattern
+
+        Args:
+            model: str
+            custom_llm_provider: Optional[str]
+
+        Returns:
+            List[Dict]: llm deployments matching the pattern
+        """
+        pattern_match = self.get_pattern(model, custom_llm_provider)
+        if pattern_match:
+            provider_deployments = []
+            for deployment in pattern_match:
+                dep = copy.deepcopy(deployment)
+                dep["litellm_params"]["model"] = model
+                provider_deployments.append(dep)
+            return provider_deployments
+        return []
+
 
 # Example usage:
 # router = PatternRouter()
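As a standalone sketch of the wildcard idea behind get_pattern/get_deployments_by_pattern (the helper names below are hypothetical and the real PatternMatchRouter's internals differ), a wildcard like "anthropic/*" is turned into a regex and matched against the requested model:

import re
from typing import Dict, List, Optional


def pattern_to_regex(pattern: str) -> str:
    # "anthropic/*" -> "^anthropic/(.*)$"
    return "^" + re.escape(pattern).replace(r"\*", "(.*)") + "$"


def route(model: str, patterns: Dict[str, List[Dict]]) -> Optional[List[Dict]]:
    for pattern, deployments in patterns.items():
        if re.match(pattern_to_regex(pattern), model):
            return deployments
    return None  # no matching pattern found


if __name__ == "__main__":
    patterns = {
        "anthropic/*": [{"model_name": "claude-*", "litellm_params": {"model": "anthropic/*"}}],
        "openai/*": [{"model_name": "*", "litellm_params": {"model": "openai/*"}}],
    }
    # get_pattern() above first derives the provider, so a bare "claude-3-5-sonnet-20240620"
    # ends up being looked up as "anthropic/claude-3-5-sonnet-20240620".
    print(route("anthropic/claude-3-5-sonnet-20240620", patterns))
    print(route("openai/gpt-4o-mini", patterns))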
router_code_coverage.py

@@ -75,29 +75,28 @@ def get_functions_from_router(file_path):
 
     ignored_function_names = [
         "__init__",
-        "_acreate_file",
-        "_acreate_batch",
-        "acreate_assistants",
-        "adelete_assistant",
-        "aget_assistants",
-        "acreate_thread",
-        "aget_thread",
-        "a_add_message",
-        "aget_messages",
-        "arun_thread",
-        "try_retrieve_batch",
     ]
 
 
 def main():
-    router_file = "./litellm/router.py"  # Update this path if it's located elsewhere
-    # router_file = "../../litellm/router.py" ## LOCAL TESTING
+    router_file = [
+        "./litellm/router.py",
+        "./litellm/router_utils/batch_utils.py",
+        "./litellm/router_utils/pattern_match_deployments.py",
+    ]
+    # router_file = [
+    #     "../../litellm/router.py",
+    #     "../../litellm/router_utils/pattern_match_deployments.py",
+    #     "../../litellm/router_utils/batch_utils.py",
+    # ] ## LOCAL TESTING
     tests_dir = (
         "./tests/"  # Update this path if your tests directory is located elsewhere
     )
     # tests_dir = "../../tests/" # LOCAL TESTING
 
-    router_functions = get_functions_from_router(router_file)
+    router_functions = []
+    for file in router_file:
+        router_functions.extend(get_functions_from_router(file))
     print("router_functions: ", router_functions)
     called_functions_in_tests = get_all_functions_called_in_tests(tests_dir)
     untested_functions = [
tests/code_coverage_tests/router_enforce_line_length.py (new file, 66 lines)

import ast
import os

MAX_FUNCTION_LINES = 100


def get_function_line_counts(file_path):
    """
    Extracts all function names and their line counts from a given Python file.
    """
    with open(file_path, "r") as file:
        tree = ast.parse(file.read())

    function_line_counts = []

    for node in tree.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            # Top-level functions
            line_count = node.end_lineno - node.lineno + 1
            function_line_counts.append((node.name, line_count))
        elif isinstance(node, ast.ClassDef):
            # Functions inside classes
            for class_node in node.body:
                if isinstance(class_node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                    line_count = class_node.end_lineno - class_node.lineno + 1
                    function_line_counts.append((class_node.name, line_count))

    return function_line_counts


ignored_functions = [
    "__init__",
]


def check_function_lengths(router_file):
    """
    Checks if any function in the specified file exceeds the maximum allowed length.
    """
    function_line_counts = get_function_line_counts(router_file)
    long_functions = [
        (name, count)
        for name, count in function_line_counts
        if count > MAX_FUNCTION_LINES and name not in ignored_functions
    ]

    if long_functions:
        print("The following functions exceed the allowed line count:")
        for name, count in long_functions:
            print(f"- {name}: {count} lines")
        raise Exception(
            f"{len(long_functions)} functions in {router_file} exceed {MAX_FUNCTION_LINES} lines"
        )
    else:
        print("All functions in the router file are within the allowed line limit.")


def main():
    # Update this path to point to the correct location of router.py
    router_file = "../../litellm/router.py"  # LOCAL TESTING

    check_function_lengths(router_file)


if __name__ == "__main__":
    main()
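As a quick illustration of the AST technique this new check relies on (a standalone sketch that does not import the test module; it uses ast.walk rather than the explicit class traversal above, but the line-count rule is the same):

import ast
import textwrap

source = textwrap.dedent(
    """
    def short():
        return 1

    class Foo:
        def long_ish(self):
            x = 1
            y = 2
            return x + y
    """
)

tree = ast.parse(source)
for node in ast.walk(tree):
    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
        # end_lineno - lineno + 1 is the same counting rule used by the new check
        print(node.name, node.end_lineno - node.lineno + 1)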
@@ -2569,6 +2569,15 @@ async def test_router_batch_endpoints(provider):
     )
     print("Response from creating file=", file_obj)
 
+    ## TEST 2 - test underlying create_file function
+    file_obj = await router._acreate_file(
+        model="my-custom-name",
+        file=open(file_path, "rb"),
+        purpose="batch",
+        custom_llm_provider=provider,
+    )
+    print("Response from creating file=", file_obj)
+
     await asyncio.sleep(10)
     batch_input_file_id = file_obj.id
     assert (
@@ -2583,6 +2592,15 @@ async def test_router_batch_endpoints(provider):
         custom_llm_provider=provider,
         metadata={"key1": "value1", "key2": "value2"},
     )
+    ## TEST 2 - test underlying create_batch function
+    create_batch_response = await router._acreate_batch(
+        model="my-custom-name",
+        completion_window="24h",
+        endpoint="/v1/chat/completions",
+        input_file_id=batch_input_file_id,
+        custom_llm_provider=provider,
+        metadata={"key1": "value1", "key2": "value2"},
+    )
 
     print("response from router.create_batch=", create_batch_response)
 
tests/router_unit_tests/test_router_batch_utils.py (new file, 84 lines)

import sys
import os
import traceback
from dotenv import load_dotenv
from fastapi import Request
from datetime import datetime

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
from litellm import Router
import pytest
import litellm
from unittest.mock import patch, MagicMock, AsyncMock

import json
from io import BytesIO
from typing import Dict, List
from litellm.router_utils.batch_utils import (
    replace_model_in_jsonl,
    _get_router_metadata_variable_name,
)


# Fixtures
@pytest.fixture
def sample_jsonl_data() -> List[Dict]:
    """Fixture providing sample JSONL data"""
    return [
        {
            "body": {
                "model": "gpt-3.5-turbo",
                "messages": [{"role": "user", "content": "Hello"}],
            }
        },
        {"body": {"model": "gpt-4", "messages": [{"role": "user", "content": "Hi"}]}},
    ]


@pytest.fixture
def sample_jsonl_bytes(sample_jsonl_data) -> bytes:
    """Fixture providing sample JSONL as bytes"""
    jsonl_str = "\n".join(json.dumps(line) for line in sample_jsonl_data)
    return jsonl_str.encode("utf-8")


@pytest.fixture
def sample_file_like(sample_jsonl_bytes):
    """Fixture providing a file-like object"""
    return BytesIO(sample_jsonl_bytes)


# Test cases
def test_bytes_input(sample_jsonl_bytes):
    """Test with bytes input"""
    new_model = "claude-3"
    result = replace_model_in_jsonl(sample_jsonl_bytes, new_model)

    assert result is not None


def test_tuple_input(sample_jsonl_bytes):
    """Test with tuple input"""
    new_model = "claude-3"
    test_tuple = ("test.jsonl", sample_jsonl_bytes, "application/json")
    result = replace_model_in_jsonl(test_tuple, new_model)

    assert result is not None


def test_file_like_object(sample_file_like):
    """Test with file-like object input"""
    new_model = "claude-3"
    result = replace_model_in_jsonl(sample_file_like, new_model)

    assert result is not None


def test_router_metadata_variable_name():
    """Test that the variable name is correct"""
    assert _get_router_metadata_variable_name(function_name="completion") == "metadata"
    assert (
        _get_router_metadata_variable_name(function_name="batch") == "litellm_metadata"
    )
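replace_model_in_jsonl is exercised above with bytes, tuple, and file-like inputs. As a rough standalone sketch of what "replace the model in a JSONL batch file" plausibly means (this is an assumption drawn from the function name and fixtures, not litellm's actual implementation):

import json
from io import BytesIO
from typing import Union


def replace_model_in_jsonl_sketch(content: Union[bytes, BytesIO], new_model: str) -> bytes:
    """Assumed behaviour: rewrite body.model on every JSONL line (illustrative only)."""
    raw = content.read() if hasattr(content, "read") else content
    out_lines = []
    for line in raw.decode("utf-8").splitlines():
        if not line.strip():
            continue
        record = json.loads(line)
        record["body"]["model"] = new_model
        out_lines.append(json.dumps(record))
    return "\n".join(out_lines).encode("utf-8")


if __name__ == "__main__":
    sample = b'{"body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}]}}'
    print(replace_model_in_jsonl_sketch(sample, "claude-3"))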
@@ -41,6 +41,20 @@ def model_list():
                 "api_key": os.getenv("OPENAI_API_KEY"),
             },
         },
+        {
+            "model_name": "*",
+            "litellm_params": {
+                "model": "openai/*",
+                "api_key": os.getenv("OPENAI_API_KEY"),
+            },
+        },
+        {
+            "model_name": "claude-*",
+            "litellm_params": {
+                "model": "anthropic/*",
+                "api_key": os.getenv("ANTHROPIC_API_KEY"),
+            },
+        },
     ]
 
 
@@ -834,3 +848,69 @@ def test_flush_cache(model_list):
     assert router.cache.get_cache("test") == "test"
     router.flush_cache()
     assert router.cache.get_cache("test") is None
+
+
+def test_initialize_assistants_endpoint(model_list):
+    """Test if the 'initialize_assistants_endpoint' function is working correctly"""
+    router = Router(model_list=model_list)
+    router.initialize_assistants_endpoint()
+    assert router.acreate_assistants is not None
+    assert router.adelete_assistant is not None
+    assert router.aget_assistants is not None
+    assert router.acreate_thread is not None
+    assert router.aget_thread is not None
+    assert router.arun_thread is not None
+    assert router.aget_messages is not None
+    assert router.a_add_message is not None
+
+
+def test_pass_through_assistants_endpoint_factory(model_list):
+    """Test if the 'pass_through_assistants_endpoint_factory' function is working correctly"""
+    router = Router(model_list=model_list)
+    router._pass_through_assistants_endpoint_factory(
+        original_function=litellm.acreate_assistants,
+        custom_llm_provider="openai",
+        client=None,
+        **{},
+    )
+
+
+def test_factory_function(model_list):
+    """Test if the 'factory_function' function is working correctly"""
+    router = Router(model_list=model_list)
+    router.factory_function(litellm.acreate_assistants)
+
+
+def test_get_model_from_alias(model_list):
+    """Test if the 'get_model_from_alias' function is working correctly"""
+    router = Router(
+        model_list=model_list,
+        model_group_alias={"gpt-4o": "gpt-3.5-turbo"},
+    )
+    model = router._get_model_from_alias(model="gpt-4o")
+    assert model == "gpt-3.5-turbo"
+
+
+def test_get_deployment_by_litellm_model(model_list):
+    """Test if the 'get_deployment_by_litellm_model' function is working correctly"""
+    router = Router(model_list=model_list)
+    deployment = router._get_deployment_by_litellm_model(model="gpt-3.5-turbo")
+    assert deployment is not None
+
+
+def test_get_pattern(model_list):
+    router = Router(model_list=model_list)
+    pattern = router.pattern_router.get_pattern(model="claude-3")
+    assert pattern is not None
+
+
+def test_deployments_by_pattern(model_list):
+    router = Router(model_list=model_list)
+    deployments = router.pattern_router.get_deployments_by_pattern(model="claude-3")
+    assert deployments is not None
+
+
+def test_replace_model_in_jsonl(model_list):
+    router = Router(model_list=model_list)
+    deployments = router.pattern_router.get_deployments_by_pattern(model="claude-3")
+    assert deployments is not None