fix(proxy_server.py): fix invalid header string

This commit is contained in:
Krrish Dholakia 2024-05-16 21:05:10 -07:00
parent ee98074273
commit 10a672634d

View file

@ -352,23 +352,26 @@ def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict:
return pydantic_obj.dict() return pydantic_obj.dict()
def get_custom_headers(*, def get_custom_headers(
*,
model_id: Optional[str] = None, model_id: Optional[str] = None,
cache_key: Optional[str] = None, cache_key: Optional[str] = None,
api_base: Optional[str] = None, api_base: Optional[str] = None,
version: Optional[str] = None, version: Optional[str] = None,
model_region: Optional[str] = None model_region: Optional[str] = None,
) -> dict: ) -> dict:
exclude_values = {'', None} exclude_values = {"", None}
headers = { headers = {
'x-litellm-model-id': model_id, "x-litellm-model-id": model_id,
'x-litellm-cache-key': cache_key, "x-litellm-cache-key": cache_key,
'x-litellm-model-api-base': api_base, "x-litellm-model-api-base": api_base,
'x-litellm-version': version, "x-litellm-version": version,
'"x-litellm-model-region': model_region "x-litellm-model-region": model_region,
} }
try: try:
return {key: value for key, value in headers.items() if value not in exclude_values} return {
key: value for key, value in headers.items() if value not in exclude_values
}
except Exception as e: except Exception as e:
verbose_proxy_logger.error(f"Error setting custom headers: {e}") verbose_proxy_logger.error(f"Error setting custom headers: {e}")
return {} return {}
@ -3797,7 +3800,7 @@ async def chat_completion(
cache_key=cache_key, cache_key=cache_key,
api_base=api_base, api_base=api_base,
version=version, version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", "") model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
) )
selected_data_generator = select_data_generator( selected_data_generator = select_data_generator(
response=response, response=response,
@ -3810,13 +3813,15 @@ async def chat_completion(
headers=custom_headers, headers=custom_headers,
) )
fastapi_response.headers.update(get_custom_headers( fastapi_response.headers.update(
get_custom_headers(
model_id=model_id, model_id=model_id,
cache_key=cache_key, cache_key=cache_key,
api_base=api_base, api_base=api_base,
version=version, version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", "") model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
)) )
)
### CALL HOOKS ### - modify outgoing data ### CALL HOOKS ### - modify outgoing data
response = await proxy_logging_obj.post_call_success_hook( response = await proxy_logging_obj.post_call_success_hook(
@ -4006,7 +4011,7 @@ async def completion(
model_id=model_id, model_id=model_id,
cache_key=cache_key, cache_key=cache_key,
api_base=api_base, api_base=api_base,
version=version version=version,
) )
selected_data_generator = select_data_generator( selected_data_generator = select_data_generator(
response=response, response=response,
@ -4019,12 +4024,14 @@ async def completion(
media_type="text/event-stream", media_type="text/event-stream",
headers=custom_headers, headers=custom_headers,
) )
fastapi_response.headers.update(get_custom_headers( fastapi_response.headers.update(
get_custom_headers(
model_id=model_id, model_id=model_id,
cache_key=cache_key, cache_key=cache_key,
api_base=api_base, api_base=api_base,
version=version version=version,
)) )
)
return response return response
except Exception as e: except Exception as e:
@ -4229,13 +4236,15 @@ async def embeddings(
cache_key = hidden_params.get("cache_key", None) or "" cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or "" api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(get_custom_headers( fastapi_response.headers.update(
get_custom_headers(
model_id=model_id, model_id=model_id,
cache_key=cache_key, cache_key=cache_key,
api_base=api_base, api_base=api_base,
version=version, version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", "") model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
)) )
)
return response return response
except Exception as e: except Exception as e:
@ -4410,13 +4419,15 @@ async def image_generation(
cache_key = hidden_params.get("cache_key", None) or "" cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or "" api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(get_custom_headers( fastapi_response.headers.update(
get_custom_headers(
model_id=model_id, model_id=model_id,
cache_key=cache_key, cache_key=cache_key,
api_base=api_base, api_base=api_base,
version=version, version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", "") model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
)) )
)
return response return response
except Exception as e: except Exception as e:
@ -4609,13 +4620,15 @@ async def audio_transcriptions(
cache_key = hidden_params.get("cache_key", None) or "" cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or "" api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(get_custom_headers( fastapi_response.headers.update(
get_custom_headers(
model_id=model_id, model_id=model_id,
cache_key=cache_key, cache_key=cache_key,
api_base=api_base, api_base=api_base,
version=version, version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", "") model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
)) )
)
return response return response
except Exception as e: except Exception as e:
@ -4790,13 +4803,15 @@ async def moderations(
cache_key = hidden_params.get("cache_key", None) or "" cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or "" api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(get_custom_headers( fastapi_response.headers.update(
get_custom_headers(
model_id=model_id, model_id=model_id,
cache_key=cache_key, cache_key=cache_key,
api_base=api_base, api_base=api_base,
version=version, version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", "") model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
)) )
)
return response return response
except Exception as e: except Exception as e: