diff --git a/docs/my-website/docs/routing.md b/docs/my-website/docs/routing.md
index 63fac9456..fd4fb8658 100644
--- a/docs/my-website/docs/routing.md
+++ b/docs/my-website/docs/routing.md
@@ -95,7 +95,7 @@ print(response)
- `router.image_generation()` - completion calls in OpenAI `/v1/images/generations` endpoint format
- `router.aimage_generation()` - async image generation calls
-## Advanced - Routing Strategies
+## Advanced - Routing Strategies ⭐️
#### Routing Strategies - Weighted Pick, Rate Limit Aware, Least Busy, Latency Based, Cost Based
Router provides 4 strategies for routing your calls across multiple deployments:
@@ -262,7 +262,7 @@ if response is not None:
)
```
-### Set Time Window
+#### Set Time Window
Set time window for how far back to consider when averaging latency for a deployment.
@@ -278,7 +278,7 @@ router_settings:
routing_strategy_args: {"ttl": 10}
```
-### Set Lowest Latency Buffer
+#### Set Lowest Latency Buffer
Set a buffer within which deployments are candidates for making calls to.
@@ -468,6 +468,122 @@ asyncio.run(router_acompletion())
```
+
+**Plug in a custom routing strategy to select deployments**
+
+Subclass `CustomRoutingStrategyBase` to define your own deployment-selection logic, then attach it to a `Router` instance.
+
+Step 1. Define your custom routing strategy
+
+```python
+from typing import Dict, List, Optional, Union
+
+from litellm.router import CustomRoutingStrategyBase
+
+
+class CustomRoutingStrategy(CustomRoutingStrategyBase):
+ async def async_get_available_deployment(
+ self,
+ model: str,
+ messages: Optional[List[Dict[str, str]]] = None,
+ input: Optional[Union[str, List]] = None,
+ specific_deployment: Optional[bool] = False,
+ request_kwargs: Optional[Dict] = None,
+ ):
+ """
+ Asynchronously retrieves the available deployment based on the given parameters.
+
+ Args:
+ model (str): The name of the model.
+ messages (Optional[List[Dict[str, str]]], optional): The list of messages for a given request. Defaults to None.
+ input (Optional[Union[str, List]], optional): The input for a given embedding request. Defaults to None.
+ specific_deployment (Optional[bool], optional): Whether to retrieve a specific deployment. Defaults to False.
+ request_kwargs (Optional[Dict], optional): Additional request keyword arguments. Defaults to None.
+
+ Returns:
+            An element from litellm.router.model_list
+
+ """
+ print("In CUSTOM async get available deployment")
+ model_list = router.model_list
+ print("router model list=", model_list)
+ for model in model_list:
+ if isinstance(model, dict):
+ if model["litellm_params"]["model"] == "openai/very-special-endpoint":
+ return model
+ pass
+
+ def get_available_deployment(
+ self,
+ model: str,
+ messages: Optional[List[Dict[str, str]]] = None,
+ input: Optional[Union[str, List]] = None,
+ specific_deployment: Optional[bool] = False,
+ request_kwargs: Optional[Dict] = None,
+ ):
+ """
+ Synchronously retrieves the available deployment based on the given parameters.
+
+ Args:
+ model (str): The name of the model.
+ messages (Optional[List[Dict[str, str]]], optional): The list of messages for a given request. Defaults to None.
+ input (Optional[Union[str, List]], optional): The input for a given embedding request. Defaults to None.
+ specific_deployment (Optional[bool], optional): Whether to retrieve a specific deployment. Defaults to False.
+ request_kwargs (Optional[Dict], optional): Additional request keyword arguments. Defaults to None.
+
+ Returns:
+            An element from litellm.router.model_list
+
+ """
+ pass
+```
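+
+Both methods should return an element from `router.model_list` (the deployment to use for the request). The async variant is the one invoked for async calls such as `router.acompletion`.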
+
+Step 2. Initialize Router with custom routing strategy
+```python
+from litellm import Router
+
+router = Router(
+ model_list=[
+ {
+ "model_name": "azure-model",
+ "litellm_params": {
+ "model": "openai/very-special-endpoint",
+ "api_base": "https://exampleopenaiendpoint-production.up.railway.app/", # If you are Krrish, this is OpenAI Endpoint3 on our Railway endpoint :)
+ "api_key": "fake-key",
+ },
+ "model_info": {"id": "very-special-endpoint"},
+ },
+ {
+ "model_name": "azure-model",
+ "litellm_params": {
+ "model": "openai/fast-endpoint",
+ "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
+ "api_key": "fake-key",
+ },
+ "model_info": {"id": "fast-endpoint"},
+ },
+ ],
+ set_verbose=True,
+ debug_level="DEBUG",
+ timeout=1,
+) # type: ignore
+
+router.set_custom_routing_strategy(CustomRoutingStrategy()) # 👈 Set your routing strategy here
+```
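+
+Note: `set_custom_routing_strategy` overrides `get_available_deployment` and `async_get_available_deployment` on the router instance, so your class takes over deployment selection for both sync and async calls.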
+
+Step 3. Test your routing strategy. Expect your custom routing strategy to be called when you run `router.acompletion` requests.
+```python
+import asyncio
+
+
+async def test_custom_routing():
+    for _ in range(10):
+        response = await router.acompletion(
+            model="azure-model", messages=[{"role": "user", "content": "hello"}]
+        )
+        print(response)
+        _picked_model_id = response._hidden_params["model_id"]
+        print("picked model=", _picked_model_id)
+
+
+asyncio.run(test_custom_routing())
+```
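+
+For a further (hypothetical) illustration, here is a sketch of a strategy that routes on request metadata. The `MetadataRoutingStrategy` name, the `tier` key, and the assumption that caller metadata arrives under `request_kwargs["metadata"]` are illustrative, not a documented contract:
+
+```python
+from typing import Dict, List, Optional, Union
+
+from litellm.router import CustomRoutingStrategyBase
+
+
+class MetadataRoutingStrategy(CustomRoutingStrategyBase):
+    async def async_get_available_deployment(
+        self,
+        model: str,
+        messages: Optional[List[Dict[str, str]]] = None,
+        input: Optional[Union[str, List]] = None,
+        specific_deployment: Optional[bool] = False,
+        request_kwargs: Optional[Dict] = None,
+    ):
+        # Send "premium" traffic to the special endpoint, everything else
+        # to the fast endpoint (`router` is the instance from Step 2)
+        tier = ((request_kwargs or {}).get("metadata") or {}).get("tier")
+        target = (
+            "openai/very-special-endpoint"
+            if tier == "premium"
+            else "openai/fast-endpoint"
+        )
+        for deployment in router.model_list:
+            if isinstance(deployment, dict):
+                if deployment["litellm_params"]["model"] == target:
+                    return deployment
+```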
+
Picks a deployment based on the lowest cost
@@ -563,7 +679,6 @@ asyncio.run(router_acompletion())
```
-
## Basic Reliability
diff --git a/litellm/router.py b/litellm/router.py
index 9200089d5..b4589c9f0 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -69,6 +69,7 @@ from litellm.types.router import (
AlertingConfig,
AllowedFailsPolicy,
AssistantsTypedDict,
+ CustomRoutingStrategyBase,
Deployment,
DeploymentTypedDict,
LiteLLM_Params,
@@ -4814,6 +4815,29 @@ class Router:
except Exception as e:
pass
+ def set_custom_routing_strategy(
+ self, CustomRoutingStrategy: CustomRoutingStrategyBase
+ ):
+ """
+        Sets get_available_deployment and async_get_available_deployment on an instance of litellm.Router
+
+ Use this to set your custom routing strategy
+
+ Args:
+ CustomRoutingStrategy: litellm.router.CustomRoutingStrategyBase
+ """
+
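+        # Bind the strategy instance's methods onto this Router instance,
+        # overriding the built-in deployment selection for both the sync
+        # and async code paths.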
+ setattr(
+ self,
+ "get_available_deployment",
+ CustomRoutingStrategy.get_available_deployment,
+ )
+ setattr(
+ self,
+ "async_get_available_deployment",
+ CustomRoutingStrategy.async_get_available_deployment,
+ )
+
def flush_cache(self):
litellm.cache = None
self.cache.flush_cache()
diff --git a/litellm/tests/test_router_custom_routing.py b/litellm/tests/test_router_custom_routing.py
new file mode 100644
index 000000000..afd602b93
--- /dev/null
+++ b/litellm/tests/test_router_custom_routing.py
@@ -0,0 +1,150 @@
+import asyncio
+import os
+import random
+import sys
+import time
+import traceback
+from datetime import datetime, timedelta
+
+from dotenv import load_dotenv
+
+load_dotenv()
+import copy
+import os
+
+sys.path.insert(
+ 0, os.path.abspath("../..")
+) # Adds the parent directory to the system path
+from typing import Dict, List, Optional, Union
+
+import pytest
+
+import litellm
+from litellm import Router
+
+router = Router(
+ model_list=[
+ {
+ "model_name": "azure-model",
+ "litellm_params": {
+ "model": "openai/very-special-endpoint",
+ "api_base": "https://exampleopenaiendpoint-production.up.railway.app/", # If you are Krrish, this is OpenAI Endpoint3 on our Railway endpoint :)
+ "api_key": "fake-key",
+ },
+ "model_info": {"id": "very-special-endpoint"},
+ },
+ {
+ "model_name": "azure-model",
+ "litellm_params": {
+ "model": "openai/fast-endpoint",
+ "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
+ "api_key": "fake-key",
+ },
+ "model_info": {"id": "fast-endpoint"},
+ },
+ ],
+ set_verbose=True,
+ debug_level="DEBUG",
+)
+
+from litellm.router import CustomRoutingStrategyBase
+
+
+class CustomRoutingStrategy(CustomRoutingStrategyBase):
+ async def async_get_available_deployment(
+ self,
+ model: str,
+ messages: Optional[List[Dict[str, str]]] = None,
+ input: Optional[Union[str, List]] = None,
+ specific_deployment: Optional[bool] = False,
+ request_kwargs: Optional[Dict] = None,
+ ):
+ """
+ Asynchronously retrieves the available deployment based on the given parameters.
+
+ Args:
+ model (str): The name of the model.
+ messages (Optional[List[Dict[str, str]]], optional): The list of messages for a given request. Defaults to None.
+ input (Optional[Union[str, List]], optional): The input for a given embedding request. Defaults to None.
+ specific_deployment (Optional[bool], optional): Whether to retrieve a specific deployment. Defaults to False.
+ request_kwargs (Optional[Dict], optional): Additional request keyword arguments. Defaults to None.
+
+ Returns:
+            An element from litellm.router.model_list
+
+ """
+ print("In CUSTOM async get available deployment")
+ model_list = router.model_list
+ print("router model list=", model_list)
+ for model in model_list:
+ if isinstance(model, dict):
+ if model["litellm_params"]["model"] == "openai/very-special-endpoint":
+ return model
+ pass
+
+ def get_available_deployment(
+ self,
+ model: str,
+ messages: Optional[List[Dict[str, str]]] = None,
+ input: Optional[Union[str, List]] = None,
+ specific_deployment: Optional[bool] = False,
+ request_kwargs: Optional[Dict] = None,
+ ):
+ """
+ Synchronously retrieves the available deployment based on the given parameters.
+
+ Args:
+ model (str): The name of the model.
+ messages (Optional[List[Dict[str, str]]], optional): The list of messages for a given request. Defaults to None.
+ input (Optional[Union[str, List]], optional): The input for a given embedding request. Defaults to None.
+ specific_deployment (Optional[bool], optional): Whether to retrieve a specific deployment. Defaults to False.
+ request_kwargs (Optional[Dict], optional): Additional request keyword arguments. Defaults to None.
+
+ Returns:
+            An element from litellm.router.model_list
+
+ """
+ pass
+
+
+@pytest.mark.asyncio
+async def test_custom_routing():
+    litellm.set_verbose = True
+ router.set_custom_routing_strategy(CustomRoutingStrategy())
+
+ # make 4 requests
+ for _ in range(4):
+ try:
+ response = await router.acompletion(
+ model="azure-model", messages=[{"role": "user", "content": "hello"}]
+ )
+ print(response)
+ except Exception as e:
+ print("got exception", e)
+
+    await asyncio.sleep(1)
+    print("done sending initial requests")
+    """
+    Note: for debugging
+    - The custom strategy always returns the very-special-endpoint deployment,
+      so the next 10 requests should all be routed to it
+    """
+
+ deployments = {}
+ # make 10 requests
+ for _ in range(10):
+ response = await router.acompletion(
+ model="azure-model", messages=[{"role": "user", "content": "hello"}]
+ )
+ print(response)
+ _picked_model_id = response._hidden_params["model_id"]
+ if _picked_model_id not in deployments:
+ deployments[_picked_model_id] = 1
+ else:
+ deployments[_picked_model_id] += 1
+ print("deployments", deployments)
+
+    # ALL the requests should have been routed to the very-special-endpoint
+    # assert deployments["very-special-endpoint"] == 10
diff --git a/litellm/types/router.py b/litellm/types/router.py
index da3c999dc..206216ef0 100644
--- a/litellm/types/router.py
+++ b/litellm/types/router.py
@@ -451,3 +451,53 @@ class ModelGroupInfo(BaseModel):
class AssistantsTypedDict(TypedDict):
custom_llm_provider: Literal["azure", "openai"]
litellm_params: LiteLLMParamsTypedDict
+
+
+class CustomRoutingStrategyBase:
+ async def async_get_available_deployment(
+ self,
+ model: str,
+ messages: Optional[List[Dict[str, str]]] = None,
+ input: Optional[Union[str, List]] = None,
+ specific_deployment: Optional[bool] = False,
+ request_kwargs: Optional[Dict] = None,
+ ):
+ """
+ Asynchronously retrieves the available deployment based on the given parameters.
+
+ Args:
+ model (str): The name of the model.
+ messages (Optional[List[Dict[str, str]]], optional): The list of messages for a given request. Defaults to None.
+ input (Optional[Union[str, List]], optional): The input for a given embedding request. Defaults to None.
+ specific_deployment (Optional[bool], optional): Whether to retrieve a specific deployment. Defaults to False.
+ request_kwargs (Optional[Dict], optional): Additional request keyword arguments. Defaults to None.
+
+ Returns:
+            An element from litellm.router.model_list
+
+ """
+ pass
+
+ def get_available_deployment(
+ self,
+ model: str,
+ messages: Optional[List[Dict[str, str]]] = None,
+ input: Optional[Union[str, List]] = None,
+ specific_deployment: Optional[bool] = False,
+ request_kwargs: Optional[Dict] = None,
+ ):
+ """
+ Synchronously retrieves the available deployment based on the given parameters.
+
+ Args:
+ model (str): The name of the model.
+ messages (Optional[List[Dict[str, str]]], optional): The list of messages for a given request. Defaults to None.
+ input (Optional[Union[str, List]], optional): The input for a given embedding request. Defaults to None.
+ specific_deployment (Optional[bool], optional): Whether to retrieve a specific deployment. Defaults to False.
+ request_kwargs (Optional[Dict], optional): Additional request keyword arguments. Defaults to None.
+
+ Returns:
+            An element from litellm.router.model_list
+
+ """
+ pass