fix(utils.py): support raw response headers for streaming requests

Krrish Dholakia 2024-07-23 11:58:58 -07:00
parent d1ffb4de5f
commit f64a3309d1
5 changed files with 60 additions and 30 deletions
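For context, a minimal usage sketch of what this change enables, written against the assertions in the updated tests below; the model name and message are placeholders, and the specific x-ratelimit-* key assumes the provider actually returns that header. With litellm.return_response_headers enabled, the provider's raw response headers should surface under _hidden_params["additional_headers"], each key prefixed with "llm_provider-", for both regular and streaming completions.

# Usage sketch, not part of the diff (assumes an OpenAI API key is configured).
import litellm
from litellm import completion

litellm.return_response_headers = True
messages = [{"role": "user", "content": "Hey, how's it going?"}]

# Non-streaming: headers are attached in convert_to_model_response_object().
response = completion(model="gpt-3.5-turbo", messages=messages)
print(
    response._hidden_params["additional_headers"][
        "llm_provider-x-ratelimit-remaining-requests"
    ]
)

# Streaming: CustomStreamWrapper now carries the same prefixed headers.
stream = completion(model="gpt-3.5-turbo", messages=messages, stream=True)
print(
    stream._hidden_params["additional_headers"][
        "llm_provider-x-ratelimit-remaining-requests"
    ]
)
for chunk in stream:
    pass  # consume the stream as usual

The "llm_provider-" prefix appears intended to keep raw provider headers distinct from the x-litellm-* headers the proxy already returns.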

@@ -2909,6 +2909,7 @@ async def chat_completion(
fastest_response_batch_completion = hidden_params.get(
"fastest_response_batch_completion", None
)
additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
# Post Call Processing
if llm_router is not None:
@@ -2931,6 +2932,7 @@ async def chat_completion(
response_cost=response_cost,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
fastest_response_batch_completion=fastest_response_batch_completion,
**additional_headers,
)
selected_data_generator = select_data_generator(
response=response,
@@ -2948,8 +2950,10 @@ async def chat_completion(
user_api_key_dict=user_api_key_dict, response=response
)
hidden_params = getattr(response, "_hidden_params", {}) or {}
additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
hidden_params = (
getattr(response, "_hidden_params", {}) or {}
) # get any updated response headers
additional_headers = hidden_params.get("additional_headers", {}) or {}
fastapi_response.headers.update(
get_custom_headers(

@@ -1364,6 +1364,12 @@ def test_completion_openai_response_headers():
print("response_headers=", response._response_headers)
assert response._response_headers is not None
assert "x-ratelimit-remaining-tokens" in response._response_headers
assert isinstance(
response._hidden_params["additional_headers"][
"llm_provider-x-ratelimit-remaining-requests"
],
str,
)
# /chat/completion - with streaming
@@ -1376,6 +1382,12 @@ def test_completion_openai_response_headers():
print("streaming response_headers=", response_headers)
assert response_headers is not None
assert "x-ratelimit-remaining-tokens" in response_headers
assert isinstance(
response._hidden_params["additional_headers"][
"llm_provider-x-ratelimit-remaining-requests"
],
str,
)
for chunk in streaming_response:
print("chunk=", chunk)
@@ -1390,6 +1402,12 @@ def test_completion_openai_response_headers():
print("embedding_response_headers=", embedding_response_headers)
assert embedding_response_headers is not None
assert "x-ratelimit-remaining-tokens" in embedding_response_headers
assert isinstance(
response._hidden_params["additional_headers"][
"llm_provider-x-ratelimit-remaining-requests"
],
str,
)
litellm.return_response_headers = False

@@ -881,6 +881,7 @@ def test_completion_azure_ai():
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_completion_cost_hidden_params(sync_mode):
litellm.return_response_headers = True
if sync_mode:
response = litellm.completion(
model="gpt-3.5-turbo",
@@ -896,9 +897,6 @@ async def test_completion_cost_hidden_params(sync_mode):
assert "response_cost" in response._hidden_params
assert isinstance(response._hidden_params["response_cost"], float)
assert isinstance(
response._hidden_params["llm_provider-x-ratelimit-remaining-requests"], float
)
def test_vertex_ai_gemini_predict_cost():

@@ -1988,25 +1988,30 @@ async def test_hf_completion_tgi_stream():
# test on openai completion call
def test_openai_chat_completion_call():
try:
litellm.set_verbose = False
print(f"making openai chat completion call")
response = completion(model="gpt-3.5-turbo", messages=messages, stream=True)
complete_response = ""
start_time = time.time()
for idx, chunk in enumerate(response):
chunk, finished = streaming_format_tests(idx, chunk)
print(f"outside chunk: {chunk}")
if finished:
break
complete_response += chunk
# print(f'complete_chunk: {complete_response}')
if complete_response.strip() == "":
raise Exception("Empty response received")
print(f"complete response: {complete_response}")
except:
print(f"error occurred: {traceback.format_exc()}")
pass
litellm.set_verbose = False
litellm.return_response_headers = True
print(f"making openai chat completion call")
response = completion(model="gpt-3.5-turbo", messages=messages, stream=True)
assert isinstance(
response._hidden_params["additional_headers"][
"llm_provider-x-ratelimit-remaining-requests"
],
str,
)
print(f"response._hidden_params: {response._hidden_params}")
complete_response = ""
start_time = time.time()
for idx, chunk in enumerate(response):
chunk, finished = streaming_format_tests(idx, chunk)
print(f"outside chunk: {chunk}")
if finished:
break
complete_response += chunk
# print(f'complete_chunk: {complete_response}')
if complete_response.strip() == "":
raise Exception("Empty response received")
print(f"complete response: {complete_response}")
# test_openai_chat_completion_call()

@@ -5679,13 +5679,13 @@ def convert_to_model_response_object(
):
received_args = locals()
if _response_headers is not None:
llm_response_headers = {
"{}-{}".format("llm_provider", k): v for k, v in _response_headers.items()
}
if hidden_params is not None:
hidden_params["additional_headers"] = {
"{}-{}".format("llm_provider", k): v
for k, v in _response_headers.items()
}
hidden_params["additional_headers"] = llm_response_headers
else:
hidden_params = {"additional_headers": _response_headers}
hidden_params = {"additional_headers": llm_response_headers}
### CHECK IF ERROR IN RESPONSE ### - openrouter returns these in the dictionary
if (
response_object is not None
@@ -8320,8 +8320,13 @@ class CustomStreamWrapper:
or {}
)
self._hidden_params = {
"model_id": (_model_info.get("id", None))
"model_id": (_model_info.get("id", None)),
} # returned as x-litellm-model-id response header in proxy
if _response_headers is not None:
self._hidden_params["additional_headers"] = {
"{}-{}".format("llm_provider", k): v
for k, v in _response_headers.items()
}
self._response_headers = _response_headers
self.response_id = None
self.logging_loop = None