diff --git a/litellm/router.py b/litellm/router.py index 08efbc414..b4589c9f0 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -69,7 +69,7 @@ from litellm.types.router import ( AlertingConfig, AllowedFailsPolicy, AssistantsTypedDict, - CustomRoutingStrategy, + CustomRoutingStrategyBase, Deployment, DeploymentTypedDict, LiteLLM_Params, @@ -4815,7 +4815,18 @@ class Router: except Exception as e: pass - def set_custom_routing_strategy(self, CustomRoutingStrategy: CustomRoutingStrategy): + def set_custom_routing_strategy( + self, CustomRoutingStrategy: CustomRoutingStrategyBase + ): + """ + Sets get_available_deployment and async_get_available_deployment on an instanced of litellm.Router + + Use this to set your custom routing strategy + + Args: + CustomRoutingStrategy: litellm.router.CustomRoutingStrategyBase + """ + setattr( self, "get_available_deployment", diff --git a/litellm/tests/test_router_custom_routing.py b/litellm/tests/test_router_custom_routing.py index d66c304be..afd602b93 100644 --- a/litellm/tests/test_router_custom_routing.py +++ b/litellm/tests/test_router_custom_routing.py @@ -21,10 +21,6 @@ import pytest import litellm from litellm import Router -from litellm.caching import DualCache -from litellm.router import CustomRoutingStrategy as BaseCustomRoutingStrategy -from litellm.router import Deployment, LiteLLM_Params -from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler router = Router( model_list=[ @@ -49,11 +45,12 @@ router = Router( ], set_verbose=True, debug_level="DEBUG", - timeout=1, -) # type: ignore +) + +from litellm.router import CustomRoutingStrategyBase -class CustomRoutingStrategy(BaseCustomRoutingStrategy): +class CustomRoutingStrategy(CustomRoutingStrategyBase): async def async_get_available_deployment( self, model: str, @@ -62,6 +59,20 @@ class CustomRoutingStrategy(BaseCustomRoutingStrategy): specific_deployment: Optional[bool] = False, request_kwargs: Optional[Dict] = None, ): + """ + Asynchronously retrieves the available deployment based on the given parameters. + + Args: + model (str): The name of the model. + messages (Optional[List[Dict[str, str]]], optional): The list of messages for a given request. Defaults to None. + input (Optional[Union[str, List]], optional): The input for a given embedding request. Defaults to None. + specific_deployment (Optional[bool], optional): Whether to retrieve a specific deployment. Defaults to False. + request_kwargs (Optional[Dict], optional): Additional request keyword arguments. Defaults to None. + + Returns: + Returns an element from litellm.router.model_list + + """ print("In CUSTOM async get available deployment") model_list = router.model_list print("router model list=", model_list) @@ -79,7 +90,20 @@ class CustomRoutingStrategy(BaseCustomRoutingStrategy): specific_deployment: Optional[bool] = False, request_kwargs: Optional[Dict] = None, ): - # used for router.completion() calls + """ + Synchronously retrieves the available deployment based on the given parameters. + + Args: + model (str): The name of the model. + messages (Optional[List[Dict[str, str]]], optional): The list of messages for a given request. Defaults to None. + input (Optional[Union[str, List]], optional): The input for a given embedding request. Defaults to None. + specific_deployment (Optional[bool], optional): Whether to retrieve a specific deployment. Defaults to False. + request_kwargs (Optional[Dict], optional): Additional request keyword arguments. Defaults to None. + + Returns: + Returns an element from litellm.router.model_list + + """ pass diff --git a/litellm/types/router.py b/litellm/types/router.py index 25b1b5c9c..206216ef0 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -453,7 +453,7 @@ class AssistantsTypedDict(TypedDict): litellm_params: LiteLLMParamsTypedDict -class CustomRoutingStrategy: +class CustomRoutingStrategyBase: async def async_get_available_deployment( self, model: str, @@ -462,6 +462,20 @@ class CustomRoutingStrategy: specific_deployment: Optional[bool] = False, request_kwargs: Optional[Dict] = None, ): + """ + Asynchronously retrieves the available deployment based on the given parameters. + + Args: + model (str): The name of the model. + messages (Optional[List[Dict[str, str]]], optional): The list of messages for a given request. Defaults to None. + input (Optional[Union[str, List]], optional): The input for a given embedding request. Defaults to None. + specific_deployment (Optional[bool], optional): Whether to retrieve a specific deployment. Defaults to False. + request_kwargs (Optional[Dict], optional): Additional request keyword arguments. Defaults to None. + + Returns: + Returns an element from litellm.router.model_list + + """ pass def get_available_deployment( @@ -472,4 +486,18 @@ class CustomRoutingStrategy: specific_deployment: Optional[bool] = False, request_kwargs: Optional[Dict] = None, ): + """ + Synchronously retrieves the available deployment based on the given parameters. + + Args: + model (str): The name of the model. + messages (Optional[List[Dict[str, str]]], optional): The list of messages for a given request. Defaults to None. + input (Optional[Union[str, List]], optional): The input for a given embedding request. Defaults to None. + specific_deployment (Optional[bool], optional): Whether to retrieve a specific deployment. Defaults to False. + request_kwargs (Optional[Dict], optional): Additional request keyword arguments. Defaults to None. + + Returns: + Returns an element from litellm.router.model_list + + """ pass