fix(router.py): support fallbacks / retries with sync embedding calls

Krrish Dholakia 2024-03-11 14:51:22 -07:00
parent e07174736f
commit 9735250db7
3 changed files with 181 additions and 37 deletions
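
In short: Router.embedding() previously picked a deployment and called litellm.embedding() directly, so a failing deployment raised immediately with no retry or fallback handling; it now routes through the same function_with_fallbacks() wrapper the completion path already uses. A minimal usage sketch, modeled on the new test added below (deployment names, keys, and endpoints are illustrative):

import os
from litellm import Router

# Two deployments of the same Azure embedding model: the first has a bad
# key, the second is healthy. The fallbacks map routes failures on the
# "bad" model group to the "good" one.
router = Router(
    model_list=[
        {
            "model_name": "bad-azure-embedding-model",
            "litellm_params": {
                "model": "azure/azure-embedding-model",
                "api_key": "bad-key",
                "api_version": os.getenv("AZURE_API_VERSION"),
                "api_base": os.getenv("AZURE_API_BASE"),
            },
        },
        {
            "model_name": "good-azure-embedding-model",
            "litellm_params": {
                "model": "azure/azure-embedding-model",
                "api_key": os.getenv("AZURE_API_KEY"),
                "api_version": os.getenv("AZURE_API_VERSION"),
                "api_base": os.getenv("AZURE_API_BASE"),
            },
        },
    ],
    fallbacks=[{"bad-azure-embedding-model": ["good-azure-embedding-model"]}],
)

# Before this commit, the sync path raised on the bad key immediately;
# now the call transparently falls back to the healthy deployment.
response = router.embedding(model="bad-azure-embedding-model", input=["Hello world"])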


@@ -970,44 +970,81 @@ class Router:
         is_async: Optional[bool] = False,
         **kwargs,
     ) -> Union[List[float], None]:
-        # pick the one that is available (lowest TPM/RPM)
-        deployment = self.get_available_deployment(
-            model=model,
-            input=input,
-            specific_deployment=kwargs.pop("specific_deployment", None),
-        )
-        kwargs.setdefault("model_info", {})
-        kwargs.setdefault("metadata", {}).update(
-            {"model_group": model, "deployment": deployment["litellm_params"]["model"]}
-        )  # [TODO]: move to using async_function_with_fallbacks
-        data = deployment["litellm_params"].copy()
-        for k, v in self.default_litellm_params.items():
-            if (
-                k not in kwargs
-            ):  # prioritize model-specific params > default router params
-                kwargs[k] = v
-            elif k == "metadata":
-                kwargs[k].update(v)
-        potential_model_client = self._get_client(deployment=deployment, kwargs=kwargs)
-        # check if provided keys == client keys #
-        dynamic_api_key = kwargs.get("api_key", None)
-        if (
-            dynamic_api_key is not None
-            and potential_model_client is not None
-            and dynamic_api_key != potential_model_client.api_key
-        ):
-            model_client = None
-        else:
-            model_client = potential_model_client
-        return litellm.embedding(
-            **{
-                **data,
-                "input": input,
-                "caching": self.cache_responses,
-                "client": model_client,
-                **kwargs,
-            }
-        )
+        try:
+            kwargs["model"] = model
+            kwargs["input"] = input
+            kwargs["original_function"] = self._embedding
+            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+            timeout = kwargs.get("request_timeout", self.timeout)
+            kwargs.setdefault("metadata", {}).update({"model_group": model})
+            response = self.function_with_fallbacks(**kwargs)
+            return response
+        except Exception as e:
+            raise e
+
+    def _embedding(self, input: Union[str, List], model: str, **kwargs):
+        try:
+            verbose_router_logger.debug(
+                f"Inside embedding()- model: {model}; kwargs: {kwargs}"
+            )
+            deployment = self.get_available_deployment(
+                model=model,
+                input=input,
+                specific_deployment=kwargs.pop("specific_deployment", None),
+            )
+            kwargs.setdefault("metadata", {}).update(
+                {
+                    "deployment": deployment["litellm_params"]["model"],
+                    "model_info": deployment.get("model_info", {}),
+                }
+            )
+            kwargs["model_info"] = deployment.get("model_info", {})
+            data = deployment["litellm_params"].copy()
+            model_name = data["model"]
+            for k, v in self.default_litellm_params.items():
+                if (
+                    k not in kwargs
+                ):  # prioritize model-specific params > default router params
+                    kwargs[k] = v
+                elif k == "metadata":
+                    kwargs[k].update(v)
+            potential_model_client = self._get_client(
+                deployment=deployment, kwargs=kwargs, client_type="sync"
+            )
+            # check if provided keys == client keys #
+            dynamic_api_key = kwargs.get("api_key", None)
+            if (
+                dynamic_api_key is not None
+                and potential_model_client is not None
+                and dynamic_api_key != potential_model_client.api_key
+            ):
+                model_client = None
+            else:
+                model_client = potential_model_client
+
+            self.total_calls[model_name] += 1
+            response = litellm.embedding(
+                **{
+                    **data,
+                    "input": input,
+                    "caching": self.cache_responses,
+                    "client": model_client,
+                    **kwargs,
+                }
+            )
+            self.success_calls[model_name] += 1
+            verbose_router_logger.info(
+                f"litellm.embedding(model={model_name})\033[32m 200 OK\033[0m"
+            )
+            return response
+        except Exception as e:
+            verbose_router_logger.info(
+                f"litellm.embedding(model={model_name})\033[31m Exception {str(e)}\033[0m"
+            )
+            if model_name is not None:
+                self.fail_calls[model_name] += 1
+            raise e
 
     async def aembedding(
         self,
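
The change mirrors the async completion path: the public embedding() now just packs the call into kwargs (including original_function=self._embedding and a num_retries default) and defers to self.function_with_fallbacks(), while the former body moves into the private _embedding(), gaining call accounting (total_calls / success_calls / fail_calls) and an explicit client_type="sync" client lookup. For intuition, function_with_fallbacks() behaves roughly like this simplified sketch — not the actual implementation, which also handles retries and cooldowns:

def function_with_fallbacks(self, **kwargs):
    # Simplified, method-style sketch of the existing wrapper that
    # embedding() now delegates to -- not the real implementation.
    model_group = kwargs["model"]
    original_function = kwargs.pop("original_function")  # here: self._embedding
    try:
        # First attempt; in the real router, retries happen inside this call.
        return original_function(**kwargs)
    except Exception:
        # On failure, find the fallback entry for this model alias and
        # try each fallback deployment group in order.
        for group in self.fallbacks or []:
            if model_group in group:
                for fallback_model in group[model_group]:
                    try:
                        kwargs["model"] = fallback_model
                        return original_function(**kwargs)
                    except Exception:
                        continue
        raise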


@@ -227,6 +227,57 @@ async def test_async_fallbacks():
 # test_async_fallbacks()
 
 
+def test_sync_fallbacks_embeddings():
+    litellm.set_verbose = False
+    model_list = [
+        {  # list of model deployments
+            "model_name": "bad-azure-embedding-model",  # openai model name
+            "litellm_params": {  # params for litellm completion/embedding call
+                "model": "azure/azure-embedding-model",
+                "api_key": "bad-key",
+                "api_version": os.getenv("AZURE_API_VERSION"),
+                "api_base": os.getenv("AZURE_API_BASE"),
+            },
+            "tpm": 240000,
+            "rpm": 1800,
+        },
+        {  # list of model deployments
+            "model_name": "good-azure-embedding-model",  # openai model name
+            "litellm_params": {  # params for litellm completion/embedding call
+                "model": "azure/azure-embedding-model",
+                "api_key": os.getenv("AZURE_API_KEY"),
+                "api_version": os.getenv("AZURE_API_VERSION"),
+                "api_base": os.getenv("AZURE_API_BASE"),
+            },
+            "tpm": 240000,
+            "rpm": 1800,
+        },
+    ]
+
+    router = Router(
+        model_list=model_list,
+        fallbacks=[{"bad-azure-embedding-model": ["good-azure-embedding-model"]}],
+        set_verbose=False,
+    )
+    customHandler = MyCustomHandler()
+    litellm.callbacks = [customHandler]
+    user_message = "Hello, how are you?"
+    input = [user_message]
+    try:
+        kwargs = {"model": "bad-azure-embedding-model", "input": input}
+        response = router.embedding(**kwargs)
+        print(f"customHandler.previous_models: {customHandler.previous_models}")
+        time.sleep(0.05)  # allow a delay as success_callbacks are on a separate thread
+        assert customHandler.previous_models == 1  # 0 retries, 1 fallback
+        router.reset()
+    except litellm.Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"An exception occurred: {e}")
+    finally:
+        router.reset()
+
+
 @pytest.mark.asyncio
 async def test_async_fallbacks_embeddings():
     litellm.set_verbose = False
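
The assertion above leans on the suite's MyCustomHandler, which counts how many deployments were attempted before the one that finally succeeded — hence previous_models == 1 for zero retries plus one fallback. A hypothetical minimal version of such a handler; the metadata["previous_models"] key is an assumption about where the router records earlier attempts:

from litellm.integrations.custom_logger import CustomLogger

class MyCustomHandler(CustomLogger):
    # Hypothetical minimal handler; assumes the router lists earlier
    # failed deployments under metadata["previous_models"].
    def __init__(self):
        self.previous_models = 0

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        metadata = kwargs.get("litellm_params", {}).get("metadata", {})
        self.previous_models = len(metadata.get("previous_models", []))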


@@ -0,0 +1,56 @@
+# [LOCAL TEST] - runs against mock openai proxy
+# # What this tests?
+# ## This tests if fallbacks works for 429 errors
+
+# import sys, os, time
+# import traceback, asyncio
+# import pytest
+
+# sys.path.insert(
+#     0, os.path.abspath("../..")
+# )  # Adds the parent directory to the system path
+
+# import litellm
+# from litellm import Router
+
+# model_list = [
+#     {  # list of model deployments
+#         "model_name": "text-embedding-ada-002",  # model alias
+#         "litellm_params": {  # params for litellm completion/embedding call
+#             "model": "text-embedding-ada-002",  # actual model name
+#             "api_key": "sk-fakekey",
+#             "api_base": "http://0.0.0.0:8080",
+#         },
+#         "tpm": 1000,
+#         "rpm": 6,
+#     },
+#     {
+#         "model_name": "text-embedding-ada-002-fallback",
+#         "litellm_params": {  # params for litellm completion/embedding call
+#             "model": "openai/text-embedding-ada-002-anything-else",  # actual model name
+#             "api_key": "sk-fakekey2",
+#             "api_base": "http://0.0.0.0:8080",
+#         },
+#         "tpm": 1000,
+#         "rpm": 6,
+#     },
+# ]
+
+# router = Router(
+#     model_list=model_list,
+#     fallbacks=[
+#         {"text-embedding-ada-002": ["text-embedding-ada-002-fallback"]},
+#         {"text-embedding-ada-002-fallback": ["text-embedding-ada-002"]},
+#     ],
+#     set_verbose=True,
+#     num_retries=0,
+#     debug_level="INFO",
+#     routing_strategy="usage-based-routing",
+# )
+
+
+# def test_embedding_with_fallbacks():
+#     response = router.embedding(model="text-embedding-ada-002", input=["Hello world"])
+#     print(f"response: {response}")
+
+
+# test_embedding_with_fallbacks()