Support temporary budget increases on keys (#7754)

* fix(gpt_transformation.py): fix response_format translation check for 4o models

Fixes https://github.com/BerriAI/litellm/issues/7616

* feat(key_management_endpoints.py): support 'temp_budget_increase' and 'temp_budget_expiry' fields

Allow proxy admin to grant temporary budget increases to keys

* fix(proxy/_types.py): enforce temp_budget_increase and temp_budget_expiry are always passed together

* feat(user_api_key_auth.py): initial working temp budget increase logic

Ensures the key-budget-exceeded check also accounts for the temp budget stored in key metadata

* feat(proxy_server.py): return the key max budget and key spend in the response headers

Allows client-side users to know their remaining limits

* test: add unit testing for new proxy utils

Ensures new key budget is correctly handled

* docs(temporary_budget_increase.md): add doc on temporary budget increase

* fix(utils.py): remove 3.5 from response_format check for now

Not all Azure 3.5 models support response_format

* fix(user_api_key_auth.py): return valid user api key auth object on all paths
Krish Dholakia, 2025-01-14 17:03:11 -08:00, committed by GitHub
parent 000d3152a8, commit d7a13ad561
11 changed files with 259 additions and 52 deletions

docs/my-website/docs/proxy/temporary_budget_increase.md (new file)

@@ -0,0 +1,74 @@
# ✨ Temporary Budget Increase
Set a temporary budget increase on a LiteLLM Virtual Key. Use this when you're asked to raise a key's budget temporarily without changing its permanent limit.

| Hierarchy | Supported |
|-----------|-----------|
| LiteLLM Virtual Key | ✅ |
| User | ❌ |
| Team | ❌ |
| Organization | ❌ |
:::note
✨ Temporary Budget Increase is a LiteLLM Enterprise feature.
[Enterprise Pricing](https://www.litellm.ai/#pricing)
[Get free 7-day trial key](https://www.litellm.ai/#trial)
:::
1. Create a LiteLLM Virtual Key with a budget
```bash
curl -L -X POST 'http://localhost:4000/key/generate' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer LITELLM_MASTER_KEY' \
-d '{
"max_budget": 0.0000001
}'
```
Expected response:
```json
{
"key": "sk-your-new-key"
}
```
2. Update the key with a temporary budget increase
```bash
curl -L -X POST 'http://localhost:4000/key/update' \
-H 'Authorization: Bearer LITELLM_MASTER_KEY' \
-H 'Content-Type: application/json' \
-d '{
"key": "sk-your-new-key",
"temp_budget_increase": 100,
"temp_budget_expiry": "2025-01-15"
}'
```
3. Test it!
```bash
curl -L -X POST 'http://localhost:4000/chat/completions' \
-H 'Authorization: Bearer sk-your-new-key' \
-H 'Content-Type: application/json' \
-d '{
"model": "gpt-4o",
"messages": [{"role": "user", "content": "Hello, world!"}]
}'
```
Expected Response Header:
```
x-litellm-key-max-budget: 100.0000001
```
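
If you're calling the proxy from Python, the same budget headers are available on the raw response. Below is a minimal sketch (not part of this commit) using the `openai` SDK's `with_raw_response` helper; it assumes the proxy from the steps above is running on `localhost:4000` and that the key has a `max_budget` set, so the header values are numeric:

```python
from openai import OpenAI

# Point the OpenAI SDK at the LiteLLM proxy, authenticating with the virtual key
client = OpenAI(api_key="sk-your-new-key", base_url="http://localhost:4000")

# with_raw_response exposes the HTTP response, including LiteLLM's budget headers
raw = client.chat.completions.with_raw_response.create(
    model="gpt-4o",
    messages=[{"role": "user", "content": "Hello, world!"}],
)

max_budget = float(raw.headers["x-litellm-key-max-budget"])  # e.g. 100.0000001
spend = float(raw.headers["x-litellm-key-spend"])
print(f"remaining budget: {max_budget - spend}")

response = raw.parse()  # the usual ChatCompletion object
print(response.choices[0].message.content)
```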

docs/my-website/sidebars.js

@@ -107,7 +107,7 @@ const sidebars = {
     {
       type: "category",
       label: "Budgets + Rate Limits",
-      items: ["proxy/users", "proxy/rate_limit_tiers", "proxy/team_budgets", "proxy/customers"],
+      items: ["proxy/users", "proxy/temporary_budget_increase", "proxy/rate_limit_tiers", "proxy/team_budgets", "proxy/customers"],
     },
     {
       type: "link",

litellm/llms/azure/chat/gpt_transformation.py

@@ -8,6 +8,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
 )
 from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.types.utils import ModelResponse
+from litellm.utils import supports_response_schema

 from ....exceptions import UnsupportedParamsError
 from ....types.llms.openai import (
@@ -105,6 +106,19 @@
             "parallel_tool_calls",
         ]

+    def _is_response_format_supported_model(self, model: str) -> bool:
+        """
+        - all 4o models are supported
+        - check if 'supports_response_format' is True from get_model_info
+        - [TODO] support smart retries for 3.5 models (some supported, some not)
+        """
+        if "4o" in model:
+            return True
+        elif supports_response_schema(model):
+            return True
+
+        return False
+
     def map_openai_params(
         self,
         non_default_params: dict,
@@ -176,10 +190,14 @@
         - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
         - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the models perspective.
         """
+        _is_response_format_supported_model = (
+            self._is_response_format_supported_model(model)
+        )
+
         if json_schema is not None and (
             (api_version_year <= "2024" and api_version_month < "08")
-            or "gpt-4o" not in model
-        ):  # azure api version "2024-08-01-preview" onwards supports 'json_schema' only for gpt-4o
+            or not _is_response_format_supported_model
+        ):  # azure api version "2024-08-01-preview" onwards supports 'json_schema' only for gpt-4o/3.5 models
             _tool_choice = ChatCompletionToolChoiceObjectParam(
                 type="function",
                 function=ChatCompletionToolChoiceFunctionParam(
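
Net effect of the hunks above: Azure requests pass `response_format` with a JSON schema through natively only when the model is known to support it; everything else falls back to the forced-tool-call workaround. A condensed, self-contained sketch of the decision (the helper name here is hypothetical; the real check lives in `AzureOpenAIConfig.map_openai_params`):

```python
from litellm.utils import supports_response_schema

def uses_native_json_schema(model: str, api_version_year: str, api_version_month: str) -> bool:
    """Hypothetical condensed version of the check in the diff above."""
    # Azure accepts 'json_schema' only from api version 2024-08-01-preview onwards
    if api_version_year <= "2024" and api_version_month < "08":
        return False
    # All 4o-family models are treated as supported outright
    if "4o" in model:
        return True
    # Otherwise defer to litellm's model-info lookup (which now returns False instead of raising)
    return supports_response_schema(model)

print(uses_native_json_schema("gpt-4o-mini", "2024", "10"))   # True: "4o" match
print(uses_native_json_schema("gpt-35-turbo", "2024", "07"))  # False: api version too old
```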

litellm/proxy/proxy_config.yaml

@@ -1,43 +1,7 @@
 model_list:
-  # At least one model must exist for the proxy to start.
-  - model_name: gpt-4o
+  - model_name: "gpt-4o"
     litellm_params:
-      model: gpt-4o
-      api_key: os.environ/OPENAI_API_KEY
-      # timeout: 0.1  # timeout in (seconds)
-      # stream_timeout: 0.01  # timeout for stream requests (seconds)
+      model: "azure/gpt-4o"
+      api_key: os.environ/AZURE_API_KEY
+      api_base: os.environ/AZURE_API_BASE
-  - model_name: anthropic.claude-3-5-sonnet-20241022-v2:0
-    litellm_params:
-      model: bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0
-  - model_name: nova-lite
-    litellm_params:
-      model: bedrock/us.amazon.nova-lite-v1:0
-  - model_name: llama3-2-11b-instruct-v1:0
-    litellm_params:
-      model: bedrock/us.meta.llama3-2-11b-instruct-v1:0
-  - model_name: gpt-4o-bad
-    litellm_params:
-      model: gpt-4o
-      api_key: bad
-  - model_name: "bedrock/*"
-    litellm_params:
-      model: "bedrock/*"
-  - model_name: "openai/*"
-    litellm_params:
-      model: "openai/*"
-      api_key: os.environ/OPENAI_API_KEY
-
-general_settings:
-  store_model_in_db: true
-  disable_prisma_schema_update: true
-  # master_key: os.environ/LITELLM_MASTER_KEY
-
-litellm_settings:
-  fallbacks: [{"gpt-4o-bad": ["gpt-4o"]}] #, {"gpt-4o": ["nova-lite"]}]
-  request_timeout: 600 # raise Timeout error if call takes longer than 600 seconds. Default value is 6000seconds if not set
-  # set_verbose: false # Switch off Debug Logging, ensure your logs do not have any debugging on
-  # json_logs: true # Get debug logs in json format
-  ssl_verify: true
-  callbacks: ["prometheus"]
-  service_callback: ["prometheus_system"]
-  turn_off_message_logging: true # turn off messages in otel
-  #callbacks: ["langfuse"]
-  redact_user_api_key_info: true

litellm/proxy/_types.py

@@ -688,6 +688,17 @@ class UpdateKeyRequest(KeyRequestBase):
     duration: Optional[str] = None
     spend: Optional[float] = None
     metadata: Optional[dict] = None
+    temp_budget_increase: Optional[float] = None
+    temp_budget_expiry: Optional[datetime] = None
+
+    @model_validator(mode="after")
+    def validate_temp_budget(self) -> "UpdateKeyRequest":
+        if self.temp_budget_increase is not None or self.temp_budget_expiry is not None:
+            if self.temp_budget_increase is None or self.temp_budget_expiry is None:
+                raise ValueError(
+                    "temp_budget_increase and temp_budget_expiry must be set together"
+                )
+        return self


 class RegenerateKeyRequest(GenerateKeyRequest):

@@ -2229,6 +2240,8 @@ LiteLLM_ManagementEndpoint_MetadataFields = [
     "guardrails",
     "tags",
     "enforced_params",
+    "temp_budget_increase",
+    "temp_budget_expiry",
 ]
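
The `@model_validator(mode="after")` hook is Pydantic's mechanism for cross-field rules; here it enforces that the two temp-budget fields always travel as a pair. A self-contained sketch of the same rule with generic names (not litellm's actual model):

```python
from datetime import datetime
from typing import Optional

from pydantic import BaseModel, model_validator

class PairedFields(BaseModel):
    """Generic sketch of the paired-field rule used by UpdateKeyRequest."""
    temp_budget_increase: Optional[float] = None
    temp_budget_expiry: Optional[datetime] = None

    @model_validator(mode="after")
    def validate_temp_budget(self) -> "PairedFields":
        # either both fields are set, or neither is
        if (self.temp_budget_increase is None) != (self.temp_budget_expiry is None):
            raise ValueError("temp_budget_increase and temp_budget_expiry must be set together")
        return self

PairedFields()  # ok: neither set
PairedFields(temp_budget_increase=100, temp_budget_expiry="2025-01-15T00:00:00Z")  # ok: both set
try:
    PairedFields(temp_budget_increase=100)  # rejected: only one set
except ValueError as err:  # pydantic's ValidationError subclasses ValueError
    print(err)
```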

litellm/proxy/auth/user_api_key_auth.py

@@ -811,7 +811,10 @@
                 valid_token.allowed_model_region = end_user_params.get(
                     "allowed_model_region"
                 )
+                # update key budget with temp budget increase
+                valid_token = _update_key_budget_with_temp_budget_increase(
+                    valid_token
+                )  # updating it here, allows all downstream reporting / checks to use the updated budget
         except Exception:
             verbose_logger.info(
                 "litellm.proxy.auth.user_api_key_auth.py::user_api_key_auth() - Unable to find token={} in cache or `LiteLLM_VerificationTokenTable`. Defaulting 'valid_token' to None'".format(
@@ -1016,6 +1019,7 @@
                     current_cost=valid_token.spend,
                     max_budget=valid_token.max_budget,
                 )
+
             if valid_token.soft_budget and valid_token.spend >= valid_token.soft_budget:
                 verbose_proxy_logger.debug(
                     "Crossed Soft Budget for token %s, spend %s, soft_budget %s",
@@ -1383,3 +1387,25 @@ def get_api_key_from_custom_header(
             f"No LiteLLM Virtual Key pass. Please set header={custom_litellm_key_header_name}: Bearer <api_key>"
         )
     return api_key
+
+
+def _get_temp_budget_increase(valid_token: UserAPIKeyAuth):
+    valid_token_metadata = valid_token.metadata
+    if (
+        "temp_budget_increase" in valid_token_metadata
+        and "temp_budget_expiry" in valid_token_metadata
+    ):
+        expiry = datetime.fromisoformat(valid_token_metadata["temp_budget_expiry"])
+        if expiry > datetime.now():
+            return valid_token_metadata["temp_budget_increase"]
+    return None
+
+
+def _update_key_budget_with_temp_budget_increase(
+    valid_token: UserAPIKeyAuth,
+) -> UserAPIKeyAuth:
+    if valid_token.max_budget is None:
+        return valid_token
+
+    temp_budget_increase = _get_temp_budget_increase(valid_token) or 0.0
+    valid_token.max_budget = valid_token.max_budget + temp_budget_increase
+    return valid_token
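
One consequence worth calling out: `_get_temp_budget_increase` compares the stored expiry against `datetime.now()` on every auth call, so the increase applies only while the expiry is in the future; afterwards the key silently reverts to its original `max_budget`. A quick illustration using the helpers from the hunk above (same token setup as the unit tests at the end of this commit):

```python
from datetime import datetime, timedelta

from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import (
    _update_key_budget_with_temp_budget_increase,
)

def token_with_expiry(expiry: datetime) -> UserAPIKeyAuth:
    return UserAPIKeyAuth(
        max_budget=100,
        spend=0,
        metadata={
            "temp_budget_increase": 100,
            "temp_budget_expiry": expiry.isoformat(),
        },
    )

# Active increase: expiry is tomorrow, so max_budget is bumped 100 -> 200
active = token_with_expiry(datetime.now() + timedelta(days=1))
assert _update_key_budget_with_temp_budget_increase(active).max_budget == 200

# Expired increase: expiry was yesterday, so max_budget stays at 100
expired = token_with_expiry(datetime.now() - timedelta(days=1))
assert _update_key_budget_with_temp_budget_increase(expired).max_budget == 100
```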

litellm/proxy/management_endpoints/key_management_endpoints.py

@@ -558,7 +558,10 @@
     try:
         for k, v in data_json.items():
             if k in LiteLLM_ManagementEndpoint_MetadataFields:
-                casted_metadata[k] = v
+                if isinstance(v, datetime):
+                    casted_metadata[k] = v.isoformat()
+                else:
+                    casted_metadata[k] = v

     except Exception as e:
         verbose_proxy_logger.exception(
@@ -658,6 +661,8 @@
     - blocked: Optional[bool] - Whether the key is blocked
     - aliases: Optional[dict] - Model aliases for the key - [Docs](https://litellm.vercel.app/docs/proxy/virtual_keys#model-aliases)
     - config: Optional[dict] - [DEPRECATED PARAM] Key-specific config.
+    - temp_budget_increase: Optional[float] - Temporary budget increase for the key (Enterprise only).
+    - temp_budget_expiry: Optional[str] - Expiry time for the temporary budget increase (Enterprise only).

     Example:
     ```bash
@@ -707,9 +712,8 @@
             existing_key_token=existing_key_row.token,
         )

-        response = await prisma_client.update_data(
-            token=key, data={**non_default_values, "token": key}
-        )
+        _data = {**non_default_values, "token": key}
+        response = await prisma_client.update_data(token=key, data=_data)

         # Delete - key from cache, since it's been updated!
         # key updated - a new model could have been added to this key. it should not block requests after this is done
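
The `isoformat()` cast in `prepare_metadata_fields` is what makes the round trip work: `UpdateKeyRequest` parses `temp_budget_expiry` into a `datetime`, but key metadata is stored as JSON, and the auth path later re-parses the string with `datetime.fromisoformat`. A minimal sketch of that round trip:

```python
from datetime import datetime

# UpdateKeyRequest parses temp_budget_expiry into a datetime...
expiry = datetime(2025, 1, 15)

# ...prepare_metadata_fields casts it to a JSON-safe string for key metadata...
stored = expiry.isoformat()  # '2025-01-15T00:00:00'

# ...and _get_temp_budget_increase parses it back when checking the temp budget
assert datetime.fromisoformat(stored) == expiry
```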

litellm/proxy/proxy_server.py

@@ -747,6 +747,8 @@ def get_custom_headers(
         "x-litellm-response-cost": str(response_cost),
         "x-litellm-key-tpm-limit": str(user_api_key_dict.tpm_limit),
         "x-litellm-key-rpm-limit": str(user_api_key_dict.rpm_limit),
+        "x-litellm-key-max-budget": str(user_api_key_dict.max_budget),
+        "x-litellm-key-spend": str(user_api_key_dict.spend),
         "x-litellm-fastest_response_batch_completion": (
             str(fastest_response_batch_completion)
             if fastest_response_batch_completion is not None

litellm/utils.py

@@ -1732,9 +1732,15 @@ def supports_response_schema(
     Does not raise error. Defaults to 'False'. Outputs logging.error.
     """
     ## GET LLM PROVIDER ##
-    model, custom_llm_provider, _, _ = get_llm_provider(
-        model=model, custom_llm_provider=custom_llm_provider
-    )
+    try:
+        model, custom_llm_provider, _, _ = get_llm_provider(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+    except Exception as e:
+        verbose_logger.debug(
+            f"Model not found or error in checking response schema support. You passed model={model}, custom_llm_provider={custom_llm_provider}. Error: {str(e)}"
+        )
+        return False

     # providers that globally support response schema
     PROVIDERS_GLOBALLY_SUPPORT_RESPONSE_SCHEMA = [
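
With this change, `supports_response_schema` degrades gracefully when `get_llm_provider` can't resolve the model, rather than raising mid-check. A quick usage sketch (expected outputs hedged; they depend on litellm's model-info data):

```python
from litellm.utils import supports_response_schema

# A known model with structured-output support should return True
print(supports_response_schema(model="gpt-4o", custom_llm_provider="openai"))

# An unrecognized model previously raised inside get_llm_provider;
# now the error is logged at debug level and False is returned
print(supports_response_schema(model="my-made-up-model"))
```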

View file (Azure OpenAI transformation unit tests)

@@ -248,3 +248,38 @@ def test_get_azure_ad_token_from_username_password(
     # Verify the result is the mock token provider
     assert result == mock_token_provider
+
+
+def test_azure_openai_gpt_4o_naming(monkeypatch):
+    from openai import AzureOpenAI
+    from pydantic import BaseModel, Field
+
+    monkeypatch.setenv("AZURE_API_VERSION", "2024-10-21")
+
+    client = AzureOpenAI(
+        api_key="test-api-key",
+        base_url="https://my-endpoint-sweden-berri992.openai.azure.com",
+        api_version="2023-12-01-preview",
+    )
+
+    class ResponseFormat(BaseModel):
+        number: str = Field(description="total number of days in a week")
+        days: list[str] = Field(description="name of days in a week")
+
+    with patch.object(client.chat.completions.with_raw_response, "create") as mock_post:
+        try:
+            completion(
+                model="azure/gpt4o",
+                messages=[{"role": "user", "content": "Hello world"}],
+                response_format=ResponseFormat,
+                client=client,
+            )
+        except Exception as e:
+            print(e)
+
+        mock_post.assert_called_once()
+        print(mock_post.call_args.kwargs)
+
+        assert "tool_calls" not in mock_post.call_args.kwargs

View file (proxy key/auth unit tests)

@@ -1382,3 +1382,68 @@ def test_custom_openid_response():
         jwt_handler=jwt_handler,
     )
     assert resp.team_ids == ["/test-group"]
+
+
+def test_update_key_request_validation():
+    """
+    Ensures that the UpdateKeyRequest model validates the temp_budget_increase and temp_budget_expiry fields together
+    """
+    from litellm.proxy._types import UpdateKeyRequest
+
+    with pytest.raises(Exception):
+        UpdateKeyRequest(
+            key="test_key",
+            temp_budget_increase=100,
+        )
+
+    with pytest.raises(Exception):
+        UpdateKeyRequest(
+            key="test_key",
+            temp_budget_expiry="2024-01-20T00:00:00Z",
+        )
+
+    UpdateKeyRequest(
+        key="test_key",
+        temp_budget_increase=100,
+        temp_budget_expiry="2024-01-20T00:00:00Z",
+    )
+
+
+def test_get_temp_budget_increase():
+    from litellm.proxy.auth.user_api_key_auth import _get_temp_budget_increase
+    from litellm.proxy._types import UserAPIKeyAuth
+    from datetime import datetime, timedelta
+
+    expiry = datetime.now() + timedelta(days=1)
+    expiry_in_isoformat = expiry.isoformat()
+
+    valid_token = UserAPIKeyAuth(
+        max_budget=100,
+        spend=0,
+        metadata={
+            "temp_budget_increase": 100,
+            "temp_budget_expiry": expiry_in_isoformat,
+        },
+    )
+    assert _get_temp_budget_increase(valid_token) == 100
+
+
+def test_update_key_budget_with_temp_budget_increase():
+    from litellm.proxy.auth.user_api_key_auth import (
+        _update_key_budget_with_temp_budget_increase,
+    )
+    from litellm.proxy._types import UserAPIKeyAuth
+    from datetime import datetime, timedelta
+
+    expiry = datetime.now() + timedelta(days=1)
+    expiry_in_isoformat = expiry.isoformat()
+
+    valid_token = UserAPIKeyAuth(
+        max_budget=100,
+        spend=0,
+        metadata={
+            "temp_budget_increase": 100,
+            "temp_budget_expiry": expiry_in_isoformat,
+        },
+    )
+    assert _update_key_budget_with_temp_budget_increase(valid_token).max_budget == 200