forked from phoenix/litellm-mirror
fix(router.py): support fallbacks / retries with sync embedding calls
This commit is contained in:
parent
e07174736f
commit
9735250db7
3 changed files with 181 additions and 37 deletions
|
@ -970,44 +970,81 @@ class Router:
|
||||||
is_async: Optional[bool] = False,
|
is_async: Optional[bool] = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[List[float], None]:
|
) -> Union[List[float], None]:
|
||||||
# pick the one that is available (lowest TPM/RPM)
|
try:
|
||||||
deployment = self.get_available_deployment(
|
kwargs["model"] = model
|
||||||
model=model,
|
kwargs["input"] = input
|
||||||
input=input,
|
kwargs["original_function"] = self._embedding
|
||||||
specific_deployment=kwargs.pop("specific_deployment", None),
|
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||||
)
|
timeout = kwargs.get("request_timeout", self.timeout)
|
||||||
kwargs.setdefault("model_info", {})
|
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
||||||
kwargs.setdefault("metadata", {}).update(
|
response = self.function_with_fallbacks(**kwargs)
|
||||||
{"model_group": model, "deployment": deployment["litellm_params"]["model"]}
|
return response
|
||||||
) # [TODO]: move to using async_function_with_fallbacks
|
except Exception as e:
|
||||||
data = deployment["litellm_params"].copy()
|
raise e
|
||||||
for k, v in self.default_litellm_params.items():
|
|
||||||
|
def _embedding(self, input: Union[str, List], model: str, **kwargs):
|
||||||
|
try:
|
||||||
|
verbose_router_logger.debug(
|
||||||
|
f"Inside embedding()- model: {model}; kwargs: {kwargs}"
|
||||||
|
)
|
||||||
|
deployment = self.get_available_deployment(
|
||||||
|
model=model,
|
||||||
|
input=input,
|
||||||
|
specific_deployment=kwargs.pop("specific_deployment", None),
|
||||||
|
)
|
||||||
|
kwargs.setdefault("metadata", {}).update(
|
||||||
|
{
|
||||||
|
"deployment": deployment["litellm_params"]["model"],
|
||||||
|
"model_info": deployment.get("model_info", {}),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
kwargs["model_info"] = deployment.get("model_info", {})
|
||||||
|
data = deployment["litellm_params"].copy()
|
||||||
|
model_name = data["model"]
|
||||||
|
for k, v in self.default_litellm_params.items():
|
||||||
|
if (
|
||||||
|
k not in kwargs
|
||||||
|
): # prioritize model-specific params > default router params
|
||||||
|
kwargs[k] = v
|
||||||
|
elif k == "metadata":
|
||||||
|
kwargs[k].update(v)
|
||||||
|
|
||||||
|
potential_model_client = self._get_client(
|
||||||
|
deployment=deployment, kwargs=kwargs, client_type="sync"
|
||||||
|
)
|
||||||
|
# check if provided keys == client keys #
|
||||||
|
dynamic_api_key = kwargs.get("api_key", None)
|
||||||
if (
|
if (
|
||||||
k not in kwargs
|
dynamic_api_key is not None
|
||||||
): # prioritize model-specific params > default router params
|
and potential_model_client is not None
|
||||||
kwargs[k] = v
|
and dynamic_api_key != potential_model_client.api_key
|
||||||
elif k == "metadata":
|
):
|
||||||
kwargs[k].update(v)
|
model_client = None
|
||||||
potential_model_client = self._get_client(deployment=deployment, kwargs=kwargs)
|
else:
|
||||||
# check if provided keys == client keys #
|
model_client = potential_model_client
|
||||||
dynamic_api_key = kwargs.get("api_key", None)
|
|
||||||
if (
|
self.total_calls[model_name] += 1
|
||||||
dynamic_api_key is not None
|
response = litellm.embedding(
|
||||||
and potential_model_client is not None
|
**{
|
||||||
and dynamic_api_key != potential_model_client.api_key
|
**data,
|
||||||
):
|
"input": input,
|
||||||
model_client = None
|
"caching": self.cache_responses,
|
||||||
else:
|
"client": model_client,
|
||||||
model_client = potential_model_client
|
**kwargs,
|
||||||
return litellm.embedding(
|
}
|
||||||
**{
|
)
|
||||||
**data,
|
self.success_calls[model_name] += 1
|
||||||
"input": input,
|
verbose_router_logger.info(
|
||||||
"caching": self.cache_responses,
|
f"litellm.embedding(model={model_name})\033[32m 200 OK\033[0m"
|
||||||
"client": model_client,
|
)
|
||||||
**kwargs,
|
return response
|
||||||
}
|
except Exception as e:
|
||||||
)
|
verbose_router_logger.info(
|
||||||
|
f"litellm.embedding(model={model_name})\033[31m Exception {str(e)}\033[0m"
|
||||||
|
)
|
||||||
|
if model_name is not None:
|
||||||
|
self.fail_calls[model_name] += 1
|
||||||
|
raise e
|
||||||
|
|
||||||
async def aembedding(
|
async def aembedding(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -227,6 +227,57 @@ async def test_async_fallbacks():
|
||||||
# test_async_fallbacks()
|
# test_async_fallbacks()
|
||||||
|
|
||||||
|
|
||||||
|
def test_sync_fallbacks_embeddings():
|
||||||
|
litellm.set_verbose = False
|
||||||
|
model_list = [
|
||||||
|
{ # list of model deployments
|
||||||
|
"model_name": "bad-azure-embedding-model", # openai model name
|
||||||
|
"litellm_params": { # params for litellm completion/embedding call
|
||||||
|
"model": "azure/azure-embedding-model",
|
||||||
|
"api_key": "bad-key",
|
||||||
|
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||||
|
"api_base": os.getenv("AZURE_API_BASE"),
|
||||||
|
},
|
||||||
|
"tpm": 240000,
|
||||||
|
"rpm": 1800,
|
||||||
|
},
|
||||||
|
{ # list of model deployments
|
||||||
|
"model_name": "good-azure-embedding-model", # openai model name
|
||||||
|
"litellm_params": { # params for litellm completion/embedding call
|
||||||
|
"model": "azure/azure-embedding-model",
|
||||||
|
"api_key": os.getenv("AZURE_API_KEY"),
|
||||||
|
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||||
|
"api_base": os.getenv("AZURE_API_BASE"),
|
||||||
|
},
|
||||||
|
"tpm": 240000,
|
||||||
|
"rpm": 1800,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
router = Router(
|
||||||
|
model_list=model_list,
|
||||||
|
fallbacks=[{"bad-azure-embedding-model": ["good-azure-embedding-model"]}],
|
||||||
|
set_verbose=False,
|
||||||
|
)
|
||||||
|
customHandler = MyCustomHandler()
|
||||||
|
litellm.callbacks = [customHandler]
|
||||||
|
user_message = "Hello, how are you?"
|
||||||
|
input = [user_message]
|
||||||
|
try:
|
||||||
|
kwargs = {"model": "bad-azure-embedding-model", "input": input}
|
||||||
|
response = router.embedding(**kwargs)
|
||||||
|
print(f"customHandler.previous_models: {customHandler.previous_models}")
|
||||||
|
time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread
|
||||||
|
assert customHandler.previous_models == 1 # 0 retries, 1 fallback
|
||||||
|
router.reset()
|
||||||
|
except litellm.Timeout as e:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"An exception occurred: {e}")
|
||||||
|
finally:
|
||||||
|
router.reset()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_async_fallbacks_embeddings():
|
async def test_async_fallbacks_embeddings():
|
||||||
litellm.set_verbose = False
|
litellm.set_verbose = False
|
||||||
|
|
56
litellm/tests/test_router_with_fallbacks.py
Normal file
56
litellm/tests/test_router_with_fallbacks.py
Normal file
|
@ -0,0 +1,56 @@
|
||||||
|
# [LOCAL TEST] - runs against mock openai proxy
|
||||||
|
# # What this tests?
|
||||||
|
# ## This tests if fallbacks works for 429 errors
|
||||||
|
|
||||||
|
# import sys, os, time
|
||||||
|
# import traceback, asyncio
|
||||||
|
# import pytest
|
||||||
|
|
||||||
|
# sys.path.insert(
|
||||||
|
# 0, os.path.abspath("../..")
|
||||||
|
# ) # Adds the parent directory to the system path
|
||||||
|
# import litellm
|
||||||
|
# from litellm import Router
|
||||||
|
|
||||||
|
# model_list = [
|
||||||
|
# { # list of model deployments
|
||||||
|
# "model_name": "text-embedding-ada-002", # model alias
|
||||||
|
# "litellm_params": { # params for litellm completion/embedding call
|
||||||
|
# "model": "text-embedding-ada-002", # actual model name
|
||||||
|
# "api_key": "sk-fakekey",
|
||||||
|
# "api_base": "http://0.0.0.0:8080",
|
||||||
|
# },
|
||||||
|
# "tpm": 1000,
|
||||||
|
# "rpm": 6,
|
||||||
|
# },
|
||||||
|
# {
|
||||||
|
# "model_name": "text-embedding-ada-002-fallback",
|
||||||
|
# "litellm_params": { # params for litellm completion/embedding call
|
||||||
|
# "model": "openai/text-embedding-ada-002-anything-else", # actual model name
|
||||||
|
# "api_key": "sk-fakekey2",
|
||||||
|
# "api_base": "http://0.0.0.0:8080",
|
||||||
|
# },
|
||||||
|
# "tpm": 1000,
|
||||||
|
# "rpm": 6,
|
||||||
|
# },
|
||||||
|
# ]
|
||||||
|
|
||||||
|
# router = Router(
|
||||||
|
# model_list=model_list,
|
||||||
|
# fallbacks=[
|
||||||
|
# {"text-embedding-ada-002": ["text-embedding-ada-002-fallback"]},
|
||||||
|
# {"text-embedding-ada-002-fallback": ["text-embedding-ada-002"]},
|
||||||
|
# ],
|
||||||
|
# set_verbose=True,
|
||||||
|
# num_retries=0,
|
||||||
|
# debug_level="INFO",
|
||||||
|
# routing_strategy="usage-based-routing",
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
|
# def test_embedding_with_fallbacks():
|
||||||
|
# response = router.embedding(model="text-embedding-ada-002", input=["Hello world"])
|
||||||
|
# print(f"response: {response}")
|
||||||
|
|
||||||
|
|
||||||
|
# test_embedding_with_fallbacks()
|
Loading…
Add table
Add a link
Reference in a new issue