From 85679470c276d2785cc87e42ceab084bb4f7294a Mon Sep 17 00:00:00 2001
From: Rajan Paneru
Date: Fri, 17 May 2024 09:03:18 +0930
Subject: [PATCH 1/3] Exclude custom headers from the response if the value is None or an empty string

This returns clean headers; sending a header with an empty value is not
standard, and this fix avoids it.
---
 litellm/proxy/proxy_server.py | 125 ++++++++++++++++++++--------------
 1 file changed, 72 insertions(+), 53 deletions(-)

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 11086315df..48a59cab8d 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -352,6 +352,24 @@ def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict:
         return pydantic_obj.dict()
 
 
+def get_custom_headers(*,
+                       model_id: Optional[str] = None,
+                       cache_key: Optional[str] = None,
+                       api_base: Optional[str] = None,
+                       version: Optional[str] = None,
+                       model_region: Optional[str] = None
+                       ) -> dict:
+    exclude_values = {'', None}
+    headers = {
+        'x-litellm-model-id': model_id,
+        'x-litellm-cache-key': cache_key,
+        'x-litellm-model-api-base': api_base,
+        'x-litellm-version': version,
+        'x-litellm-model-region': model_region
+    }
+    return {key: value for key, value in headers.items() if value not in exclude_values}
+
+
 async def check_request_disconnection(request: Request, llm_api_call_task):
     """
     Asynchronously checks if the request is disconnected at regular intervals.
@@ -3747,13 +3765,13 @@ async def chat_completion(
         if (
             "stream" in data and data["stream"] == True
         ):  # use generate_responses to stream responses
-            custom_headers = {
-                "x-litellm-model-id": model_id,
-                "x-litellm-cache-key": cache_key,
-                "x-litellm-model-api-base": api_base,
-                "x-litellm-version": version,
-                "x-litellm-model-region": user_api_key_dict.allowed_model_region or "",
-            }
+            custom_headers = get_custom_headers(
+                model_id=model_id,
+                cache_key=cache_key,
+                api_base=api_base,
+                version=version,
+                model_region=user_api_key_dict.allowed_model_region
+            )
             selected_data_generator = select_data_generator(
                 response=response,
                 user_api_key_dict=user_api_key_dict,
@@ -3765,13 +3783,13 @@ async def chat_completion(
             headers=custom_headers,
         )
 
-        fastapi_response.headers["x-litellm-model-id"] = model_id
-        fastapi_response.headers["x-litellm-cache-key"] = cache_key
-        fastapi_response.headers["x-litellm-model-api-base"] = api_base
-        fastapi_response.headers["x-litellm-version"] = version
-        fastapi_response.headers["x-litellm-model-region"] = (
-            user_api_key_dict.allowed_model_region or ""
-        )
+        fastapi_response.headers.update(get_custom_headers(
+            model_id=model_id,
+            cache_key=cache_key,
+            api_base=api_base,
+            version=version,
+            model_region=user_api_key_dict.allowed_model_region
+        ))
 
         ### CALL HOOKS ### - modify outgoing data
         response = await proxy_logging_obj.post_call_success_hook(
@@ -3957,12 +3975,12 @@ async def completion(
         if (
             "stream" in data and data["stream"] == True
         ):  # use generate_responses to stream responses
-            custom_headers = {
-                "x-litellm-model-id": model_id,
-                "x-litellm-cache-key": cache_key,
-                "x-litellm-model-api-base": api_base,
-                "x-litellm-version": version,
-            }
+            custom_headers = get_custom_headers(
+                model_id=model_id,
+                cache_key=cache_key,
+                api_base=api_base,
+                version=version
+            )
             selected_data_generator = select_data_generator(
                 response=response,
                 user_api_key_dict=user_api_key_dict,
@@ -3974,11 +3992,12 @@ async def completion(
             media_type="text/event-stream",
             headers=custom_headers,
         )
-
-        fastapi_response.headers["x-litellm-model-id"] = model_id
-        fastapi_response.headers["x-litellm-cache-key"] = cache_key
-        fastapi_response.headers["x-litellm-model-api-base"] = api_base
-        fastapi_response.headers["x-litellm-version"] = version
+        fastapi_response.headers.update(get_custom_headers(
+            model_id=model_id,
+            cache_key=cache_key,
+            api_base=api_base,
+            version=version
+        ))
 
         return response
     except Exception as e:
@@ -4183,13 +4202,13 @@ async def embeddings(
         cache_key = hidden_params.get("cache_key", None) or ""
         api_base = hidden_params.get("api_base", None) or ""
 
-        fastapi_response.headers["x-litellm-model-id"] = model_id
-        fastapi_response.headers["x-litellm-cache-key"] = cache_key
-        fastapi_response.headers["x-litellm-model-api-base"] = api_base
-        fastapi_response.headers["x-litellm-version"] = version
-        fastapi_response.headers["x-litellm-model-region"] = (
-            user_api_key_dict.allowed_model_region or ""
-        )
+        fastapi_response.headers.update(get_custom_headers(
+            model_id=model_id,
+            cache_key=cache_key,
+            api_base=api_base,
+            version=version,
+            model_region=user_api_key_dict.allowed_model_region
+        ))
 
         return response
     except Exception as e:
@@ -4364,13 +4383,13 @@ async def image_generation(
         cache_key = hidden_params.get("cache_key", None) or ""
         api_base = hidden_params.get("api_base", None) or ""
 
-        fastapi_response.headers["x-litellm-model-id"] = model_id
-        fastapi_response.headers["x-litellm-cache-key"] = cache_key
-        fastapi_response.headers["x-litellm-model-api-base"] = api_base
-        fastapi_response.headers["x-litellm-version"] = version
-        fastapi_response.headers["x-litellm-model-region"] = (
-            user_api_key_dict.allowed_model_region or ""
-        )
+        fastapi_response.headers.update(get_custom_headers(
+            model_id=model_id,
+            cache_key=cache_key,
+            api_base=api_base,
+            version=version,
+            model_region=user_api_key_dict.allowed_model_region
+        ))
 
         return response
     except Exception as e:
@@ -4563,13 +4582,13 @@ async def audio_transcriptions(
         cache_key = hidden_params.get("cache_key", None) or ""
         api_base = hidden_params.get("api_base", None) or ""
 
-        fastapi_response.headers["x-litellm-model-id"] = model_id
-        fastapi_response.headers["x-litellm-cache-key"] = cache_key
-        fastapi_response.headers["x-litellm-model-api-base"] = api_base
-        fastapi_response.headers["x-litellm-version"] = version
-        fastapi_response.headers["x-litellm-model-region"] = (
-            user_api_key_dict.allowed_model_region or ""
-        )
+        fastapi_response.headers.update(get_custom_headers(
+            model_id=model_id,
+            cache_key=cache_key,
+            api_base=api_base,
+            version=version,
+            model_region=user_api_key_dict.allowed_model_region
+        ))
 
         return response
     except Exception as e:
@@ -4744,13 +4763,13 @@ async def moderations(
         cache_key = hidden_params.get("cache_key", None) or ""
         api_base = hidden_params.get("api_base", None) or ""
 
-        fastapi_response.headers["x-litellm-model-id"] = model_id
-        fastapi_response.headers["x-litellm-cache-key"] = cache_key
-        fastapi_response.headers["x-litellm-model-api-base"] = api_base
-        fastapi_response.headers["x-litellm-version"] = version
-        fastapi_response.headers["x-litellm-model-region"] = (
-            user_api_key_dict.allowed_model_region or ""
-        )
+        fastapi_response.headers.update(get_custom_headers(
+            model_id=model_id,
+            cache_key=cache_key,
+            api_base=api_base,
+            version=version,
+            model_region=user_api_key_dict.allowed_model_region
+        ))
 
         return response
     except Exception as e:

From 54f8d060574dbefead8477451502e9599a429bee Mon Sep 17 00:00:00 2001
From: Rajan Paneru
Date: Fri, 17 May 2024 09:55:13 +0930
Subject: [PATCH 2/3] Handle exceptions and log them

---
 litellm/proxy/proxy_server.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 48a59cab8d..734ef5676a 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -367,7 +367,11 @@ def get_custom_headers(*,
         'x-litellm-version': version,
         'x-litellm-model-region': model_region
     }
-    return {key: value for key, value in headers.items() if value not in exclude_values}
+    try:
+        return {key: value for key, value in headers.items() if value not in exclude_values}
+    except Exception as e:
+        verbose_proxy_logger.error(f"Error setting custom headers: {e}")
+        return {}
 
 
 async def check_request_disconnection(request: Request, llm_api_call_task):

From e4ce10038a6cbbc42bfa8a3ff86998e3cde5eaeb Mon Sep 17 00:00:00 2001
From: Rajan Paneru
Date: Fri, 17 May 2024 10:05:18 +0930
Subject: [PATCH 3/3] Use a default empty string if the allowed_model_region attribute is not present

---
 litellm/proxy/proxy_server.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 734ef5676a..0694ef6c20 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -3774,7 +3774,7 @@ async def chat_completion(
                 cache_key=cache_key,
                 api_base=api_base,
                 version=version,
-                model_region=user_api_key_dict.allowed_model_region
+                model_region=getattr(user_api_key_dict, "allowed_model_region", "")
             )
             selected_data_generator = select_data_generator(
                 response=response,
@@ -3792,7 +3792,7 @@ async def chat_completion(
             cache_key=cache_key,
             api_base=api_base,
             version=version,
-            model_region=user_api_key_dict.allowed_model_region
+            model_region=getattr(user_api_key_dict, "allowed_model_region", "")
         ))
 
         ### CALL HOOKS ### - modify outgoing data
@@ -4211,7 +4211,7 @@ async def embeddings(
             cache_key=cache_key,
             api_base=api_base,
             version=version,
-            model_region=user_api_key_dict.allowed_model_region
+            model_region=getattr(user_api_key_dict, "allowed_model_region", "")
         ))
 
         return response
@@ -4392,7 +4392,7 @@ async def image_generation(
             cache_key=cache_key,
             api_base=api_base,
             version=version,
-            model_region=user_api_key_dict.allowed_model_region
+            model_region=getattr(user_api_key_dict, "allowed_model_region", "")
         ))
 
         return response
@@ -4591,7 +4591,7 @@ async def audio_transcriptions(
             cache_key=cache_key,
            api_base=api_base,
             version=version,
-            model_region=user_api_key_dict.allowed_model_region
+            model_region=getattr(user_api_key_dict, "allowed_model_region", "")
         ))
 
         return response
@@ -4772,7 +4772,7 @@ async def moderations(
             cache_key=cache_key,
             api_base=api_base,
             version=version,
-            model_region=user_api_key_dict.allowed_model_region
+            model_region=getattr(user_api_key_dict, "allowed_model_region", "")
         ))
 
         return response
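
For reference, a minimal standalone sketch of the helper these three patches converge on, with the defensive try/except from patch 2 omitted (a dict comprehension over literal keys cannot realistically raise) and made-up argument values in the example call; in the proxy the real values come from hidden_params and user_api_key_dict:

    from typing import Optional

    def get_custom_headers(*, model_id: Optional[str] = None,
                           cache_key: Optional[str] = None,
                           api_base: Optional[str] = None,
                           version: Optional[str] = None,
                           model_region: Optional[str] = None) -> dict:
        # Collect the candidate headers, then drop entries whose value is None or "".
        exclude_values = {'', None}
        headers = {
            'x-litellm-model-id': model_id,
            'x-litellm-cache-key': cache_key,
            'x-litellm-model-api-base': api_base,
            'x-litellm-version': version,
            'x-litellm-model-region': model_region,
        }
        return {key: value for key, value in headers.items() if value not in exclude_values}

    # Empty-string and None values never reach the response headers:
    print(get_custom_headers(model_id="gpt-4", version="", model_region=None))
    # {'x-litellm-model-id': 'gpt-4'}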