forked from phoenix/litellm-mirror

feat(router.py): add support for retry/fallbacks for async embedding calls

parent: c12e3bd565
commit: a37a18ca80
4 changed files with 124 additions and 102 deletions
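Taken together, the change means Router.aembedding() now goes through the same retry/fallback machinery as chat completions. As a rough usage sketch, modeled on the test_async_fallbacks_embeddings test added in this commit (the deployment names, the deliberately bad API key, and the AZURE_* environment variables come from that test fixture; the asyncio.run() wrapper is only for illustration):

import asyncio
import os

from litellm import Router

# Two deployments of the same underlying Azure embedding model: the first has a
# deliberately bad key so it always fails, the second is the fallback target.
model_list = [
    {
        "model_name": "bad-azure-embedding-model",
        "litellm_params": {
            "model": "azure/azure-embedding-model",
            "api_key": "bad-key",  # intentionally invalid
            "api_version": os.getenv("AZURE_API_VERSION"),
            "api_base": os.getenv("AZURE_API_BASE"),
        },
    },
    {
        "model_name": "good-azure-embedding-model",
        "litellm_params": {
            "model": "azure/azure-embedding-model",
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_version": os.getenv("AZURE_API_VERSION"),
            "api_base": os.getenv("AZURE_API_BASE"),
        },
    },
]

router = Router(
    model_list=model_list,
    fallbacks=[{"bad-azure-embedding-model": ["good-azure-embedding-model"]}],
)


async def main():
    # The bad deployment raises, the router retries and then falls back to the
    # good deployment, so the caller still gets an embedding response back.
    response = await router.aembedding(
        model="bad-azure-embedding-model", input=["Hello, how are you?"]
    )
    print(response)


asyncio.run(main())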
@@ -686,48 +686,72 @@ class Router:
         is_async: Optional[bool] = True,
         **kwargs,
     ) -> Union[List[float], None]:
-        # pick the one that is available (lowest TPM/RPM)
-        deployment = self.get_available_deployment(
-            model=model,
-            input=input,
-            specific_deployment=kwargs.pop("specific_deployment", None),
-        )
-        kwargs.setdefault("metadata", {}).update(
-            {"model_group": model, "deployment": deployment["litellm_params"]["model"]}
-        )
-        data = deployment["litellm_params"].copy()
-        kwargs["model_info"] = deployment.get("model_info", {})
-        for k, v in self.default_litellm_params.items():
-            if (
-                k not in kwargs
-            ):  # prioritize model-specific params > default router params
-                kwargs[k] = v
-            elif k == "metadata":
-                kwargs[k].update(v)
-
-        potential_model_client = self._get_client(
-            deployment=deployment, kwargs=kwargs, client_type="async"
-        )
-        # check if provided keys == client keys #
-        dynamic_api_key = kwargs.get("api_key", None)
-        if (
-            dynamic_api_key is not None
-            and potential_model_client is not None
-            and dynamic_api_key != potential_model_client.api_key
-        ):
-            model_client = None
-        else:
-            model_client = potential_model_client
-
-        return await litellm.aembedding(
-            **{
-                **data,
-                "input": input,
-                "caching": self.cache_responses,
-                "client": model_client,
-                **kwargs,
-            }
-        )
+        try:
+            kwargs["model"] = model
+            kwargs["input"] = input
+            kwargs["original_function"] = self._aembedding
+            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+            timeout = kwargs.get("request_timeout", self.timeout)
+            kwargs.setdefault("metadata", {}).update({"model_group": model})
+            response = await self.async_function_with_fallbacks(**kwargs)
+            return response
+        except Exception as e:
+            raise e
+
+    async def _aembedding(self, input: Union[str, List], model: str, **kwargs):
+        try:
+            self.print_verbose(
+                f"Inside _aembedding()- model: {model}; kwargs: {kwargs}"
+            )
+            deployment = self.get_available_deployment(
+                model=model,
+                input=input,
+                specific_deployment=kwargs.pop("specific_deployment", None),
+            )
+            kwargs.setdefault("metadata", {}).update(
+                {"deployment": deployment["litellm_params"]["model"]}
+            )
+            kwargs["model_info"] = deployment.get("model_info", {})
+            data = deployment["litellm_params"].copy()
+            model_name = data["model"]
+            for k, v in self.default_litellm_params.items():
+                if (
+                    k not in kwargs
+                ):  # prioritize model-specific params > default router params
+                    kwargs[k] = v
+                elif k == "metadata":
+                    kwargs[k].update(v)
+
+            potential_model_client = self._get_client(
+                deployment=deployment, kwargs=kwargs, client_type="async"
+            )
+            # check if provided keys == client keys #
+            dynamic_api_key = kwargs.get("api_key", None)
+            if (
+                dynamic_api_key is not None
+                and potential_model_client is not None
+                and dynamic_api_key != potential_model_client.api_key
+            ):
+                model_client = None
+            else:
+                model_client = potential_model_client
+
+            self.total_calls[model_name] += 1
+            response = await litellm.aembedding(
+                **{
+                    **data,
+                    "input": input,
+                    "caching": self.cache_responses,
+                    "client": model_client,
+                    **kwargs,
+                }
+            )
+            self.success_calls[model_name] += 1
+            return response
+        except Exception as e:
+            if model_name is not None:
+                self.fail_calls[model_name] += 1
+            raise e
 
     async def async_function_with_fallbacks(self, *args, **kwargs):
         """
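The restructuring above mirrors the pattern the router already uses for completions: the public aembedding() only packs its arguments, points kwargs["original_function"] at the private _aembedding() worker, and delegates to async_function_with_fallbacks(). A minimal standalone sketch of that delegation idea, not litellm's actual implementation (the function name, signature, and fallbacks shape here are illustrative):

from typing import Any, Awaitable, Callable, Dict, List


async def call_with_retries_and_fallbacks(
    original_function: Callable[..., Awaitable[Any]],  # e.g. the private worker coroutine
    model: str,                                         # requested model group
    fallbacks: Dict[str, List[str]],                    # group -> ordered fallback groups
    num_retries: int = 0,
    **kwargs: Any,
) -> Any:
    last_exc: Exception = RuntimeError("no deployment was attempted")
    # Try the requested group first, then each configured fallback group.
    for candidate in [model] + fallbacks.get(model, []):
        # Retry the current group before moving on to the next one.
        for _ in range(num_retries + 1):
            try:
                return await original_function(model=candidate, **kwargs)
            except Exception as e:
                last_exc = e
    raise last_exc

Keeping a single wrapper for every call type means retry counts, fallback ordering, and failure bookkeeping only have to be maintained in one place, whether the call is a completion or an embedding.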
The commit also removes the Router's threaded health-check helpers (_start_health_check_thread, _perform_health_checks, _health_check) and the latency-weighted shuffle helper (weighted_shuffle_by_latency):

@@ -1200,65 +1224,6 @@ class Router:
         self.print_verbose(f"retrieve cooldown models: {cooldown_models}")
         return cooldown_models
 
-    def _start_health_check_thread(self):
-        """
-        Starts a separate thread to perform health checks periodically.
-        """
-        health_check_thread = threading.Thread(
-            target=self._perform_health_checks, daemon=True
-        )
-        health_check_thread.start()
-
-    def _perform_health_checks(self):
-        """
-        Periodically performs health checks on the servers.
-        Updates the list of healthy servers accordingly.
-        """
-        while True:
-            self.healthy_deployments = self._health_check()
-            # Adjust the time interval based on your needs
-            time.sleep(15)
-
-    def _health_check(self):
-        """
-        Performs a health check on the deployments
-        Returns the list of healthy deployments
-        """
-        healthy_deployments = []
-        for deployment in self.model_list:
-            litellm_args = deployment["litellm_params"]
-            try:
-                start_time = time.time()
-                litellm.completion(
-                    messages=[{"role": "user", "content": ""}],
-                    max_tokens=1,
-                    **litellm_args,
-                )  # hit the server with a blank message to see how long it takes to respond
-                end_time = time.time()
-                response_time = end_time - start_time
-                logging.debug(f"response_time: {response_time}")
-                healthy_deployments.append((deployment, response_time))
-                healthy_deployments.sort(key=lambda x: x[1])
-            except Exception as e:
-                pass
-        return healthy_deployments
-
-    def weighted_shuffle_by_latency(self, items):
-        # Sort the items by latency
-        sorted_items = sorted(items, key=lambda x: x[1])
-        # Get only the latencies
-        latencies = [i[1] for i in sorted_items]
-        # Calculate the sum of all latencies
-        total_latency = sum(latencies)
-        # Calculate the weight for each latency (lower latency = higher weight)
-        weights = [total_latency - latency for latency in latencies]
-        # Get a weighted random item
-        if sum(weights) == 0:
-            chosen_item = random.choice(sorted_items)[0]
-        else:
-            chosen_item = random.choices(sorted_items, weights=weights, k=1)[0][0]
-        return chosen_item
-
     def set_client(self, model: dict):
         """
         Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
@@ -151,7 +151,6 @@ class LowestTPMLoggingHandler(CustomLogger):
             ## if healthy deployment not yet used
             if d["model_info"]["id"] not in all_deployments:
                 all_deployments[d["model_info"]["id"]] = 0
-
         input_tokens = token_counter(messages=messages, text=input)
         for item, item_tpm in all_deployments.items():
             ## get the item from model list
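In the lowest-TPM strategy, the per-request token estimate comes from token_counter(messages=messages, text=input), so embedding inputs count toward the estimate as well as chat messages; the tests further down now pass input=["Hello world"] to get_available_deployments for the same reason. A minimal sketch of how such token-aware lowest-TPM selection could look, under assumed data structures (this is not the LowestTPMLoggingHandler code):

from typing import Dict, Optional


def pick_lowest_tpm_deployment(
    usage: Dict[str, int],    # deployment id -> tokens already used this minute
    limits: Dict[str, int],   # deployment id -> TPM limit
    input_tokens: int,        # estimated tokens for the pending request
) -> Optional[str]:
    best_id: Optional[str] = None
    best_usage = float("inf")
    for deployment_id, used in usage.items():
        # Skip deployments that the pending request would push over their limit.
        if used + input_tokens > limits.get(deployment_id, float("inf")):
            continue
        if used < best_usage:
            best_id, best_usage = deployment_id, used
    return best_id

For example, pick_lowest_tpm_deployment({"1234": 900, "5678": 10}, {"1234": 1000, "5678": 1000}, 50) returns "5678", which is the shape of assertion the updated routing test makes.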
A new test exercises retry/fallback behavior for async embedding calls end to end:

@@ -227,6 +227,60 @@ async def test_async_fallbacks():
 # test_async_fallbacks()
 
 
+@pytest.mark.asyncio
+async def test_async_fallbacks_embeddings():
+    litellm.set_verbose = False
+    model_list = [
+        {  # list of model deployments
+            "model_name": "bad-azure-embedding-model",  # openai model name
+            "litellm_params": {  # params for litellm completion/embedding call
+                "model": "azure/azure-embedding-model",
+                "api_key": "bad-key",
+                "api_version": os.getenv("AZURE_API_VERSION"),
+                "api_base": os.getenv("AZURE_API_BASE"),
+            },
+            "tpm": 240000,
+            "rpm": 1800,
+        },
+        {  # list of model deployments
+            "model_name": "good-azure-embedding-model",  # openai model name
+            "litellm_params": {  # params for litellm completion/embedding call
+                "model": "azure/azure-embedding-model",
+                "api_key": os.getenv("AZURE_API_KEY"),
+                "api_version": os.getenv("AZURE_API_VERSION"),
+                "api_base": os.getenv("AZURE_API_BASE"),
+            },
+            "tpm": 240000,
+            "rpm": 1800,
+        },
+    ]
+
+    router = Router(
+        model_list=model_list,
+        fallbacks=[{"bad-azure-embedding-model": ["good-azure-embedding-model"]}],
+        set_verbose=False,
+    )
+    customHandler = MyCustomHandler()
+    litellm.callbacks = [customHandler]
+    user_message = "Hello, how are you?"
+    input = [user_message]
+    try:
+        kwargs = {"model": "bad-azure-embedding-model", "input": input}
+        response = await router.aembedding(**kwargs)
+        print(f"customHandler.previous_models: {customHandler.previous_models}")
+        await asyncio.sleep(
+            0.05
+        )  # allow a delay as success_callbacks are on a separate thread
+        assert customHandler.previous_models == 1  # 0 retries, 1 fallback
+        router.reset()
+    except litellm.Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"An exception occurred: {e}")
+    finally:
+        router.reset()
+
+
 def test_dynamic_fallbacks_sync():
     """
     Allow setting the fallback in the router.completion() call.
And the lowest-TPM routing test now passes the request input when asking for an available deployment:

@@ -122,12 +122,16 @@ def test_get_available_deployments():
     ## CHECK WHAT'S SELECTED ##
     print(
         lowest_tpm_logger.get_available_deployments(
-            model_group=model_group, healthy_deployments=model_list
+            model_group=model_group,
+            healthy_deployments=model_list,
+            input=["Hello world"],
         )
     )
     assert (
         lowest_tpm_logger.get_available_deployments(
-            model_group=model_group, healthy_deployments=model_list
+            model_group=model_group,
+            healthy_deployments=model_list,
+            input=["Hello world"],
         )["model_info"]["id"]
         == "5678"
     )