From dec558ba4cde95e7b2fe30513b0cd343778f3668 Mon Sep 17 00:00:00 2001
From: Krish Dholakia
Date: Tue, 21 Jan 2025 20:36:11 -0800
Subject: [PATCH] Litellm dev 01 21 2025 p1 (#7898)

* fix(utils.py): don't pass 'anthropic-beta' header to vertex - it causes the
  request to fail

* fix(utils.py): add flag to allow user to disable filtering of invalid
  headers, ensuring the user can control the behaviour

* style(utils.py): clean up message

* test(test_utils.py): add unit test to cover invalid-header filtering

* fix(proxy_server.py): fix custom openapi schema generation

* fix(utils.py): pass extra headers if set

* fix(main.py): fix image variation to use 'client' param
---
 litellm/__init__.py                           |  1 +
 litellm/main.py                               | 11 ++--
 litellm/proxy/litellm_pre_call_utils.py       |  7 +--
 litellm/proxy/proxy_server.py                 | 55 ++++++++++++++++++-
 litellm/types/llms/anthropic.py               | 10 ++++
 litellm/utils.py                              | 32 +++++++++++
 tests/image_gen_tests/test_image_variation.py | 27 +++++----
 tests/local_testing/test_utils.py             | 23 ++++++++
 tests/proxy_unit_tests/test_proxy_utils.py    | 16 ++++++
 9 files changed, 161 insertions(+), 21 deletions(-)

diff --git a/litellm/__init__.py b/litellm/__init__.py
index 9784adbd87..165113a595 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -105,6 +105,7 @@ turn_off_message_logging: Optional[bool] = False
 log_raw_request_response: bool = False
 redact_messages_in_exceptions: Optional[bool] = False
 redact_user_api_key_info: Optional[bool] = False
+filter_invalid_headers: Optional[bool] = False
 add_user_information_to_llm_headers: Optional[bool] = (
     None  # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
 )
diff --git a/litellm/main.py b/litellm/main.py
index b6b35969ba..8042fb1cc8 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -844,9 +844,6 @@ def completion(  # type: ignore # noqa: PLR0915
     )
     if headers is None:
         headers = {}
-
-    if extra_headers is not None:
-        headers.update(extra_headers)
     num_retries = kwargs.get(
         "num_retries", None
     )  ## alt. param for 'max_retries'. Use this to pass retries w/ instructor.
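Together with the completion() hunk that follows, the removal above changes when user-supplied extra_headers are merged: they now round-trip through get_optional_params(), where the new get_clean_extra_headers() helper (added to litellm/utils.py later in this patch) can drop provider-incompatible entries first. A condensed sketch of that flow, assuming the new litellm.filter_invalid_headers flag is enabled; build_request_headers is an illustrative stand-in, not a litellm function:

from typing import Optional

import litellm
from litellm.utils import get_clean_extra_headers

litellm.filter_invalid_headers = True  # opt in to the new filtering behaviour


def build_request_headers(
    headers: dict, extra_headers: Optional[dict], provider: str
) -> dict:
    optional_params: dict = {}
    if extra_headers is not None:
        # get_optional_params() now runs the filter and carries the result along
        optional_params["extra_headers"] = get_clean_extra_headers(
            extra_headers=extra_headers, custom_llm_provider=provider
        )
    # completion() then pops the cleaned headers and merges them late
    cleaned = optional_params.pop("extra_headers", None)
    if cleaned is not None:
        headers.update(cleaned)
    return headers


print(build_request_headers({}, {"anthropic-beta": "123"}, "vertex_ai"))  # {}
print(build_request_headers({}, {"anthropic-beta": "123"}, "anthropic"))  # {'anthropic-beta': '123'}

Merging late means a header that one provider rejects no longer poisons requests routed to another.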
@@ -1042,9 +1039,14 @@ def completion(  # type: ignore # noqa: PLR0915
             api_version=api_version,
             parallel_tool_calls=parallel_tool_calls,
             messages=messages,
+            extra_headers=extra_headers,
             **non_default_params,
         )
 
+        extra_headers = optional_params.pop("extra_headers", None)
+        if extra_headers is not None:
+            headers.update(extra_headers)
+
         if litellm.add_function_to_prompt and optional_params.get(
             "functions_unsupported_model", None
         ):  # if user opts to add it to prompt, when API doesn't support function calling
@@ -4670,7 +4672,7 @@ def image_variation(
     **kwargs,
 ) -> ImageResponse:
     # get non-default params
-
+    client = kwargs.get("client", None)
 
     # get logging object
     litellm_logging_obj = cast(LiteLLMLoggingObj, kwargs.get("litellm_logging_obj"))
@@ -4744,6 +4746,7 @@ def image_variation(
             logging_obj=litellm_logging_obj,
             optional_params={},
             litellm_params=litellm_params,
+            client=client,
         )
 
     # return the response
diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py
index 5944d0248c..9839a519a2 100644
--- a/litellm/proxy/litellm_pre_call_utils.py
+++ b/litellm/proxy/litellm_pre_call_utils.py
@@ -17,6 +17,7 @@ from litellm.proxy._types import (
     TeamCallbackMetadata,
     UserAPIKeyAuth,
 )
+from litellm.types.llms.anthropic import ANTHROPIC_API_HEADERS
 from litellm.types.services import ServiceTypes
 from litellm.types.utils import (
     StandardLoggingUserAPIKeyMetadata,
@@ -396,6 +397,7 @@ async def add_litellm_data_to_request(  # noqa: PLR0915
 
         dict: The modified data dictionary.
     """
+
     from litellm.proxy.proxy_server import llm_router, premium_user
 
     safe_add_api_version_from_query_params(data, request)
@@ -626,6 +628,7 @@ async def add_litellm_data_to_request(  # noqa: PLR0915
             parent_otel_span=user_api_key_dict.parent_otel_span,
         )
     )
+
     return data
 
 
@@ -726,10 +729,6 @@ def add_provider_specific_headers_to_request(
     data: dict,
     headers: dict,
 ):
-    ANTHROPIC_API_HEADERS = [
-        "anthropic-version",
-        "anthropic-beta",
-    ]
 
     extra_headers = data.get("extra_headers", {}) or {}
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 2beb1bd435..2c126e54c0 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -562,20 +562,71 @@ app = FastAPI(
 
 ### CUSTOM API DOCS [ENTERPRISE FEATURE] ###
 # Custom OpenAPI schema generator to include only selected routes
-def custom_openapi():
+from fastapi.routing import APIWebSocketRoute
+
+
+def get_openapi_schema():
     if app.openapi_schema:
         return app.openapi_schema
+
     openapi_schema = get_openapi(
         title=app.title,
         version=app.version,
         description=app.description,
         routes=app.routes,
     )
+
+    # Find all WebSocket routes
+    websocket_routes = [
+        route for route in app.routes if isinstance(route, APIWebSocketRoute)
+    ]
+
+    # Add each WebSocket route to the schema
+    for route in websocket_routes:
+        # Get the base path without query parameters
+        base_path = route.path.split("{")[0].rstrip("?")
+
+        # Extract parameters from the route
+        parameters = []
+        if hasattr(route, "dependant"):
+            for param in route.dependant.query_params:
+                parameters.append(
+                    {
+                        "name": param.name,
+                        "in": "query",
+                        "required": param.required,
+                        "schema": {
+                            "type": "string"
+                        },  # You can make this more specific if needed
+                    }
+                )
+
+        openapi_schema["paths"][base_path] = {
+            "get": {
+                "summary": f"WebSocket: {route.name or base_path}",
+                "description": "WebSocket connection endpoint",
+                "operationId": f"websocket_{route.name or base_path.replace('/', '_')}",
+                "parameters": parameters,
+                "responses": {"101": {"description": "Switching Protocols"}},
+                "tags": ["WebSocket"],
+            }
+        }
+
+    app.openapi_schema = openapi_schema
+    return app.openapi_schema
+
+
+def custom_openapi():
+    if app.openapi_schema:
+        return app.openapi_schema
+    openapi_schema = get_openapi_schema()
+
     # Filter routes to include only specific ones
     openai_routes = LiteLLMRoutes.openai_routes.value
     paths_to_include: dict = {}
     for route in openai_routes:
-        paths_to_include[route] = openapi_schema["paths"][route]
+        if route in openapi_schema["paths"]:
+            paths_to_include[route] = openapi_schema["paths"][route]
     openapi_schema["paths"] = paths_to_include
     app.openapi_schema = openapi_schema
     return app.openapi_schema
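For context on the proxy_server.py change above: FastAPI's get_openapi() documents only HTTP routes, so WebSocket endpoints silently disappear from the generated schema unless they are added by hand. A minimal self-contained sketch of the same technique; the app, the /v1/realtime path, and the handler here are placeholders, not litellm's real routes:

from fastapi import FastAPI, WebSocket
from fastapi.openapi.utils import get_openapi
from fastapi.routing import APIWebSocketRoute

app = FastAPI(title="sketch", version="0.0.1")


@app.websocket("/v1/realtime")
async def realtime(websocket: WebSocket, model: str):
    await websocket.accept()


schema = get_openapi(title=app.title, version=app.version, routes=app.routes)
for route in app.routes:
    if isinstance(route, APIWebSocketRoute):
        # Document the upgrade handshake as a GET that answers 101
        schema["paths"][route.path] = {
            "get": {
                "summary": f"WebSocket: {route.name}",
                "parameters": [
                    {
                        "name": p.name,
                        "in": "query",
                        "required": p.required,
                        "schema": {"type": "string"},
                    }
                    for p in route.dependant.query_params
                ],
                "responses": {"101": {"description": "Switching Protocols"}},
            }
        }

assert "/v1/realtime" in schema["paths"]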
"WebSocket Protocol Switched"}}, + "tags": ["WebSocket"], + } + } + + app.openapi_schema = openapi_schema + return app.openapi_schema + + +def custom_openapi(): + if app.openapi_schema: + return app.openapi_schema + openapi_schema = get_openapi_schema() + # Filter routes to include only specific ones openai_routes = LiteLLMRoutes.openai_routes.value paths_to_include: dict = {} for route in openai_routes: - paths_to_include[route] = openapi_schema["paths"][route] + if route in openapi_schema["paths"]: + paths_to_include[route] = openapi_schema["paths"][route] openapi_schema["paths"] = paths_to_include app.openapi_schema = openapi_schema return app.openapi_schema diff --git a/litellm/types/llms/anthropic.py b/litellm/types/llms/anthropic.py index 55e37ad971..71b9161eb4 100644 --- a/litellm/types/llms/anthropic.py +++ b/litellm/types/llms/anthropic.py @@ -334,3 +334,13 @@ from .openai import ChatCompletionUsageBlock class AnthropicChatCompletionUsageBlock(ChatCompletionUsageBlock, total=False): cache_creation_input_tokens: int cache_read_input_tokens: int + + +ANTHROPIC_API_HEADERS = { + "anthropic-version", + "anthropic-beta", +} + +ANTHROPIC_API_ONLY_HEADERS = { # fails if calling anthropic on vertex ai / bedrock + "anthropic-beta", +} diff --git a/litellm/utils.py b/litellm/utils.py index 298c2652cb..37d721fc29 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -112,6 +112,7 @@ from litellm.router_utils.get_retry_from_policy import ( reset_retry_policy, ) from litellm.secret_managers.main import get_secret +from litellm.types.llms.anthropic import ANTHROPIC_API_ONLY_HEADERS from litellm.types.llms.openai import ( AllMessageValues, AllPromptValues, @@ -2601,6 +2602,25 @@ def _remove_unsupported_params( return non_default_params +def get_clean_extra_headers(extra_headers: dict, custom_llm_provider: str) -> dict: + """ + For `anthropic-beta` headers, ensure provider is anthropic. + + Vertex AI raises an exception if `anthropic-beta` is passed in. + """ + if litellm.filter_invalid_headers is not True: # allow user to opt out of filtering + return extra_headers + clean_extra_headers = {} + for k, v in extra_headers.items(): + if k in ANTHROPIC_API_ONLY_HEADERS and custom_llm_provider != "anthropic": + verbose_logger.debug( + f"Provider {custom_llm_provider} does not support {k} header. Dropping from request, to prevent errors." + ) # Switching between anthropic api and vertex ai anthropic fails when anthropic-beta is passed in. Welcome feedback on this. 
@@ -2739,6 +2759,12 @@ def get_optional_params(  # noqa: PLR0915
         )
     }
 
+    ## Filter out provider-specific headers the target provider can't accept
+    if extra_headers is not None:
+        extra_headers = get_clean_extra_headers(
+            extra_headers=extra_headers, custom_llm_provider=custom_llm_provider
+        )
+
     ## raise exception if function calling passed in for a provider that doesn't support it
     if (
         "functions" in non_default_params
@@ -3508,6 +3534,12 @@ def get_optional_params(  # noqa: PLR0915
     for k in passed_params.keys():
         if k not in default_params.keys():
             optional_params[k] = passed_params[k]
+    if extra_headers is not None:
+        optional_params.setdefault("extra_headers", {})
+        optional_params["extra_headers"] = {
+            **optional_params["extra_headers"],
+            **extra_headers,
+        }
     print_verbose(f"Final returned optional params: {optional_params}")
     return optional_params
diff --git a/tests/image_gen_tests/test_image_variation.py b/tests/image_gen_tests/test_image_variation.py
index 9fea648631..d4f6660335 100644
--- a/tests/image_gen_tests/test_image_variation.py
+++ b/tests/image_gen_tests/test_image_variation.py
@@ -68,16 +68,21 @@ async def test_openai_image_variation_litellm_sdk(image_url, sync_mode):
         await aimage_variation(image=image_url, n=2, size="1024x1024")
 
 
-@pytest.mark.parametrize("sync_mode", [True, False])  # ,
-@pytest.mark.asyncio
-async def test_topaz_image_variation(image_url, sync_mode):
+def test_topaz_image_variation(image_url):
     from litellm import image_variation, aimage_variation
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+    from unittest.mock import patch
 
-    if sync_mode:
-        image_variation(
-            model="topaz/Standard V2", image=image_url, n=2, size="1024x1024"
-        )
-    else:
-        response = await aimage_variation(
-            model="topaz/Standard V2", image=image_url, n=2, size="1024x1024"
-        )
+    client = HTTPHandler()
+    with patch.object(client, "post") as mock_post:
+        try:
+            image_variation(
+                model="topaz/Standard V2",
+                image=image_url,
+                n=2,
+                size="1024x1024",
+                client=client,
+            )
+        except Exception as e:
+            print(e)
+
+        mock_post.assert_called_once()
diff --git a/tests/local_testing/test_utils.py b/tests/local_testing/test_utils.py
index 35d676c088..fbc1f9d7d4 100644
--- a/tests/local_testing/test_utils.py
+++ b/tests/local_testing/test_utils.py
@@ -1494,3 +1494,26 @@ def test_get_num_retries(num_retries):
             "num_retries": num_retries,
         },
     )
+
+
+@pytest.mark.parametrize("filter_invalid_headers", [True, False])
+@pytest.mark.parametrize(
+    "custom_llm_provider, expected_result",
+    [("anthropic", {"anthropic-beta": "123"}), ("bedrock", {}), ("vertex_ai", {})],
+)
+def test_get_clean_extra_headers(
+    filter_invalid_headers, custom_llm_provider, expected_result, monkeypatch
+):
+    from litellm.utils import get_clean_extra_headers
+
+    monkeypatch.setattr(litellm, "filter_invalid_headers", filter_invalid_headers)
+
+    if filter_invalid_headers:
+        assert (
+            get_clean_extra_headers({"anthropic-beta": "123"}, custom_llm_provider)
+            == expected_result
+        )
+    else:
+        assert get_clean_extra_headers(
+            {"anthropic-beta": "123"}, custom_llm_provider
+        ) == {"anthropic-beta": "123"}
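The client-injection pattern in the Topaz test above can also exercise the header filtering end to end without network access. A rough, untested sketch along those lines; the model name is a placeholder, and the call is expected to error out after the mocked post, since only the intercepted request matters:

from unittest.mock import patch

import litellm
from litellm.llms.custom_httpx.http_handler import HTTPHandler

litellm.filter_invalid_headers = True
client = HTTPHandler()

with patch.object(client, "post") as mock_post:
    try:
        litellm.completion(
            model="vertex_ai/claude-3-5-sonnet",  # placeholder model name
            messages=[{"role": "user", "content": "hi"}],
            extra_headers={"anthropic-beta": "123"},
            client=client,
        )
    except Exception:
        pass  # the mocked response won't parse; we only care about the outgoing request
    # With filtering on, 'anthropic-beta' should not appear in the recorded call,
    # e.g. inspect mock_post.call_args.kwargs.get("headers") if post was reached.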
diff --git a/tests/proxy_unit_tests/test_proxy_utils.py b/tests/proxy_unit_tests/test_proxy_utils.py
index 6934521718..a3de35a2ab 100644
--- a/tests/proxy_unit_tests/test_proxy_utils.py
+++ b/tests/proxy_unit_tests/test_proxy_utils.py
@@ -1479,3 +1479,19 @@ async def test_health_check_not_called_when_disabled(monkeypatch):
     # Verify health check wasn't called
     mock_prisma.health_check.assert_not_called()
+
+
+@patch(
+    "litellm.proxy.proxy_server.get_openapi_schema",
+    return_value={
+        "paths": {
+            "/new/route": {"get": {"summary": "New"}},
+        }
+    },
+)
+def test_custom_openapi(mock_get_openapi_schema):
+    from litellm.proxy.proxy_server import app, custom_openapi
+
+    openapi_schema = custom_openapi()
+    assert openapi_schema is not None
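Finally, hoisting ANTHROPIC_API_HEADERS into litellm.types.llms.anthropic lets the proxy and the SDK share one allowlist. The litellm_pre_call_utils.py hunk earlier in this patch only shows the local constant being removed, so the following is a reconstruction of the likely flow in add_provider_specific_headers_to_request, not the verbatim function; forward_provider_headers is an illustrative name:

from litellm.types.llms.anthropic import ANTHROPIC_API_HEADERS


def forward_provider_headers(data: dict, headers: dict) -> dict:
    """Copy recognized provider headers from the incoming HTTP request into
    'extra_headers' so they are forwarded to the upstream provider."""
    extra_headers = data.get("extra_headers", {}) or {}
    for header, value in headers.items():
        if header.lower() in ANTHROPIC_API_HEADERS:
            extra_headers[header.lower()] = value
    if extra_headers:
        data["extra_headers"] = extra_headers
    return data


# e.g. a proxy request carrying 'anthropic-version: 2023-06-01' gets that header
# forwarded upstream, while unrelated headers (authorization, tracing) are ignored.
print(forward_provider_headers({}, {"anthropic-version": "2023-06-01", "x-request-id": "abc"}))
# {'extra_headers': {'anthropic-version': '2023-06-01'}}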