fix(router.py): support retry and fallbacks for atext_completion

Krrish Dholakia 2023-12-30 11:19:13 +05:30
parent 7ecd7b3e8d
commit 38f55249e1
6 changed files with 290 additions and 69 deletions
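For context, a minimal usage sketch of what this commit enables: retries and fallbacks now also cover text completions routed through the Router. The deployment entry is borrowed from the test added in this commit; the Azure key and endpoint values are placeholders resolved from environment variables, and a real call would need those credentials set.

import asyncio
from litellm import Router

# One deployment is enough to illustrate the call path; the test file below
# registers three deployments under the same "azure-model" group.
model_list = [
    {
        "model_name": "azure-model",
        "litellm_params": {
            "model": "azure/gpt-35-turbo",
            "api_key": "os.environ/AZURE_EUROPE_API_KEY",
            "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
        },
        "model_info": {"id": 1},
    },
]

router = Router(
    model_list=model_list,
    routing_strategy="least-busy",
    num_retries=3,  # after this commit, retries/fallbacks also apply to atext_completion
)

async def main():
    response = await router.atext_completion(
        model="azure-model", prompt="Hello, can you generate a 500 words poem?"
    )
    print(response)

asyncio.run(main())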


@@ -212,15 +212,15 @@ async def acompletion(*args, **kwargs):
else:
# Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context)
if kwargs.get("stream", False): # return an async generator
return _async_streaming(
response=response,
model=model,
custom_llm_provider=custom_llm_provider,
args=args,
)
else:
return response
# if kwargs.get("stream", False): # return an async generator
# return _async_streaming(
# response=response,
# model=model,
# custom_llm_provider=custom_llm_provider,
# args=args,
# )
# else:
return response
except Exception as e:
custom_llm_provider = custom_llm_provider or "openai"
raise exception_type(


@@ -86,6 +86,7 @@ from fastapi import (
Depends,
BackgroundTasks,
Header,
Response,
)
from fastapi.routing import APIRouter
from fastapi.security import OAuth2PasswordBearer
@@ -1068,6 +1069,7 @@ def model_list():
)
async def completion(
request: Request,
fastapi_response: Response,
model: Optional[str] = None,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
background_tasks: BackgroundTasks = BackgroundTasks(),
@@ -1143,17 +1145,23 @@ async def completion(
else: # router is not set
response = await litellm.atext_completion(**data)
model_id = response._hidden_params.get("model_id", None) or ""
print(f"final response: {response}")
if (
"stream" in data and data["stream"] == True
): # use generate_responses to stream responses
custom_headers = {"x-litellm-model-id": model_id}
return StreamingResponse(
async_data_generator(
user_api_key_dict=user_api_key_dict, response=response
user_api_key_dict=user_api_key_dict,
response=response,
),
media_type="text/event-stream",
headers=custom_headers,
)
fastapi_response.headers["x-litellm-model-id"] = model_id
return response
except Exception as e:
print(f"EXCEPTION RAISED IN PROXY MAIN.PY")
@@ -1187,6 +1195,7 @@ async def completion(
) # azure compatible endpoint
async def chat_completion(
request: Request,
fastapi_response: Response,
model: Optional[str] = None,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
background_tasks: BackgroundTasks = BackgroundTasks(),
@@ -1282,19 +1291,24 @@ async def chat_completion(
else: # router is not set
response = await litellm.acompletion(**data)
print(f"final response: {response}")
model_id = response._hidden_params.get("model_id", None) or ""
if (
"stream" in data and data["stream"] == True
): # use generate_responses to stream responses
custom_headers = {"x-litellm-model-id": model_id}
return StreamingResponse(
async_data_generator(
user_api_key_dict=user_api_key_dict, response=response
user_api_key_dict=user_api_key_dict,
response=response,
),
media_type="text/event-stream",
headers=custom_headers,
)
fastapi_response.headers["x-litellm-model-id"] = model_id
return response
except Exception as e:
traceback.print_exc()
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e
)
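To show what the new `fastapi_response` parameter and `x-litellm-model-id` header buy a caller, here is a rough client-side sketch. Only the header name comes from the code above; the proxy base URL, port, endpoint path, and API key are illustrative assumptions, not values from this diff.

import requests

PROXY_BASE = "http://localhost:8000"  # hypothetical local proxy address
API_KEY = "sk-1234"                   # hypothetical proxy key

resp = requests.post(
    f"{PROXY_BASE}/chat/completions",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={
        "model": "azure-model",
        "messages": [{"role": "user", "content": "hi"}],
    },
)

# The proxy now sets this header on both streaming and non-streaming responses,
# so callers can tell which deployment served the request.
print(resp.headers.get("x-litellm-model-id"))
print(resp.json())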


@@ -191,7 +191,9 @@ class Router:
) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
### ROUTING SETUP ###
if routing_strategy == "least-busy":
self.leastbusy_logger = LeastBusyLoggingHandler(router_cache=self.cache)
self.leastbusy_logger = LeastBusyLoggingHandler(
router_cache=self.cache, model_list=self.model_list
)
## add callback
if isinstance(litellm.input_callback, list):
litellm.input_callback.append(self.leastbusy_logger) # type: ignore
@@ -506,7 +508,13 @@
**kwargs,
):
try:
kwargs["model"] = model
kwargs["prompt"] = prompt
kwargs["original_function"] = self._acompletion
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
messages = [{"role": "user", "content": prompt}]
# pick the one that is available (lowest TPM/RPM)
deployment = self.get_available_deployment(
@@ -530,7 +538,6 @@
if self.num_retries > 0:
kwargs["model"] = model
kwargs["messages"] = messages
kwargs["original_exception"] = e
kwargs["original_function"] = self.completion
return self.function_with_retries(**kwargs)
else:
@@ -546,16 +553,34 @@
**kwargs,
):
try:
kwargs["model"] = model
kwargs["prompt"] = prompt
kwargs["original_function"] = self._atext_completion
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
messages = [{"role": "user", "content": prompt}]
# pick the one that is available (lowest TPM/RPM)
response = await self.async_function_with_fallbacks(**kwargs)
return response
except Exception as e:
raise e
async def _atext_completion(self, model: str, prompt: str, **kwargs):
try:
self.print_verbose(
f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}"
)
deployment = self.get_available_deployment(
model=model,
messages=messages,
messages=[{"role": "user", "content": prompt}],
specific_deployment=kwargs.pop("specific_deployment", None),
)
kwargs.setdefault("metadata", {}).update(
{"deployment": deployment["litellm_params"]["model"]}
)
kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy()
model_name = data["model"]
for k, v in self.default_litellm_params.items():
if (
k not in kwargs
@@ -564,27 +589,38 @@
elif k == "metadata":
kwargs[k].update(v)
########## remove -ModelID-XXXX from model ##############
original_model_string = data["model"]
# Find the index of "ModelID" in the string
index_of_model_id = original_model_string.find("-ModelID")
# Remove everything after "-ModelID" if it exists
if index_of_model_id != -1:
data["model"] = original_model_string[:index_of_model_id]
potential_model_client = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="async"
)
# check if provided keys == client keys #
dynamic_api_key = kwargs.get("api_key", None)
if (
dynamic_api_key is not None
and potential_model_client is not None
and dynamic_api_key != potential_model_client.api_key
):
model_client = None
else:
data["model"] = original_model_string
# call via litellm.atext_completion()
response = await litellm.atext_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
model_client = potential_model_client
self.total_calls[model_name] += 1
response = await asyncio.wait_for(
litellm.atext_completion(
**{
**data,
"prompt": prompt,
"caching": self.cache_responses,
"client": model_client,
**kwargs,
}
),
timeout=self.timeout,
)
self.success_calls[model_name] += 1
return response
except Exception as e:
if self.num_retries > 0:
kwargs["model"] = model
kwargs["messages"] = messages
kwargs["original_exception"] = e
kwargs["original_function"] = self.completion
return self.function_with_retries(**kwargs)
else:
raise e
if model_name is not None:
self.fail_calls[model_name] += 1
raise e
def embedding(
self,
@@ -1531,34 +1567,10 @@ class Router:
model
] # update the model to the actual value if an alias has been passed in
if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
deployments = self.leastbusy_logger.get_available_deployments(
model_group=model
deployment = self.leastbusy_logger.get_available_deployments(
model_group=model, healthy_deployments=healthy_deployments
)
self.print_verbose(f"deployments in least-busy router: {deployments}")
# pick least busy deployment
min_traffic = float("inf")
min_deployment = None
for k, v in deployments.items():
if v < min_traffic:
min_traffic = v
min_deployment = k
self.print_verbose(f"min_deployment: {min_deployment};")
############## No Available Deployments passed, we do a random pick #################
if min_deployment is None:
min_deployment = random.choice(healthy_deployments)
############## Available Deployments passed, we find the relevant item #################
else:
## check if min deployment is a string, if so, cast it to int
for m in healthy_deployments:
if isinstance(min_deployment, str) and isinstance(
m["model_info"]["id"], int
):
min_deployment = int(min_deployment)
if m["model_info"]["id"] == min_deployment:
return m
self.print_verbose(f"no healthy deployment with that id found!")
min_deployment = random.choice(healthy_deployments)
return min_deployment
return deployment
elif self.routing_strategy == "simple-shuffle":
# if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
############## Check if we can do a RPM/TPM based weighted pick #################
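The `_atext_completion` helper above wraps the underlying call in `asyncio.wait_for`, so a hung deployment surfaces as a timeout that the retry/fallback wrapper can act on. A stripped-down, self-contained sketch of that pattern follows; the helper and names are illustrative, not the router's actual internals.

import asyncio

async def call_with_timeout(make_call, timeout: float):
    # Mirrors the pattern in _atext_completion: the awaited call either
    # finishes within `timeout` seconds or raises asyncio.TimeoutError.
    return await asyncio.wait_for(make_call(), timeout=timeout)

async def _demo():
    async def slow_deployment():
        await asyncio.sleep(5)  # stand-in for a stalled provider call
        return "done"

    try:
        await call_with_timeout(slow_deployment, timeout=1.0)
    except asyncio.TimeoutError:
        # In the router, this exception propagates to
        # async_function_with_fallbacks, which retries or falls back.
        print("timed out -> retry / fallback")

asyncio.run(_demo())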


@@ -6,7 +6,7 @@
# - use litellm.success + failure callbacks to log when a request completed
# - in get_available_deployment, for a given model group name -> pick based on traffic
import dotenv, os, requests
import dotenv, os, requests, random
from typing import Optional
dotenv.load_dotenv() # Loading env variables using dotenv
@@ -20,9 +20,10 @@ class LeastBusyLoggingHandler(CustomLogger):
logged_success: int = 0
logged_failure: int = 0
def __init__(self, router_cache: DualCache):
def __init__(self, router_cache: DualCache, model_list: list):
self.router_cache = router_cache
self.mapping_deployment_to_id: dict = {}
self.model_list = model_list
def log_pre_api_call(self, model, messages, kwargs):
"""
@@ -168,8 +169,28 @@ class LeastBusyLoggingHandler(CustomLogger):
except Exception as e:
pass
def get_available_deployments(self, model_group: str):
def get_available_deployments(self, model_group: str, healthy_deployments: list):
request_count_api_key = f"{model_group}_request_count"
return_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
deployments = self.router_cache.get_cache(key=request_count_api_key) or {}
all_deployments = deployments
for d in healthy_deployments:
## if healthy deployment not yet used
if d["model_info"]["id"] not in all_deployments:
all_deployments[d["model_info"]["id"]] = 0
# map deployment to id
return return_dict
# pick least busy deployment
min_traffic = float("inf")
min_deployment = None
for k, v in all_deployments.items():
if v < min_traffic:
min_traffic = v
min_deployment = k
if min_deployment is not None:
## check if min deployment is a string, if so, cast it to int
for m in healthy_deployments:
if m["model_info"]["id"] == min_deployment:
return m
min_deployment = random.choice(healthy_deployments)
else:
min_deployment = random.choice(healthy_deployments)
return min_deployment
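The selection logic above, reduced to a toy function for clarity. The deployment dicts follow the `model_info` shape used elsewhere in this diff, and the traffic dict stands in for the cached request counts; this is a sketch, not the handler's actual code.

import random

def pick_least_busy(traffic: dict, healthy_deployments: list):
    # Seed healthy deployments with no recorded traffic at zero, then return
    # the deployment whose id carries the lowest in-flight request count.
    counts = dict(traffic)
    for d in healthy_deployments:
        counts.setdefault(d["model_info"]["id"], 0)

    min_id = min(counts, key=counts.get)
    for d in healthy_deployments:
        if d["model_info"]["id"] == min_id:
            return d
    # Fallback mirrors the code above: pick randomly if the id is unknown.
    return random.choice(healthy_deployments)

healthy = [{"model_info": {"id": 1}}, {"model_info": {"id": 2}}]
print(pick_least_busy({1: 10}, healthy))  # -> deployment 2 (no traffic yet)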


@@ -1,7 +1,7 @@
#### What this tests ####
# This tests the router's ability to identify the least busy deployment
import sys, os, asyncio, time
import sys, os, asyncio, time, random
import traceback
from dotenv import load_dotenv
@@ -128,3 +128,139 @@ def test_router_get_available_deployments():
assert return_dict[1] == 10
assert return_dict[2] == 54
assert return_dict[3] == 100
## Test with Real calls ##
@pytest.mark.asyncio
async def test_router_atext_completion_streaming():
prompt = "Hello, can you generate a 500 words poem?"
model = "azure-model"
model_list = [
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-turbo",
"api_key": "os.environ/AZURE_FRANCE_API_KEY",
"api_base": "https://openai-france-1234.openai.azure.com",
"rpm": 1440,
},
"model_info": {"id": 1},
},
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-35-turbo",
"api_key": "os.environ/AZURE_EUROPE_API_KEY",
"api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
"rpm": 6,
},
"model_info": {"id": 2},
},
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-35-turbo",
"api_key": "os.environ/AZURE_CANADA_API_KEY",
"api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
"rpm": 6,
},
"model_info": {"id": 3},
},
]
router = Router(
model_list=model_list,
routing_strategy="least-busy",
set_verbose=False,
num_retries=3,
) # type: ignore
### Call the async calls in sequence, so we start 1 call before going to the next.
## CALL 1
await asyncio.sleep(random.uniform(0, 2))
await router.atext_completion(model=model, prompt=prompt, stream=True)
## CALL 2
await asyncio.sleep(random.uniform(0, 2))
await router.atext_completion(model=model, prompt=prompt, stream=True)
## CALL 3
await asyncio.sleep(random.uniform(0, 2))
await router.atext_completion(model=model, prompt=prompt, stream=True)
cache_key = f"{model}_request_count"
## check if calls equally distributed
cache_dict = router.cache.get_cache(key=cache_key)
for k, v in cache_dict.items():
assert v == 1
# asyncio.run(test_router_atext_completion_streaming())
@pytest.mark.asyncio
async def test_router_completion_streaming():
messages = [
{"role": "user", "content": "Hello, can you generate a 500 words poem?"}
]
model = "azure-model"
model_list = [
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-turbo",
"api_key": "os.environ/AZURE_FRANCE_API_KEY",
"api_base": "https://openai-france-1234.openai.azure.com",
"rpm": 1440,
},
"model_info": {"id": 1},
},
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-35-turbo",
"api_key": "os.environ/AZURE_EUROPE_API_KEY",
"api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
"rpm": 6,
},
"model_info": {"id": 2},
},
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-35-turbo",
"api_key": "os.environ/AZURE_CANADA_API_KEY",
"api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
"rpm": 6,
},
"model_info": {"id": 3},
},
]
router = Router(
model_list=model_list,
routing_strategy="least-busy",
set_verbose=False,
num_retries=3,
) # type: ignore
### Call the async calls in sequence, so we start 1 call before going to the next.
## CALL 1
await asyncio.sleep(random.uniform(0, 2))
await router.acompletion(model=model, messages=messages, stream=True)
## CALL 2
await asyncio.sleep(random.uniform(0, 2))
await router.acompletion(model=model, messages=messages, stream=True)
## CALL 3
await asyncio.sleep(random.uniform(0, 2))
await router.acompletion(model=model, messages=messages, stream=True)
cache_key = f"{model}_request_count"
## check if calls equally distributed
cache_dict = router.cache.get_cache(key=cache_key)
for k, v in cache_dict.items():
assert v == 1


@@ -479,6 +479,8 @@ class EmbeddingResponse(OpenAIObject):
usage: Optional[Usage] = None
"""Usage statistics for the embedding request."""
_hidden_params: dict = {}
def __init__(
self, model=None, usage=None, stream=False, response_ms=None, data=None
):
@@ -640,6 +642,8 @@ class ImageResponse(OpenAIObject):
usage: Optional[dict] = None
_hidden_params: dict = {}
def __init__(self, created=None, data=None, response_ms=None):
if response_ms:
_response_ms = response_ms
@@ -2053,6 +2057,10 @@ def client(original_function):
target=logging_obj.success_handler, args=(result, start_time, end_time)
).start()
# RETURN RESULT
if hasattr(result, "_hidden_params"):
result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
"id", None
)
result._response_ms = (
end_time - start_time
).total_seconds() * 1000 # return response latency in ms like openai
@@ -2273,6 +2281,10 @@ def client(original_function):
target=logging_obj.success_handler, args=(result, start_time, end_time)
).start()
# RETURN RESULT
if hasattr(result, "_hidden_params"):
result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
"id", None
)
if isinstance(result, ModelResponse):
result._response_ms = (
end_time - start_time
@@ -6527,6 +6539,13 @@ class CustomStreamWrapper:
self.special_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
self.holding_chunk = ""
self.complete_response = ""
self._hidden_params = {
"model_id": (
self.logging_obj.model_call_details.get("litellm_params", {})
.get("model_info", {})
.get("id", None)
)
} # returned as x-litellm-model-id response header in proxy
def __iter__(self):
return self
@@ -7417,6 +7436,15 @@ class CustomStreamWrapper:
threading.Thread(
target=self.logging_obj.success_handler, args=(response,)
).start() # log response
# RETURN RESULT
if hasattr(response, "_hidden_params"):
response._hidden_params["model_id"] = (
self.logging_obj.model_call_details.get(
"litellm_params", {}
)
.get("model_info", {})
.get("id", None)
)
return response
except StopIteration:
raise # Re-raise StopIteration
@@ -7467,6 +7495,16 @@
processed_chunk,
)
)
# RETURN RESULT
if hasattr(processed_chunk, "_hidden_params"):
model_id = (
self.logging_obj.model_call_details.get(
"litellm_params", {}
)
.get("model_info", {})
.get("id", None)
)
processed_chunk._hidden_params["model_id"] = model_id
return processed_chunk
raise StopAsyncIteration
else: # temporary patch for non-aiohttp async calls
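Tying the utils.py changes back to the proxy: both ModelResponse-style objects and `CustomStreamWrapper` now carry `_hidden_params["model_id"]`, so a caller can read the deployment id the same way for streaming and non-streaming results. A small self-contained sketch, using a stand-in class rather than a real response object:

class _FakeResponse:
    # Stand-in for a ModelResponse / CustomStreamWrapper returned by litellm;
    # after this commit, both expose the deployment's model_info["id"] here.
    _hidden_params = {"model_id": 2}

def extract_model_id(response):
    # Same guard the proxy uses above: other return types may lack the attribute.
    if hasattr(response, "_hidden_params"):
        return response._hidden_params.get("model_id", None) or ""
    return ""

print(extract_model_id(_FakeResponse()))  # -> 2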