Litellm router code coverage 3 (#6274)

* refactor(router.py): move the Assistants API endpoints onto a single pass-through factory function

Reduces duplicated code and increases test coverage; a condensed sketch of the pattern follows.
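A minimal, self-contained sketch of the idea (MiniRouter and fake_aget_assistants are illustrative stand-ins, not LiteLLM code): the factory builds each async endpoint wrapper once and injects the shared provider config, so the eight hand-written assistants methods collapse into one code path.

import asyncio
from typing import Callable

class MiniRouter:
    def __init__(self, custom_llm_provider: str):
        self.custom_llm_provider = custom_llm_provider

    def factory_function(self, original_function: Callable):
        async def new_function(**kwargs):
            # inject the shared provider config once instead of repeating it per endpoint
            return await original_function(
                custom_llm_provider=self.custom_llm_provider, **kwargs
            )

        return new_function

async def fake_aget_assistants(custom_llm_provider: str, **kwargs):
    # stand-in for litellm.aget_assistants in this sketch
    return f"listed assistants via {custom_llm_provider}"

router = MiniRouter(custom_llm_provider="azure")
router.aget_assistants = router.factory_function(fake_aget_assistants)
print(asyncio.run(router.aget_assistants()))  # -> listed assistants via azure

In the diff below, the real factory_function wraps litellm.aget_assistants, litellm.arun_thread, and the other assistants calls, while the shared provider/config resolution lives in _pass_through_assistants_endpoint_factory.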

* refactor(router.py): reduce the size of _common_checks_available_deployment

Makes the code more maintainable and reduces the chance of errors; standalone versions of the extracted helpers are sketched below.
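For illustration, standalone adaptations of the two helpers this refactor extracts (written as free functions purely for this sketch; in the diff below they are the Router methods _get_model_from_alias and _get_deployment_by_litellm_model):

from typing import Dict, List, Optional, Union

def get_model_from_alias(model: str, model_group_alias: Dict[str, Union[str, dict]]) -> Optional[str]:
    """Resolve a model-group alias to the underlying litellm model name."""
    if model not in model_group_alias:
        return None
    item = model_group_alias[model]
    return item if isinstance(item, str) else item["model"]

def get_deployment_by_litellm_model(model: str, model_list: List[dict]) -> List[dict]:
    """Return every deployment whose litellm_params['model'] matches the given model."""
    return [m for m in model_list if m["litellm_params"]["model"] == model]

model_list = [{"model_name": "gpt-4o", "litellm_params": {"model": "azure/gpt-4o"}}]
print(get_model_from_alias("my-alias", {"my-alias": "gpt-4o"}))     # gpt-4o
print(get_deployment_by_litellm_model("azure/gpt-4o", model_list))  # one matching deployment

With these pulled out, _common_checks_available_deployment only has to decide which lookup to run, which is what shrinks the function.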

* test(router_code_coverage.py): include batch_utils + pattern matching in enforced 100% code coverage

Improves reliability; a hypothetical sketch of such an enforcement test follows.
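The enforcement test itself is not part of this diff. Purely as a hypothetical sketch (the file paths, test name, and substring-matching heuristic below are assumptions, not the real router_code_coverage.py), a test of this kind can collect every function defined in the enforced modules and fail if one never appears in the test suite:

import ast
from pathlib import Path

# assumed module list; batch_utils and the pattern-matching router are the ones
# this commit adds to enforcement
ENFORCED_FILES = [
    "litellm/router.py",
    "litellm/router_utils/batch_utils.py",
    "litellm/router_utils/pattern_match_deployments.py",
]

def get_function_names(path: str) -> set:
    """Collect every sync and async function name defined in a module."""
    tree = ast.parse(Path(path).read_text())
    return {
        node.name
        for node in ast.walk(tree)
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
    }

def test_enforced_functions_are_directly_tested():
    # naive check: every enforced function name must show up somewhere in the test sources
    tested_source = "".join(p.read_text() for p in Path("tests").rglob("test_*.py"))
    untested = {
        name
        for path in ENFORCED_FILES
        for name in get_function_names(path)
        if name not in tested_source
    }
    assert not untested, f"functions without direct tests: {untested}"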

* fix(router.py): fix model ID matching by resolving the deployment via get_deployment() and returning it with model_dump()
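Condensed from the hunk further down: when the requested model is a model ID, the router now resolves it with get_deployment(model_id=...) and returns deployment.model_dump(exclude_none=True), so callers of _common_checks_available_deployment still receive a plain dict. A small runnable sketch of what model_dump produces (the Deployment fields here are simplified, not the full LiteLLM type):

from typing import Optional
from pydantic import BaseModel, ConfigDict

class LiteLLMParams(BaseModel):
    model: str
    api_base: Optional[str] = None

class Deployment(BaseModel):
    # allow field names starting with "model_" without pydantic warnings
    model_config = ConfigDict(protected_namespaces=())
    model_name: str
    litellm_params: LiteLLMParams
    model_info: Optional[dict] = None

dep = Deployment(model_name="gpt-4o", litellm_params=LiteLLMParams(model="azure/gpt-4o"))
print(dep.model_dump(exclude_none=True))
# {'model_name': 'gpt-4o', 'litellm_params': {'model': 'azure/gpt-4o'}}

This keeps the function's return contract (a model name plus a dict) unchanged while the lookup itself goes through the typed deployment object.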
Author: Krish Dholakia, 2024-10-16 21:30:25 -07:00 (committed by GitHub)
Parent: 891e9001b5
Commit: e22e8d24ef
8 changed files with 407 additions and 244 deletions


@@ -24,7 +24,18 @@ import traceback
import uuid
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
Tuple,
TypedDict,
Union,
)
import httpx
import openai
@@ -520,6 +531,19 @@ class Router:
if self.alerting_config is not None:
self._initialize_alerting()
self.initialize_assistants_endpoint()
def initialize_assistants_endpoint(self):
## INITIALIZE PASS THROUGH ASSISTANTS ENDPOINT ##
self.acreate_assistants = self.factory_function(litellm.acreate_assistants)
self.adelete_assistant = self.factory_function(litellm.adelete_assistant)
self.aget_assistants = self.factory_function(litellm.aget_assistants)
self.acreate_thread = self.factory_function(litellm.acreate_thread)
self.aget_thread = self.factory_function(litellm.aget_thread)
self.a_add_message = self.factory_function(litellm.a_add_message)
self.aget_messages = self.factory_function(litellm.aget_messages)
self.arun_thread = self.factory_function(litellm.arun_thread)
def validate_fallbacks(self, fallback_param: Optional[List]):
"""
Validate the fallbacks parameter.
@@ -2167,7 +2191,6 @@ class Router:
raise e
#### FILES API ####
async def acreate_file(
self,
model: str,
@@ -2504,114 +2527,29 @@ class Router:
#### ASSISTANTS API ####
async def acreate_assistants(
self,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> Assistant:
if custom_llm_provider is None:
if self.assistants_config is not None:
custom_llm_provider = self.assistants_config["custom_llm_provider"]
kwargs.update(self.assistants_config["litellm_params"])
else:
raise Exception(
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
)
return await litellm.acreate_assistants(
custom_llm_provider=custom_llm_provider, client=client, **kwargs
)
async def adelete_assistant(
self,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> AssistantDeleted:
if custom_llm_provider is None:
if self.assistants_config is not None:
custom_llm_provider = self.assistants_config["custom_llm_provider"]
kwargs.update(self.assistants_config["litellm_params"])
else:
raise Exception(
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
)
return await litellm.adelete_assistant(
custom_llm_provider=custom_llm_provider, client=client, **kwargs
)
async def aget_assistants(
self,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> AsyncCursorPage[Assistant]:
if custom_llm_provider is None:
if self.assistants_config is not None:
custom_llm_provider = self.assistants_config["custom_llm_provider"]
kwargs.update(self.assistants_config["litellm_params"])
else:
raise Exception(
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
)
return await litellm.aget_assistants(
custom_llm_provider=custom_llm_provider, client=client, **kwargs
)
async def acreate_thread(
self,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> Thread:
if custom_llm_provider is None:
if self.assistants_config is not None:
custom_llm_provider = self.assistants_config["custom_llm_provider"]
kwargs.update(self.assistants_config["litellm_params"])
else:
raise Exception(
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
)
return await litellm.acreate_thread(
custom_llm_provider=custom_llm_provider, client=client, **kwargs
)
async def aget_thread(
self,
thread_id: str,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> Thread:
if custom_llm_provider is None:
if self.assistants_config is not None:
custom_llm_provider = self.assistants_config["custom_llm_provider"]
kwargs.update(self.assistants_config["litellm_params"])
else:
raise Exception(
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
)
return await litellm.aget_thread(
custom_llm_provider=custom_llm_provider,
thread_id=thread_id,
client=client,
def factory_function(self, original_function: Callable):
async def new_function(
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional["AsyncOpenAI"] = None,
**kwargs,
)
):
return await self._pass_through_assistants_endpoint_factory(
original_function=original_function,
custom_llm_provider=custom_llm_provider,
client=client,
**kwargs,
)
async def a_add_message(
return new_function
async def _pass_through_assistants_endpoint_factory(
self,
thread_id: str,
role: Literal["user", "assistant"],
content: str,
attachments: Optional[List[Attachment]] = None,
metadata: Optional[dict] = None,
original_function: Callable,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> OpenAIMessage:
):
"""Internal helper function to pass through the assistants endpoint"""
if custom_llm_provider is None:
if self.assistants_config is not None:
custom_llm_provider = self.assistants_config["custom_llm_provider"]
@@ -2620,76 +2558,8 @@ class Router:
raise Exception(
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
)
return await litellm.a_add_message(
custom_llm_provider=custom_llm_provider,
thread_id=thread_id,
role=role,
content=content,
attachments=attachments,
metadata=metadata,
client=client,
**kwargs,
)
async def aget_messages(
self,
thread_id: str,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> AsyncCursorPage[OpenAIMessage]:
if custom_llm_provider is None:
if self.assistants_config is not None:
custom_llm_provider = self.assistants_config["custom_llm_provider"]
kwargs.update(self.assistants_config["litellm_params"])
else:
raise Exception(
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
)
return await litellm.aget_messages(
custom_llm_provider=custom_llm_provider,
thread_id=thread_id,
client=client,
**kwargs,
)
async def arun_thread(
self,
thread_id: str,
assistant_id: str,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
additional_instructions: Optional[str] = None,
instructions: Optional[str] = None,
metadata: Optional[dict] = None,
model: Optional[str] = None,
stream: Optional[bool] = None,
tools: Optional[Iterable[AssistantToolParam]] = None,
client: Optional[Any] = None,
**kwargs,
) -> Run:
if custom_llm_provider is None:
if self.assistants_config is not None:
custom_llm_provider = self.assistants_config["custom_llm_provider"]
kwargs.update(self.assistants_config["litellm_params"])
else:
raise Exception(
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
)
return await litellm.arun_thread(
custom_llm_provider=custom_llm_provider,
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
stream=stream,
tools=tools,
client=client,
**kwargs,
return await original_function( # type: ignore
custom_llm_provider=custom_llm_provider, client=client, **kwargs
)
#### [END] ASSISTANTS API ####
@@ -4609,7 +4479,7 @@ class Router:
"""
returned_models: List[DeploymentTypedDict] = []
for model in self.model_list:
if model["model_name"] == model_name:
if model_name is not None and model["model_name"] == model_name:
if model_alias is not None:
alias_model = copy.deepcopy(model)
alias_model["model_name"] = model_alias
@@ -5007,82 +4877,73 @@ class Router:
return _returned_deployments
def _get_model_from_alias(self, model: str) -> Optional[str]:
"""
Get the model from the alias.
Returns:
- str, the litellm model name
- None, if model is not in model group alias
"""
if model not in self.model_group_alias:
return None
_item = self.model_group_alias[model]
if isinstance(_item, str):
model = _item
else:
model = _item["model"]
return model
def _get_deployment_by_litellm_model(self, model: str) -> List:
"""
Get the deployment by litellm model.
"""
return [m for m in self.model_list if m["litellm_params"]["model"] == model]
def _common_checks_available_deployment(
self,
model: str,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
specific_deployment: Optional[bool] = False,
) -> Tuple[str, Union[list, dict]]:
) -> Tuple[str, Union[List, Dict]]:
"""
Common checks for 'get_available_deployment' across sync + async call.
If 'healthy_deployments' returned is None, this means the user chose a specific deployment
Returns
- Dict, if specific model chosen
- str, the litellm model name
- List, if multiple models chosen
- Dict, if specific model chosen
"""
# check if aliases set on litellm model alias map
if specific_deployment is True:
# users can also specify a specific deployment name. At this point we should check if they are just trying to call a specific deployment
for deployment in self.model_list:
deployment_model = deployment.get("litellm_params").get("model")
if deployment_model == model:
# User Passed a specific deployment name on their config.yaml, example azure/chat-gpt-v-2
# return the first deployment where the `model` matches the specified deployment name
return deployment_model, deployment
raise ValueError(
f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.get_model_names()}"
)
return model, self._get_deployment_by_litellm_model(model=model)
elif model in self.get_model_ids():
deployment = self.get_model_info(id=model)
deployment = self.get_deployment(model_id=model)
if deployment is not None:
deployment_model = deployment.get("litellm_params", {}).get("model")
return deployment_model, deployment
deployment_model = deployment.litellm_params.model
return deployment_model, deployment.model_dump(exclude_none=True)
raise ValueError(
f"LiteLLM Router: Trying to call specific deployment, but Model ID :{model} does not exist in \
Model ID List: {self.get_model_ids}"
)
if model in self.model_group_alias:
_item = self.model_group_alias[model]
if isinstance(_item, str):
model = _item
else:
model = _item["model"]
_model_from_alias = self._get_model_from_alias(model=model)
if _model_from_alias is not None:
model = _model_from_alias
if model not in self.model_names:
# check if provider/ specific wildcard routing use pattern matching
custom_llm_provider: Optional[str] = None
try:
(
_,
custom_llm_provider,
_,
_,
) = litellm.get_llm_provider(model=model)
except Exception:
# get_llm_provider raises exception when provider is unknown
pass
"""
self.pattern_router.route(model):
does exact pattern matching. Example openai/gpt-3.5-turbo gets routed to pattern openai/*
self.pattern_router.route(f"{custom_llm_provider}/{model}"):
does pattern matching using litellm.get_llm_provider(), example claude-3-5-sonnet-20240620 gets routed to anthropic/* since 'claude-3-5-sonnet-20240620' is an Anthropic Model
"""
_pattern_router_response = self.pattern_router.route(
model
) or self.pattern_router.route(f"{custom_llm_provider}/{model}")
if _pattern_router_response is not None:
provider_deployments = []
for deployment in _pattern_router_response:
dep = copy.deepcopy(deployment)
dep["litellm_params"]["model"] = model
provider_deployments.append(dep)
return model, provider_deployments
pattern_deployments = self.pattern_router.get_deployments_by_pattern(
model=model,
)
if pattern_deployments:
return model, pattern_deployments
# check if default deployment is set
if self.default_deployment is not None:
@@ -5094,12 +4955,11 @@ class Router:
## get healthy deployments
### get all deployments
healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
healthy_deployments = self._get_all_deployments(model_name=model)
if len(healthy_deployments) == 0:
# check if the user sent in a deployment name instead
healthy_deployments = [
m for m in self.model_list if m["litellm_params"]["model"] == model
]
healthy_deployments = self._get_deployment_by_litellm_model(model=model)
verbose_router_logger.debug(
f"initial list of deployments: {healthy_deployments}"
@@ -5151,7 +5011,6 @@ class Router:
input=input,
specific_deployment=specific_deployment,
) # type: ignore
if isinstance(healthy_deployments, dict):
return healthy_deployments