fix(router.py): support retry and fallbacks for atext_completion

This commit is contained in:
Krrish Dholakia 2023-12-30 11:19:13 +05:30
parent 7ecd7b3e8d
commit 38f55249e1
6 changed files with 290 additions and 69 deletions

View file

@ -191,7 +191,9 @@ class Router:
) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
### ROUTING SETUP ###
if routing_strategy == "least-busy":
self.leastbusy_logger = LeastBusyLoggingHandler(router_cache=self.cache)
self.leastbusy_logger = LeastBusyLoggingHandler(
router_cache=self.cache, model_list=self.model_list
)
## add callback
if isinstance(litellm.input_callback, list):
litellm.input_callback.append(self.leastbusy_logger) # type: ignore
@ -506,7 +508,13 @@ class Router:
**kwargs,
):
try:
kwargs["model"] = model
kwargs["prompt"] = prompt
kwargs["original_function"] = self._acompletion
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
messages = [{"role": "user", "content": prompt}]
# pick the one that is available (lowest TPM/RPM)
deployment = self.get_available_deployment(
@ -530,7 +538,6 @@ class Router:
if self.num_retries > 0:
kwargs["model"] = model
kwargs["messages"] = messages
kwargs["original_exception"] = e
kwargs["original_function"] = self.completion
return self.function_with_retries(**kwargs)
else:
@ -546,16 +553,34 @@ class Router:
**kwargs,
):
try:
kwargs["model"] = model
kwargs["prompt"] = prompt
kwargs["original_function"] = self._atext_completion
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
messages = [{"role": "user", "content": prompt}]
# pick the one that is available (lowest TPM/RPM)
response = await self.async_function_with_fallbacks(**kwargs)
return response
except Exception as e:
raise e
async def _atext_completion(self, model: str, prompt: str, **kwargs):
try:
self.print_verbose(
f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}"
)
deployment = self.get_available_deployment(
model=model,
messages=messages,
messages=[{"role": "user", "content": prompt}],
specific_deployment=kwargs.pop("specific_deployment", None),
)
kwargs.setdefault("metadata", {}).update(
{"deployment": deployment["litellm_params"]["model"]}
)
kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy()
model_name = data["model"]
for k, v in self.default_litellm_params.items():
if (
k not in kwargs
@ -564,27 +589,38 @@ class Router:
elif k == "metadata":
kwargs[k].update(v)
########## remove -ModelID-XXXX from model ##############
original_model_string = data["model"]
# Find the index of "ModelID" in the string
index_of_model_id = original_model_string.find("-ModelID")
# Remove everything after "-ModelID" if it exists
if index_of_model_id != -1:
data["model"] = original_model_string[:index_of_model_id]
potential_model_client = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="async"
)
# check if provided keys == client keys #
dynamic_api_key = kwargs.get("api_key", None)
if (
dynamic_api_key is not None
and potential_model_client is not None
and dynamic_api_key != potential_model_client.api_key
):
model_client = None
else:
data["model"] = original_model_string
# call via litellm.atext_completion()
response = await litellm.atext_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
model_client = potential_model_client
self.total_calls[model_name] += 1
response = await asyncio.wait_for(
litellm.atext_completion(
**{
**data,
"prompt": prompt,
"caching": self.cache_responses,
"client": model_client,
**kwargs,
}
),
timeout=self.timeout,
)
self.success_calls[model_name] += 1
return response
except Exception as e:
if self.num_retries > 0:
kwargs["model"] = model
kwargs["messages"] = messages
kwargs["original_exception"] = e
kwargs["original_function"] = self.completion
return self.function_with_retries(**kwargs)
else:
raise e
if model_name is not None:
self.fail_calls[model_name] += 1
raise e
def embedding(
self,
@ -1531,34 +1567,10 @@ class Router:
model
] # update the model to the actual value if an alias has been passed in
if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
deployments = self.leastbusy_logger.get_available_deployments(
model_group=model
deployment = self.leastbusy_logger.get_available_deployments(
model_group=model, healthy_deployments=healthy_deployments
)
self.print_verbose(f"deployments in least-busy router: {deployments}")
# pick least busy deployment
min_traffic = float("inf")
min_deployment = None
for k, v in deployments.items():
if v < min_traffic:
min_traffic = v
min_deployment = k
self.print_verbose(f"min_deployment: {min_deployment};")
############## No Available Deployments passed, we do a random pick #################
if min_deployment is None:
min_deployment = random.choice(healthy_deployments)
############## Available Deployments passed, we find the relevant item #################
else:
## check if min deployment is a string, if so, cast it to int
for m in healthy_deployments:
if isinstance(min_deployment, str) and isinstance(
m["model_info"]["id"], int
):
min_deployment = int(min_deployment)
if m["model_info"]["id"] == min_deployment:
return m
self.print_verbose(f"no healthy deployment with that id found!")
min_deployment = random.choice(healthy_deployments)
return min_deployment
return deployment
elif self.routing_strategy == "simple-shuffle":
# if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
############## Check if we can do a RPM/TPM based weighted pick #################