diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 11086315df..48a59cab8d 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -352,6 +352,24 @@ def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict:
     return pydantic_obj.dict()
 
 
+def get_custom_headers(*,
+                       model_id: Optional[str] = None,
+                       cache_key: Optional[str] = None,
+                       api_base: Optional[str] = None,
+                       version: Optional[str] = None,
+                       model_region: Optional[str] = None
+                       ) -> dict:
+    exclude_values = {'', None}
+    headers = {
+        'x-litellm-model-id': model_id,
+        'x-litellm-cache-key': cache_key,
+        'x-litellm-model-api-base': api_base,
+        'x-litellm-version': version,
+        'x-litellm-model-region': model_region
+    }
+    return {key: value for key, value in headers.items() if value not in exclude_values}
+
+
 async def check_request_disconnection(request: Request, llm_api_call_task):
     """
     Asynchronously checks if the request is disconnected at regular intervals.
@@ -3747,13 +3765,13 @@ async def chat_completion(
         if (
             "stream" in data and data["stream"] == True
         ):  # use generate_responses to stream responses
-            custom_headers = {
-                "x-litellm-model-id": model_id,
-                "x-litellm-cache-key": cache_key,
-                "x-litellm-model-api-base": api_base,
-                "x-litellm-version": version,
-                "x-litellm-model-region": user_api_key_dict.allowed_model_region or "",
-            }
+            custom_headers = get_custom_headers(
+                model_id=model_id,
+                cache_key=cache_key,
+                api_base=api_base,
+                version=version,
+                model_region=user_api_key_dict.allowed_model_region
+            )
             selected_data_generator = select_data_generator(
                 response=response,
                 user_api_key_dict=user_api_key_dict,
@@ -3765,13 +3783,13 @@ async def chat_completion(
             headers=custom_headers,
         )
 
-        fastapi_response.headers["x-litellm-model-id"] = model_id
-        fastapi_response.headers["x-litellm-cache-key"] = cache_key
-        fastapi_response.headers["x-litellm-model-api-base"] = api_base
-        fastapi_response.headers["x-litellm-version"] = version
-        fastapi_response.headers["x-litellm-model-region"] = (
-            user_api_key_dict.allowed_model_region or ""
-        )
+        fastapi_response.headers.update(get_custom_headers(
+            model_id=model_id,
+            cache_key=cache_key,
+            api_base=api_base,
+            version=version,
+            model_region=user_api_key_dict.allowed_model_region
+        ))
 
         ### CALL HOOKS ### - modify outgoing data
         response = await proxy_logging_obj.post_call_success_hook(
@@ -3957,12 +3975,12 @@ async def completion(
         if (
             "stream" in data and data["stream"] == True
         ):  # use generate_responses to stream responses
-            custom_headers = {
-                "x-litellm-model-id": model_id,
-                "x-litellm-cache-key": cache_key,
-                "x-litellm-model-api-base": api_base,
-                "x-litellm-version": version,
-            }
+            custom_headers = get_custom_headers(
+                model_id=model_id,
+                cache_key=cache_key,
+                api_base=api_base,
+                version=version
+            )
             selected_data_generator = select_data_generator(
                 response=response,
                 user_api_key_dict=user_api_key_dict,
@@ -3974,11 +3992,12 @@ async def completion(
                 media_type="text/event-stream",
                 headers=custom_headers,
             )
-
-        fastapi_response.headers["x-litellm-model-id"] = model_id
-        fastapi_response.headers["x-litellm-cache-key"] = cache_key
-        fastapi_response.headers["x-litellm-model-api-base"] = api_base
-        fastapi_response.headers["x-litellm-version"] = version
+        fastapi_response.headers.update(get_custom_headers(
+            model_id=model_id,
+            cache_key=cache_key,
+            api_base=api_base,
+            version=version
+        ))
 
         return response
     except Exception as e:
@@ -4183,13 +4202,13 @@ async def embeddings(
             cache_key = hidden_params.get("cache_key", None) or ""
             api_base = hidden_params.get("api_base", None) or ""
 
-        fastapi_response.headers["x-litellm-model-id"] = model_id
-        fastapi_response.headers["x-litellm-cache-key"] = cache_key
-        fastapi_response.headers["x-litellm-model-api-base"] = api_base
-        fastapi_response.headers["x-litellm-version"] = version
-        fastapi_response.headers["x-litellm-model-region"] = (
-            user_api_key_dict.allowed_model_region or ""
-        )
+        fastapi_response.headers.update(get_custom_headers(
+            model_id=model_id,
+            cache_key=cache_key,
+            api_base=api_base,
+            version=version,
+            model_region=user_api_key_dict.allowed_model_region
+        ))
 
         return response
     except Exception as e:
@@ -4364,13 +4383,13 @@ async def image_generation(
             cache_key = hidden_params.get("cache_key", None) or ""
             api_base = hidden_params.get("api_base", None) or ""
 
-        fastapi_response.headers["x-litellm-model-id"] = model_id
-        fastapi_response.headers["x-litellm-cache-key"] = cache_key
-        fastapi_response.headers["x-litellm-model-api-base"] = api_base
-        fastapi_response.headers["x-litellm-version"] = version
-        fastapi_response.headers["x-litellm-model-region"] = (
-            user_api_key_dict.allowed_model_region or ""
-        )
+        fastapi_response.headers.update(get_custom_headers(
+            model_id=model_id,
+            cache_key=cache_key,
+            api_base=api_base,
+            version=version,
+            model_region=user_api_key_dict.allowed_model_region
+        ))
 
         return response
     except Exception as e:
@@ -4563,13 +4582,13 @@ async def audio_transcriptions(
             cache_key = hidden_params.get("cache_key", None) or ""
             api_base = hidden_params.get("api_base", None) or ""
 
-        fastapi_response.headers["x-litellm-model-id"] = model_id
-        fastapi_response.headers["x-litellm-cache-key"] = cache_key
-        fastapi_response.headers["x-litellm-model-api-base"] = api_base
-        fastapi_response.headers["x-litellm-version"] = version
-        fastapi_response.headers["x-litellm-model-region"] = (
-            user_api_key_dict.allowed_model_region or ""
-        )
+        fastapi_response.headers.update(get_custom_headers(
+            model_id=model_id,
+            cache_key=cache_key,
+            api_base=api_base,
+            version=version,
+            model_region=user_api_key_dict.allowed_model_region
+        ))
 
         return response
     except Exception as e:
@@ -4744,13 +4763,13 @@ async def moderations(
             cache_key = hidden_params.get("cache_key", None) or ""
             api_base = hidden_params.get("api_base", None) or ""
 
-        fastapi_response.headers["x-litellm-model-id"] = model_id
-        fastapi_response.headers["x-litellm-cache-key"] = cache_key
-        fastapi_response.headers["x-litellm-model-api-base"] = api_base
-        fastapi_response.headers["x-litellm-version"] = version
-        fastapi_response.headers["x-litellm-model-region"] = (
-            user_api_key_dict.allowed_model_region or ""
-        )
+        fastapi_response.headers.update(get_custom_headers(
+            model_id=model_id,
+            cache_key=cache_key,
+            api_base=api_base,
+            version=version,
+            model_region=user_api_key_dict.allowed_model_region
+        ))
 
         return response
     except Exception as e: