From fe0f9213af28237848d553f4c532e1ad57710b91 Mon Sep 17 00:00:00 2001
From: Steve Farthing
Date: Mon, 27 Jan 2025 08:58:04 -0500
Subject: [PATCH 001/144] Bing Search Pass Thru

---
 litellm/proxy/_types.py                 |  1 +
 litellm/proxy/auth/user_api_key_auth.py | 10 +++
 .../pass_through_endpoints.py           | 31 ++++++++-
 .../test_pass_through_endpoints.py      | 69 +++++++++++++++++++
 4 files changed, 110 insertions(+), 1 deletion(-)

diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index e68d92cee6..831bd21f9c 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -2175,6 +2175,7 @@ class SpecialHeaders(enum.Enum):
     azure_authorization = "API-Key"
     anthropic_authorization = "x-api-key"
     google_ai_studio_authorization = "x-goog-api-key"
+    bing_search_authorization = "Ocp-Apim-Subscription-Key"
 
 
 class LitellmDataForBackendLLMCall(TypedDict, total=False):
diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py
index 33247308f6..6b69aefd4f 100644
--- a/litellm/proxy/auth/user_api_key_auth.py
+++ b/litellm/proxy/auth/user_api_key_auth.py
@@ -78,6 +78,11 @@ google_ai_studio_api_key_header = APIKeyHeader(
     auto_error=False,
     description="If google ai studio client used.",
 )
+bing_search_header = APIKeyHeader(
+    name=SpecialHeaders.bing_search_authorization.value,
+    auto_error=False,
+    description="Custom header for Bing Search requests",
+)
 
 
 def _get_bearer_token(
@@ -451,6 +456,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
     azure_api_key_header: str,
     anthropic_api_key_header: Optional[str],
     google_ai_studio_api_key_header: Optional[str],
+    bing_search_header: Optional[str],
     request_data: dict,
 ) -> UserAPIKeyAuth:
 
@@ -494,6 +500,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
             api_key = anthropic_api_key_header
         elif isinstance(google_ai_studio_api_key_header, str):
             api_key = google_ai_studio_api_key_header
+        elif isinstance(bing_search_header, str):
+            api_key = bing_search_header
         elif pass_through_endpoints is not None:
             for endpoint in pass_through_endpoints:
                 if endpoint.get("path", "") == route:
@@ -1317,6 +1325,7 @@ async def user_api_key_auth(
     google_ai_studio_api_key_header: Optional[str] = fastapi.Security(
         google_ai_studio_api_key_header
     ),
+    bing_search_header: Optional[str] = fastapi.Security(bing_search_header),
 ) -> UserAPIKeyAuth:
     """
     Parent function to authenticate user api key / jwt token.
@@ -1330,6 +1339,7 @@ async def user_api_key_auth( azure_api_key_header=azure_api_key_header, anthropic_api_key_header=anthropic_api_key_header, google_ai_studio_api_key_header=google_ai_studio_api_key_header, + bing_search_header=bing_search_header, request_data=request_data, ) diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 970af05f6d..fcbdfc1fc6 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -4,6 +4,7 @@ import json from base64 import b64encode from datetime import datetime from typing import List, Optional +from urllib.parse import urlencode, parse_qs import httpx from fastapi import APIRouter, Depends, HTTPException, Request, Response, status @@ -310,6 +311,7 @@ async def pass_through_request( # noqa: PLR0915 user_api_key_dict: UserAPIKeyAuth, custom_body: Optional[dict] = None, forward_headers: Optional[bool] = False, + merge_query_params: Optional[bool] = False, query_params: Optional[dict] = None, stream: Optional[bool] = None, ): @@ -325,6 +327,25 @@ async def pass_through_request( # noqa: PLR0915 request=request, headers=headers, forward_headers=forward_headers ) + if merge_query_params: + # Get the query params from the request + request_query_params = dict(request.query_params) + + # Get the existing query params from the target URL + existing_query_string = url.query.decode("utf-8") + existing_query_params = parse_qs(existing_query_string) + + # parse_qs returns a dict where each value is a list, so let's flatten it + existing_query_params = { + k: v[0] if len(v) == 1 else v for k, v in existing_query_params.items() + } + + # Merge the query params, giving priority to the existing ones + merged_query_params = {**request_query_params, **existing_query_params} + + # Create a new URL with the merged query params + url = url.copy_with(query=urlencode(merged_query_params).encode("ascii")) + endpoint_type: EndpointType = get_endpoint_type(str(url)) _parsed_body = None @@ -604,6 +625,7 @@ def create_pass_through_route( target: str, custom_headers: Optional[dict] = None, _forward_headers: Optional[bool] = False, + _merge_query_params: Optional[bool] = False, dependencies: Optional[List] = None, ): # check if target is an adapter.py or a url @@ -650,6 +672,7 @@ def create_pass_through_route( custom_headers=custom_headers or {}, user_api_key_dict=user_api_key_dict, forward_headers=_forward_headers, + merge_query_params=_merge_query_params, query_params=query_params, stream=stream, custom_body=custom_body, @@ -679,6 +702,7 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list): custom_headers=_custom_headers ) _forward_headers = endpoint.get("forward_headers", None) + _merge_query_params = endpoint.get("merge_query_params", None) _auth = endpoint.get("auth", None) _dependencies = None if _auth is not None and str(_auth).lower() == "true": @@ -700,7 +724,12 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list): app.add_api_route( # type: ignore path=_path, endpoint=create_pass_through_route( # type: ignore - _path, _target, _custom_headers, _forward_headers, _dependencies + _path, + _target, + _custom_headers, + _forward_headers, + _merge_query_params, + _dependencies, ), methods=["GET", "POST", "PUT", "DELETE", "PATCH"], dependencies=_dependencies, diff --git a/tests/local_testing/test_pass_through_endpoints.py 
b/tests/local_testing/test_pass_through_endpoints.py index 7e9dfcfc79..8914b9877e 100644 --- a/tests/local_testing/test_pass_through_endpoints.py +++ b/tests/local_testing/test_pass_through_endpoints.py @@ -383,3 +383,72 @@ async def test_pass_through_endpoint_anthropic(client): # Assert the response assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_pass_through_endpoint_bing(client, monkeypatch): + import litellm + + captured_requests = [] + + async def mock_bing_request(*args, **kwargs): + + captured_requests.append((args, kwargs)) + mock_response = httpx.Response( + 200, + json={ + "_type": "SearchResponse", + "queryContext": {"originalQuery": "bob barker"}, + "webPages": { + "webSearchUrl": "https://www.bing.com/search?q=bob+barker", + "totalEstimatedMatches": 12000000, + "value": [], + }, + }, + ) + mock_response.request = Mock(spec=httpx.Request) + return mock_response + + monkeypatch.setattr("httpx.AsyncClient.request", mock_bing_request) + + # Define a pass-through endpoint + pass_through_endpoints = [ + { + "path": "/bing/search", + "target": "https://api.bing.microsoft.com/v7.0/search?setLang=en-US&mkt=en-US", + "headers": {"Ocp-Apim-Subscription-Key": "XX"}, + "forward_headers": True, + # Additional settings + "merge_query_params": True, + "auth": True, + }, + { + "path": "/bing/search-no-merge-params", + "target": "https://api.bing.microsoft.com/v7.0/search?setLang=en-US&mkt=en-US", + "headers": {"Ocp-Apim-Subscription-Key": "XX"}, + "forward_headers": True, + }, + ] + + # Initialize the pass-through endpoint + await initialize_pass_through_endpoints(pass_through_endpoints) + general_settings: Optional[dict] = ( + getattr(litellm.proxy.proxy_server, "general_settings", {}) or {} + ) + general_settings.update({"pass_through_endpoints": pass_through_endpoints}) + setattr(litellm.proxy.proxy_server, "general_settings", general_settings) + + # Make 2 requests thru the pass-through endpoint + client.get("/bing/search?q=bob+barker") + client.get("/bing/search-no-merge-params?q=bob+barker") + + first_transformed_url = captured_requests[0][1]["url"] + second_transformed_url = captured_requests[1][1]["url"] + + # Assert the response + assert ( + first_transformed_url + == "https://api.bing.microsoft.com/v7.0/search?q=bob+barker&setLang=en-US&mkt=en-US" + and second_transformed_url + == "https://api.bing.microsoft.com/v7.0/search?setLang=en-US&mkt=en-US" + ) From 931b44ac1f1d5e920a7487bb7c7b760d212ef37b Mon Sep 17 00:00:00 2001 From: superpoussin22 Date: Sat, 1 Feb 2025 21:23:39 +0100 Subject: [PATCH 002/144] Update bedrock.md remove CUSTOM_ for consistency --- docs/my-website/docs/providers/bedrock.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/my-website/docs/providers/bedrock.md b/docs/my-website/docs/providers/bedrock.md index ad2124676f..720b3e8f5d 100644 --- a/docs/my-website/docs/providers/bedrock.md +++ b/docs/my-website/docs/providers/bedrock.md @@ -62,9 +62,9 @@ model_list: - model_name: bedrock-claude-v1 litellm_params: model: bedrock/anthropic.claude-instant-v1 - aws_access_key_id: os.environ/CUSTOM_AWS_ACCESS_KEY_ID - aws_secret_access_key: os.environ/CUSTOM_AWS_SECRET_ACCESS_KEY - aws_region_name: os.environ/CUSTOM_AWS_REGION_NAME + aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID + aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY + aws_region_name: os.environ/AWS_REGION_NAME ``` All possible auth params: @@ -1553,6 +1553,8 @@ curl http://0.0.0.0:4000/rerank \ "Capital punishment has existed in the United 
States since before it was a country." ], "top_n": 3 + + }' ``` From 9724ee94df144ff338b1f9750545c7858c8eab19 Mon Sep 17 00:00:00 2001 From: Steve Farthing Date: Tue, 4 Feb 2025 21:11:19 -0500 Subject: [PATCH 003/144] Feedback --- litellm/proxy/_types.py | 2 +- litellm/proxy/auth/user_api_key_auth.py | 16 ++++---- .../pass_through_endpoints.py | 41 +++++++++++-------- 3 files changed, 34 insertions(+), 25 deletions(-) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 831bd21f9c..9000c17426 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -2175,7 +2175,7 @@ class SpecialHeaders(enum.Enum): azure_authorization = "API-Key" anthropic_authorization = "x-api-key" google_ai_studio_authorization = "x-goog-api-key" - bing_search_authorization = "Ocp-Apim-Subscription-Key" + azure_apim_authorization = "Ocp-Apim-Subscription-Key" class LitellmDataForBackendLLMCall(TypedDict, total=False): diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 6b69aefd4f..b8a3a4d847 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -78,10 +78,10 @@ google_ai_studio_api_key_header = APIKeyHeader( auto_error=False, description="If google ai studio client used.", ) -bing_search_header = APIKeyHeader( - name=SpecialHeaders.bing_search_authorization.value, +azure_apim_header = APIKeyHeader( + name=SpecialHeaders.azure_apim_authorization.value, auto_error=False, - description="Custom header for Bing Search requests", + description="The default name of the subscription key header of Azure", ) @@ -456,7 +456,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 azure_api_key_header: str, anthropic_api_key_header: Optional[str], google_ai_studio_api_key_header: Optional[str], - bing_search_header: Optional[str], + azure_apim_header: Optional[str], request_data: dict, ) -> UserAPIKeyAuth: @@ -500,8 +500,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 api_key = anthropic_api_key_header elif isinstance(google_ai_studio_api_key_header, str): api_key = google_ai_studio_api_key_header - elif isinstance(bing_search_header, str): - api_key = bing_search_header + elif isinstance(azure_apim_header, str): + api_key = azure_apim_header elif pass_through_endpoints is not None: for endpoint in pass_through_endpoints: if endpoint.get("path", "") == route: @@ -1325,7 +1325,7 @@ async def user_api_key_auth( google_ai_studio_api_key_header: Optional[str] = fastapi.Security( google_ai_studio_api_key_header ), - bing_search_header: Optional[str] = fastapi.Security(bing_search_header), + azure_apim_header: Optional[str] = fastapi.Security(azure_apim_header), ) -> UserAPIKeyAuth: """ Parent function to authenticate user api key / jwt token. 
@@ -1339,7 +1339,7 @@ async def user_api_key_auth( azure_api_key_header=azure_api_key_header, anthropic_api_key_header=anthropic_api_key_header, google_ai_studio_api_key_header=google_ai_studio_api_key_header, - bing_search_header=bing_search_header, + azure_apim_header=azure_apim_header, request_data=request_data, ) diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index fcbdfc1fc6..e919cb1a60 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -3,7 +3,7 @@ import asyncio import json from base64 import b64encode from datetime import datetime -from typing import List, Optional +from typing import List, Optional, Union, Dict from urllib.parse import urlencode, parse_qs import httpx @@ -296,6 +296,22 @@ def get_response_headers( return return_headers +def get_merged_query_parameters( + existing_url: httpx.URL, request_query_params: Dict[str, Union[str, list]] +) -> Dict[str, Union[str, List[str]]]: + # Get the existing query params from the target URL + existing_query_string = existing_url.query.decode("utf-8") + existing_query_params = parse_qs(existing_query_string) + + # parse_qs returns a dict where each value is a list, so let's flatten it + existing_query_params = { + k: v[0] if len(v) == 1 else v for k, v in existing_query_params.items() + } + + # Merge the query params, giving priority to the existing ones + return {**request_query_params, **existing_query_params} + + def get_endpoint_type(url: str) -> EndpointType: if ("generateContent") in url or ("streamGenerateContent") in url: return EndpointType.VERTEX_AI @@ -328,23 +344,16 @@ async def pass_through_request( # noqa: PLR0915 ) if merge_query_params: - # Get the query params from the request - request_query_params = dict(request.query_params) - - # Get the existing query params from the target URL - existing_query_string = url.query.decode("utf-8") - existing_query_params = parse_qs(existing_query_string) - - # parse_qs returns a dict where each value is a list, so let's flatten it - existing_query_params = { - k: v[0] if len(v) == 1 else v for k, v in existing_query_params.items() - } - - # Merge the query params, giving priority to the existing ones - merged_query_params = {**request_query_params, **existing_query_params} # Create a new URL with the merged query params - url = url.copy_with(query=urlencode(merged_query_params).encode("ascii")) + url = url.copy_with( + query=urlencode( + get_merged_query_parameters( + existing_url=url, + request_query_params=dict(request.query_params), + ) + ).encode("ascii") + ) endpoint_type: EndpointType = get_endpoint_type(str(url)) From 57faa623e3fd3c301f737a2ad79ebd6f8df112fb Mon Sep 17 00:00:00 2001 From: Emerson Gomes Date: Tue, 25 Feb 2025 10:44:10 -0600 Subject: [PATCH 004/144] Adding Azure Phi-4 --- litellm/model_prices_and_context_window_backup.json | 13 +++++++++++++ model_prices_and_context_window.json | 13 +++++++++++++ 2 files changed, 26 insertions(+) diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index a11930cc7f..442d5fa776 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -1732,6 +1732,19 @@ "source":"https://azuremarketplace.microsoft.com/en-us/marketplace/apps/metagenai.meta-llama-3-1-405b-instruct-offer?tab=PlansAndPrice", 
"supports_tool_choice": true }, + "azure_ai/Phi-4": { + "max_tokens": 4096, + "max_input_tokens": 128000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.000000125, + "output_cost_per_token": 0.0000005, + "litellm_provider": "azure_ai", + "mode": "chat", + "supports_vision": false, + "source": "https://techcommunity.microsoft.com/blog/machinelearningblog/affordable-innovation-unveiling-the-pricing-of-phi-3-slms-on-models-as-a-service/4156495", + "supports_function_calling": true, + "supports_tool_choice": true + }, "azure_ai/Phi-3.5-mini-instruct": { "max_tokens": 4096, "max_input_tokens": 128000, diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index a11930cc7f..442d5fa776 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -1732,6 +1732,19 @@ "source":"https://azuremarketplace.microsoft.com/en-us/marketplace/apps/metagenai.meta-llama-3-1-405b-instruct-offer?tab=PlansAndPrice", "supports_tool_choice": true }, + "azure_ai/Phi-4": { + "max_tokens": 4096, + "max_input_tokens": 128000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.000000125, + "output_cost_per_token": 0.0000005, + "litellm_provider": "azure_ai", + "mode": "chat", + "supports_vision": false, + "source": "https://techcommunity.microsoft.com/blog/machinelearningblog/affordable-innovation-unveiling-the-pricing-of-phi-3-slms-on-models-as-a-service/4156495", + "supports_function_calling": true, + "supports_tool_choice": true + }, "azure_ai/Phi-3.5-mini-instruct": { "max_tokens": 4096, "max_input_tokens": 128000, From eeee61db658de558a914569567c048fb51278c0f Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 25 Feb 2025 14:50:10 -0800 Subject: [PATCH 005/144] can_team_access_model --- litellm/proxy/auth/auth_checks.py | 92 +++++++++++++------------------ 1 file changed, 39 insertions(+), 53 deletions(-) diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 0590bcb50a..c922599f86 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -38,6 +38,7 @@ from litellm.proxy._types import ( ProxyErrorTypes, ProxyException, RoleBasedPermissions, + SpecialModelNames, UserAPIKeyAuth, ) from litellm.proxy.auth.route_checks import RouteChecks @@ -97,12 +98,23 @@ async def common_checks( ) # 2. If team can call model - _team_model_access_check( - team_object=team_object, - model=_model, - llm_router=llm_router, - team_model_aliases=valid_token.team_model_aliases if valid_token else None, - ) + if ( + team_object is not None + and _model is not None + and can_team_access_model( + model=_model, + team_object=team_object, + llm_router=llm_router, + team_model_aliases=valid_token.team_model_aliases if valid_token else None, + ) + is False + ): + raise ProxyException( + message=f"Team not allowed to access model. Team={team_object.team_id}, Model={_model}. 
Allowed team models = {team_object.models}", + type=ProxyErrorTypes.team_model_access_denied, + param="model", + code=status.HTTP_401_UNAUTHORIZED, + ) ## 2.1 If user can call model (if personal key) if team_object is None and user_object is not None: @@ -1017,6 +1029,9 @@ async def _can_object_call_model( if (len(filtered_models) == 0 and len(models) == 0) or "*" in filtered_models: all_model_access = True + if SpecialModelNames.all_proxy_models in filtered_models: + all_model_access = True + if model is not None and model not in filtered_models and all_model_access is False: raise ProxyException( message=f"API Key not allowed to access model. This token can only access models={models}. Tried to access {model}", @@ -1074,6 +1089,24 @@ async def can_key_call_model( ) +async def can_team_access_model( + model: str, + team_object: Optional[LiteLLM_TeamTable], + llm_router: Optional[Router], + team_model_aliases: Optional[Dict[str, str]] = None, +) -> Literal[True]: + """ + Returns True if the team can access a specific model. + + """ + return await _can_object_call_model( + model=model, + llm_router=llm_router, + models=team_object.models if team_object else [], + team_model_aliases=team_model_aliases, + ) + + async def can_user_call_model( model: str, llm_router: Optional[Router], @@ -1239,53 +1272,6 @@ async def _team_max_budget_check( ) -def _team_model_access_check( - model: Optional[str], - team_object: Optional[LiteLLM_TeamTable], - llm_router: Optional[Router], - team_model_aliases: Optional[Dict[str, str]] = None, -): - """ - Access check for team models - Raises: - Exception if the team is not allowed to call the`model` - """ - if ( - model is not None - and team_object is not None - and team_object.models is not None - and len(team_object.models) > 0 - and model not in team_object.models - ): - # this means the team has access to all models on the proxy - if "all-proxy-models" in team_object.models or "*" in team_object.models: - # this means the team has access to all models on the proxy - pass - # check if the team model is an access_group - elif ( - model_in_access_group( - model=model, team_models=team_object.models, llm_router=llm_router - ) - is True - ): - pass - elif model and "*" in model: - pass - elif _model_in_team_aliases(model=model, team_model_aliases=team_model_aliases): - pass - elif _model_matches_any_wildcard_pattern_in_list( - model=model, allowed_model_list=team_object.models - ): - pass - else: - raise ProxyException( - message=f"Team not allowed to access model. Team={team_object.team_id}, Model={model}. Allowed team models = {team_object.models}", - type=ProxyErrorTypes.team_model_access_denied, - param="model", - code=status.HTTP_401_UNAUTHORIZED, - ) - - def is_model_allowed_by_pattern(model: str, allowed_model_pattern: str) -> bool: """ Check if a model matches an allowed pattern. 
From b6d6e270b49e72a49c3f8a6496a8e965e8eb55c6 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 25 Feb 2025 14:51:57 -0800 Subject: [PATCH 006/144] can_team_access_model --- litellm/proxy/auth/handle_jwt.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index 29f4b31f9c..248d553662 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -33,6 +33,7 @@ from litellm.proxy._types import ( ScopeMapping, Span, ) +from litellm.proxy.auth.auth_checks import can_team_access_model from litellm.proxy.utils import PrismaClient, ProxyLogging from .auth_checks import ( @@ -723,8 +724,12 @@ class JWTAuthManager: team_models = team_object.models if isinstance(team_models, list) and ( not requested_model - or requested_model in team_models - or "*" in team_models + or can_team_access_model( + model=requested_model, + team_object=team_object, + llm_router=None, + team_model_aliases=None, + ) ): is_allowed = allowed_routes_check( user_role=LitellmUserRoles.TEAM, From 3d0b56e8a34b73baf5057e8b45fd6dcb2a558920 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 25 Feb 2025 14:57:13 -0800 Subject: [PATCH 007/144] test_can_team_access_model --- tests/proxy_unit_tests/test_auth_checks.py | 28 ++++++++-------------- 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/tests/proxy_unit_tests/test_auth_checks.py b/tests/proxy_unit_tests/test_auth_checks.py index 0a8ebbe018..5b79ace1b9 100644 --- a/tests/proxy_unit_tests/test_auth_checks.py +++ b/tests/proxy_unit_tests/test_auth_checks.py @@ -27,7 +27,7 @@ from litellm.proxy._types import ( ) from litellm.proxy.utils import PrismaClient from litellm.proxy.auth.auth_checks import ( - _team_model_access_check, + can_team_access_model, _virtual_key_soft_budget_check, ) from litellm.proxy.utils import ProxyLogging @@ -427,9 +427,9 @@ async def test_virtual_key_max_budget_check( ], ) @pytest.mark.asyncio -async def test_team_model_access_check(model, team_models, expect_to_work): +async def test_can_team_access_model(model, team_models, expected_result): """ - Test cases for _team_model_access_check: + Test cases for can_team_access_model: 1. Exact model match 2. all-proxy-models access 3. Wildcard (*) access @@ -443,21 +443,13 @@ async def test_team_model_access_check(model, team_models, expect_to_work): models=team_models, ) - try: - _team_model_access_check( - model=model, - team_object=team_object, - llm_router=None, - ) - if not expect_to_work: - pytest.fail( - f"Expected model access check to fail for model={model}, team_models={team_models}" - ) - except Exception as e: - if expect_to_work: - pytest.fail( - f"Expected model access check to work for model={model}, team_models={team_models}. 
Got error: {str(e)}" - ) + result = await can_team_access_model( + model=model, + team_object=team_object, + llm_router=None, + team_model_aliases=None, + ) + assert result == expected_result @pytest.mark.parametrize( From 7eaf0039193363fd562349d08df2a7744646a7f2 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 25 Feb 2025 15:25:51 -0800 Subject: [PATCH 008/144] expected_result --- tests/proxy_unit_tests/test_auth_checks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/proxy_unit_tests/test_auth_checks.py b/tests/proxy_unit_tests/test_auth_checks.py index 5b79ace1b9..a5782653ad 100644 --- a/tests/proxy_unit_tests/test_auth_checks.py +++ b/tests/proxy_unit_tests/test_auth_checks.py @@ -394,7 +394,7 @@ async def test_virtual_key_max_budget_check( @pytest.mark.parametrize( - "model, team_models, expect_to_work", + "model, team_models, expected_result", [ ("gpt-4", ["gpt-4"], True), # exact match ("gpt-4", ["all-proxy-models"], True), # all-proxy-models access From 5ead81786dea2322b614118d0f4eac7c7843e3f0 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 1 Mar 2025 17:42:50 -0800 Subject: [PATCH 009/144] test_can_team_access_model --- tests/proxy_unit_tests/test_auth_checks.py | 36 +++++++++++++--------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/tests/proxy_unit_tests/test_auth_checks.py b/tests/proxy_unit_tests/test_auth_checks.py index ec36823633..0eb1a38755 100644 --- a/tests/proxy_unit_tests/test_auth_checks.py +++ b/tests/proxy_unit_tests/test_auth_checks.py @@ -394,7 +394,7 @@ async def test_virtual_key_max_budget_check( @pytest.mark.parametrize( - "model, team_models, expected_result", + "model, team_models, expect_to_work", [ ("gpt-4", ["gpt-4"], True), # exact match ("gpt-4", ["all-proxy-models"], True), # all-proxy-models access @@ -427,7 +427,7 @@ async def test_virtual_key_max_budget_check( ], ) @pytest.mark.asyncio -async def test_can_team_access_model(model, team_models, expected_result): +async def test_can_team_access_model(model, team_models, expect_to_work): """ Test cases for can_team_access_model: 1. Exact model match @@ -438,18 +438,26 @@ async def test_can_team_access_model(model, team_models, expected_result): 6. Empty model list 7. None model list """ - team_object = LiteLLM_TeamTable( - team_id="test-team", - models=team_models, - ) - - result = await can_team_access_model( - model=model, - team_object=team_object, - llm_router=None, - team_model_aliases=None, - ) - assert result == expected_result + try: + team_object = LiteLLM_TeamTable( + team_id="test-team", + models=team_models, + ) + result = await can_team_access_model( + model=model, + team_object=team_object, + llm_router=None, + team_model_aliases=None, + ) + if not expect_to_work: + pytest.fail( + f"Expected model access check to fail for model={model}, team_models={team_models}" + ) + except Exception as e: + if expect_to_work: + pytest.fail( + f"Expected model access check to work for model={model}, team_models={team_models}. 
Got error: {str(e)}" + ) @pytest.mark.parametrize( From fa88bc96328fd8e313632fd80cb962297a056323 Mon Sep 17 00:00:00 2001 From: Utkash Dubey Date: Mon, 3 Mar 2025 04:16:12 -0800 Subject: [PATCH 010/144] changes --- litellm/__init__.py | 2 +- .../litellm_core_utils/get_model_cost_map.py | 21 ++-- model_prices_and_context_window.json | 2 +- tests/code_coverage_tests/bedrock_pricing.py | 3 +- tests/litellm_utils_tests/test_utils.py | 12 +- .../base_embedding_unit_tests.py | 3 +- tests/llm_translation/base_llm_unit_tests.py | 27 ++--- .../llm_translation/base_rerank_unit_tests.py | 3 +- .../test_anthropic_completion.py | 3 +- .../test_bedrock_completion.py | 15 +-- tests/llm_translation/test_openai_o1.py | 9 +- tests/llm_translation/test_rerank.py | 3 +- tests/llm_translation/test_together_ai.py | 3 +- .../test_amazing_vertex_completion.py | 6 +- tests/local_testing/test_completion_cost.py | 70 ++++------- tests/local_testing/test_embedding.py | 6 +- tests/local_testing/test_get_model_info.py | 21 ++-- tests/local_testing/test_router_utils.py | 12 +- ..._model_prices_and_context_window_schema.py | 111 ++++++++++++++++++ 19 files changed, 191 insertions(+), 141 deletions(-) create mode 100644 tests/test_model_prices_and_context_window_schema.py diff --git a/litellm/__init__.py b/litellm/__init__.py index 60b8cf81a0..a3756251d1 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -308,7 +308,7 @@ _key_management_settings: KeyManagementSettings = KeyManagementSettings() #### PII MASKING #### output_parse_pii: bool = False ############################################# -from litellm.litellm_core_utils.get_model_cost_map import get_model_cost_map +from litellm.litellm_core_utils.get_model_cost_map import get_model_cost_map, get_locally_cached_model_cost_map model_cost = get_model_cost_map(url=model_cost_map_url) custom_prompt_dict: Dict[str, dict] = {} diff --git a/litellm/litellm_core_utils/get_model_cost_map.py b/litellm/litellm_core_utils/get_model_cost_map.py index b8bdaee19c..0e14457b2a 100644 --- a/litellm/litellm_core_utils/get_model_cost_map.py +++ b/litellm/litellm_core_utils/get_model_cost_map.py @@ -8,24 +8,29 @@ export LITELLM_LOCAL_MODEL_COST_MAP=True ``` """ +from functools import cache import os import httpx +@cache +def get_locally_cached_model_cost_map(): + import importlib.resources + import json + + with importlib.resources.open_text( + "litellm", "model_prices_and_context_window_backup.json" + ) as f: + content = json.load(f) + return content + def get_model_cost_map(url: str): if ( os.getenv("LITELLM_LOCAL_MODEL_COST_MAP", False) or os.getenv("LITELLM_LOCAL_MODEL_COST_MAP", False) == "True" ): - import importlib.resources - import json - - with importlib.resources.open_text( - "litellm", "model_prices_and_context_window_backup.json" - ) as f: - content = json.load(f) - return content + return get_locally_cached_model_cost_map() try: response = httpx.get( diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 96076fa3b8..961b55f49b 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -6,7 +6,7 @@ "input_cost_per_token": 0.0000, "output_cost_per_token": 0.000, "litellm_provider": "one of https://docs.litellm.ai/docs/providers", - "mode": "one of chat, embedding, completion, image_generation, audio_transcription, audio_speech", + "mode": "one of: chat, embedding, completion, image_generation, audio_transcription, audio_speech, image_generation, moderation, moderations, rerank", 
"supports_function_calling": true, "supports_parallel_function_calling": true, "supports_vision": true, diff --git a/tests/code_coverage_tests/bedrock_pricing.py b/tests/code_coverage_tests/bedrock_pricing.py index b2c9e78b06..9984cb8b0e 100644 --- a/tests/code_coverage_tests/bedrock_pricing.py +++ b/tests/code_coverage_tests/bedrock_pricing.py @@ -191,8 +191,7 @@ def _check_if_model_name_in_pricing( input_cost_per_1k_tokens: str, output_cost_per_1k_tokens: str, ): - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() for model, value in litellm.model_cost.items(): if model.startswith(bedrock_model_name): diff --git a/tests/litellm_utils_tests/test_utils.py b/tests/litellm_utils_tests/test_utils.py index 2b1e78a681..fd8ad01c8b 100644 --- a/tests/litellm_utils_tests/test_utils.py +++ b/tests/litellm_utils_tests/test_utils.py @@ -907,8 +907,7 @@ def test_supports_response_schema(model, expected_bool): Should be true for gemini-1.5-pro on google ai studio / vertex ai AND predibase models Should be false otherwise """ - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() from litellm.utils import supports_response_schema @@ -1066,8 +1065,7 @@ def test_async_http_handler_force_ipv4(mock_async_client): "model, expected_bool", [("gpt-3.5-turbo", False), ("gpt-4o-audio-preview", True)] ) def test_supports_audio_input(model, expected_bool): - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() from litellm.utils import supports_audio_input, supports_audio_output @@ -1165,8 +1163,7 @@ def test_models_by_provider(): """ Make sure all providers from model map are in the valid providers list """ - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() from litellm import models_by_provider @@ -1484,8 +1481,7 @@ def test_get_valid_models_default(monkeypatch): def test_supports_vision_gemini(): - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() from litellm.utils import supports_vision assert supports_vision("gemini-1.5-pro") is True diff --git a/tests/llm_translation/base_embedding_unit_tests.py b/tests/llm_translation/base_embedding_unit_tests.py index 30a9dcc0da..1fcc825481 100644 --- a/tests/llm_translation/base_embedding_unit_tests.py +++ b/tests/llm_translation/base_embedding_unit_tests.py @@ -84,8 +84,7 @@ class BaseLLMEmbeddingTest(ABC): litellm.set_verbose = True from litellm.utils import supports_embedding_image_input - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() base_embedding_call_args = self.get_base_embedding_call_args() if not supports_embedding_image_input(base_embedding_call_args["model"], None): diff --git a/tests/llm_translation/base_llm_unit_tests.py b/tests/llm_translation/base_llm_unit_tests.py index f91ef0eae9..eb18cbce90 100644 --- a/tests/llm_translation/base_llm_unit_tests.py +++ b/tests/llm_translation/base_llm_unit_tests.py @@ -342,8 
+342,7 @@ class BaseLLMChatTest(ABC): from pydantic import BaseModel from litellm.utils import supports_response_schema - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() class TestModel(BaseModel): first_response: str @@ -382,16 +381,14 @@ class BaseLLMChatTest(ABC): from pydantic import BaseModel from litellm.utils import supports_response_schema - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() @pytest.mark.flaky(retries=6, delay=1) def test_json_response_nested_pydantic_obj(self): from pydantic import BaseModel from litellm.utils import supports_response_schema - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() class CalendarEvent(BaseModel): name: str @@ -438,8 +435,7 @@ class BaseLLMChatTest(ABC): from litellm.utils import supports_response_schema from litellm.llms.base_llm.base_utils import type_to_response_format_param - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() class CalendarEvent(BaseModel): name: str @@ -560,8 +556,7 @@ class BaseLLMChatTest(ABC): litellm.set_verbose = True from litellm.utils import supports_vision - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() base_completion_call_args = self.get_base_completion_call_args() if not supports_vision(base_completion_call_args["model"], None): @@ -615,8 +610,7 @@ class BaseLLMChatTest(ABC): litellm.set_verbose = True from litellm.utils import supports_vision - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" @@ -656,8 +650,7 @@ class BaseLLMChatTest(ABC): litellm.set_verbose = True from litellm.utils import supports_prompt_caching - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() base_completion_call_args = self.get_base_completion_call_args() if not supports_prompt_caching(base_completion_call_args["model"], None): @@ -773,8 +766,7 @@ class BaseLLMChatTest(ABC): litellm._turn_on_debug() from litellm.utils import supports_function_calling - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() base_completion_call_args = self.get_base_completion_call_args() if not supports_function_calling(base_completion_call_args["model"], None): @@ -872,8 +864,7 @@ class BaseLLMChatTest(ABC): async def test_completion_cost(self): from litellm import completion_cost - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() litellm.set_verbose = 
True response = await self.async_completion_function( diff --git a/tests/llm_translation/base_rerank_unit_tests.py b/tests/llm_translation/base_rerank_unit_tests.py index cff4a02753..b3f56f7c64 100644 --- a/tests/llm_translation/base_rerank_unit_tests.py +++ b/tests/llm_translation/base_rerank_unit_tests.py @@ -87,8 +87,7 @@ class BaseLLMRerankTest(ABC): @pytest.mark.parametrize("sync_mode", [True, False]) async def test_basic_rerank(self, sync_mode): litellm._turn_on_debug() - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() rerank_call_args = self.get_base_rerank_call_args() custom_llm_provider = self.get_custom_llm_provider() if sync_mode is True: diff --git a/tests/llm_translation/test_anthropic_completion.py b/tests/llm_translation/test_anthropic_completion.py index 37253a37e6..04158b4ab4 100644 --- a/tests/llm_translation/test_anthropic_completion.py +++ b/tests/llm_translation/test_anthropic_completion.py @@ -693,8 +693,7 @@ class TestAnthropicCompletion(BaseLLMChatTest): from pydantic import BaseModel from litellm.utils import supports_response_schema - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() class RFormat(BaseModel): question: str diff --git a/tests/llm_translation/test_bedrock_completion.py b/tests/llm_translation/test_bedrock_completion.py index 2fb0ffb9e5..99e4e7ed1a 100644 --- a/tests/llm_translation/test_bedrock_completion.py +++ b/tests/llm_translation/test_bedrock_completion.py @@ -1975,8 +1975,7 @@ def test_bedrock_converse_route(): def test_bedrock_mapped_converse_models(): litellm.set_verbose = True - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() litellm.add_known_models() litellm.completion( model="bedrock/us.amazon.nova-pro-v1:0", @@ -2108,8 +2107,7 @@ def test_bedrock_supports_tool_call(model, expected_supports_tool_call): class TestBedrockConverseChatCrossRegion(BaseLLMChatTest): def get_base_completion_call_args(self) -> dict: - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() litellm.add_known_models() return { "model": "bedrock/us.anthropic.claude-3-5-sonnet-20241022-v2:0", @@ -2137,8 +2135,7 @@ class TestBedrockConverseChatCrossRegion(BaseLLMChatTest): """ Test if region models info is correctly used for cost calculation. Using the base model info for cost calculation. 
""" - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() bedrock_model = "us.anthropic.claude-3-5-sonnet-20241022-v2:0" litellm.model_cost.pop(bedrock_model, None) model = f"bedrock/{bedrock_model}" @@ -2155,8 +2152,7 @@ class TestBedrockConverseChatCrossRegion(BaseLLMChatTest): class TestBedrockConverseChatNormal(BaseLLMChatTest): def get_base_completion_call_args(self) -> dict: - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() litellm.add_known_models() return { "model": "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0", @@ -2325,8 +2321,7 @@ def test_bedrock_nova_topk(top_k_param): def test_bedrock_cross_region_inference(monkeypatch): from litellm.llms.custom_httpx.http_handler import HTTPHandler - monkeypatch.setenv("LITELLM_LOCAL_MODEL_COST_MAP", "True") - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() litellm.add_known_models() litellm.set_verbose = True diff --git a/tests/llm_translation/test_openai_o1.py b/tests/llm_translation/test_openai_o1.py index 4208f1ae38..bcd7648f2a 100644 --- a/tests/llm_translation/test_openai_o1.py +++ b/tests/llm_translation/test_openai_o1.py @@ -29,8 +29,7 @@ async def test_o1_handle_system_role(model): from openai import AsyncOpenAI from litellm.utils import supports_system_messages - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() litellm.set_verbose = True @@ -83,8 +82,7 @@ async def test_o1_handle_tool_calling_optional_params( from litellm.utils import ProviderConfigManager from litellm.types.utils import LlmProviders - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() config = ProviderConfigManager.get_provider_chat_config( model=model, provider=LlmProviders.OPENAI @@ -190,8 +188,7 @@ class TestOpenAIO3(BaseOSeriesModelsTest, BaseLLMChatTest): def test_o1_supports_vision(): """Test that o1 supports vision""" - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() for k, v in litellm.model_cost.items(): if k.startswith("o1") and v.get("litellm_provider") == "openai": assert v.get("supports_vision") is True, f"{k} does not support vision" diff --git a/tests/llm_translation/test_rerank.py b/tests/llm_translation/test_rerank.py index d2cb2b6fea..ef5df795ab 100644 --- a/tests/llm_translation/test_rerank.py +++ b/tests/llm_translation/test_rerank.py @@ -274,8 +274,7 @@ class TestLogger(CustomLogger): @pytest.mark.asyncio() async def test_rerank_custom_callbacks(): - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() custom_logger = TestLogger() litellm.callbacks = [custom_logger] diff --git a/tests/llm_translation/test_together_ai.py b/tests/llm_translation/test_together_ai.py index b83a700002..f275500817 100644 --- a/tests/llm_translation/test_together_ai.py +++ b/tests/llm_translation/test_together_ai.py @@ -42,8 +42,7 @@ class 
TestTogetherAI(BaseLLMChatTest): def test_get_supported_response_format_together_ai( self, model: str, expected_bool: bool ) -> None: - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() optional_params = litellm.get_supported_openai_params( model, custom_llm_provider="together_ai" ) diff --git a/tests/local_testing/test_amazing_vertex_completion.py b/tests/local_testing/test_amazing_vertex_completion.py index 02e0c9b2f1..d59df956be 100644 --- a/tests/local_testing/test_amazing_vertex_completion.py +++ b/tests/local_testing/test_amazing_vertex_completion.py @@ -1433,8 +1433,7 @@ async def test_gemini_pro_json_schema_args_sent_httpx( enforce_validation, ): load_vertex_ai_credentials() - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() litellm.set_verbose = True messages = [{"role": "user", "content": "List 5 cookie recipes"}] @@ -1554,8 +1553,7 @@ async def test_gemini_pro_json_schema_args_sent_httpx_openai_schema( from pydantic import BaseModel load_vertex_ai_credentials() - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() litellm.set_verbose = True diff --git a/tests/local_testing/test_completion_cost.py b/tests/local_testing/test_completion_cost.py index 200f2c012e..77d49961bd 100644 --- a/tests/local_testing/test_completion_cost.py +++ b/tests/local_testing/test_completion_cost.py @@ -634,8 +634,7 @@ def test_gemini_completion_cost(above_128k, provider): """ Check if cost correctly calculated for gemini models based on context window """ - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() if provider == "gemini": model_name = "gemini-1.5-flash-latest" else: @@ -690,8 +689,7 @@ def _count_characters(text): def test_vertex_ai_completion_cost(): - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() text = "The quick brown fox jumps over the lazy dog." 
characters = _count_characters(text=text) @@ -726,8 +724,7 @@ def test_vertex_ai_medlm_completion_cost(): model=model, messages=messages, custom_llm_provider="vertex_ai" ) - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() model = "vertex_ai/medlm-medium" messages = [{"role": "user", "content": "Test MedLM completion cost."}] @@ -746,8 +743,7 @@ def test_vertex_ai_claude_completion_cost(): from litellm import Choices, Message, ModelResponse from litellm.utils import Usage - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() litellm.set_verbose = True input_tokens = litellm.token_counter( @@ -796,8 +792,7 @@ def test_vertex_ai_embedding_completion_cost(caplog): """ Relevant issue - https://github.com/BerriAI/litellm/issues/4630 """ - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() text = "The quick brown fox jumps over the lazy dog." input_tokens = litellm.token_counter( @@ -839,8 +834,7 @@ def test_vertex_ai_embedding_completion_cost(caplog): # from test_amazing_vertex_completion import load_vertex_ai_credentials # load_vertex_ai_credentials() -# os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" -# litellm.model_cost = litellm.get_model_cost_map(url="") +# litellm.model_cost = litellm.get_locally_cached_model_cost_map() # text = "The quick brown fox jumps over the lazy dog." # input_tokens = litellm.token_counter( @@ -867,8 +861,7 @@ def test_vertex_ai_embedding_completion_cost(caplog): def test_completion_azure_ai(): try: - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() litellm.set_verbose = True response = litellm.completion( @@ -974,8 +967,7 @@ def test_vertex_ai_mistral_predict_cost(usage): @pytest.mark.parametrize("model", ["openai/tts-1", "azure/tts-1"]) def test_completion_cost_tts(model): - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() cost = completion_cost( model=model, @@ -1171,8 +1163,7 @@ def test_completion_cost_azure_common_deployment_name(): ], ) def test_completion_cost_prompt_caching(model, custom_llm_provider): - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() from litellm.utils import Choices, Message, ModelResponse, Usage @@ -1273,8 +1264,7 @@ def test_completion_cost_prompt_caching(model, custom_llm_provider): ], ) def test_completion_cost_databricks(model): - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() model, messages = model, [{"role": "user", "content": "What is 2+2?"}] resp = litellm.completion(model=model, messages=messages) # works fine @@ -1291,8 +1281,7 @@ def test_completion_cost_databricks(model): ], ) def test_completion_cost_databricks_embedding(model): - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + 
litellm.model_cost = litellm.get_locally_cached_model_cost_map() resp = litellm.embedding(model=model, input=["hey, how's it going?"]) # works fine print(resp) @@ -1319,8 +1308,7 @@ def test_get_model_params_fireworks_ai(model, base_model): ["fireworks_ai/llama-v3p1-405b-instruct", "fireworks_ai/mixtral-8x7b-instruct"], ) def test_completion_cost_fireworks_ai(model): - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() messages = [{"role": "user", "content": "Hey, how's it going?"}] resp = litellm.completion(model=model, messages=messages) # works fine @@ -1337,8 +1325,7 @@ def test_cost_azure_openai_prompt_caching(): ) from litellm import get_model_info - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() model = "azure/o1-mini" @@ -1427,8 +1414,7 @@ def test_cost_azure_openai_prompt_caching(): def test_completion_cost_vertex_llama3(): - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() from litellm.utils import Choices, Message, ModelResponse, Usage @@ -1468,8 +1454,7 @@ def test_cost_openai_prompt_caching(): from litellm.utils import Choices, Message, ModelResponse, Usage from litellm import get_model_info - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() model = "gpt-4o-mini-2024-07-18" @@ -1559,8 +1544,7 @@ def test_cost_openai_prompt_caching(): def test_completion_cost_azure_ai_rerank(model): from litellm import RerankResponse, rerank - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() response = RerankResponse( id="b01dbf2e-63c8-4981-9e69-32241da559ed", @@ -1591,8 +1575,7 @@ def test_completion_cost_azure_ai_rerank(model): def test_together_ai_embedding_completion_cost(): from litellm.utils import Choices, EmbeddingResponse, Message, ModelResponse, Usage - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() response = EmbeddingResponse( model="togethercomputer/m2-bert-80M-8k-retrieval", data=[ @@ -2449,8 +2432,7 @@ def test_completion_cost_params_gemini_3(): from litellm.llms.vertex_ai.cost_calculator import cost_per_character - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() response = ModelResponse( id="chatcmpl-61043504-4439-48be-9996-e29bdee24dc3", @@ -2519,8 +2501,7 @@ def test_completion_cost_params_gemini_3(): # @pytest.mark.flaky(retries=3, delay=1) @pytest.mark.parametrize("stream", [False]) # True, async def test_test_completion_cost_gpt4o_audio_output_from_model(stream): - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() from litellm.types.utils import ( Choices, Message, @@ -2617,8 +2598,7 @@ def test_completion_cost_model_response_cost(response_model, custom_llm_provider 
""" from litellm import ModelResponse - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() litellm.set_verbose = True response = { @@ -2718,8 +2698,7 @@ def test_select_model_name_for_cost_calc(): def test_moderations(): from litellm import moderation - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() litellm.add_known_models() assert "omni-moderation-latest" in litellm.model_cost @@ -2772,8 +2751,7 @@ def test_bedrock_cost_calc_with_region(): from litellm import ModelResponse - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() litellm.add_known_models() @@ -2972,9 +2950,7 @@ async def test_cost_calculator_with_custom_pricing_router(model_item, custom_pri def test_json_valid_model_cost_map(): import json - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - - model_cost = litellm.get_model_cost_map(url="") + model_cost = litellm.get_locally_cached_model_cost_map() try: # Attempt to serialize and deserialize the JSON diff --git a/tests/local_testing/test_embedding.py b/tests/local_testing/test_embedding.py index c85a830e5f..c369dd73eb 100644 --- a/tests/local_testing/test_embedding.py +++ b/tests/local_testing/test_embedding.py @@ -115,8 +115,7 @@ def test_openai_embedding_3(): @pytest.mark.asyncio async def test_openai_azure_embedding_simple(model, api_base, api_key, sync_mode): try: - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() # litellm.set_verbose = True if sync_mode: response = embedding( @@ -198,8 +197,7 @@ def _azure_ai_image_mock_response(*args, **kwargs): @pytest.mark.asyncio async def test_azure_ai_embedding_image(model, api_base, api_key, sync_mode): try: - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() input = base64_image if sync_mode: client = HTTPHandler() diff --git a/tests/local_testing/test_get_model_info.py b/tests/local_testing/test_get_model_info.py index c879332c7b..c40ac41be2 100644 --- a/tests/local_testing/test_get_model_info.py +++ b/tests/local_testing/test_get_model_info.py @@ -58,16 +58,14 @@ def test_get_model_info_shows_correct_supports_vision(): def test_get_model_info_shows_assistant_prefill(): - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() info = litellm.get_model_info("deepseek/deepseek-chat") print("info", info) assert info.get("supports_assistant_prefill") is True def test_get_model_info_shows_supports_prompt_caching(): - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() info = litellm.get_model_info("deepseek/deepseek-chat") print("info", info) assert info.get("supports_prompt_caching") is True @@ -116,8 +114,7 @@ def test_get_model_info_gemini(): """ Tests if ALL gemini models have 'tpm' and 'rpm' in the model info """ - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = 
"True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() model_map = litellm.model_cost for model, info in model_map.items(): @@ -127,8 +124,7 @@ def test_get_model_info_gemini(): def test_get_model_info_bedrock_region(): - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() args = { "model": "us.anthropic.claude-3-5-sonnet-20241022-v2:0", "custom_llm_provider": "bedrock", @@ -212,8 +208,7 @@ def test_model_info_bedrock_converse(monkeypatch): This ensures they are automatically routed to the converse endpoint. """ - monkeypatch.setenv("LITELLM_LOCAL_MODEL_COST_MAP", "True") - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() try: # Load whitelist models from file with open("whitelisted_bedrock_models.txt", "r") as file: @@ -231,8 +226,7 @@ def test_model_info_bedrock_converse_enforcement(monkeypatch): """ Test the enforcement of the whitelist by adding a fake model and ensuring the test fails. """ - monkeypatch.setenv("LITELLM_LOCAL_MODEL_COST_MAP", "True") - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() # Add a fake unwhitelisted model litellm.model_cost["fake.bedrock-chat-model"] = { @@ -323,8 +317,7 @@ def test_get_model_info_bedrock_models(): """ from litellm.llms.bedrock.common_utils import BedrockModelInfo - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() for k, v in litellm.model_cost.items(): if v["litellm_provider"] == "bedrock": diff --git a/tests/local_testing/test_router_utils.py b/tests/local_testing/test_router_utils.py index 7de9707579..d0afc440d9 100644 --- a/tests/local_testing/test_router_utils.py +++ b/tests/local_testing/test_router_utils.py @@ -178,8 +178,7 @@ async def test_update_kwargs_before_fallbacks(call_type): def test_router_get_model_info_wildcard_routes(): - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() router = Router( model_list=[ { @@ -200,8 +199,7 @@ def test_router_get_model_info_wildcard_routes(): @pytest.mark.asyncio async def test_router_get_model_group_usage_wildcard_routes(): - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() router = Router( model_list=[ { @@ -297,8 +295,7 @@ async def test_call_router_callbacks_on_failure(): @pytest.mark.asyncio async def test_router_model_group_headers(): - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() from litellm.types.utils import OPENAI_RESPONSE_HEADERS router = Router( @@ -330,8 +327,7 @@ async def test_router_model_group_headers(): @pytest.mark.asyncio async def test_get_remaining_model_group_usage(): - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.model_cost = litellm.get_locally_cached_model_cost_map() from litellm.types.utils import OPENAI_RESPONSE_HEADERS router = Router( diff 
--git a/tests/test_model_prices_and_context_window_schema.py b/tests/test_model_prices_and_context_window_schema.py new file mode 100644 index 0000000000..80d35f84b4 --- /dev/null +++ b/tests/test_model_prices_and_context_window_schema.py @@ -0,0 +1,111 @@ +import litellm +from jsonschema import validate + +def test_model_prices_and_context_window_json_is_valid(): + ''' + Validates the `model_prices_and_context_window.json` file. + + If this test fails after you update the json, you need to update the schema or correct the change you made. + ''' + + INTENDED_SCHEMA = { + "type": "object", + "additionalProperties": { + "type": "object", + "properties": { + "cache_creation_input_audio_token_cost": {"type": "number"}, + "cache_creation_input_token_cost": {"type": "number"}, + "cache_read_input_token_cost": {"type": "number"}, + "deprecation_date": {"type": "string"}, + "input_cost_per_audio_per_second": {"type": "number"}, + "input_cost_per_audio_per_second_above_128k_tokens": {"type": "number"}, + "input_cost_per_audio_token": {"type": "number"}, + "input_cost_per_character": {"type": "number"}, + "input_cost_per_character_above_128k_tokens": {"type": "number"}, + "input_cost_per_image": {"type": "number"}, + "input_cost_per_image_above_128k_tokens": {"type": "number"}, + "input_cost_per_pixel": {"type": "number"}, + "input_cost_per_query": {"type": "number"}, + "input_cost_per_request": {"type": "number"}, + "input_cost_per_second": {"type": "number"}, + "input_cost_per_token": {"type": "number"}, + "input_cost_per_token_above_128k_tokens": {"type": "number"}, + "input_cost_per_token_batch_requests": {"type": "number"}, + "input_cost_per_token_batches": {"type": "number"}, + "input_cost_per_token_cache_hit": {"type": "number"}, + "input_cost_per_video_per_second": {"type": "number"}, + "input_cost_per_video_per_second_above_128k_tokens": {"type": "number"}, + "input_dbu_cost_per_token": {"type": "number"}, + "litellm_provider": {"type": "string"}, + "max_audio_length_hours": {"type": "number"}, + "max_audio_per_prompt": {"type": "number"}, + "max_document_chunks_per_query": {"type": "number"}, + "max_images_per_prompt": {"type": "number"}, + "max_input_tokens": {"type": "number"}, + "max_output_tokens": {"type": "number"}, + "max_pdf_size_mb": {"type": "number"}, + "max_query_tokens": {"type": "number"}, + "max_tokens": {"type": "number"}, + "max_tokens_per_document_chunk": {"type": "number"}, + "max_video_length": {"type": "number"}, + "max_videos_per_prompt": {"type": "number"}, + "metadata": {"type": "object"}, + "mode": { + "type": "string", + "enum": [ + "audio_speech", + "audio_transcription", + "chat", + "completion", + "embedding", + "image_generation", + "moderation", + "moderations", + "rerank" + ], + }, + "output_cost_per_audio_token": {"type": "number"}, + "output_cost_per_character": {"type": "number"}, + "output_cost_per_character_above_128k_tokens": {"type": "number"}, + "output_cost_per_image": {"type": "number"}, + "output_cost_per_pixel": {"type": "number"}, + "output_cost_per_second": {"type": "number"}, + "output_cost_per_token": {"type": "number"}, + "output_cost_per_token_above_128k_tokens": {"type": "number"}, + "output_cost_per_token_batches": {"type": "number"}, + "output_db_cost_per_token": {"type": "number"}, + "output_dbu_cost_per_token": {"type": "number"}, + "output_vector_size": {"type": "number"}, + "rpd": {"type": "number"}, + "rpm": {"type": "number"}, + "source": {"type": "string"}, + "supports_assistant_prefill": {"type": "boolean"}, + 
"supports_audio_input": {"type": "boolean"}, + "supports_audio_output": {"type": "boolean"}, + "supports_embedding_image_input": {"type": "boolean"}, + "supports_function_calling": {"type": "boolean"}, + "supports_image_input": {"type": "boolean"}, + "supports_parallel_function_calling": {"type": "boolean"}, + "supports_pdf_input": {"type": "boolean"}, + "supports_prompt_caching": {"type": "boolean"}, + "supports_response_schema": {"type": "boolean"}, + "supports_system_messages": {"type": "boolean"}, + "supports_tool_choice": {"type": "boolean"}, + "supports_video_input": {"type": "boolean"}, + "supports_vision": {"type": "boolean"}, + "tool_use_system_prompt_tokens": {"type": "number"}, + "tpm": {"type": "number"}, + }, + "additionalProperties": False, + }, + } + + actual_json = litellm.get_locally_cached_model_cost_map() + assert isinstance(actual_json, dict) + temporarily_removed = actual_json.pop('sample_spec', None) # remove the sample, whose schema is inconsistent with the real data + + validate(actual_json, INTENDED_SCHEMA) + + if temporarily_removed is not None: + # put back the sample spec that we removed + actual_json.update({'sample_spec': temporarily_removed}) From 805679becccbee2107b1d0f07a6c811fa8418067 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 6 Mar 2025 23:05:54 -0800 Subject: [PATCH 011/144] feat(handle_jwt.py): support multiple jwt url's --- litellm/proxy/auth/handle_jwt.py | 50 +++++++++++++++------------ tests/proxy_unit_tests/test_jwt.py | 55 +++++++++++++++++++++++------- 2 files changed, 70 insertions(+), 35 deletions(-) diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index 29f4b31f9c..61da9825e6 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -344,32 +344,38 @@ class JWTHandler: if keys_url is None: raise Exception("Missing JWT Public Key URL from environment.") - cached_keys = await self.user_api_key_cache.async_get_cache( - "litellm_jwt_auth_keys" - ) - if cached_keys is None: - response = await self.http_handler.get(keys_url) + keys_url_list = [url.strip() for url in keys_url.split(",")] - response_json = response.json() - if "keys" in response_json: - keys: JWKKeyValue = response.json()["keys"] + for key_url in keys_url_list: + + cache_key = f"litellm_jwt_auth_keys_{key_url}" + + cached_keys = await self.user_api_key_cache.async_get_cache(cache_key) + + if cached_keys is None: + response = await self.http_handler.get(key_url) + + response_json = response.json() + if "keys" in response_json: + keys: JWKKeyValue = response.json()["keys"] + else: + keys = response_json + + await self.user_api_key_cache.async_set_cache( + key=cache_key, + value=keys, + ttl=self.litellm_jwtauth.public_key_ttl, # cache for 10 mins + ) else: - keys = response_json + keys = cached_keys - await self.user_api_key_cache.async_set_cache( - key="litellm_jwt_auth_keys", - value=keys, - ttl=self.litellm_jwtauth.public_key_ttl, # cache for 10 mins - ) - else: - keys = cached_keys + public_key = self.parse_keys(keys=keys, kid=kid) + if public_key is not None: + return cast(dict, public_key) - public_key = self.parse_keys(keys=keys, kid=kid) - if public_key is None: - raise Exception( - f"No matching public key found. kid={kid}, keys_url={keys_url}, cached_keys={cached_keys}, len(keys)={len(keys)}" - ) - return cast(dict, public_key) + raise Exception( + f"No matching public key found. 
keys={keys_url_list}, kid={kid}" + ) def parse_keys(self, keys: JWKKeyValue, kid: Optional[str]) -> Optional[JWTKeyItem]: public_key: Optional[JWTKeyItem] = None diff --git a/tests/proxy_unit_tests/test_jwt.py b/tests/proxy_unit_tests/test_jwt.py index 7a9d2f0019..d96fb691f7 100644 --- a/tests/proxy_unit_tests/test_jwt.py +++ b/tests/proxy_unit_tests/test_jwt.py @@ -64,7 +64,7 @@ def test_load_config_with_custom_role_names(): @pytest.mark.asyncio -async def test_token_single_public_key(): +async def test_token_single_public_key(monkeypatch): import jwt jwt_handler = JWTHandler() @@ -80,10 +80,15 @@ async def test_token_single_public_key(): ] } + monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key") + # set cache cache = DualCache() - await cache.async_set_cache(key="litellm_jwt_auth_keys", value=backend_keys["keys"]) + await cache.async_set_cache( + key="litellm_jwt_auth_keys_https://example.com/public-key", + value=backend_keys["keys"], + ) jwt_handler.user_api_key_cache = cache @@ -99,7 +104,7 @@ async def test_token_single_public_key(): @pytest.mark.parametrize("audience", [None, "litellm-proxy"]) @pytest.mark.asyncio -async def test_valid_invalid_token(audience): +async def test_valid_invalid_token(audience, monkeypatch): """ Tests - valid token @@ -116,6 +121,8 @@ async def test_valid_invalid_token(audience): if audience: os.environ["JWT_AUDIENCE"] = audience + monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key") + # Generate a private / public key pair using RSA algorithm key = rsa.generate_private_key( public_exponent=65537, key_size=2048, backend=default_backend() @@ -145,7 +152,9 @@ async def test_valid_invalid_token(audience): # set cache cache = DualCache() - await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk]) + await cache.async_set_cache( + key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk] + ) jwt_handler = JWTHandler() @@ -294,7 +303,7 @@ def team_token_tuple(): @pytest.mark.parametrize("audience", [None, "litellm-proxy"]) @pytest.mark.asyncio -async def test_team_token_output(prisma_client, audience): +async def test_team_token_output(prisma_client, audience, monkeypatch): import json import uuid @@ -316,6 +325,8 @@ async def test_team_token_output(prisma_client, audience): if audience: os.environ["JWT_AUDIENCE"] = audience + monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key") + # Generate a private / public key pair using RSA algorithm key = rsa.generate_private_key( public_exponent=65537, key_size=2048, backend=default_backend() @@ -345,7 +356,9 @@ async def test_team_token_output(prisma_client, audience): # set cache cache = DualCache() - await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk]) + await cache.async_set_cache( + key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk] + ) jwt_handler = JWTHandler() @@ -463,7 +476,7 @@ async def test_team_token_output(prisma_client, audience): @pytest.mark.parametrize("user_id_upsert", [True, False]) @pytest.mark.asyncio async def aaaatest_user_token_output( - prisma_client, audience, team_id_set, default_team_id, user_id_upsert + prisma_client, audience, team_id_set, default_team_id, user_id_upsert, monkeypatch ): import uuid @@ -528,10 +541,14 @@ async def aaaatest_user_token_output( assert isinstance(public_jwk, dict) + monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key") + # set cache cache = DualCache() - await 
cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk]) + await cache.async_set_cache( + key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk] + ) jwt_handler = JWTHandler() @@ -699,7 +716,9 @@ async def aaaatest_user_token_output( @pytest.mark.parametrize("admin_allowed_routes", [None, ["ui_routes"]]) @pytest.mark.parametrize("audience", [None, "litellm-proxy"]) @pytest.mark.asyncio -async def test_allowed_routes_admin(prisma_client, audience, admin_allowed_routes): +async def test_allowed_routes_admin( + prisma_client, audience, admin_allowed_routes, monkeypatch +): """ Add a check to make sure jwt proxy admin scope can access all allowed admin routes @@ -723,6 +742,8 @@ async def test_allowed_routes_admin(prisma_client, audience, admin_allowed_route setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) await litellm.proxy.proxy_server.prisma_client.connect() + monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key") + os.environ.pop("JWT_AUDIENCE", None) if audience: os.environ["JWT_AUDIENCE"] = audience @@ -756,7 +777,9 @@ async def test_allowed_routes_admin(prisma_client, audience, admin_allowed_route # set cache cache = DualCache() - await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk]) + await cache.async_set_cache( + key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk] + ) jwt_handler = JWTHandler() @@ -910,7 +933,9 @@ def mock_user_object(*args, **kwargs): "user_email, should_work", [("ishaan@berri.ai", True), ("krrish@tassle.xyz", False)] ) @pytest.mark.asyncio -async def test_allow_access_by_email(public_jwt_key, user_email, should_work): +async def test_allow_access_by_email( + public_jwt_key, user_email, should_work, monkeypatch +): """ Allow anyone with an `@xyz.com` email make a request to the proxy. 
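For reference, a minimal sketch of the per-URL cache-key scheme these test updates exercise: every URL in a comma-separated `JWT_PUBLIC_KEY_URL` gets its own cache entry named `litellm_jwt_auth_keys_<url>`. This is an illustrative simplification, not the proxy's actual `JWTHandler`; the plain `dict` cache and the `get_jwks_per_url` helper name are assumed stand-ins.

```python
from typing import Dict, Optional

import httpx


async def get_jwks_per_url(keys_url: str, cache: Dict[str, list]) -> Optional[list]:
    """Resolve JWKS key sets from one or more comma-separated URLs, caching per URL."""
    for key_url in (u.strip() for u in keys_url.split(",")):
        cache_key = f"litellm_jwt_auth_keys_{key_url}"  # one cache entry per URL
        keys = cache.get(cache_key)
        if keys is None:
            async with httpx.AsyncClient() as client:
                body = (await client.get(key_url)).json()
            # JWKS responses usually nest the key list under "keys"
            keys = body["keys"] if isinstance(body, dict) and "keys" in body else body
            cache[cache_key] = keys  # the real handler also applies a public_key_ttl
        if keys:
            return keys
    return None
```

Seeding the cache under `litellm_jwt_auth_keys_https://example.com/public-key`, which is exactly what the updated tests do, lets key resolution succeed without any network call.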
@@ -925,10 +950,14 @@ async def test_allow_access_by_email(public_jwt_key, user_email, should_work): public_jwk = public_jwt_key["public_jwk"] private_key = public_jwt_key["private_key"] + monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key") + # set cache cache = DualCache() - await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk]) + await cache.async_set_cache( + key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk] + ) jwt_handler = JWTHandler() @@ -1074,7 +1103,7 @@ async def test_end_user_jwt_auth(monkeypatch): ] cache.set_cache( - key="litellm_jwt_auth_keys", + key="litellm_jwt_auth_keys_https://example.com/public-key", value=keys, ) From fb4ebf0fd435d1ff6dadf46f113d75d5f48fa0a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20Can=CC=83ete?= Date: Sat, 8 Mar 2025 01:14:45 +0100 Subject: [PATCH 012/144] ci: add helm unittest --- .github/workflows/helm_unit_test.yml | 27 ++++++++++ .../litellm-helm/tests/deployment_tests.yaml | 54 +++++++++++++++++++ 2 files changed, 81 insertions(+) create mode 100644 .github/workflows/helm_unit_test.yml create mode 100644 deploy/charts/litellm-helm/tests/deployment_tests.yaml diff --git a/.github/workflows/helm_unit_test.yml b/.github/workflows/helm_unit_test.yml new file mode 100644 index 0000000000..c4b83af70a --- /dev/null +++ b/.github/workflows/helm_unit_test.yml @@ -0,0 +1,27 @@ +name: Helm unit test + +on: + pull_request: + push: + branches: + - main + +jobs: + unit-test: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v2 + + - name: Set up Helm 3.11.1 + uses: azure/setup-helm@v1 + with: + version: '3.11.1' + + - name: Install Helm Unit Test Plugin + run: | + helm plugin install https://github.com/helm-unittest/helm-unittest --version v0.4.4 + + - name: Run unit tests + run: + helm unittest -f 'tests/*.yaml' deploy/charts/litellm-helm \ No newline at end of file diff --git a/deploy/charts/litellm-helm/tests/deployment_tests.yaml b/deploy/charts/litellm-helm/tests/deployment_tests.yaml new file mode 100644 index 0000000000..e7ce44b052 --- /dev/null +++ b/deploy/charts/litellm-helm/tests/deployment_tests.yaml @@ -0,0 +1,54 @@ +suite: test deployment +templates: + - deployment.yaml + - configmap-litellm.yaml +tests: + - it: should work + template: deployment.yaml + set: + image.tag: test + asserts: + - isKind: + of: Deployment + - matchRegex: + path: metadata.name + pattern: -litellm$ + - equal: + path: spec.template.spec.containers[0].image + value: ghcr.io/berriai/litellm-database:test + - it: should work with tolerations + template: deployment.yaml + set: + tolerations: + - key: node-role.kubernetes.io/master + operator: Exists + effect: NoSchedule + asserts: + - equal: + path: spec.template.spec.tolerations[0].key + value: node-role.kubernetes.io/master + - equal: + path: spec.template.spec.tolerations[0].operator + value: Exists + - it: should work with affinity + template: deployment.yaml + set: + affinity: + nodeAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + nodeSelectorTerms: + - matchExpressions: + - key: topology.kubernetes.io/zone + operator: In + values: + - antarctica-east1 + asserts: + - equal: + path: spec.template.spec.affinity.nodeAffinity.requiredDuringSchedulingIgnoredDuringExecution.nodeSelectorTerms[0].matchExpressions[0].key + value: topology.kubernetes.io/zone + - equal: + path: 
spec.template.spec.affinity.nodeAffinity.requiredDuringSchedulingIgnoredDuringExecution.nodeSelectorTerms[0].matchExpressions[0].operator + value: In + - equal: + path: spec.template.spec.affinity.nodeAffinity.requiredDuringSchedulingIgnoredDuringExecution.nodeSelectorTerms[0].matchExpressions[0].values[0] + value: antarctica-east1 From 2c5b2da9558cbda394ff669785dc41ebf89f76d5 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sun, 9 Mar 2025 18:35:10 -0700 Subject: [PATCH 013/144] fix: make type object subscriptable --- litellm/router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/router.py b/litellm/router.py index d1c410e786..7fe5c2fb94 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -618,7 +618,7 @@ class Router: @staticmethod def _create_redis_cache( - cache_config: dict[str, Any] + cache_config: Dict[str, Any] ) -> RedisCache | RedisClusterCache: if cache_config.get("startup_nodes"): return RedisClusterCache(**cache_config) From c08705517bfcca1ad48cb6029a4899f0820ef20c Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sun, 9 Mar 2025 19:40:03 -0700 Subject: [PATCH 014/144] test: fix test --- tests/proxy_unit_tests/test_user_api_key_auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/proxy_unit_tests/test_user_api_key_auth.py b/tests/proxy_unit_tests/test_user_api_key_auth.py index dbe49a560d..e956a22282 100644 --- a/tests/proxy_unit_tests/test_user_api_key_auth.py +++ b/tests/proxy_unit_tests/test_user_api_key_auth.py @@ -826,7 +826,7 @@ async def test_jwt_user_api_key_auth_builder_enforce_rbac(enforce_rbac, monkeypa ] local_cache.set_cache( - key="litellm_jwt_auth_keys", + key="litellm_jwt_auth_keys_my-fake-url", value=keys, ) From 06744913862365526344157f3fcab551b985b751 Mon Sep 17 00:00:00 2001 From: omrishiv <327609+omrishiv@users.noreply.github.com> Date: Mon, 10 Mar 2025 08:02:00 -0700 Subject: [PATCH 015/144] add support for Amazon Nova Canvas model (#7838) * add initial support for Amazon Nova Canvas model Signed-off-by: omrishiv <327609+omrishiv@users.noreply.github.com> * adjust name to AmazonNovaCanvas and map function variables to config Signed-off-by: omrishiv <327609+omrishiv@users.noreply.github.com> * tighten model name check Signed-off-by: omrishiv <327609+omrishiv@users.noreply.github.com> * fix quality mapping Signed-off-by: omrishiv <327609+omrishiv@users.noreply.github.com> * add premium quality in config Signed-off-by: omrishiv <327609+omrishiv@users.noreply.github.com> * support all Amazon Nova Canvas tasks * remove unused import Signed-off-by: omrishiv <327609+omrishiv@users.noreply.github.com> * add tests for image generation tasks and fix payload Signed-off-by: omrishiv <327609+omrishiv@users.noreply.github.com> * add missing util file Signed-off-by: omrishiv <327609+omrishiv@users.noreply.github.com> * update model prices backup file Signed-off-by: omrishiv <327609+omrishiv@users.noreply.github.com> * remove image tasks other than text->image Signed-off-by: omrishiv <327609+omrishiv@users.noreply.github.com> --------- Signed-off-by: omrishiv <327609+omrishiv@users.noreply.github.com> Co-authored-by: Krish Dholakia --- litellm/__init__.py | 1 + .../amazon_nova_canvas_transformation.py | 106 ++++++++++++++++++ litellm/llms/bedrock/image/image_handler.py | 3 + ...odel_prices_and_context_window_backup.json | 28 +++-- litellm/types/llms/bedrock.py | 57 ++++++++++ litellm/utils.py | 1 + model_prices_and_context_window.json | 28 +++-- .../image_gen_tests/test_image_generation.py | 10 ++ 8 
files changed, 212 insertions(+), 22 deletions(-) create mode 100644 litellm/llms/bedrock/image/amazon_nova_canvas_transformation.py diff --git a/litellm/__init__.py b/litellm/__init__.py index d66707f8b3..fd026ffb9d 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -899,6 +899,7 @@ from .llms.bedrock.chat.invoke_transformations.base_invoke_transformation import from .llms.bedrock.image.amazon_stability1_transformation import AmazonStabilityConfig from .llms.bedrock.image.amazon_stability3_transformation import AmazonStability3Config +from .llms.bedrock.image.amazon_nova_canvas_transformation import AmazonNovaCanvasConfig from .llms.bedrock.embed.amazon_titan_g1_transformation import AmazonTitanG1Config from .llms.bedrock.embed.amazon_titan_multimodal_transformation import ( AmazonTitanMultimodalEmbeddingG1Config, diff --git a/litellm/llms/bedrock/image/amazon_nova_canvas_transformation.py b/litellm/llms/bedrock/image/amazon_nova_canvas_transformation.py new file mode 100644 index 0000000000..de46edb923 --- /dev/null +++ b/litellm/llms/bedrock/image/amazon_nova_canvas_transformation.py @@ -0,0 +1,106 @@ +import types +from typing import List, Optional + +from openai.types.image import Image + +from litellm.types.llms.bedrock import ( + AmazonNovaCanvasTextToImageRequest, AmazonNovaCanvasTextToImageResponse, + AmazonNovaCanvasTextToImageParams, AmazonNovaCanvasRequestBase, +) +from litellm.types.utils import ImageResponse + + +class AmazonNovaCanvasConfig: + """ + Reference: https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/model-catalog/serverless/amazon.nova-canvas-v1:0 + + """ + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + @classmethod + def get_supported_openai_params(cls, model: Optional[str] = None) -> List: + """ + """ + return ["n", "size", "quality"] + + @classmethod + def _is_nova_model(cls, model: Optional[str] = None) -> bool: + """ + Returns True if the model is a Nova Canvas model + + Nova models follow this pattern: + + """ + if model: + if "amazon.nova-canvas" in model: + return True + return False + + @classmethod + def transform_request_body( + cls, text: str, optional_params: dict + ) -> AmazonNovaCanvasRequestBase: + """ + Transform the request body for Amazon Nova Canvas model + """ + task_type = optional_params.pop("taskType", "TEXT_IMAGE") + image_generation_config = optional_params.pop("imageGenerationConfig", {}) + image_generation_config = {**image_generation_config, **optional_params} + if task_type == "TEXT_IMAGE": + text_to_image_params = image_generation_config.pop("textToImageParams", {}) + text_to_image_params = {"text" :text, **text_to_image_params} + text_to_image_params = AmazonNovaCanvasTextToImageParams(**text_to_image_params) + return AmazonNovaCanvasTextToImageRequest(textToImageParams=text_to_image_params, taskType=task_type, + imageGenerationConfig=image_generation_config) + raise NotImplementedError(f"Task type {task_type} is not supported") + + @classmethod + def map_openai_params(cls, non_default_params: dict, optional_params: dict) -> dict: + """ + Map the OpenAI params to the Bedrock params + """ + _size = non_default_params.get("size") + if _size is not None: + width, height = _size.split("x") + optional_params["width"], optional_params["height"] = int(width), int(height) + if 
non_default_params.get("n") is not None: + optional_params["numberOfImages"] = non_default_params.get("n") + if non_default_params.get("quality") is not None: + if non_default_params.get("quality") in ("hd", "premium"): + optional_params["quality"] = "premium" + if non_default_params.get("quality") == "standard": + optional_params["quality"] = "standard" + return optional_params + + @classmethod + def transform_response_dict_to_openai_response( + cls, model_response: ImageResponse, response_dict: dict + ) -> ImageResponse: + """ + Transform the response dict to the OpenAI response + """ + + nova_response = AmazonNovaCanvasTextToImageResponse(**response_dict) + openai_images: List[Image] = [] + for _img in nova_response.get("images", []): + openai_images.append(Image(b64_json=_img)) + + model_response.data = openai_images + return model_response diff --git a/litellm/llms/bedrock/image/image_handler.py b/litellm/llms/bedrock/image/image_handler.py index 59a80b2222..8f7762e547 100644 --- a/litellm/llms/bedrock/image/image_handler.py +++ b/litellm/llms/bedrock/image/image_handler.py @@ -266,6 +266,8 @@ class BedrockImageGeneration(BaseAWSLLM): "text_prompts": [{"text": prompt, "weight": 1}], **inference_params, } + elif provider == "amazon": + return dict(litellm.AmazonNovaCanvasConfig.transform_request_body(text=prompt, optional_params=optional_params)) else: raise BedrockError( status_code=422, message=f"Unsupported model={model}, passed in" @@ -301,6 +303,7 @@ class BedrockImageGeneration(BaseAWSLLM): config_class = ( litellm.AmazonStability3Config if litellm.AmazonStability3Config._is_stability_3_model(model=model) + else litellm.AmazonNovaCanvasConfig if litellm.AmazonNovaCanvasConfig._is_nova_model(model=model) else litellm.AmazonStabilityConfig ) config_class.transform_response_dict_to_openai_response( diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index cb2322752b..a34cfb7db9 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -6434,7 +6434,7 @@ "supports_response_schema": true }, "us.amazon.nova-micro-v1:0": { - "max_tokens": 4096, + "max_tokens": 4096, "max_input_tokens": 300000, "max_output_tokens": 4096, "input_cost_per_token": 0.000000035, @@ -6472,7 +6472,7 @@ "supports_response_schema": true }, "us.amazon.nova-lite-v1:0": { - "max_tokens": 4096, + "max_tokens": 4096, "max_input_tokens": 128000, "max_output_tokens": 4096, "input_cost_per_token": 0.00000006, @@ -6514,7 +6514,7 @@ "supports_response_schema": true }, "us.amazon.nova-pro-v1:0": { - "max_tokens": 4096, + "max_tokens": 4096, "max_input_tokens": 300000, "max_output_tokens": 4096, "input_cost_per_token": 0.0000008, @@ -6527,6 +6527,12 @@ "supports_prompt_caching": true, "supports_response_schema": true }, + "1024-x-1024/50-steps/bedrock/amazon.nova-canvas-v1:0": { + "max_input_tokens": 2600, + "output_cost_per_image": 0.06, + "litellm_provider": "bedrock", + "mode": "image_generation" + }, "eu.amazon.nova-pro-v1:0": { "max_tokens": 4096, "max_input_tokens": 300000, @@ -7871,22 +7877,22 @@ "mode": "image_generation" }, "stability.sd3-5-large-v1:0": { - "max_tokens": 77, - "max_input_tokens": 77, + "max_tokens": 77, + "max_input_tokens": 77, "output_cost_per_image": 0.08, "litellm_provider": "bedrock", "mode": "image_generation" }, "stability.stable-image-core-v1:0": { - "max_tokens": 77, - "max_input_tokens": 77, + "max_tokens": 77, + "max_input_tokens": 77, 
"output_cost_per_image": 0.04, "litellm_provider": "bedrock", "mode": "image_generation" }, "stability.stable-image-core-v1:1": { - "max_tokens": 77, - "max_input_tokens": 77, + "max_tokens": 77, + "max_input_tokens": 77, "output_cost_per_image": 0.04, "litellm_provider": "bedrock", "mode": "image_generation" @@ -7899,8 +7905,8 @@ "mode": "image_generation" }, "stability.stable-image-ultra-v1:1": { - "max_tokens": 77, - "max_input_tokens": 77, + "max_tokens": 77, + "max_input_tokens": 77, "output_cost_per_image": 0.14, "litellm_provider": "bedrock", "mode": "image_generation" diff --git a/litellm/types/llms/bedrock.py b/litellm/types/llms/bedrock.py index 7013c8a800..9d276d7d60 100644 --- a/litellm/types/llms/bedrock.py +++ b/litellm/types/llms/bedrock.py @@ -365,6 +365,63 @@ class AmazonStability3TextToImageResponse(TypedDict, total=False): finish_reasons: List[str] +class AmazonNovaCanvasRequestBase(TypedDict, total=False): + """ + Base class for Amazon Nova Canvas API requests + """ + + pass + + +class AmazonNovaCanvasImageGenerationConfig(TypedDict, total=False): + """ + Config for Amazon Nova Canvas Text to Image API + + Ref: https://docs.aws.amazon.com/nova/latest/userguide/image-gen-req-resp-structure.html + """ + + cfgScale: int + seed: int + quality: Literal["standard", "premium"] + width: int + height: int + numberOfImages: int + + +class AmazonNovaCanvasTextToImageParams(TypedDict, total=False): + """ + Params for Amazon Nova Canvas Text to Image API + """ + + text: str + negativeText: str + controlStrength: float + controlMode: Literal["CANNY_EDIT", "SEGMENTATION"] + conditionImage: str + + +class AmazonNovaCanvasTextToImageRequest(AmazonNovaCanvasRequestBase, TypedDict, total=False): + """ + Request for Amazon Nova Canvas Text to Image API + + Ref: https://docs.aws.amazon.com/nova/latest/userguide/image-gen-req-resp-structure.html + """ + + textToImageParams: AmazonNovaCanvasTextToImageParams + taskType: Literal["TEXT_IMAGE"] + imageGenerationConfig: AmazonNovaCanvasImageGenerationConfig + + +class AmazonNovaCanvasTextToImageResponse(TypedDict, total=False): + """ + Response for Amazon Nova Canvas Text to Image API + + Ref: https://docs.aws.amazon.com/nova/latest/userguide/image-gen-req-resp-structure.html + """ + + images: List[str] + + if TYPE_CHECKING: from botocore.awsrequest import AWSPreparedRequest else: diff --git a/litellm/utils.py b/litellm/utils.py index ce5acbc694..2f1cac743c 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2427,6 +2427,7 @@ def get_optional_params_image_gen( config_class = ( litellm.AmazonStability3Config if litellm.AmazonStability3Config._is_stability_3_model(model=model) + else litellm.AmazonNovaCanvasConfig if litellm.AmazonNovaCanvasConfig._is_nova_model(model=model) else litellm.AmazonStabilityConfig ) supported_params = config_class.get_supported_openai_params(model=model) diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index cb2322752b..a34cfb7db9 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -6434,7 +6434,7 @@ "supports_response_schema": true }, "us.amazon.nova-micro-v1:0": { - "max_tokens": 4096, + "max_tokens": 4096, "max_input_tokens": 300000, "max_output_tokens": 4096, "input_cost_per_token": 0.000000035, @@ -6472,7 +6472,7 @@ "supports_response_schema": true }, "us.amazon.nova-lite-v1:0": { - "max_tokens": 4096, + "max_tokens": 4096, "max_input_tokens": 128000, "max_output_tokens": 4096, "input_cost_per_token": 0.00000006, @@ -6514,7 
+6514,7 @@ "supports_response_schema": true }, "us.amazon.nova-pro-v1:0": { - "max_tokens": 4096, + "max_tokens": 4096, "max_input_tokens": 300000, "max_output_tokens": 4096, "input_cost_per_token": 0.0000008, @@ -6527,6 +6527,12 @@ "supports_prompt_caching": true, "supports_response_schema": true }, + "1024-x-1024/50-steps/bedrock/amazon.nova-canvas-v1:0": { + "max_input_tokens": 2600, + "output_cost_per_image": 0.06, + "litellm_provider": "bedrock", + "mode": "image_generation" + }, "eu.amazon.nova-pro-v1:0": { "max_tokens": 4096, "max_input_tokens": 300000, @@ -7871,22 +7877,22 @@ "mode": "image_generation" }, "stability.sd3-5-large-v1:0": { - "max_tokens": 77, - "max_input_tokens": 77, + "max_tokens": 77, + "max_input_tokens": 77, "output_cost_per_image": 0.08, "litellm_provider": "bedrock", "mode": "image_generation" }, "stability.stable-image-core-v1:0": { - "max_tokens": 77, - "max_input_tokens": 77, + "max_tokens": 77, + "max_input_tokens": 77, "output_cost_per_image": 0.04, "litellm_provider": "bedrock", "mode": "image_generation" }, "stability.stable-image-core-v1:1": { - "max_tokens": 77, - "max_input_tokens": 77, + "max_tokens": 77, + "max_input_tokens": 77, "output_cost_per_image": 0.04, "litellm_provider": "bedrock", "mode": "image_generation" @@ -7899,8 +7905,8 @@ "mode": "image_generation" }, "stability.stable-image-ultra-v1:1": { - "max_tokens": 77, - "max_input_tokens": 77, + "max_tokens": 77, + "max_input_tokens": 77, "output_cost_per_image": 0.14, "litellm_provider": "bedrock", "mode": "image_generation" diff --git a/tests/image_gen_tests/test_image_generation.py b/tests/image_gen_tests/test_image_generation.py index 544f25bc67..c2115abeb8 100644 --- a/tests/image_gen_tests/test_image_generation.py +++ b/tests/image_gen_tests/test_image_generation.py @@ -130,6 +130,16 @@ class TestBedrockSd1(BaseImageGenTest): return {"model": "bedrock/stability.sd3-large-v1:0"} +class TestBedrockNovaCanvasTextToImage(BaseImageGenTest): + def get_base_image_generation_call_args(self) -> dict: + litellm.in_memory_llm_clients_cache = InMemoryCache() + return {"model": "bedrock/amazon.nova-canvas-v1:0", + "n": 1, + "size": "320x320", + "imageGenerationConfig": {"cfgScale":6.5,"seed":12}, + "taskType": "TEXT_IMAGE"} + + class TestOpenAIDalle3(BaseImageGenTest): def get_base_image_generation_call_args(self) -> dict: return {"model": "dall-e-3"} From 666690c31cc415317cb981cae00b5664bff40f38 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 10 Mar 2025 10:18:03 -0700 Subject: [PATCH 016/144] fix atext_completion --- litellm/main.py | 46 +++++++++++----------------------------------- 1 file changed, 11 insertions(+), 35 deletions(-) diff --git a/litellm/main.py b/litellm/main.py index 846a908a8e..903e0e938b 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -3900,42 +3900,18 @@ async def atext_completion( ctx = contextvars.copy_context() func_with_context = partial(ctx.run, func) - _, custom_llm_provider, _, _ = get_llm_provider( - model=model, api_base=kwargs.get("api_base", None) - ) - - if ( - custom_llm_provider == "openai" - or custom_llm_provider == "azure" - or custom_llm_provider == "azure_text" - or custom_llm_provider == "custom_openai" - or custom_llm_provider == "anyscale" - or custom_llm_provider == "mistral" - or custom_llm_provider == "openrouter" - or custom_llm_provider == "deepinfra" - or custom_llm_provider == "perplexity" - or custom_llm_provider == "groq" - or custom_llm_provider == "nvidia_nim" - or custom_llm_provider == "cerebras" - or custom_llm_provider == 
"sambanova" - or custom_llm_provider == "ai21_chat" - or custom_llm_provider == "ai21" - or custom_llm_provider == "volcengine" - or custom_llm_provider == "text-completion-codestral" - or custom_llm_provider == "deepseek" - or custom_llm_provider == "text-completion-openai" - or custom_llm_provider == "huggingface" - or custom_llm_provider == "ollama" - or custom_llm_provider == "vertex_ai" - or custom_llm_provider in litellm.openai_compatible_providers - ): # currently implemented aiohttp calls for just azure and openai, soon all. - # Await normally - response = await loop.run_in_executor(None, func_with_context) - if asyncio.iscoroutine(response): - response = await response + init_response = await loop.run_in_executor(None, func_with_context) + if isinstance(init_response, dict) or isinstance( + init_response, TextCompletionResponse + ): ## CACHING SCENARIO + if isinstance(init_response, dict): + response = TextCompletionResponse(**init_response) + response = init_response + elif asyncio.iscoroutine(init_response): + response = await init_response else: - # Call the synchronous function using run_in_executor - response = await loop.run_in_executor(None, func_with_context) + response = init_response # type: ignore + if ( kwargs.get("stream", False) is True or isinstance(response, TextCompletionStreamWrapper) From 6d537aec48274e8482fbd46389714bef92e22c41 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 10 Mar 2025 10:36:50 -0700 Subject: [PATCH 017/144] OpenAI_Text --- .../components/add_model/provider_specific_fields.tsx | 3 ++- .../src/components/provider_info_helpers.tsx | 11 +++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/ui/litellm-dashboard/src/components/add_model/provider_specific_fields.tsx b/ui/litellm-dashboard/src/components/add_model/provider_specific_fields.tsx index b3da80c715..365d75dbad 100644 --- a/ui/litellm-dashboard/src/components/add_model/provider_specific_fields.tsx +++ b/ui/litellm-dashboard/src/components/add_model/provider_specific_fields.tsx @@ -99,7 +99,8 @@ const ProviderSpecificFields: React.FC = ({ {(selectedProviderEnum === Providers.Azure || selectedProviderEnum === Providers.Azure_AI_Studio || - selectedProviderEnum === Providers.OpenAI_Compatible + selectedProviderEnum === Providers.OpenAI_Compatible || + selectedProviderEnum === Providers.OpenAI_Text_Compatible ) && ( = { OpenAI: "openai", + OpenAI_Text: "text-completion-openai", Azure: "azure", Azure_AI_Studio: "azure_ai", Anthropic: "anthropic", @@ -37,6 +41,7 @@ export const provider_map: Record = { MistralAI: "mistral", Cohere: "cohere_chat", OpenAI_Compatible: "openai", + OpenAI_Text_Compatible: "text-completion-openai", Vertex_AI: "vertex_ai", Databricks: "databricks", xAI: "xai", @@ -53,6 +58,9 @@ export const provider_map: Record = { export const providerLogoMap: Record = { [Providers.OpenAI]: "https://artificialanalysis.ai/img/logos/openai_small.svg", + [Providers.OpenAI_Text]: "https://artificialanalysis.ai/img/logos/openai_small.svg", + [Providers.OpenAI_Text_Compatible]: "https://artificialanalysis.ai/img/logos/openai_small.svg", + [Providers.OpenAI_Compatible]: "https://artificialanalysis.ai/img/logos/openai_small.svg", [Providers.Azure]: "https://upload.wikimedia.org/wikipedia/commons/a/a8/Microsoft_Azure_Logo.svg", [Providers.Azure_AI_Studio]: "https://upload.wikimedia.org/wikipedia/commons/a/a8/Microsoft_Azure_Logo.svg", [Providers.Anthropic]: "https://artificialanalysis.ai/img/logos/anthropic_small.svg", @@ -61,7 +69,6 @@ export const providerLogoMap: 
Record = { [Providers.Groq]: "https://artificialanalysis.ai/img/logos/groq_small.png", [Providers.MistralAI]: "https://artificialanalysis.ai/img/logos/mistral_small.png", [Providers.Cohere]: "https://artificialanalysis.ai/img/logos/cohere_small.png", - [Providers.OpenAI_Compatible]: "https://upload.wikimedia.org/wikipedia/commons/4/4e/OpenAI_Logo.svg", [Providers.Vertex_AI]: "https://artificialanalysis.ai/img/logos/google_small.svg", [Providers.Databricks]: "https://artificialanalysis.ai/img/logos/databricks_small.png", [Providers.Ollama]: "https://artificialanalysis.ai/img/logos/ollama_small.svg", From 51f074682f420e11df5a468c998add1b230a0b4b Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 10 Mar 2025 10:40:48 -0700 Subject: [PATCH 018/144] show eu api base on openai + text --- .../src/components/add_model/provider_specific_fields.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ui/litellm-dashboard/src/components/add_model/provider_specific_fields.tsx b/ui/litellm-dashboard/src/components/add_model/provider_specific_fields.tsx index 365d75dbad..b7565b0494 100644 --- a/ui/litellm-dashboard/src/components/add_model/provider_specific_fields.tsx +++ b/ui/litellm-dashboard/src/components/add_model/provider_specific_fields.tsx @@ -23,7 +23,7 @@ const ProviderSpecificFields: React.FC = ({ console.log(`type of selectedProviderEnum: ${typeof selectedProviderEnum}`); return ( <> - {selectedProviderEnum === Providers.OpenAI && ( + {selectedProviderEnum === Providers.OpenAI || selectedProviderEnum === Providers.OpenAI_Text && ( <> Date: Mon, 10 Mar 2025 12:20:37 -0700 Subject: [PATCH 019/144] fix linting error --- ui/litellm-dashboard/src/components/transform_request.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ui/litellm-dashboard/src/components/transform_request.tsx b/ui/litellm-dashboard/src/components/transform_request.tsx index 879132ef50..5d405df78f 100644 --- a/ui/litellm-dashboard/src/components/transform_request.tsx +++ b/ui/litellm-dashboard/src/components/transform_request.tsx @@ -156,7 +156,7 @@ ${formattedBody} }}>

            Original Request
-              The request you would send to LiteLLM's `/chat/completions` endpoint.
+              The request you would send to LiteLLM's /chat/completions endpoint.
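Patch 015 above wires Amazon Nova Canvas into Bedrock image generation; a hedged usage sketch mirroring the parameters of the new TestBedrockNovaCanvasTextToImage test follows. The prompt string is an illustrative assumption, and AWS credentials plus region configuration are assumed to already be set in the environment.

```python
import litellm

# Mirrors the arguments used by the new Bedrock Nova Canvas test; the prompt is an example.
response = litellm.image_generation(
    model="bedrock/amazon.nova-canvas-v1:0",
    prompt="A cup of coffee on a wooden table",
    n=1,
    size="320x320",
    taskType="TEXT_IMAGE",
    imageGenerationConfig={"cfgScale": 6.5, "seed": 12},
)
# Nova Canvas returns base64-encoded images, surfaced as OpenAI-style Image objects.
print(response.data)
```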