Merge pull request #3701 from paneru-rajan/Issue-3675-remove-empty-valued-header

Exclude custom headers from response if the value is None or empty string
This commit is contained in:
Krish Dholakia 2024-05-16 17:42:07 -07:00 committed by GitHub
commit 7502e15295
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -352,6 +352,28 @@ def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict:
return pydantic_obj.dict() return pydantic_obj.dict()
def get_custom_headers(*,
model_id: Optional[str] = None,
cache_key: Optional[str] = None,
api_base: Optional[str] = None,
version: Optional[str] = None,
model_region: Optional[str] = None
) -> dict:
exclude_values = {'', None}
headers = {
'x-litellm-model-id': model_id,
'x-litellm-cache-key': cache_key,
'x-litellm-model-api-base': api_base,
'x-litellm-version': version,
'"x-litellm-model-region': model_region
}
try:
return {key: value for key, value in headers.items() if value not in exclude_values}
except Exception as e:
verbose_proxy_logger.error(f"Error setting custom headers: {e}")
return {}
async def check_request_disconnection(request: Request, llm_api_call_task): async def check_request_disconnection(request: Request, llm_api_call_task):
""" """
Asynchronously checks if the request is disconnected at regular intervals. Asynchronously checks if the request is disconnected at regular intervals.
@ -3770,13 +3792,13 @@ async def chat_completion(
if ( if (
"stream" in data and data["stream"] == True "stream" in data and data["stream"] == True
): # use generate_responses to stream responses ): # use generate_responses to stream responses
custom_headers = { custom_headers = get_custom_headers(
"x-litellm-model-id": model_id, model_id=model_id,
"x-litellm-cache-key": cache_key, cache_key=cache_key,
"x-litellm-model-api-base": api_base, api_base=api_base,
"x-litellm-version": version, version=version,
"x-litellm-model-region": user_api_key_dict.allowed_model_region or "", model_region=getattr(user_api_key_dict, "allowed_model_region", "")
} )
selected_data_generator = select_data_generator( selected_data_generator = select_data_generator(
response=response, response=response,
user_api_key_dict=user_api_key_dict, user_api_key_dict=user_api_key_dict,
@ -3788,13 +3810,13 @@ async def chat_completion(
headers=custom_headers, headers=custom_headers,
) )
fastapi_response.headers["x-litellm-model-id"] = model_id fastapi_response.headers.update(get_custom_headers(
fastapi_response.headers["x-litellm-cache-key"] = cache_key model_id=model_id,
fastapi_response.headers["x-litellm-model-api-base"] = api_base cache_key=cache_key,
fastapi_response.headers["x-litellm-version"] = version api_base=api_base,
fastapi_response.headers["x-litellm-model-region"] = ( version=version,
user_api_key_dict.allowed_model_region or "" model_region=getattr(user_api_key_dict, "allowed_model_region", "")
) ))
### CALL HOOKS ### - modify outgoing data ### CALL HOOKS ### - modify outgoing data
response = await proxy_logging_obj.post_call_success_hook( response = await proxy_logging_obj.post_call_success_hook(
@ -3980,12 +4002,12 @@ async def completion(
if ( if (
"stream" in data and data["stream"] == True "stream" in data and data["stream"] == True
): # use generate_responses to stream responses ): # use generate_responses to stream responses
custom_headers = { custom_headers = get_custom_headers(
"x-litellm-model-id": model_id, model_id=model_id,
"x-litellm-cache-key": cache_key, cache_key=cache_key,
"x-litellm-model-api-base": api_base, api_base=api_base,
"x-litellm-version": version, version=version
} )
selected_data_generator = select_data_generator( selected_data_generator = select_data_generator(
response=response, response=response,
user_api_key_dict=user_api_key_dict, user_api_key_dict=user_api_key_dict,
@ -3997,11 +4019,12 @@ async def completion(
media_type="text/event-stream", media_type="text/event-stream",
headers=custom_headers, headers=custom_headers,
) )
fastapi_response.headers.update(get_custom_headers(
fastapi_response.headers["x-litellm-model-id"] = model_id model_id=model_id,
fastapi_response.headers["x-litellm-cache-key"] = cache_key cache_key=cache_key,
fastapi_response.headers["x-litellm-model-api-base"] = api_base api_base=api_base,
fastapi_response.headers["x-litellm-version"] = version version=version
))
return response return response
except Exception as e: except Exception as e:
@ -4206,13 +4229,13 @@ async def embeddings(
cache_key = hidden_params.get("cache_key", None) or "" cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or "" api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers["x-litellm-model-id"] = model_id fastapi_response.headers.update(get_custom_headers(
fastapi_response.headers["x-litellm-cache-key"] = cache_key model_id=model_id,
fastapi_response.headers["x-litellm-model-api-base"] = api_base cache_key=cache_key,
fastapi_response.headers["x-litellm-version"] = version api_base=api_base,
fastapi_response.headers["x-litellm-model-region"] = ( version=version,
user_api_key_dict.allowed_model_region or "" model_region=getattr(user_api_key_dict, "allowed_model_region", "")
) ))
return response return response
except Exception as e: except Exception as e:
@ -4387,13 +4410,13 @@ async def image_generation(
cache_key = hidden_params.get("cache_key", None) or "" cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or "" api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers["x-litellm-model-id"] = model_id fastapi_response.headers.update(get_custom_headers(
fastapi_response.headers["x-litellm-cache-key"] = cache_key model_id=model_id,
fastapi_response.headers["x-litellm-model-api-base"] = api_base cache_key=cache_key,
fastapi_response.headers["x-litellm-version"] = version api_base=api_base,
fastapi_response.headers["x-litellm-model-region"] = ( version=version,
user_api_key_dict.allowed_model_region or "" model_region=getattr(user_api_key_dict, "allowed_model_region", "")
) ))
return response return response
except Exception as e: except Exception as e:
@ -4586,13 +4609,13 @@ async def audio_transcriptions(
cache_key = hidden_params.get("cache_key", None) or "" cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or "" api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers["x-litellm-model-id"] = model_id fastapi_response.headers.update(get_custom_headers(
fastapi_response.headers["x-litellm-cache-key"] = cache_key model_id=model_id,
fastapi_response.headers["x-litellm-model-api-base"] = api_base cache_key=cache_key,
fastapi_response.headers["x-litellm-version"] = version api_base=api_base,
fastapi_response.headers["x-litellm-model-region"] = ( version=version,
user_api_key_dict.allowed_model_region or "" model_region=getattr(user_api_key_dict, "allowed_model_region", "")
) ))
return response return response
except Exception as e: except Exception as e:
@ -4767,13 +4790,13 @@ async def moderations(
cache_key = hidden_params.get("cache_key", None) or "" cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or "" api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers["x-litellm-model-id"] = model_id fastapi_response.headers.update(get_custom_headers(
fastapi_response.headers["x-litellm-cache-key"] = cache_key model_id=model_id,
fastapi_response.headers["x-litellm-model-api-base"] = api_base cache_key=cache_key,
fastapi_response.headers["x-litellm-version"] = version api_base=api_base,
fastapi_response.headers["x-litellm-model-region"] = ( version=version,
user_api_key_dict.allowed_model_region or "" model_region=getattr(user_api_key_dict, "allowed_model_region", "")
) ))
return response return response
except Exception as e: except Exception as e: