Mirror of https://github.com/BerriAI/litellm.git
Synced 2025-04-25 10:44:24 +00:00
Litellm router code coverage 3 (#6274)
* refactor(router.py): move assistants api endpoints to using 1 pass-through factory function. Reduces code, increases testing coverage.
* refactor(router.py): reduce _common_checks_available_deployment function size. Makes the code more maintainable and reduces possible errors.
* test(router_code_coverage.py): include batch_utils + pattern matching in enforced 100% code coverage. Improves reliability.
* fix(router.py): fix model id match model dump.
This commit is contained in:
parent
891e9001b5
commit
e22e8d24ef
8 changed files with 407 additions and 244 deletions
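The core refactor replaces per-endpoint assistants wrappers with a single factory that builds each pass-through coroutine. Below is a minimal, self-contained sketch of that pattern; `MiniRouter` and the stand-in endpoint functions are hypothetical illustrations for this page, not litellm code.

```python
# Minimal sketch of the pass-through factory pattern (illustrative names only).
import asyncio
from typing import Any, Callable, Dict, Optional


async def acreate_thread(custom_llm_provider: str, **kwargs) -> Dict[str, Any]:
    # Stand-in for a module-level endpoint such as litellm.acreate_thread.
    return {"endpoint": "acreate_thread", "provider": custom_llm_provider, **kwargs}


async def aget_thread(custom_llm_provider: str, **kwargs) -> Dict[str, Any]:
    return {"endpoint": "aget_thread", "provider": custom_llm_provider, **kwargs}


class MiniRouter:
    def __init__(self, assistants_config: Optional[dict] = None):
        self.assistants_config = assistants_config
        # One factory call per endpoint replaces one hand-written wrapper per endpoint.
        self.acreate_thread = self.factory_function(acreate_thread)
        self.aget_thread = self.factory_function(aget_thread)

    def factory_function(self, original_function: Callable):
        async def new_function(custom_llm_provider: Optional[str] = None, **kwargs):
            # Shared argument resolution lives in one place.
            if custom_llm_provider is None:
                if self.assistants_config is not None:
                    custom_llm_provider = self.assistants_config["custom_llm_provider"]
                    kwargs.update(self.assistants_config.get("litellm_params", {}))
                else:
                    raise ValueError("'custom_llm_provider' must be set")
            return await original_function(custom_llm_provider=custom_llm_provider, **kwargs)

        return new_function


if __name__ == "__main__":
    router = MiniRouter(assistants_config={"custom_llm_provider": "openai", "litellm_params": {}})
    print(asyncio.run(router.acreate_thread(metadata={"topic": "demo"})))
```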
@@ -24,7 +24,18 @@ import traceback
import uuid
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Literal,
    Optional,
    Tuple,
    TypedDict,
    Union,
)

import httpx
import openai
@@ -520,6 +531,19 @@ class Router:
        if self.alerting_config is not None:
            self._initialize_alerting()

        self.initialize_assistants_endpoint()

    def initialize_assistants_endpoint(self):
        ## INITIALIZE PASS THROUGH ASSISTANTS ENDPOINT ##
        self.acreate_assistants = self.factory_function(litellm.acreate_assistants)
        self.adelete_assistant = self.factory_function(litellm.adelete_assistant)
        self.aget_assistants = self.factory_function(litellm.aget_assistants)
        self.acreate_thread = self.factory_function(litellm.acreate_thread)
        self.aget_thread = self.factory_function(litellm.aget_thread)
        self.a_add_message = self.factory_function(litellm.a_add_message)
        self.aget_messages = self.factory_function(litellm.aget_messages)
        self.arun_thread = self.factory_function(litellm.arun_thread)

    def validate_fallbacks(self, fallback_param: Optional[List]):
        """
        Validate the fallbacks parameter.
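Once initialize_assistants_endpoint() has run, every assistants method on the router shares the factory-generated signature. A hedged usage sketch, assuming a Router can be constructed for assistants-only use and that provider credentials are set in the environment:

```python
import asyncio

from litellm import Router

router = Router()  # assistants calls below do not rely on the chat model_list


async def main():
    # acreate_thread / aget_messages were attached by factory_function(), so they
    # accept custom_llm_provider, an optional client, and provider-specific kwargs.
    thread = await router.acreate_thread(custom_llm_provider="openai")
    return await router.aget_messages(
        custom_llm_provider="openai", thread_id=thread.id
    )


# asyncio.run(main())  # needs OPENAI_API_KEY set to execute against the API
```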
@@ -2167,7 +2191,6 @@ class Router:
            raise e

    #### FILES API ####

    async def acreate_file(
        self,
        model: str,
@@ -2504,114 +2527,29 @@ class Router:

    #### ASSISTANTS API ####

    async def acreate_assistants(
        self,
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> Assistant:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )

        return await litellm.acreate_assistants(
            custom_llm_provider=custom_llm_provider, client=client, **kwargs
        )

    async def adelete_assistant(
        self,
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> AssistantDeleted:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )

        return await litellm.adelete_assistant(
            custom_llm_provider=custom_llm_provider, client=client, **kwargs
        )

    async def aget_assistants(
        self,
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> AsyncCursorPage[Assistant]:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )

        return await litellm.aget_assistants(
            custom_llm_provider=custom_llm_provider, client=client, **kwargs
        )

    async def acreate_thread(
        self,
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> Thread:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )
        return await litellm.acreate_thread(
            custom_llm_provider=custom_llm_provider, client=client, **kwargs
        )

    async def aget_thread(
        self,
        thread_id: str,
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> Thread:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )
        return await litellm.aget_thread(
            custom_llm_provider=custom_llm_provider,
            thread_id=thread_id,
            client=client,
    def factory_function(self, original_function: Callable):
        async def new_function(
            custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
            client: Optional["AsyncOpenAI"] = None,
            **kwargs,
        )
        ):
            return await self._pass_through_assistants_endpoint_factory(
                original_function=original_function,
                custom_llm_provider=custom_llm_provider,
                client=client,
                **kwargs,
            )

    async def a_add_message(
        return new_function

    async def _pass_through_assistants_endpoint_factory(
        self,
        thread_id: str,
        role: Literal["user", "assistant"],
        content: str,
        attachments: Optional[List[Attachment]] = None,
        metadata: Optional[dict] = None,
        original_function: Callable,
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> OpenAIMessage:
    ):
        """Internal helper function to pass through the assistants endpoint"""
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
@@ -2620,76 +2558,8 @@ class Router:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )

        return await litellm.a_add_message(
            custom_llm_provider=custom_llm_provider,
            thread_id=thread_id,
            role=role,
            content=content,
            attachments=attachments,
            metadata=metadata,
            client=client,
            **kwargs,
        )

    async def aget_messages(
        self,
        thread_id: str,
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> AsyncCursorPage[OpenAIMessage]:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )
        return await litellm.aget_messages(
            custom_llm_provider=custom_llm_provider,
            thread_id=thread_id,
            client=client,
            **kwargs,
        )

    async def arun_thread(
        self,
        thread_id: str,
        assistant_id: str,
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        additional_instructions: Optional[str] = None,
        instructions: Optional[str] = None,
        metadata: Optional[dict] = None,
        model: Optional[str] = None,
        stream: Optional[bool] = None,
        tools: Optional[Iterable[AssistantToolParam]] = None,
        client: Optional[Any] = None,
        **kwargs,
    ) -> Run:

        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )

        return await litellm.arun_thread(
            custom_llm_provider=custom_llm_provider,
            thread_id=thread_id,
            assistant_id=assistant_id,
            additional_instructions=additional_instructions,
            instructions=instructions,
            metadata=metadata,
            model=model,
            stream=stream,
            tools=tools,
            client=client,
            **kwargs,
        return await original_function( # type: ignore
            custom_llm_provider=custom_llm_provider, client=client, **kwargs
        )

    #### [END] ASSISTANTS API ####
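The other documented way to satisfy the 'custom_llm_provider' requirement in the error message above is to set it once on the router. A sketch under the assumption that assistants_config carries a custom_llm_provider plus litellm_params, exactly the two keys the helper reads; the parameter values here are illustrative only.

```python
import os

from litellm import Router

router = Router(
    assistants_config={
        "custom_llm_provider": "azure",
        "litellm_params": {
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_base": os.getenv("AZURE_API_BASE"),
        },
    },
)

# With assistants_config set, a call such as `await router.aget_assistants()`
# lets _pass_through_assistants_endpoint_factory fill in custom_llm_provider
# and merge litellm_params into kwargs before delegating to litellm.aget_assistants.
```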
@@ -4609,7 +4479,7 @@ class Router:
        """
        returned_models: List[DeploymentTypedDict] = []
        for model in self.model_list:
            if model["model_name"] == model_name:
            if model_name is not None and model["model_name"] == model_name:
                if model_alias is not None:
                    alias_model = copy.deepcopy(model)
                    alias_model["model_name"] = model_alias
@@ -5007,82 +4877,73 @@ class Router:

        return _returned_deployments

    def _get_model_from_alias(self, model: str) -> Optional[str]:
        """
        Get the model from the alias.

        Returns:
        - str, the litellm model name
        - None, if model is not in model group alias
        """
        if model not in self.model_group_alias:
            return None

        _item = self.model_group_alias[model]
        if isinstance(_item, str):
            model = _item
        else:
            model = _item["model"]

        return model

    def _get_deployment_by_litellm_model(self, model: str) -> List:
        """
        Get the deployment by litellm model.
        """
        return [m for m in self.model_list if m["litellm_params"]["model"] == model]

    def _common_checks_available_deployment(
        self,
        model: str,
        messages: Optional[List[Dict[str, str]]] = None,
        input: Optional[Union[str, List]] = None,
        specific_deployment: Optional[bool] = False,
    ) -> Tuple[str, Union[list, dict]]:
    ) -> Tuple[str, Union[List, Dict]]:
        """
        Common checks for 'get_available_deployment' across sync + async call.

        If 'healthy_deployments' returned is None, this means the user chose a specific deployment

        Returns
        - Dict, if specific model chosen
        - str, the litellm model name
        - List, if multiple models chosen
        - Dict, if specific model chosen
        """

        # check if aliases set on litellm model alias map
        if specific_deployment is True:
            # users can also specify a specific deployment name. At this point we should check if they are just trying to call a specific deployment
            for deployment in self.model_list:
                deployment_model = deployment.get("litellm_params").get("model")
                if deployment_model == model:
                    # User Passed a specific deployment name on their config.yaml, example azure/chat-gpt-v-2
                    # return the first deployment where the `model` matches the specificed deployment name
                    return deployment_model, deployment
            raise ValueError(
                f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.get_model_names()}"
            )
            return model, self._get_deployment_by_litellm_model(model=model)
        elif model in self.get_model_ids():
            deployment = self.get_model_info(id=model)
            deployment = self.get_deployment(model_id=model)
            if deployment is not None:
                deployment_model = deployment.get("litellm_params", {}).get("model")
                return deployment_model, deployment
                deployment_model = deployment.litellm_params.model
                return deployment_model, deployment.model_dump(exclude_none=True)
            raise ValueError(
                f"LiteLLM Router: Trying to call specific deployment, but Model ID :{model} does not exist in \
                    Model ID List: {self.get_model_ids}"
            )

        if model in self.model_group_alias:
            _item = self.model_group_alias[model]
            if isinstance(_item, str):
                model = _item
            else:
                model = _item["model"]
        _model_from_alias = self._get_model_from_alias(model=model)
        if _model_from_alias is not None:
            model = _model_from_alias

        if model not in self.model_names:
            # check if provider/ specific wildcard routing use pattern matching
            custom_llm_provider: Optional[str] = None
            try:
                (
                    _,
                    custom_llm_provider,
                    _,
                    _,
                ) = litellm.get_llm_provider(model=model)
            except Exception:
                # get_llm_provider raises exception when provider is unknown
                pass

            """
            self.pattern_router.route(model):
            does exact pattern matching. Example openai/gpt-3.5-turbo gets routed to pattern openai/*

            self.pattern_router.route(f"{custom_llm_provider}/{model}"):
            does pattern matching using litellm.get_llm_provider(), example claude-3-5-sonnet-20240620 gets routed to anthropic/* since 'claude-3-5-sonnet-20240620' is an Anthropic Model
            """
            _pattern_router_response = self.pattern_router.route(
                model
            ) or self.pattern_router.route(f"{custom_llm_provider}/{model}")
            if _pattern_router_response is not None:
                provider_deployments = []
                for deployment in _pattern_router_response:
                    dep = copy.deepcopy(deployment)
                    dep["litellm_params"]["model"] = model
                    provider_deployments.append(dep)
                return model, provider_deployments
            pattern_deployments = self.pattern_router.get_deployments_by_pattern(
                model=model,
            )
            if pattern_deployments:
                return model, pattern_deployments

        # check if default deployment is set
        if self.default_deployment is not None:
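The inline comment block in this hunk explains the wildcard routing the pattern router performs. The snippet below is a concept-only illustration of that matching, not litellm's PatternMatchRouter implementation.

```python
# Illustrative-only sketch of wildcard ("provider/*") deployment matching.
import re
from typing import Dict, List, Optional


def route_by_pattern(model: str, patterns: Dict[str, List[dict]]) -> Optional[List[dict]]:
    """Return the deployments whose wildcard pattern (e.g. 'openai/*') matches `model`."""
    for pattern, deployments in patterns.items():
        regex = "^" + re.escape(pattern).replace(r"\*", ".*") + "$"  # 'openai/*' -> '^openai/.*$'
        if re.match(regex, model):
            return deployments
    return None


deployments_by_pattern = {
    "openai/*": [{"model_name": "openai/*", "litellm_params": {"model": "openai/*"}}],
    "anthropic/*": [{"model_name": "anthropic/*", "litellm_params": {"model": "anthropic/*"}}],
}

# 'openai/gpt-3.5-turbo' is matched directly by the 'openai/*' pattern.
assert route_by_pattern("openai/gpt-3.5-turbo", deployments_by_pattern) is not None
# 'claude-3-5-sonnet-20240620' only matches after get_llm_provider-style prefixing: 'anthropic/<model>'.
assert route_by_pattern("anthropic/claude-3-5-sonnet-20240620", deployments_by_pattern) is not None
```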
@@ -5094,12 +4955,11 @@ class Router:

        ## get healthy deployments
        ### get all deployments
        healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
        healthy_deployments = self._get_all_deployments(model_name=model)

        if len(healthy_deployments) == 0:
            # check if the user sent in a deployment name instead
            healthy_deployments = [
                m for m in self.model_list if m["litellm_params"]["model"] == model
            ]
            healthy_deployments = self._get_deployment_by_litellm_model(model=model)

        verbose_router_logger.debug(
            f"initial list of deployments: {healthy_deployments}"
@@ -5151,7 +5011,6 @@ class Router:
            input=input,
            specific_deployment=specific_deployment,
        ) # type: ignore

        if isinstance(healthy_deployments, dict):
            return healthy_deployments