forked from phoenix/litellm-mirror

fix(router.py): support retry and fallbacks for atext_completion

parent: 7ecd7b3e8d
commit: 38f55249e1

6 changed files with 290 additions and 69 deletions
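Before the hunks, a hedged usage sketch of what this change enables: with retries and fallbacks now wired into Router.atext_completion, a routed text-completion call can look like the following. Model names, keys, and endpoints are placeholders mirroring the test fixtures added in this commit; this is an illustration, not part of the diff.

import asyncio
from litellm import Router

# Two deployments in one model group; the router retries a failing call and
# can fall back to the other deployment. Keys/endpoints are placeholders.
model_list = [
    {
        "model_name": "azure-model",
        "litellm_params": {
            "model": "azure/gpt-35-turbo",
            "api_key": "os.environ/AZURE_EUROPE_API_KEY",
            "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
        },
        "model_info": {"id": 1},
    },
    {
        "model_name": "azure-model",
        "litellm_params": {
            "model": "azure/gpt-turbo",
            "api_key": "os.environ/AZURE_FRANCE_API_KEY",
            "api_base": "https://openai-france-1234.openai.azure.com",
        },
        "model_info": {"id": 2},
    },
]

async def main():
    router = Router(model_list=model_list, num_retries=3)
    # Retries/fallbacks are applied by async_function_with_fallbacks under the hood.
    response = await router.atext_completion(
        model="azure-model", prompt="Say hello in one sentence."
    )
    print(response)

asyncio.run(main())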
@@ -212,14 +212,14 @@ async def acompletion(*args, **kwargs):
         else:
             # Call the synchronous function using run_in_executor
             response = await loop.run_in_executor(None, func_with_context)
-        if kwargs.get("stream", False): # return an async generator
-            return _async_streaming(
-                response=response,
-                model=model,
-                custom_llm_provider=custom_llm_provider,
-                args=args,
-            )
-        else:
-            return response
+        # if kwargs.get("stream", False): # return an async generator
+        #     return _async_streaming(
+        #         response=response,
+        #         model=model,
+        #         custom_llm_provider=custom_llm_provider,
+        #         args=args,
+        #     )
+        # else:
+        return response
     except Exception as e:
         custom_llm_provider = custom_llm_provider or "openai"
@@ -86,6 +86,7 @@ from fastapi import (
     Depends,
     BackgroundTasks,
     Header,
+    Response,
 )
 from fastapi.routing import APIRouter
 from fastapi.security import OAuth2PasswordBearer
@@ -1068,6 +1069,7 @@ def model_list():
 )
 async def completion(
     request: Request,
+    fastapi_response: Response,
     model: Optional[str] = None,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
     background_tasks: BackgroundTasks = BackgroundTasks(),
@@ -1143,17 +1145,23 @@ async def completion(
         else: # router is not set
             response = await litellm.atext_completion(**data)
 
+        model_id = response._hidden_params.get("model_id", None) or ""
 
         print(f"final response: {response}")
         if (
             "stream" in data and data["stream"] == True
         ): # use generate_responses to stream responses
+            custom_headers = {"x-litellm-model-id": model_id}
             return StreamingResponse(
                 async_data_generator(
-                    user_api_key_dict=user_api_key_dict, response=response
+                    user_api_key_dict=user_api_key_dict,
+                    response=response,
                 ),
                 media_type="text/event-stream",
+                headers=custom_headers,
             )
 
+        fastapi_response.headers["x-litellm-model-id"] = model_id
         return response
     except Exception as e:
         print(f"EXCEPTION RAISED IN PROXY MAIN.PY")
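To see the new header from the client side, a hedged sketch (assuming a litellm proxy running locally with an OpenAI-style /completions route; the URL, key, and model name are placeholders, not part of the diff):

import requests

# Placeholder URL and key; assumes a locally running litellm proxy.
resp = requests.post(
    "http://0.0.0.0:8000/completions",
    headers={"Authorization": "Bearer sk-1234"},
    json={"model": "azure-model", "prompt": "hello"},
)

# The proxy now echoes the serving deployment's model_info id back to the
# caller on both streaming and non-streaming responses.
print(resp.headers.get("x-litellm-model-id"))
print(resp.json())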
@@ -1187,6 +1195,7 @@ async def completion(
 ) # azure compatible endpoint
 async def chat_completion(
     request: Request,
+    fastapi_response: Response,
     model: Optional[str] = None,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
     background_tasks: BackgroundTasks = BackgroundTasks(),
@@ -1282,19 +1291,24 @@ async def chat_completion(
         else: # router is not set
             response = await litellm.acompletion(**data)
 
-        print(f"final response: {response}")
+        model_id = response._hidden_params.get("model_id", None) or ""
         if (
             "stream" in data and data["stream"] == True
         ): # use generate_responses to stream responses
+            custom_headers = {"x-litellm-model-id": model_id}
             return StreamingResponse(
                 async_data_generator(
-                    user_api_key_dict=user_api_key_dict, response=response
+                    user_api_key_dict=user_api_key_dict,
+                    response=response,
                 ),
                 media_type="text/event-stream",
+                headers=custom_headers,
             )
 
+        fastapi_response.headers["x-litellm-model-id"] = model_id
         return response
     except Exception as e:
+        traceback.print_exc()
         await proxy_logging_obj.post_call_failure_hook(
             user_api_key_dict=user_api_key_dict, original_exception=e
         )
@@ -191,7 +191,9 @@ class Router:
         ) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
         ### ROUTING SETUP ###
         if routing_strategy == "least-busy":
-            self.leastbusy_logger = LeastBusyLoggingHandler(router_cache=self.cache)
+            self.leastbusy_logger = LeastBusyLoggingHandler(
+                router_cache=self.cache, model_list=self.model_list
+            )
             ## add callback
             if isinstance(litellm.input_callback, list):
                 litellm.input_callback.append(self.leastbusy_logger) # type: ignore
@@ -506,7 +508,13 @@ class Router:
         **kwargs,
     ):
         try:
+            kwargs["model"] = model
+            kwargs["prompt"] = prompt
+            kwargs["original_function"] = self._acompletion
+            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+            timeout = kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
+
             messages = [{"role": "user", "content": prompt}]
             # pick the one that is available (lowest TPM/RPM)
             deployment = self.get_available_deployment(
@@ -530,7 +538,6 @@ class Router:
             if self.num_retries > 0:
                 kwargs["model"] = model
                 kwargs["messages"] = messages
-                kwargs["original_exception"] = e
                 kwargs["original_function"] = self.completion
                 return self.function_with_retries(**kwargs)
             else:
@@ -546,16 +553,34 @@ class Router:
         **kwargs,
     ):
         try:
+            kwargs["model"] = model
+            kwargs["prompt"] = prompt
+            kwargs["original_function"] = self._atext_completion
+            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+            timeout = kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
-            messages = [{"role": "user", "content": prompt}]
-            # pick the one that is available (lowest TPM/RPM)
+            response = await self.async_function_with_fallbacks(**kwargs)
+
+            return response
+        except Exception as e:
+            raise e
+
+    async def _atext_completion(self, model: str, prompt: str, **kwargs):
+        try:
+            self.print_verbose(
+                f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}"
+            )
             deployment = self.get_available_deployment(
                 model=model,
-                messages=messages,
+                messages=[{"role": "user", "content": prompt}],
                 specific_deployment=kwargs.pop("specific_deployment", None),
             )
+            kwargs.setdefault("metadata", {}).update(
+                {"deployment": deployment["litellm_params"]["model"]}
+            )
+            kwargs["model_info"] = deployment.get("model_info", {})
             data = deployment["litellm_params"].copy()
+            model_name = data["model"]
             for k, v in self.default_litellm_params.items():
                 if (
                     k not in kwargs
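For orientation: after this hunk, router.atext_completion only packages kwargs and delegates to async_function_with_fallbacks, which retries the chosen deployment and then walks fallbacks; the next hunk continues the new _atext_completion body. Below is a simplified, hypothetical sketch of that retry-then-fallback pattern (illustrative only; the real wrapper also handles cooldowns, per-group fallback lists, and timeouts, and these names are not litellm internals):

import asyncio

async def call_with_retries_and_fallbacks(
    original_function, kwargs, num_retries=3, fallback_groups=None
):
    """Hypothetical sketch: try the requested model group, retrying on failure,
    then walk through fallback groups before giving up."""
    groups = [kwargs["model"]] + list(fallback_groups or [])
    last_exc = None
    for group in groups:
        kwargs["model"] = group
        for attempt in range(num_retries + 1):
            try:
                return await original_function(**kwargs)
            except Exception as e:
                last_exc = e
                await asyncio.sleep(min(2**attempt, 10))  # simple backoff
    raise last_exc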
@@ -564,26 +589,37 @@ class Router:
                 elif k == "metadata":
                     kwargs[k].update(v)
 
-            ########## remove -ModelID-XXXX from model ##############
-            original_model_string = data["model"]
-            # Find the index of "ModelID" in the string
-            index_of_model_id = original_model_string.find("-ModelID")
-            # Remove everything after "-ModelID" if it exists
-            if index_of_model_id != -1:
-                data["model"] = original_model_string[:index_of_model_id]
+            potential_model_client = self._get_client(
+                deployment=deployment, kwargs=kwargs, client_type="async"
+            )
+            # check if provided keys == client keys #
+            dynamic_api_key = kwargs.get("api_key", None)
+            if (
+                dynamic_api_key is not None
+                and potential_model_client is not None
+                and dynamic_api_key != potential_model_client.api_key
+            ):
+                model_client = None
             else:
-                data["model"] = original_model_string
-            # call via litellm.atext_completion()
-            response = await litellm.atext_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
+                model_client = potential_model_client
+            self.total_calls[model_name] += 1
+            response = await asyncio.wait_for(
+                litellm.atext_completion(
+                    **{
+                        **data,
+                        "prompt": prompt,
+                        "caching": self.cache_responses,
+                        "client": model_client,
+                        **kwargs,
+                    }
+                ),
+                timeout=self.timeout,
+            )
+            self.success_calls[model_name] += 1
             return response
         except Exception as e:
-            if self.num_retries > 0:
-                kwargs["model"] = model
-                kwargs["messages"] = messages
-                kwargs["original_exception"] = e
-                kwargs["original_function"] = self.completion
-                return self.function_with_retries(**kwargs)
-            else:
-                raise e
+            if model_name is not None:
+                self.fail_calls[model_name] += 1
+            raise e
 
     def embedding(
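The guard added above reuses the cached async client only when the caller did not pass a different api_key; otherwise model_client is set to None so a fresh client is built for the override. A standalone restatement of that check (names here are illustrative, not litellm internals):

def pick_client(cached_client, dynamic_api_key=None):
    """Reuse the cached client unless the caller supplied a different api_key."""
    if (
        dynamic_api_key is not None
        and cached_client is not None
        and dynamic_api_key != cached_client.api_key
    ):
        return None  # force a fresh client with the caller-supplied key
    return cached_client

The surrounding asyncio.wait_for simply bounds the whole litellm.atext_completion call by the router's timeout, raising asyncio.TimeoutError if it is exceeded.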
@@ -1531,34 +1567,10 @@ class Router:
                 model
             ] # update the model to the actual value if an alias has been passed in
         if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
-            deployments = self.leastbusy_logger.get_available_deployments(
-                model_group=model
+            deployment = self.leastbusy_logger.get_available_deployments(
+                model_group=model, healthy_deployments=healthy_deployments
             )
-            self.print_verbose(f"deployments in least-busy router: {deployments}")
-            # pick least busy deployment
-            min_traffic = float("inf")
-            min_deployment = None
-            for k, v in deployments.items():
-                if v < min_traffic:
-                    min_traffic = v
-                    min_deployment = k
-            self.print_verbose(f"min_deployment: {min_deployment};")
-            ############## No Available Deployments passed, we do a random pick #################
-            if min_deployment is None:
-                min_deployment = random.choice(healthy_deployments)
-            ############## Available Deployments passed, we find the relevant item #################
-            else:
-                ## check if min deployment is a string, if so, cast it to int
-                for m in healthy_deployments:
-                    if isinstance(min_deployment, str) and isinstance(
-                        m["model_info"]["id"], int
-                    ):
-                        min_deployment = int(min_deployment)
-                    if m["model_info"]["id"] == min_deployment:
-                        return m
-                self.print_verbose(f"no healthy deployment with that id found!")
-                min_deployment = random.choice(healthy_deployments)
-            return min_deployment
+            return deployment
         elif self.routing_strategy == "simple-shuffle":
             # if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
             ############## Check if we can do a RPM/TPM based weighted pick #################
@@ -6,7 +6,7 @@
 # - use litellm.success + failure callbacks to log when a request completed
 # - in get_available_deployment, for a given model group name -> pick based on traffic
 
-import dotenv, os, requests
+import dotenv, os, requests, random
 from typing import Optional
 
 dotenv.load_dotenv() # Loading env variables using dotenv
@@ -20,9 +20,10 @@ class LeastBusyLoggingHandler(CustomLogger):
     logged_success: int = 0
     logged_failure: int = 0
 
-    def __init__(self, router_cache: DualCache):
+    def __init__(self, router_cache: DualCache, model_list: list):
         self.router_cache = router_cache
         self.mapping_deployment_to_id: dict = {}
+        self.model_list = model_list
 
     def log_pre_api_call(self, model, messages, kwargs):
         """
@@ -168,8 +169,28 @@ class LeastBusyLoggingHandler(CustomLogger):
         except Exception as e:
             pass
 
-    def get_available_deployments(self, model_group: str):
+    def get_available_deployments(self, model_group: str, healthy_deployments: list):
         request_count_api_key = f"{model_group}_request_count"
-        return_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
+        deployments = self.router_cache.get_cache(key=request_count_api_key) or {}
+        all_deployments = deployments
+        for d in healthy_deployments:
+            ## if healthy deployment not yet used
+            if d["model_info"]["id"] not in all_deployments:
+                all_deployments[d["model_info"]["id"]] = 0
         # map deployment to id
-        return return_dict
+        # pick least busy deployment
+        min_traffic = float("inf")
+        min_deployment = None
+        for k, v in all_deployments.items():
+            if v < min_traffic:
+                min_traffic = v
+                min_deployment = k
+        if min_deployment is not None:
+            ## check if min deployment is a string, if so, cast it to int
+            for m in healthy_deployments:
+                if m["model_info"]["id"] == min_deployment:
+                    return m
+            min_deployment = random.choice(healthy_deployments)
+        else:
+            min_deployment = random.choice(healthy_deployments)
+        return min_deployment
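A minimal, runnable restatement of the selection logic moved into the handler above (the deployment dicts are placeholders shaped like the router's healthy_deployments entries):

import random

def pick_least_busy(all_deployments: dict, healthy_deployments: list):
    # all_deployments maps model_info id -> in-flight request count
    min_traffic = float("inf")
    min_id = None
    for deployment_id, traffic in all_deployments.items():
        if traffic < min_traffic:
            min_traffic = traffic
            min_id = deployment_id
    if min_id is not None:
        for m in healthy_deployments:
            if m["model_info"]["id"] == min_id:
                return m
    # no usable id found -> fall back to a random healthy deployment
    return random.choice(healthy_deployments)

# Example: deployment 2 has the fewest in-flight requests
healthy = [{"model_info": {"id": i}} for i in (1, 2, 3)]
print(pick_least_busy({1: 5, 2: 0, 3: 2}, healthy))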
@@ -1,7 +1,7 @@
 #### What this tests ####
 # This tests the router's ability to identify the least busy deployment
 
-import sys, os, asyncio, time
+import sys, os, asyncio, time, random
 import traceback
 from dotenv import load_dotenv
 
@@ -128,3 +128,139 @@ def test_router_get_available_deployments():
     assert return_dict[1] == 10
     assert return_dict[2] == 54
     assert return_dict[3] == 100
+
+
+## Test with Real calls ##
+
+
+@pytest.mark.asyncio
+async def test_router_atext_completion_streaming():
+    prompt = "Hello, can you generate a 500 words poem?"
+    model = "azure-model"
+    model_list = [
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-turbo",
+                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
+                "api_base": "https://openai-france-1234.openai.azure.com",
+                "rpm": 1440,
+            },
+            "model_info": {"id": 1},
+        },
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-35-turbo",
+                "api_key": "os.environ/AZURE_EUROPE_API_KEY",
+                "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
+                "rpm": 6,
+            },
+            "model_info": {"id": 2},
+        },
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-35-turbo",
+                "api_key": "os.environ/AZURE_CANADA_API_KEY",
+                "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
+                "rpm": 6,
+            },
+            "model_info": {"id": 3},
+        },
+    ]
+    router = Router(
+        model_list=model_list,
+        routing_strategy="least-busy",
+        set_verbose=False,
+        num_retries=3,
+    ) # type: ignore
+
+    ### Call the async calls in sequence, so we start 1 call before going to the next.
+
+    ## CALL 1
+    await asyncio.sleep(random.uniform(0, 2))
+    await router.atext_completion(model=model, prompt=prompt, stream=True)
+
+    ## CALL 2
+    await asyncio.sleep(random.uniform(0, 2))
+    await router.atext_completion(model=model, prompt=prompt, stream=True)
+
+    ## CALL 3
+    await asyncio.sleep(random.uniform(0, 2))
+    await router.atext_completion(model=model, prompt=prompt, stream=True)
+
+    cache_key = f"{model}_request_count"
+    ## check if calls equally distributed
+    cache_dict = router.cache.get_cache(key=cache_key)
+    for k, v in cache_dict.items():
+        assert v == 1
+
+
+# asyncio.run(test_router_atext_completion_streaming())
+
+
+@pytest.mark.asyncio
+async def test_router_completion_streaming():
+    messages = [
+        {"role": "user", "content": "Hello, can you generate a 500 words poem?"}
+    ]
+    model = "azure-model"
+    model_list = [
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-turbo",
+                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
+                "api_base": "https://openai-france-1234.openai.azure.com",
+                "rpm": 1440,
+            },
+            "model_info": {"id": 1},
+        },
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-35-turbo",
+                "api_key": "os.environ/AZURE_EUROPE_API_KEY",
+                "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
+                "rpm": 6,
+            },
+            "model_info": {"id": 2},
+        },
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-35-turbo",
+                "api_key": "os.environ/AZURE_CANADA_API_KEY",
+                "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
+                "rpm": 6,
+            },
+            "model_info": {"id": 3},
+        },
+    ]
+    router = Router(
+        model_list=model_list,
+        routing_strategy="least-busy",
+        set_verbose=False,
+        num_retries=3,
+    ) # type: ignore
+
+    ### Call the async calls in sequence, so we start 1 call before going to the next.
+
+    ## CALL 1
+    await asyncio.sleep(random.uniform(0, 2))
+    await router.acompletion(model=model, messages=messages, stream=True)
+
+    ## CALL 2
+    await asyncio.sleep(random.uniform(0, 2))
+    await router.acompletion(model=model, messages=messages, stream=True)
+
+    ## CALL 3
+    await asyncio.sleep(random.uniform(0, 2))
+    await router.acompletion(model=model, messages=messages, stream=True)
+
+    cache_key = f"{model}_request_count"
+    ## check if calls equally distributed
+    cache_dict = router.cache.get_cache(key=cache_key)
+    for k, v in cache_dict.items():
+        assert v == 1
@@ -479,6 +479,8 @@ class EmbeddingResponse(OpenAIObject):
     usage: Optional[Usage] = None
     """Usage statistics for the embedding request."""
 
+    _hidden_params: dict = {}
+
     def __init__(
         self, model=None, usage=None, stream=False, response_ms=None, data=None
     ):
@@ -640,6 +642,8 @@ class ImageResponse(OpenAIObject):
 
     usage: Optional[dict] = None
 
+    _hidden_params: dict = {}
+
     def __init__(self, created=None, data=None, response_ms=None):
         if response_ms:
             _response_ms = response_ms
@@ -2053,6 +2057,10 @@ def client(original_function):
             target=logging_obj.success_handler, args=(result, start_time, end_time)
         ).start()
         # RETURN RESULT
+        if hasattr(result, "_hidden_params"):
+            result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
+                "id", None
+            )
         result._response_ms = (
             end_time - start_time
         ).total_seconds() * 1000 # return response latency in ms like openai
@@ -2273,6 +2281,10 @@ def client(original_function):
             target=logging_obj.success_handler, args=(result, start_time, end_time)
         ).start()
         # RETURN RESULT
+        if hasattr(result, "_hidden_params"):
+            result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
+                "id", None
+            )
         if isinstance(result, ModelResponse):
             result._response_ms = (
                 end_time - start_time
@@ -6527,6 +6539,13 @@ class CustomStreamWrapper:
         self.special_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
         self.holding_chunk = ""
         self.complete_response = ""
+        self._hidden_params = {
+            "model_id": (
+                self.logging_obj.model_call_details.get("litellm_params", {})
+                .get("model_info", {})
+                .get("id", None)
+            )
+        } # returned as x-litellm-model-id response header in proxy
 
     def __iter__(self):
         return self
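To make the lookup concrete, here is the same chained .get() applied to a fabricated model_call_details dict (purely illustrative data, not litellm state):

# Fabricated example of the structure the wrapper reads from.
model_call_details = {"litellm_params": {"model_info": {"id": 2}}}

hidden_params = {
    "model_id": (
        model_call_details.get("litellm_params", {})
        .get("model_info", {})
        .get("id", None)
    )
}  # surfaced by the proxy as the x-litellm-model-id response header

print(hidden_params["model_id"])  # -> 2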
@@ -7417,6 +7436,15 @@ class CustomStreamWrapper:
                 threading.Thread(
                     target=self.logging_obj.success_handler, args=(response,)
                 ).start() # log response
+                # RETURN RESULT
+                if hasattr(response, "_hidden_params"):
+                    response._hidden_params["model_id"] = (
+                        self.logging_obj.model_call_details.get(
+                            "litellm_params", {}
+                        )
+                        .get("model_info", {})
+                        .get("id", None)
+                    )
                 return response
             except StopIteration:
                 raise # Re-raise StopIteration
@@ -7467,6 +7495,16 @@ class CustomStreamWrapper:
                             processed_chunk,
                         )
                     )
+                    # RETURN RESULT
+                    if hasattr(processed_chunk, "_hidden_params"):
+                        model_id = (
+                            self.logging_obj.model_call_details.get(
+                                "litellm_params", {}
+                            )
+                            .get("model_info", {})
+                            .get("id", None)
+                        )
+                        processed_chunk._hidden_params["model_id"] = model_id
                     return processed_chunk
                 raise StopAsyncIteration
             else: # temporary patch for non-aiohttp async calls