diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index c60799ac91..710290ff00 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -352,6 +352,28 @@ def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict:
         return pydantic_obj.dict()
 
 
+def get_custom_headers(*,
+                       model_id: Optional[str] = None,
+                       cache_key: Optional[str] = None,
+                       api_base: Optional[str] = None,
+                       version: Optional[str] = None,
+                       model_region: Optional[str] = None
+                       ) -> dict:
+    exclude_values = {'', None}
+    headers = {
+        'x-litellm-model-id': model_id,
+        'x-litellm-cache-key': cache_key,
+        'x-litellm-model-api-base': api_base,
+        'x-litellm-version': version,
+        'x-litellm-model-region': model_region
+    }
+    try:
+        return {key: value for key, value in headers.items() if value not in exclude_values}
+    except Exception as e:
+        verbose_proxy_logger.error(f"Error setting custom headers: {e}")
+        return {}
+
+
 async def check_request_disconnection(request: Request, llm_api_call_task):
     """
     Asynchronously checks if the request is disconnected at regular intervals.
@@ -3770,13 +3792,13 @@ async def chat_completion(
         if (
             "stream" in data and data["stream"] == True
         ):  # use generate_responses to stream responses
-            custom_headers = {
-                "x-litellm-model-id": model_id,
-                "x-litellm-cache-key": cache_key,
-                "x-litellm-model-api-base": api_base,
-                "x-litellm-version": version,
-                "x-litellm-model-region": user_api_key_dict.allowed_model_region or "",
-            }
+            custom_headers = get_custom_headers(
+                model_id=model_id,
+                cache_key=cache_key,
+                api_base=api_base,
+                version=version,
+                model_region=getattr(user_api_key_dict, "allowed_model_region", "")
+            )
             selected_data_generator = select_data_generator(
                 response=response,
                 user_api_key_dict=user_api_key_dict,
@@ -3788,13 +3810,13 @@ async def chat_completion(
                 headers=custom_headers,
             )
 
-        fastapi_response.headers["x-litellm-model-id"] = model_id
-        fastapi_response.headers["x-litellm-cache-key"] = cache_key
-        fastapi_response.headers["x-litellm-model-api-base"] = api_base
-        fastapi_response.headers["x-litellm-version"] = version
-        fastapi_response.headers["x-litellm-model-region"] = (
-            user_api_key_dict.allowed_model_region or ""
-        )
+        fastapi_response.headers.update(get_custom_headers(
+            model_id=model_id,
+            cache_key=cache_key,
+            api_base=api_base,
+            version=version,
+            model_region=getattr(user_api_key_dict, "allowed_model_region", "")
+        ))
 
         ### CALL HOOKS ### - modify outgoing data
         response = await proxy_logging_obj.post_call_success_hook(
@@ -3980,12 +4002,12 @@ async def completion(
         if (
             "stream" in data and data["stream"] == True
         ):  # use generate_responses to stream responses
-            custom_headers = {
-                "x-litellm-model-id": model_id,
-                "x-litellm-cache-key": cache_key,
-                "x-litellm-model-api-base": api_base,
-                "x-litellm-version": version,
-            }
+            custom_headers = get_custom_headers(
+                model_id=model_id,
+                cache_key=cache_key,
+                api_base=api_base,
+                version=version
+            )
             selected_data_generator = select_data_generator(
                 response=response,
                 user_api_key_dict=user_api_key_dict,
@@ -3997,11 +4019,12 @@ async def completion(
                 media_type="text/event-stream",
                 headers=custom_headers,
             )
-
-        fastapi_response.headers["x-litellm-model-id"] = model_id
-        fastapi_response.headers["x-litellm-cache-key"] = cache_key
-        fastapi_response.headers["x-litellm-model-api-base"] = api_base
-        fastapi_response.headers["x-litellm-version"] = version
+        fastapi_response.headers.update(get_custom_headers(
+            model_id=model_id,
+            cache_key=cache_key,
+            api_base=api_base,
+            version=version
+        ))
 
         return response
     except Exception as e:
@@ -4206,13 +4229,13 @@ async def embeddings(
         cache_key = hidden_params.get("cache_key", None) or ""
         api_base = hidden_params.get("api_base", None) or ""
 
-        fastapi_response.headers["x-litellm-model-id"] = model_id
-        fastapi_response.headers["x-litellm-cache-key"] = cache_key
-        fastapi_response.headers["x-litellm-model-api-base"] = api_base
-        fastapi_response.headers["x-litellm-version"] = version
-        fastapi_response.headers["x-litellm-model-region"] = (
-            user_api_key_dict.allowed_model_region or ""
-        )
+        fastapi_response.headers.update(get_custom_headers(
+            model_id=model_id,
+            cache_key=cache_key,
+            api_base=api_base,
+            version=version,
+            model_region=getattr(user_api_key_dict, "allowed_model_region", "")
+        ))
 
         return response
     except Exception as e:
@@ -4387,13 +4410,13 @@ async def image_generation(
         cache_key = hidden_params.get("cache_key", None) or ""
         api_base = hidden_params.get("api_base", None) or ""
 
-        fastapi_response.headers["x-litellm-model-id"] = model_id
-        fastapi_response.headers["x-litellm-cache-key"] = cache_key
-        fastapi_response.headers["x-litellm-model-api-base"] = api_base
-        fastapi_response.headers["x-litellm-version"] = version
-        fastapi_response.headers["x-litellm-model-region"] = (
-            user_api_key_dict.allowed_model_region or ""
-        )
+        fastapi_response.headers.update(get_custom_headers(
+            model_id=model_id,
+            cache_key=cache_key,
+            api_base=api_base,
+            version=version,
+            model_region=getattr(user_api_key_dict, "allowed_model_region", "")
+        ))
 
         return response
     except Exception as e:
@@ -4586,13 +4609,13 @@ async def audio_transcriptions(
         cache_key = hidden_params.get("cache_key", None) or ""
         api_base = hidden_params.get("api_base", None) or ""
 
-        fastapi_response.headers["x-litellm-model-id"] = model_id
-        fastapi_response.headers["x-litellm-cache-key"] = cache_key
-        fastapi_response.headers["x-litellm-model-api-base"] = api_base
-        fastapi_response.headers["x-litellm-version"] = version
-        fastapi_response.headers["x-litellm-model-region"] = (
-            user_api_key_dict.allowed_model_region or ""
-        )
+        fastapi_response.headers.update(get_custom_headers(
+            model_id=model_id,
+            cache_key=cache_key,
+            api_base=api_base,
+            version=version,
+            model_region=getattr(user_api_key_dict, "allowed_model_region", "")
+        ))
 
         return response
     except Exception as e:
@@ -4767,13 +4790,13 @@ async def moderations(
         cache_key = hidden_params.get("cache_key", None) or ""
         api_base = hidden_params.get("api_base", None) or ""
 
-        fastapi_response.headers["x-litellm-model-id"] = model_id
-        fastapi_response.headers["x-litellm-cache-key"] = cache_key
-        fastapi_response.headers["x-litellm-model-api-base"] = api_base
-        fastapi_response.headers["x-litellm-version"] = version
-        fastapi_response.headers["x-litellm-model-region"] = (
-            user_api_key_dict.allowed_model_region or ""
-        )
+        fastapi_response.headers.update(get_custom_headers(
+            model_id=model_id,
+            cache_key=cache_key,
+            api_base=api_base,
+            version=version,
+            model_region=getattr(user_api_key_dict, "allowed_model_region", "")
+        ))
 
         return response
     except Exception as e:
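
For reference, a minimal standalone sketch of the `get_custom_headers` helper introduced above, with the logging dependency stripped so it runs on its own. The example values (`"gpt-4-deployment"`, `"1.35.0"`) are hypothetical, used only to show that empty values are dropped instead of being emitted as blank response headers:

```python
from typing import Optional


def get_custom_headers(*,
                       model_id: Optional[str] = None,
                       cache_key: Optional[str] = None,
                       api_base: Optional[str] = None,
                       version: Optional[str] = None,
                       model_region: Optional[str] = None
                       ) -> dict:
    # Headers whose value is '' or None are filtered out entirely,
    # rather than set to an empty string on the response.
    exclude_values = {'', None}
    headers = {
        'x-litellm-model-id': model_id,
        'x-litellm-cache-key': cache_key,
        'x-litellm-model-api-base': api_base,
        'x-litellm-version': version,
        'x-litellm-model-region': model_region
    }
    return {key: value for key, value in headers.items() if value not in exclude_values}


# Only the non-empty values survive (example values are hypothetical):
print(get_custom_headers(model_id="gpt-4-deployment", version="1.35.0", api_base=""))
# -> {'x-litellm-model-id': 'gpt-4-deployment', 'x-litellm-version': '1.35.0'}
```

This replaces the previous per-endpoint pattern, where each header was always assigned, sometimes to `""` (e.g. `user_api_key_dict.allowed_model_region or ""`), with a single helper that omits empty headers and deduplicates the construction logic across `chat_completion`, `completion`, `embeddings`, `image_generation`, `audio_transcriptions`, and `moderations`.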