fix(utils.py): guarantee openai-compatible headers always exist in response

Fixes https://github.com/BerriAI/litellm/issues/5957
Krrish Dholakia 2024-09-28 16:57:04 -07:00
parent 498e14ba59
commit 55d7bc7f32
3 changed files with 42 additions and 19 deletions
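
In effect, a minimal sketch of the caller-visible behavior after this change (the model name and header keys below are illustrative, not taken from the diff):

import litellm

resp = litellm.completion(
    model="gpt-4o",
    messages=[{"role": "user", "content": "hi"}],
)
headers = resp._hidden_params["additional_headers"]
# With this fix, every OpenAI-compatible header key (e.g. the x-ratelimit-*
# names listed in OPENAI_RESPONSE_HEADERS) is always present in `headers`,
# with None as the value whenever the upstream provider did not send it.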

@@ -1002,6 +1002,10 @@ def client(original_function):
                 result._hidden_params["response_cost"] = (
                     logging_obj._response_cost_calculator(result=result)
                 )
+                result._hidden_params["additional_headers"] = process_response_headers(
+                    result._hidden_params.get("additional_headers") or {}
+                )  # GUARANTEE OPENAI HEADERS IN RESPONSE
             result._response_ms = (
                 end_time - start_time
             ).total_seconds() * 1000  # return response latency in ms like openai
@@ -1410,6 +1414,9 @@ def client(original_function):
                 result._hidden_params["response_cost"] = (
                     logging_obj._response_cost_calculator(result=result)
                 )
+                result._hidden_params["additional_headers"] = process_response_headers(
+                    result._hidden_params.get("additional_headers") or {}
+                )  # GUARANTEE OPENAI HEADERS IN RESPONSE
             if (
                 isinstance(result, ModelResponse)
                 or isinstance(result, EmbeddingResponse)
@@ -6394,24 +6401,9 @@ class CustomStreamWrapper:
             "model_id": (_model_info.get("id", None)),
         }  # returned as x-litellm-model-id response header in proxy
         if _response_headers is not None:
-            openai_headers = {}
-            processed_headers = {}
-            additional_headers = {}
-            for k, v in _response_headers.items():
-                if k in OPENAI_RESPONSE_HEADERS:  # return openai-compatible headers
-                    openai_headers[k] = v
-                if k.startswith(
-                    "llm_provider-"
-                ):  # return raw provider headers (incl. openai-compatible ones)
-                    processed_headers[k] = v
-                else:
-                    additional_headers["{}-{}".format("llm_provider", k)] = v
-            self._hidden_params["additional_headers"] = {
-                **openai_headers,
-                **processed_headers,
-                **additional_headers,
-            }
+            self._hidden_params["additional_headers"] = process_response_headers(
+                _response_headers or {}
+            )  # GUARANTEE OPENAI HEADERS IN RESPONSE
         self._response_headers = _response_headers
         self.response_id = None
@@ -9263,3 +9255,29 @@ def has_tool_call_blocks(messages: List[AllMessageValues]) -> bool:
         if message.get("tool_calls") is not None:
             return True
     return False
+
+
+def process_response_headers(response_headers: Union[httpx.Headers, dict]) -> dict:
+    openai_headers = {}
+    processed_headers = {}
+    additional_headers = {}
+    for k, v in response_headers.items():
+        if k in OPENAI_RESPONSE_HEADERS:  # return openai-compatible headers
+            openai_headers[k] = v
+        if k.startswith(
+            "llm_provider-"
+        ):  # return raw provider headers (incl. openai-compatible ones)
+            processed_headers[k] = v
+        else:
+            additional_headers["{}-{}".format("llm_provider", k)] = v
+    ## GUARANTEE OPENAI HEADERS IN RESPONSE
+    for item in OPENAI_RESPONSE_HEADERS:
+        if item not in openai_headers:
+            openai_headers[item] = None
+    additional_headers = {
+        **openai_headers,
+        **processed_headers,
+        **additional_headers,
+    }
+    return additional_headers
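
For reference, a rough sketch of what the new helper returns (assuming it lands in litellm.utils as the commit title suggests; the concrete header names come from OPENAI_RESPONSE_HEADERS and the ones below are only an example):

from litellm.utils import process_response_headers

raw = {"x-ratelimit-remaining-requests": "99", "request-id": "abc123"}
out = process_response_headers(raw)
# `out` contains every key in OPENAI_RESPONSE_HEADERS, defaulting to None
# when the provider omitted it ("x-ratelimit-remaining-requests" keeps "99"),
# plus a "llm_provider-"-prefixed copy of each remaining raw header, e.g.
# "llm_provider-request-id": "abc123".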

@@ -4547,7 +4547,12 @@ async def test_completion_ai21_chat():
 @pytest.mark.parametrize(
     "model",
-    ["gpt-4o", "azure/chatgpt-v-2", "claude-3-sonnet-20240229"],  #
+    [
+        "gpt-4o",
+        "azure/chatgpt-v-2",
+        "claude-3-sonnet-20240229",
+        "fireworks_ai/mixtral-8x7b-instruct",
+    ],
 )
 @pytest.mark.parametrize(
     "stream",