From 10a672634d26cbb8ed66fa5ffe98ec8019b421d9 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Thu, 16 May 2024 21:05:10 -0700
Subject: [PATCH] fix(proxy_server.py): fix invalid header string

---
 litellm/proxy/proxy_server.py | 129 +++++++++++++++++++---------------
 1 file changed, 72 insertions(+), 57 deletions(-)

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 710290ff0..763a53daf 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -352,23 +352,26 @@ def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict:
     return pydantic_obj.dict()
 
 
-def get_custom_headers(*,
-                       model_id: Optional[str] = None,
-                       cache_key: Optional[str] = None,
-                       api_base: Optional[str] = None,
-                       version: Optional[str] = None,
-                       model_region: Optional[str] = None
-                       ) -> dict:
-    exclude_values = {'', None}
+def get_custom_headers(
+    *,
+    model_id: Optional[str] = None,
+    cache_key: Optional[str] = None,
+    api_base: Optional[str] = None,
+    version: Optional[str] = None,
+    model_region: Optional[str] = None,
+) -> dict:
+    exclude_values = {"", None}
     headers = {
-        'x-litellm-model-id': model_id,
-        'x-litellm-cache-key': cache_key,
-        'x-litellm-model-api-base': api_base,
-        'x-litellm-version': version,
-        '"x-litellm-model-region': model_region
+        "x-litellm-model-id": model_id,
+        "x-litellm-cache-key": cache_key,
+        "x-litellm-model-api-base": api_base,
+        "x-litellm-version": version,
+        "x-litellm-model-region": model_region,
     }
     try:
-        return {key: value for key, value in headers.items() if value not in exclude_values}
+        return {
+            key: value for key, value in headers.items() if value not in exclude_values
+        }
     except Exception as e:
         verbose_proxy_logger.error(f"Error setting custom headers: {e}")
         return {}
@@ -3797,7 +3800,7 @@ async def chat_completion(
                 cache_key=cache_key,
                 api_base=api_base,
                 version=version,
-                model_region=getattr(user_api_key_dict, "allowed_model_region", "")
+                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
             )
             selected_data_generator = select_data_generator(
                 response=response,
@@ -3810,13 +3813,15 @@ async def chat_completion(
                 headers=custom_headers,
             )
 
-        fastapi_response.headers.update(get_custom_headers(
-            model_id=model_id,
-            cache_key=cache_key,
-            api_base=api_base,
-            version=version,
-            model_region=getattr(user_api_key_dict, "allowed_model_region", "")
-        ))
+        fastapi_response.headers.update(
+            get_custom_headers(
+                model_id=model_id,
+                cache_key=cache_key,
+                api_base=api_base,
+                version=version,
+                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+            )
+        )
 
         ### CALL HOOKS ### - modify outgoing data
         response = await proxy_logging_obj.post_call_success_hook(
@@ -4006,7 +4011,7 @@ async def completion(
                 model_id=model_id,
                 cache_key=cache_key,
                 api_base=api_base,
-                version=version
+                version=version,
             )
             selected_data_generator = select_data_generator(
                 response=response,
@@ -4019,12 +4024,14 @@ async def completion(
                 media_type="text/event-stream",
                 headers=custom_headers,
             )
-        fastapi_response.headers.update(get_custom_headers(
-            model_id=model_id,
-            cache_key=cache_key,
-            api_base=api_base,
-            version=version
-        ))
+        fastapi_response.headers.update(
+            get_custom_headers(
+                model_id=model_id,
+                cache_key=cache_key,
+                api_base=api_base,
+                version=version,
+            )
+        )
 
         return response
     except Exception as e:
@@ -4229,13 +4236,15 @@ async def embeddings(
         cache_key = hidden_params.get("cache_key", None) or ""
         api_base = hidden_params.get("api_base", None) or ""
 
-        fastapi_response.headers.update(get_custom_headers(
-            model_id=model_id,
-            cache_key=cache_key,
-            api_base=api_base,
-            version=version,
-            model_region=getattr(user_api_key_dict, "allowed_model_region", "")
-        ))
+        fastapi_response.headers.update(
+            get_custom_headers(
+                model_id=model_id,
+                cache_key=cache_key,
+                api_base=api_base,
+                version=version,
+                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+            )
+        )
 
         return response
     except Exception as e:
@@ -4410,13 +4419,15 @@ async def image_generation(
         cache_key = hidden_params.get("cache_key", None) or ""
         api_base = hidden_params.get("api_base", None) or ""
 
-        fastapi_response.headers.update(get_custom_headers(
-            model_id=model_id,
-            cache_key=cache_key,
-            api_base=api_base,
-            version=version,
-            model_region=getattr(user_api_key_dict, "allowed_model_region", "")
-        ))
+        fastapi_response.headers.update(
+            get_custom_headers(
+                model_id=model_id,
+                cache_key=cache_key,
+                api_base=api_base,
+                version=version,
+                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+            )
+        )
 
         return response
     except Exception as e:
@@ -4609,13 +4620,15 @@ async def audio_transcriptions(
         cache_key = hidden_params.get("cache_key", None) or ""
         api_base = hidden_params.get("api_base", None) or ""
 
-        fastapi_response.headers.update(get_custom_headers(
-            model_id=model_id,
-            cache_key=cache_key,
-            api_base=api_base,
-            version=version,
-            model_region=getattr(user_api_key_dict, "allowed_model_region", "")
-        ))
+        fastapi_response.headers.update(
+            get_custom_headers(
+                model_id=model_id,
+                cache_key=cache_key,
+                api_base=api_base,
+                version=version,
+                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+            )
+        )
 
         return response
     except Exception as e:
@@ -4790,13 +4803,15 @@ async def moderations(
         cache_key = hidden_params.get("cache_key", None) or ""
         api_base = hidden_params.get("api_base", None) or ""
 
-        fastapi_response.headers.update(get_custom_headers(
-            model_id=model_id,
-            cache_key=cache_key,
-            api_base=api_base,
-            version=version,
-            model_region=getattr(user_api_key_dict, "allowed_model_region", "")
-        ))
+        fastapi_response.headers.update(
+            get_custom_headers(
+                model_id=model_id,
+                cache_key=cache_key,
+                api_base=api_base,
+                version=version,
+                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+            )
+        )
 
         return response
     except Exception as e:
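
For reference, the behaviour this patch locks in can be illustrated with a small standalone sketch. This is a mirror of the patched helper for illustration only, not code taken verbatim from the repo: the try/except logging around the comprehension is omitted, and the sample header values below are made up. It shows the two effects of the change: the region header key no longer carries the stray leading quote, and empty or None values are filtered out before the headers are attached to the response.

    from typing import Optional


    def get_custom_headers(
        *,
        model_id: Optional[str] = None,
        cache_key: Optional[str] = None,
        api_base: Optional[str] = None,
        version: Optional[str] = None,
        model_region: Optional[str] = None,
    ) -> dict:
        # Mirrors the patched helper: every key is a valid header name
        # (no stray quote) and empty-string/None values are dropped.
        exclude_values = {"", None}
        headers = {
            "x-litellm-model-id": model_id,
            "x-litellm-cache-key": cache_key,
            "x-litellm-model-api-base": api_base,
            "x-litellm-version": version,
            "x-litellm-model-region": model_region,
        }
        return {
            key: value for key, value in headers.items() if value not in exclude_values
        }


    if __name__ == "__main__":
        # Illustrative values only; model_region is empty, so it is dropped.
        headers = get_custom_headers(
            model_id="example-model-id", version="1.0.0", model_region=""
        )
        assert headers == {
            "x-litellm-model-id": "example-model-id",
            "x-litellm-version": "1.0.0",
        }
        # No header key contains a quote character after the fix.
        assert all('"' not in key for key in headers)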