mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)

fix(router.py): support retry and fallbacks for atext_completion

parent 7ecd7b3e8d
commit 38f55249e1

6 changed files with 290 additions and 69 deletions
@@ -191,7 +191,9 @@ class Router:
         )  # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
         ### ROUTING SETUP ###
         if routing_strategy == "least-busy":
-            self.leastbusy_logger = LeastBusyLoggingHandler(router_cache=self.cache)
+            self.leastbusy_logger = LeastBusyLoggingHandler(
+                router_cache=self.cache, model_list=self.model_list
+            )
             ## add callback
             if isinstance(litellm.input_callback, list):
                 litellm.input_callback.append(self.leastbusy_logger)  # type: ignore
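The constructor change passes the router's model_list into the least-busy handler, so the callback can resolve usage counts back to concrete deployments. A rough sketch of the underlying idea follows; it is illustrative only, not the LeastBusyLoggingHandler API, and the names InFlightTracker, track_start, and track_end are hypothetical:

    class InFlightTracker:
        """Hypothetical stand-in for a least-busy callback: counts in-flight
        requests per deployment id, seeded from the router's model_list."""

        def __init__(self, model_list: list):
            # seed every known deployment at zero so idle deployments can win
            self.in_flight = {m["model_info"]["id"]: 0 for m in model_list}

        def track_start(self, deployment_id) -> None:
            self.in_flight[deployment_id] = self.in_flight.get(deployment_id, 0) + 1

        def track_end(self, deployment_id) -> None:
            self.in_flight[deployment_id] = max(0, self.in_flight.get(deployment_id, 0) - 1)

        def least_busy(self):
            # the deployment with the fewest in-flight requests wins
            return min(self.in_flight, key=self.in_flight.get)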
@@ -506,7 +508,13 @@ class Router:
         **kwargs,
     ):
         try:
+            kwargs["model"] = model
+            kwargs["prompt"] = prompt
+            kwargs["original_function"] = self._acompletion
+            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+            timeout = kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
+
             messages = [{"role": "user", "content": prompt}]
             # pick the one that is available (lowest TPM/RPM)
             deployment = self.get_available_deployment(
@@ -530,7 +538,6 @@ class Router:
             if self.num_retries > 0:
-                kwargs["model"] = model
                 kwargs["messages"] = messages
                 kwargs["original_exception"] = e
                 kwargs["original_function"] = self.completion
                 return self.function_with_retries(**kwargs)
             else:
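The two hunks above hoist the retry bookkeeping (original_function, num_retries) to the top of text_completion so it is in place before anything can fail, and drop the now-redundant kwargs["model"] assignment from the except path. The convention function_with_retries depends on looks roughly like this simplified sketch; the real helper in router.py also handles fallback groups and deployment cooldowns:

    def function_with_retries_sketch(**kwargs):
        # simplified: pop the routing bookkeeping out of kwargs, then
        # re-invoke the original function until retries are exhausted
        original_function = kwargs.pop("original_function")
        num_retries = kwargs.pop("num_retries", 0)
        kwargs.pop("original_exception", None)  # exception that triggered the retry
        last_exception = None
        for _attempt in range(num_retries + 1):
            try:
                return original_function(**kwargs)
            except Exception as e:  # the real code only retries retryable errors
                last_exception = e
        raise last_exception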
@@ -546,16 +553,34 @@ class Router:
         **kwargs,
     ):
         try:
+            kwargs["model"] = model
+            kwargs["prompt"] = prompt
+            kwargs["original_function"] = self._atext_completion
+            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+            timeout = kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
-            messages = [{"role": "user", "content": prompt}]
-            # pick the one that is available (lowest TPM/RPM)
+            response = await self.async_function_with_fallbacks(**kwargs)
+
+            return response
+        except Exception as e:
+            raise e
+
+    async def _atext_completion(self, model: str, prompt: str, **kwargs):
+        try:
+            self.print_verbose(
+                f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}"
+            )
             deployment = self.get_available_deployment(
                 model=model,
-                messages=messages,
+                messages=[{"role": "user", "content": prompt}],
                 specific_deployment=kwargs.pop("specific_deployment", None),
             )
-
+            kwargs.setdefault("metadata", {}).update(
+                {"deployment": deployment["litellm_params"]["model"]}
+            )
+            kwargs["model_info"] = deployment.get("model_info", {})
             data = deployment["litellm_params"].copy()
+            model_name = data["model"]
             for k, v in self.default_litellm_params.items():
                 if (
                     k not in kwargs
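This is the core of the fix: atext_completion now routes through async_function_with_fallbacks instead of calling the deployment directly, and the actual call moves into the new _atext_completion helper. In practice, a call like the following picks up the router's retry and fallback behavior (model names and fallback groups are placeholders, and a provider key such as OPENAI_API_KEY would be needed at call time):

    import asyncio
    from litellm import Router

    model_list = [
        {
            "model_name": "text-model",  # placeholder model group
            "litellm_params": {"model": "gpt-3.5-turbo-instruct"},
        },
        {
            "model_name": "backup-text-model",  # placeholder fallback group
            "litellm_params": {"model": "gpt-3.5-turbo-instruct"},
        },
    ]

    router = Router(
        model_list=model_list,
        num_retries=2,
        fallbacks=[{"text-model": ["backup-text-model"]}],
    )

    async def main():
        # retried on transient errors, then routed to the backup group
        response = await router.atext_completion(
            model="text-model", prompt="Say hello in one sentence."
        )
        print(response)

    asyncio.run(main())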
@@ -564,27 +589,38 @@ class Router:
                 elif k == "metadata":
                     kwargs[k].update(v)
 
-            ########## remove -ModelID-XXXX from model ##############
-            original_model_string = data["model"]
-            # Find the index of "ModelID" in the string
-            index_of_model_id = original_model_string.find("-ModelID")
-            # Remove everything after "-ModelID" if it exists
-            if index_of_model_id != -1:
-                data["model"] = original_model_string[:index_of_model_id]
+            potential_model_client = self._get_client(
+                deployment=deployment, kwargs=kwargs, client_type="async"
+            )
+            # check if provided keys == client keys #
+            dynamic_api_key = kwargs.get("api_key", None)
+            if (
+                dynamic_api_key is not None
+                and potential_model_client is not None
+                and dynamic_api_key != potential_model_client.api_key
+            ):
+                model_client = None
             else:
-                data["model"] = original_model_string
-            # call via litellm.atext_completion()
-            response = await litellm.atext_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs})  # type: ignore
+                model_client = potential_model_client
+            self.total_calls[model_name] += 1
+            response = await asyncio.wait_for(
+                litellm.atext_completion(
+                    **{
+                        **data,
+                        "prompt": prompt,
+                        "caching": self.cache_responses,
+                        "client": model_client,
+                        **kwargs,
+                    }
+                ),
+                timeout=self.timeout,
+            )
+            self.success_calls[model_name] += 1
             return response
         except Exception as e:
-            if self.num_retries > 0:
-                kwargs["model"] = model
-                kwargs["messages"] = messages
-                kwargs["original_exception"] = e
-                kwargs["original_function"] = self.completion
-                return self.function_with_retries(**kwargs)
-            else:
-                raise e
+            if model_name is not None:
+                self.fail_calls[model_name] += 1
+            raise e
 
     def embedding(
         self,
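Two details in this hunk are worth isolating. The client check keeps a dynamically passed api_key from silently reusing a cached client built with a different key, and the provider call is now bounded by asyncio.wait_for, so a hung deployment surfaces as asyncio.TimeoutError that the retry/fallback layer can act on. The timeout pattern in isolation, as a self-contained sketch:

    import asyncio

    async def slow_llm_call() -> str:
        await asyncio.sleep(5)  # stand-in for a slow deployment
        return "done"

    async def main():
        try:
            result = await asyncio.wait_for(slow_llm_call(), timeout=1.0)
        except asyncio.TimeoutError:
            # in router.py this propagates up to the retry/fallback handling
            result = "timed out"
        print(result)

    asyncio.run(main())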
@@ -1531,34 +1567,10 @@ class Router:
                 model
             ]  # update the model to the actual value if an alias has been passed in
         if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
-            deployments = self.leastbusy_logger.get_available_deployments(
-                model_group=model
+            deployment = self.leastbusy_logger.get_available_deployments(
+                model_group=model, healthy_deployments=healthy_deployments
             )
-            self.print_verbose(f"deployments in least-busy router: {deployments}")
-            # pick least busy deployment
-            min_traffic = float("inf")
-            min_deployment = None
-            for k, v in deployments.items():
-                if v < min_traffic:
-                    min_traffic = v
-                    min_deployment = k
-            self.print_verbose(f"min_deployment: {min_deployment};")
-            ############## No Available Deployments passed, we do a random pick #################
-            if min_deployment is None:
-                min_deployment = random.choice(healthy_deployments)
-            ############## Available Deployments passed, we find the relevant item #################
-            else:
-                ## check if min deployment is a string, if so, cast it to int
-                for m in healthy_deployments:
-                    if isinstance(min_deployment, str) and isinstance(
-                        m["model_info"]["id"], int
-                    ):
-                        min_deployment = int(min_deployment)
-                    if m["model_info"]["id"] == min_deployment:
-                        return m
-                self.print_verbose(f"no healthy deployment with that id found!")
-                min_deployment = random.choice(healthy_deployments)
-            return min_deployment
+            return deployment
         elif self.routing_strategy == "simple-shuffle":
             # if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
             ############## Check if we can do a RPM/TPM based weighted pick #################
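The inline min-traffic scan is deleted because get_available_deployments on the least-busy logger now takes healthy_deployments and returns the chosen deployment directly. The removed selection logic reduces to roughly this standalone restatement (pick_least_busy is an illustrative name, not a router.py function):

    import random

    def pick_least_busy(deployments: dict, healthy_deployments: list):
        # deployments maps deployment id -> current traffic, as in the removed code
        if not deployments:
            return random.choice(healthy_deployments)
        min_id = min(deployments, key=deployments.get)
        for m in healthy_deployments:
            # ids may be stored as str on one side and int on the other
            if str(m["model_info"]["id"]) == str(min_id):
                return m
        # no healthy deployment with that id: fall back to a random pick
        return random.choice(healthy_deployments)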