mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
fix(router.py): skip setting model_group response headers for now
current implementation increases redis cache calls by 3x
parent 6e590e4949
commit bc6ed7a06f
4 changed files with 43 additions and 89 deletions
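
For context on the "3x" figure: the Router block removed below made extra awaited lookups on every successful response (a healthy-deployments check for the model group plus reads of its current TPM/RPM usage), and each of those can fall through to Redis. Below is a minimal sketch of that amplification pattern; the class, method, and key names are illustrative stand-ins, not litellm's actual call chain.

# Context sketch (not litellm code): why computing model-group headers per
# response multiplies cache traffic. Each awaited lookup below stands in for a
# DualCache read that can fall through to Redis; names and keys are illustrative.
import asyncio
from typing import Optional


class RedisBackedCacheSketch:
    """Stand-in for a Redis-backed cache; every read may be a network round trip."""

    def __init__(self) -> None:
        self.reads = 0

    async def async_get(self, key: str) -> Optional[int]:
        self.reads += 1  # count potential Redis round trips
        return 0


async def add_model_group_headers(cache: RedisBackedCacheSketch, model_group: str) -> int:
    # 1) healthy-deployment check for the model group (cooldown state in cache)
    await cache.async_get(f"cooldown:{model_group}")
    # 2) current requests-per-minute usage for the group
    await cache.async_get(f"rpm:{model_group}")
    # 3) current tokens-per-minute usage for the group
    await cache.async_get(f"tpm:{model_group}")
    return cache.reads  # three extra reads for every response served


if __name__ == "__main__":
    cache = RedisBackedCacheSketch()
    print(asyncio.run(add_model_group_headers(cache, "gpt-3.5-turbo")))  # -> 3

Three extra reads per response, on top of the bookkeeping the router already does, is roughly the multiplier the commit message calls out.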
@@ -1860,6 +1860,7 @@ class DualCache(BaseCache):
         Returns - int - the incremented value
         """
+        traceback.print_stack()
         try:
             result: int = value
             if self.in_memory_cache is not None:
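
For reference, the method receiving the print_stack() call above increments a counter and returns the new value, consulting the in-memory layer before the Redis layer. A minimal sketch of that dual-layer pattern, assuming the usual write-through layout (an illustration, not litellm's implementation):

# Sketch of a dual-layer increment: in-memory first, then Redis if configured.
# Only the in-memory layer is implemented here; the Redis layer is assumed.
from typing import Dict, Optional


class InMemoryCounterSketch:
    def __init__(self) -> None:
        self._data: Dict[str, int] = {}

    async def async_increment(self, key: str, value: int) -> int:
        self._data[key] = self._data.get(key, 0) + value
        return self._data[key]


class DualCacheSketch:
    def __init__(self, in_memory: Optional[InMemoryCounterSketch] = None, redis_cache=None) -> None:
        self.in_memory_cache = in_memory
        self.redis_cache = redis_cache

    async def async_increment_cache(self, key: str, value: int) -> int:
        result: int = value
        if self.in_memory_cache is not None:
            result = await self.in_memory_cache.async_increment(key, value)
        if self.redis_cache is not None:
            result = await self.redis_cache.async_increment(key, value)
        return result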
@@ -4664,58 +4664,10 @@ class Router:
         """
         Add the most accurate rate limit headers for a given model response.
 
-        - if healthy_deployments > 1, return model group rate limit headers
-        - else return the model's rate limit headers
+        ## TODO: add model group rate limit headers
+        # - if healthy_deployments > 1, return model group rate limit headers
+        # - else return the model's rate limit headers
         """
-        if model_group is None:
-            return response
-
-        healthy_deployments, all_deployments = (
-            await self._async_get_healthy_deployments(model=model_group)
-        )
-
-        hidden_params = getattr(response, "_hidden_params", {}) or {}
-        additional_headers = hidden_params.get("additional_headers", {}) or {}
-
-        if len(healthy_deployments) <= 1:
-            return (
-                response  # setting response headers is handled in wrappers in utils.py
-            )
-        else:
-            # return model group rate limit headers
-            model_group_info = self.get_model_group_info(model_group=model_group)
-            tpm_usage, rpm_usage = await self.get_model_group_usage(
-                model_group=model_group
-            )
-            model_group_remaining_rpm_limit: Optional[int] = None
-            model_group_rpm_limit: Optional[int] = None
-            model_group_remaining_tpm_limit: Optional[int] = None
-            model_group_tpm_limit: Optional[int] = None
-
-            if model_group_info is not None and model_group_info.rpm is not None:
-                model_group_rpm_limit = model_group_info.rpm
-                if rpm_usage is not None:
-                    model_group_remaining_rpm_limit = model_group_info.rpm - rpm_usage
-            if model_group_info is not None and model_group_info.tpm is not None:
-                model_group_tpm_limit = model_group_info.tpm
-                if tpm_usage is not None:
-                    model_group_remaining_tpm_limit = model_group_info.tpm - tpm_usage
-
-            if model_group_remaining_rpm_limit is not None:
-                additional_headers["x-ratelimit-remaining-requests"] = (
-                    model_group_remaining_rpm_limit
-                )
-            if model_group_rpm_limit is not None:
-                additional_headers["x-ratelimit-limit-requests"] = model_group_rpm_limit
-            if model_group_remaining_tpm_limit is not None:
-                additional_headers["x-ratelimit-remaining-tokens"] = (
-                    model_group_remaining_tpm_limit
-                )
-            if model_group_tpm_limit is not None:
-                additional_headers["x-ratelimit-limit-tokens"] = model_group_tpm_limit
-
-            hidden_params["additional_headers"] = additional_headers
-            setattr(response, "_hidden_params", hidden_params)
         return response
 
     def get_model_ids(self, model_name: Optional[str] = None) -> List[str]:
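
With the model-group branch gone, a caller still sees the serving deployment's rate limit headers through the response's hidden params (the wrappers in utils.py populate them); only the aggregated model-group values are deferred per the TODO. A hedged usage sketch, modeled on the commented-out test further down, with placeholder Azure credentials:

# Sketch: reading per-deployment rate limit headers off a response's hidden
# params. Model config values are placeholders; set real credentials to run.
import asyncio
import os

from litellm import Router


async def main() -> None:
    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "azure/chatgpt-v-2",
                    "api_key": os.getenv("AZURE_API_KEY"),
                    "api_base": os.getenv("AZURE_API_BASE"),
                },
            }
        ]
    )
    response = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello world!"}],
    )
    headers = getattr(response, "_hidden_params", {}).get("additional_headers", {}) or {}
    # With this commit, these reflect the single deployment that served the
    # request; model-group aggregation is deferred (see the TODO above).
    print(headers.get("x-ratelimit-remaining-requests"))
    print(headers.get("x-ratelimit-remaining-tokens"))


if __name__ == "__main__":
    asyncio.run(main())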
@@ -3069,6 +3069,7 @@ def test_completion_azure():
             api_key="os.environ/AZURE_API_KEY",
         )
         print(f"response: {response}")
+        print(f"response hidden params: {response._hidden_params}")
         ## Test azure flag for backwards-compat
         # response = completion(
         #     model="chatgpt-v-2",
@@ -2568,45 +2568,45 @@ def test_model_group_alias(hidden):
     assert len(model_names) == len(_model_list) + 1
 
 
-@pytest.mark.parametrize("on_error", [True, False])
-@pytest.mark.asyncio
-async def test_router_response_headers(on_error):
-    router = Router(
-        model_list=[
-            {
-                "model_name": "gpt-3.5-turbo",
-                "litellm_params": {
-                    "model": "azure/chatgpt-v-2",
-                    "api_key": os.getenv("AZURE_API_KEY"),
-                    "api_base": os.getenv("AZURE_API_BASE"),
-                    "tpm": 100000,
-                    "rpm": 100000,
-                },
-            },
-            {
-                "model_name": "gpt-3.5-turbo",
-                "litellm_params": {
-                    "model": "azure/chatgpt-v-2",
-                    "api_key": os.getenv("AZURE_API_KEY"),
-                    "api_base": os.getenv("AZURE_API_BASE"),
-                    "tpm": 500,
-                    "rpm": 500,
-                },
-            },
-        ]
-    )
+# @pytest.mark.parametrize("on_error", [True, False])
+# @pytest.mark.asyncio
+# async def test_router_response_headers(on_error):
+#     router = Router(
+#         model_list=[
+#             {
+#                 "model_name": "gpt-3.5-turbo",
+#                 "litellm_params": {
+#                     "model": "azure/chatgpt-v-2",
+#                     "api_key": os.getenv("AZURE_API_KEY"),
+#                     "api_base": os.getenv("AZURE_API_BASE"),
+#                     "tpm": 100000,
+#                     "rpm": 100000,
+#                 },
+#             },
+#             {
+#                 "model_name": "gpt-3.5-turbo",
+#                 "litellm_params": {
+#                     "model": "azure/chatgpt-v-2",
+#                     "api_key": os.getenv("AZURE_API_KEY"),
+#                     "api_base": os.getenv("AZURE_API_BASE"),
+#                     "tpm": 500,
+#                     "rpm": 500,
+#                 },
+#             },
+#         ]
+#     )
 
-    response = await router.acompletion(
-        model="gpt-3.5-turbo",
-        messages=[{"role": "user", "content": "Hello world!"}],
-        mock_testing_rate_limit_error=on_error,
-    )
+#     response = await router.acompletion(
+#         model="gpt-3.5-turbo",
+#         messages=[{"role": "user", "content": "Hello world!"}],
+#         mock_testing_rate_limit_error=on_error,
+#     )
 
-    response_headers = response._hidden_params["additional_headers"]
+#     response_headers = response._hidden_params["additional_headers"]
 
-    print(response_headers)
+#     print(response_headers)
 
-    assert response_headers["x-ratelimit-limit-requests"] == 100500
-    assert int(response_headers["x-ratelimit-remaining-requests"]) > 0
-    assert response_headers["x-ratelimit-limit-tokens"] == 100500
-    assert int(response_headers["x-ratelimit-remaining-tokens"]) > 0
+#     assert response_headers["x-ratelimit-limit-requests"] == 100500
+#     assert int(response_headers["x-ratelimit-remaining-requests"]) > 0
+#     assert response_headers["x-ratelimit-limit-tokens"] == 100500
+#     assert int(response_headers["x-ratelimit-remaining-tokens"]) > 0
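
One note on the disabled assertions: the expected 100500 presumably comes from summing the two deployments' configured limits.

# Assumed derivation of the 100500 expected by the disabled asserts:
# the model-group limit is the sum of each deployment's configured rpm/tpm.
deployment_rpm_limits = [100000, 500]
assert sum(deployment_rpm_limits) == 100500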