forked from phoenix/litellm-mirror
fix(router.py): support retry and fallbacks for atext_completion
This commit is contained in:
parent
7ecd7b3e8d
commit
38f55249e1
6 changed files with 290 additions and 69 deletions
|
@ -212,15 +212,15 @@ async def acompletion(*args, **kwargs):
|
|||
else:
|
||||
# Call the synchronous function using run_in_executor
|
||||
response = await loop.run_in_executor(None, func_with_context)
|
||||
if kwargs.get("stream", False): # return an async generator
|
||||
return _async_streaming(
|
||||
response=response,
|
||||
model=model,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
args=args,
|
||||
)
|
||||
else:
|
||||
return response
|
||||
# if kwargs.get("stream", False): # return an async generator
|
||||
# return _async_streaming(
|
||||
# response=response,
|
||||
# model=model,
|
||||
# custom_llm_provider=custom_llm_provider,
|
||||
# args=args,
|
||||
# )
|
||||
# else:
|
||||
return response
|
||||
except Exception as e:
|
||||
custom_llm_provider = custom_llm_provider or "openai"
|
||||
raise exception_type(
|
||||
|
|
|
@ -86,6 +86,7 @@ from fastapi import (
|
|||
Depends,
|
||||
BackgroundTasks,
|
||||
Header,
|
||||
Response,
|
||||
)
|
||||
from fastapi.routing import APIRouter
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
|
@ -1068,6 +1069,7 @@ def model_list():
|
|||
)
|
||||
async def completion(
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
model: Optional[str] = None,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
background_tasks: BackgroundTasks = BackgroundTasks(),
|
||||
|
@ -1143,17 +1145,23 @@ async def completion(
|
|||
else: # router is not set
|
||||
response = await litellm.atext_completion(**data)
|
||||
|
||||
model_id = response._hidden_params.get("model_id", None) or ""
|
||||
|
||||
print(f"final response: {response}")
|
||||
if (
|
||||
"stream" in data and data["stream"] == True
|
||||
): # use generate_responses to stream responses
|
||||
custom_headers = {"x-litellm-model-id": model_id}
|
||||
return StreamingResponse(
|
||||
async_data_generator(
|
||||
user_api_key_dict=user_api_key_dict, response=response
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
response=response,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers=custom_headers,
|
||||
)
|
||||
|
||||
fastapi_response.headers["x-litellm-model-id"] = model_id
|
||||
return response
|
||||
except Exception as e:
|
||||
print(f"EXCEPTION RAISED IN PROXY MAIN.PY")
|
||||
|
@ -1187,6 +1195,7 @@ async def completion(
|
|||
) # azure compatible endpoint
|
||||
async def chat_completion(
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
model: Optional[str] = None,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
background_tasks: BackgroundTasks = BackgroundTasks(),
|
||||
|
@ -1282,19 +1291,24 @@ async def chat_completion(
|
|||
else: # router is not set
|
||||
response = await litellm.acompletion(**data)
|
||||
|
||||
print(f"final response: {response}")
|
||||
model_id = response._hidden_params.get("model_id", None) or ""
|
||||
if (
|
||||
"stream" in data and data["stream"] == True
|
||||
): # use generate_responses to stream responses
|
||||
custom_headers = {"x-litellm-model-id": model_id}
|
||||
return StreamingResponse(
|
||||
async_data_generator(
|
||||
user_api_key_dict=user_api_key_dict, response=response
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
response=response,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers=custom_headers,
|
||||
)
|
||||
|
||||
fastapi_response.headers["x-litellm-model-id"] = model_id
|
||||
return response
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e
|
||||
)
|
||||
|
|
|
@ -191,7 +191,9 @@ class Router:
|
|||
) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
|
||||
### ROUTING SETUP ###
|
||||
if routing_strategy == "least-busy":
|
||||
self.leastbusy_logger = LeastBusyLoggingHandler(router_cache=self.cache)
|
||||
self.leastbusy_logger = LeastBusyLoggingHandler(
|
||||
router_cache=self.cache, model_list=self.model_list
|
||||
)
|
||||
## add callback
|
||||
if isinstance(litellm.input_callback, list):
|
||||
litellm.input_callback.append(self.leastbusy_logger) # type: ignore
|
||||
|
@ -506,7 +508,13 @@ class Router:
|
|||
**kwargs,
|
||||
):
|
||||
try:
|
||||
kwargs["model"] = model
|
||||
kwargs["prompt"] = prompt
|
||||
kwargs["original_function"] = self._acompletion
|
||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||
timeout = kwargs.get("request_timeout", self.timeout)
|
||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
||||
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
# pick the one that is available (lowest TPM/RPM)
|
||||
deployment = self.get_available_deployment(
|
||||
|
@ -530,7 +538,6 @@ class Router:
|
|||
if self.num_retries > 0:
|
||||
kwargs["model"] = model
|
||||
kwargs["messages"] = messages
|
||||
kwargs["original_exception"] = e
|
||||
kwargs["original_function"] = self.completion
|
||||
return self.function_with_retries(**kwargs)
|
||||
else:
|
||||
|
@ -546,16 +553,34 @@ class Router:
|
|||
**kwargs,
|
||||
):
|
||||
try:
|
||||
kwargs["model"] = model
|
||||
kwargs["prompt"] = prompt
|
||||
kwargs["original_function"] = self._atext_completion
|
||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||
timeout = kwargs.get("request_timeout", self.timeout)
|
||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
# pick the one that is available (lowest TPM/RPM)
|
||||
response = await self.async_function_with_fallbacks(**kwargs)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def _atext_completion(self, model: str, prompt: str, **kwargs):
|
||||
try:
|
||||
self.print_verbose(
|
||||
f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}"
|
||||
)
|
||||
deployment = self.get_available_deployment(
|
||||
model=model,
|
||||
messages=messages,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
specific_deployment=kwargs.pop("specific_deployment", None),
|
||||
)
|
||||
|
||||
kwargs.setdefault("metadata", {}).update(
|
||||
{"deployment": deployment["litellm_params"]["model"]}
|
||||
)
|
||||
kwargs["model_info"] = deployment.get("model_info", {})
|
||||
data = deployment["litellm_params"].copy()
|
||||
model_name = data["model"]
|
||||
for k, v in self.default_litellm_params.items():
|
||||
if (
|
||||
k not in kwargs
|
||||
|
@ -564,27 +589,38 @@ class Router:
|
|||
elif k == "metadata":
|
||||
kwargs[k].update(v)
|
||||
|
||||
########## remove -ModelID-XXXX from model ##############
|
||||
original_model_string = data["model"]
|
||||
# Find the index of "ModelID" in the string
|
||||
index_of_model_id = original_model_string.find("-ModelID")
|
||||
# Remove everything after "-ModelID" if it exists
|
||||
if index_of_model_id != -1:
|
||||
data["model"] = original_model_string[:index_of_model_id]
|
||||
potential_model_client = self._get_client(
|
||||
deployment=deployment, kwargs=kwargs, client_type="async"
|
||||
)
|
||||
# check if provided keys == client keys #
|
||||
dynamic_api_key = kwargs.get("api_key", None)
|
||||
if (
|
||||
dynamic_api_key is not None
|
||||
and potential_model_client is not None
|
||||
and dynamic_api_key != potential_model_client.api_key
|
||||
):
|
||||
model_client = None
|
||||
else:
|
||||
data["model"] = original_model_string
|
||||
# call via litellm.atext_completion()
|
||||
response = await litellm.atext_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
|
||||
model_client = potential_model_client
|
||||
self.total_calls[model_name] += 1
|
||||
response = await asyncio.wait_for(
|
||||
litellm.atext_completion(
|
||||
**{
|
||||
**data,
|
||||
"prompt": prompt,
|
||||
"caching": self.cache_responses,
|
||||
"client": model_client,
|
||||
**kwargs,
|
||||
}
|
||||
),
|
||||
timeout=self.timeout,
|
||||
)
|
||||
self.success_calls[model_name] += 1
|
||||
return response
|
||||
except Exception as e:
|
||||
if self.num_retries > 0:
|
||||
kwargs["model"] = model
|
||||
kwargs["messages"] = messages
|
||||
kwargs["original_exception"] = e
|
||||
kwargs["original_function"] = self.completion
|
||||
return self.function_with_retries(**kwargs)
|
||||
else:
|
||||
raise e
|
||||
if model_name is not None:
|
||||
self.fail_calls[model_name] += 1
|
||||
raise e
|
||||
|
||||
def embedding(
|
||||
self,
|
||||
|
@ -1531,34 +1567,10 @@ class Router:
|
|||
model
|
||||
] # update the model to the actual value if an alias has been passed in
|
||||
if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
|
||||
deployments = self.leastbusy_logger.get_available_deployments(
|
||||
model_group=model
|
||||
deployment = self.leastbusy_logger.get_available_deployments(
|
||||
model_group=model, healthy_deployments=healthy_deployments
|
||||
)
|
||||
self.print_verbose(f"deployments in least-busy router: {deployments}")
|
||||
# pick least busy deployment
|
||||
min_traffic = float("inf")
|
||||
min_deployment = None
|
||||
for k, v in deployments.items():
|
||||
if v < min_traffic:
|
||||
min_traffic = v
|
||||
min_deployment = k
|
||||
self.print_verbose(f"min_deployment: {min_deployment};")
|
||||
############## No Available Deployments passed, we do a random pick #################
|
||||
if min_deployment is None:
|
||||
min_deployment = random.choice(healthy_deployments)
|
||||
############## Available Deployments passed, we find the relevant item #################
|
||||
else:
|
||||
## check if min deployment is a string, if so, cast it to int
|
||||
for m in healthy_deployments:
|
||||
if isinstance(min_deployment, str) and isinstance(
|
||||
m["model_info"]["id"], int
|
||||
):
|
||||
min_deployment = int(min_deployment)
|
||||
if m["model_info"]["id"] == min_deployment:
|
||||
return m
|
||||
self.print_verbose(f"no healthy deployment with that id found!")
|
||||
min_deployment = random.choice(healthy_deployments)
|
||||
return min_deployment
|
||||
return deployment
|
||||
elif self.routing_strategy == "simple-shuffle":
|
||||
# if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
|
||||
############## Check if we can do a RPM/TPM based weighted pick #################
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
# - use litellm.success + failure callbacks to log when a request completed
|
||||
# - in get_available_deployment, for a given model group name -> pick based on traffic
|
||||
|
||||
import dotenv, os, requests
|
||||
import dotenv, os, requests, random
|
||||
from typing import Optional
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
|
@ -20,9 +20,10 @@ class LeastBusyLoggingHandler(CustomLogger):
|
|||
logged_success: int = 0
|
||||
logged_failure: int = 0
|
||||
|
||||
def __init__(self, router_cache: DualCache):
|
||||
def __init__(self, router_cache: DualCache, model_list: list):
|
||||
self.router_cache = router_cache
|
||||
self.mapping_deployment_to_id: dict = {}
|
||||
self.model_list = model_list
|
||||
|
||||
def log_pre_api_call(self, model, messages, kwargs):
|
||||
"""
|
||||
|
@ -168,8 +169,28 @@ class LeastBusyLoggingHandler(CustomLogger):
|
|||
except Exception as e:
|
||||
pass
|
||||
|
||||
def get_available_deployments(self, model_group: str):
|
||||
def get_available_deployments(self, model_group: str, healthy_deployments: list):
|
||||
request_count_api_key = f"{model_group}_request_count"
|
||||
return_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
|
||||
deployments = self.router_cache.get_cache(key=request_count_api_key) or {}
|
||||
all_deployments = deployments
|
||||
for d in healthy_deployments:
|
||||
## if healthy deployment not yet used
|
||||
if d["model_info"]["id"] not in all_deployments:
|
||||
all_deployments[d["model_info"]["id"]] = 0
|
||||
# map deployment to id
|
||||
return return_dict
|
||||
# pick least busy deployment
|
||||
min_traffic = float("inf")
|
||||
min_deployment = None
|
||||
for k, v in all_deployments.items():
|
||||
if v < min_traffic:
|
||||
min_traffic = v
|
||||
min_deployment = k
|
||||
if min_deployment is not None:
|
||||
## check if min deployment is a string, if so, cast it to int
|
||||
for m in healthy_deployments:
|
||||
if m["model_info"]["id"] == min_deployment:
|
||||
return m
|
||||
min_deployment = random.choice(healthy_deployments)
|
||||
else:
|
||||
min_deployment = random.choice(healthy_deployments)
|
||||
return min_deployment
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#### What this tests ####
|
||||
# This tests the router's ability to identify the least busy deployment
|
||||
|
||||
import sys, os, asyncio, time
|
||||
import sys, os, asyncio, time, random
|
||||
import traceback
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
@ -128,3 +128,139 @@ def test_router_get_available_deployments():
|
|||
assert return_dict[1] == 10
|
||||
assert return_dict[2] == 54
|
||||
assert return_dict[3] == 100
|
||||
|
||||
|
||||
## Test with Real calls ##
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_atext_completion_streaming():
|
||||
prompt = "Hello, can you generate a 500 words poem?"
|
||||
model = "azure-model"
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "azure-model",
|
||||
"litellm_params": {
|
||||
"model": "azure/gpt-turbo",
|
||||
"api_key": "os.environ/AZURE_FRANCE_API_KEY",
|
||||
"api_base": "https://openai-france-1234.openai.azure.com",
|
||||
"rpm": 1440,
|
||||
},
|
||||
"model_info": {"id": 1},
|
||||
},
|
||||
{
|
||||
"model_name": "azure-model",
|
||||
"litellm_params": {
|
||||
"model": "azure/gpt-35-turbo",
|
||||
"api_key": "os.environ/AZURE_EUROPE_API_KEY",
|
||||
"api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
|
||||
"rpm": 6,
|
||||
},
|
||||
"model_info": {"id": 2},
|
||||
},
|
||||
{
|
||||
"model_name": "azure-model",
|
||||
"litellm_params": {
|
||||
"model": "azure/gpt-35-turbo",
|
||||
"api_key": "os.environ/AZURE_CANADA_API_KEY",
|
||||
"api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
|
||||
"rpm": 6,
|
||||
},
|
||||
"model_info": {"id": 3},
|
||||
},
|
||||
]
|
||||
router = Router(
|
||||
model_list=model_list,
|
||||
routing_strategy="least-busy",
|
||||
set_verbose=False,
|
||||
num_retries=3,
|
||||
) # type: ignore
|
||||
|
||||
### Call the async calls in sequence, so we start 1 call before going to the next.
|
||||
|
||||
## CALL 1
|
||||
await asyncio.sleep(random.uniform(0, 2))
|
||||
await router.atext_completion(model=model, prompt=prompt, stream=True)
|
||||
|
||||
## CALL 2
|
||||
await asyncio.sleep(random.uniform(0, 2))
|
||||
await router.atext_completion(model=model, prompt=prompt, stream=True)
|
||||
|
||||
## CALL 3
|
||||
await asyncio.sleep(random.uniform(0, 2))
|
||||
await router.atext_completion(model=model, prompt=prompt, stream=True)
|
||||
|
||||
cache_key = f"{model}_request_count"
|
||||
## check if calls equally distributed
|
||||
cache_dict = router.cache.get_cache(key=cache_key)
|
||||
for k, v in cache_dict.items():
|
||||
assert v == 1
|
||||
|
||||
|
||||
# asyncio.run(test_router_atext_completion_streaming())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_completion_streaming():
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello, can you generate a 500 words poem?"}
|
||||
]
|
||||
model = "azure-model"
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "azure-model",
|
||||
"litellm_params": {
|
||||
"model": "azure/gpt-turbo",
|
||||
"api_key": "os.environ/AZURE_FRANCE_API_KEY",
|
||||
"api_base": "https://openai-france-1234.openai.azure.com",
|
||||
"rpm": 1440,
|
||||
},
|
||||
"model_info": {"id": 1},
|
||||
},
|
||||
{
|
||||
"model_name": "azure-model",
|
||||
"litellm_params": {
|
||||
"model": "azure/gpt-35-turbo",
|
||||
"api_key": "os.environ/AZURE_EUROPE_API_KEY",
|
||||
"api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
|
||||
"rpm": 6,
|
||||
},
|
||||
"model_info": {"id": 2},
|
||||
},
|
||||
{
|
||||
"model_name": "azure-model",
|
||||
"litellm_params": {
|
||||
"model": "azure/gpt-35-turbo",
|
||||
"api_key": "os.environ/AZURE_CANADA_API_KEY",
|
||||
"api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
|
||||
"rpm": 6,
|
||||
},
|
||||
"model_info": {"id": 3},
|
||||
},
|
||||
]
|
||||
router = Router(
|
||||
model_list=model_list,
|
||||
routing_strategy="least-busy",
|
||||
set_verbose=False,
|
||||
num_retries=3,
|
||||
) # type: ignore
|
||||
|
||||
### Call the async calls in sequence, so we start 1 call before going to the next.
|
||||
|
||||
## CALL 1
|
||||
await asyncio.sleep(random.uniform(0, 2))
|
||||
await router.acompletion(model=model, messages=messages, stream=True)
|
||||
|
||||
## CALL 2
|
||||
await asyncio.sleep(random.uniform(0, 2))
|
||||
await router.acompletion(model=model, messages=messages, stream=True)
|
||||
|
||||
## CALL 3
|
||||
await asyncio.sleep(random.uniform(0, 2))
|
||||
await router.acompletion(model=model, messages=messages, stream=True)
|
||||
|
||||
cache_key = f"{model}_request_count"
|
||||
## check if calls equally distributed
|
||||
cache_dict = router.cache.get_cache(key=cache_key)
|
||||
for k, v in cache_dict.items():
|
||||
assert v == 1
|
||||
|
|
|
@ -479,6 +479,8 @@ class EmbeddingResponse(OpenAIObject):
|
|||
usage: Optional[Usage] = None
|
||||
"""Usage statistics for the embedding request."""
|
||||
|
||||
_hidden_params: dict = {}
|
||||
|
||||
def __init__(
|
||||
self, model=None, usage=None, stream=False, response_ms=None, data=None
|
||||
):
|
||||
|
@ -640,6 +642,8 @@ class ImageResponse(OpenAIObject):
|
|||
|
||||
usage: Optional[dict] = None
|
||||
|
||||
_hidden_params: dict = {}
|
||||
|
||||
def __init__(self, created=None, data=None, response_ms=None):
|
||||
if response_ms:
|
||||
_response_ms = response_ms
|
||||
|
@ -2053,6 +2057,10 @@ def client(original_function):
|
|||
target=logging_obj.success_handler, args=(result, start_time, end_time)
|
||||
).start()
|
||||
# RETURN RESULT
|
||||
if hasattr(result, "_hidden_params"):
|
||||
result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
|
||||
"id", None
|
||||
)
|
||||
result._response_ms = (
|
||||
end_time - start_time
|
||||
).total_seconds() * 1000 # return response latency in ms like openai
|
||||
|
@ -2273,6 +2281,10 @@ def client(original_function):
|
|||
target=logging_obj.success_handler, args=(result, start_time, end_time)
|
||||
).start()
|
||||
# RETURN RESULT
|
||||
if hasattr(result, "_hidden_params"):
|
||||
result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
|
||||
"id", None
|
||||
)
|
||||
if isinstance(result, ModelResponse):
|
||||
result._response_ms = (
|
||||
end_time - start_time
|
||||
|
@ -6527,6 +6539,13 @@ class CustomStreamWrapper:
|
|||
self.special_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
|
||||
self.holding_chunk = ""
|
||||
self.complete_response = ""
|
||||
self._hidden_params = {
|
||||
"model_id": (
|
||||
self.logging_obj.model_call_details.get("litellm_params", {})
|
||||
.get("model_info", {})
|
||||
.get("id", None)
|
||||
)
|
||||
} # returned as x-litellm-model-id response header in proxy
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
@ -7417,6 +7436,15 @@ class CustomStreamWrapper:
|
|||
threading.Thread(
|
||||
target=self.logging_obj.success_handler, args=(response,)
|
||||
).start() # log response
|
||||
# RETURN RESULT
|
||||
if hasattr(response, "_hidden_params"):
|
||||
response._hidden_params["model_id"] = (
|
||||
self.logging_obj.model_call_details.get(
|
||||
"litellm_params", {}
|
||||
)
|
||||
.get("model_info", {})
|
||||
.get("id", None)
|
||||
)
|
||||
return response
|
||||
except StopIteration:
|
||||
raise # Re-raise StopIteration
|
||||
|
@ -7467,6 +7495,16 @@ class CustomStreamWrapper:
|
|||
processed_chunk,
|
||||
)
|
||||
)
|
||||
# RETURN RESULT
|
||||
if hasattr(processed_chunk, "_hidden_params"):
|
||||
model_id = (
|
||||
self.logging_obj.model_call_details.get(
|
||||
"litellm_params", {}
|
||||
)
|
||||
.get("model_info", {})
|
||||
.get("id", None)
|
||||
)
|
||||
processed_chunk._hidden_params["model_id"] = model_id
|
||||
return processed_chunk
|
||||
raise StopAsyncIteration
|
||||
else: # temporary patch for non-aiohttp async calls
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue