fix(utils.py): support raw response headers for streaming requests

Krrish Dholakia 2024-07-23 11:58:58 -07:00
parent d1ffb4de5f
commit f64a3309d1
5 changed files with 60 additions and 30 deletions

@@ -2909,6 +2909,7 @@ async def chat_completion(
         fastest_response_batch_completion = hidden_params.get(
             "fastest_response_batch_completion", None
         )
+        additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
 
         # Post Call Processing
         if llm_router is not None:
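
Note on the read pattern above: `hidden_params.get("additional_headers", {}) or {}` guards against both a missing key and an explicit `None` value, so the later `**additional_headers` spread always sees a dict. A minimal standalone illustration:

    hidden_params = {"additional_headers": None}
    additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
    assert additional_headers == {}  # the trailing `or {}` catches an explicit None
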
@@ -2931,6 +2932,7 @@ async def chat_completion(
                 response_cost=response_cost,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                 fastest_response_batch_completion=fastest_response_batch_completion,
+                **additional_headers,
             )
             selected_data_generator = select_data_generator(
                 response=response,
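
Spreading `**additional_headers` into the `get_custom_headers(...)` call works even though the keys are hyphenated (e.g. `llm_provider-x-ratelimit-remaining-requests`): Python accepts non-identifier string keys when unpacking into a `**kwargs` parameter. A sketch of the pattern, where `collect_headers` is a hypothetical stand-in rather than litellm's `get_custom_headers`:

    def collect_headers(**kwargs) -> dict:
        # hypothetical stand-in; hyphenated keys are legal because they only live in **kwargs
        return {k: str(v) for k, v in kwargs.items() if v is not None}

    headers = collect_headers(**{"llm_provider-x-ratelimit-remaining-requests": "9999"})
    assert headers == {"llm_provider-x-ratelimit-remaining-requests": "9999"}
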
@@ -2948,8 +2950,10 @@ async def chat_completion(
                 user_api_key_dict=user_api_key_dict, response=response
             )
-            hidden_params = getattr(response, "_hidden_params", {}) or {}
-            additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
+            hidden_params = (
+                getattr(response, "_hidden_params", {}) or {}
+            )  # get any updated response headers
+            additional_headers = hidden_params.get("additional_headers", {}) or {}
 
             fastapi_response.headers.update(
                 get_custom_headers(
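
Taken together, the streaming branch re-reads `_hidden_params` after the post-call hooks run (streaming can update them mid-flight) and merges any provider headers into the outgoing response. A minimal sketch of that flow, assuming a FastAPI `Response`; `forward_provider_headers` is a hypothetical helper, not litellm code:

    from fastapi import Response

    def forward_provider_headers(fastapi_response: Response, llm_response) -> None:
        # hypothetical helper: re-read hidden params in case post-call hooks updated them
        hidden_params = getattr(llm_response, "_hidden_params", {}) or {}
        additional_headers = hidden_params.get("additional_headers", {}) or {}
        for key, value in additional_headers.items():
            # keys arrive prefixed, e.g. "llm_provider-x-ratelimit-remaining-requests"
            fastapi_response.headers[key] = str(value)
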

@@ -1364,6 +1364,12 @@ def test_completion_openai_response_headers():
     print("response_headers=", response._response_headers)
     assert response._response_headers is not None
     assert "x-ratelimit-remaining-tokens" in response._response_headers
+    assert isinstance(
+        response._hidden_params["additional_headers"][
+            "llm_provider-x-ratelimit-remaining-requests"
+        ],
+        str,
+    )
 
     # /chat/completion - with streaming
@@ -1376,6 +1382,12 @@ def test_completion_openai_response_headers():
     print("streaming response_headers=", response_headers)
     assert response_headers is not None
     assert "x-ratelimit-remaining-tokens" in response_headers
+    assert isinstance(
+        response._hidden_params["additional_headers"][
+            "llm_provider-x-ratelimit-remaining-requests"
+        ],
+        str,
+    )
 
     for chunk in streaming_response:
         print("chunk=", chunk)
@@ -1390,6 +1402,12 @@ def test_completion_openai_response_headers():
     print("embedding_response_headers=", embedding_response_headers)
     assert embedding_response_headers is not None
     assert "x-ratelimit-remaining-tokens" in embedding_response_headers
+    assert isinstance(
+        response._hidden_params["additional_headers"][
+            "llm_provider-x-ratelimit-remaining-requests"
+        ],
+        str,
+    )
 
     litellm.return_response_headers = False
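
The three assertions added to this test all check the same invariant: with `litellm.return_response_headers` enabled, raw provider headers surface under `_hidden_params["additional_headers"]`, keyed with an `llm_provider-` prefix and holding string values (HTTP header values are strings even when numeric). The expected shape, with illustrative values:

    hidden_params = {
        "additional_headers": {
            "llm_provider-x-ratelimit-remaining-requests": "9999",  # illustrative value
            "llm_provider-x-ratelimit-remaining-tokens": "3999850",  # illustrative value
        }
    }
    assert isinstance(
        hidden_params["additional_headers"]["llm_provider-x-ratelimit-remaining-requests"],
        str,
    )
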

@@ -881,6 +881,7 @@ def test_completion_azure_ai():
 @pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
 async def test_completion_cost_hidden_params(sync_mode):
+    litellm.return_response_headers = True
     if sync_mode:
         response = litellm.completion(
             model="gpt-3.5-turbo",
@@ -896,9 +897,6 @@ async def test_completion_cost_hidden_params(sync_mode):
     assert "response_cost" in response._hidden_params
     assert isinstance(response._hidden_params["response_cost"], float)
-    assert isinstance(
-        response._hidden_params["llm_provider-x-ratelimit-remaining-requests"], float
-    )
 
 def test_vertex_ai_gemini_predict_cost():
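
The deleted assertion expected a float at a top-level, un-prefixed key; after this commit the value lives under `additional_headers` and, being an HTTP header, is a string. A caller that wants a number has to parse it, e.g. (sketch, assuming `response` came from `litellm.completion(...)` with `litellm.return_response_headers = True`):

    headers = response._hidden_params["additional_headers"]  # assumes headers were returned
    remaining = int(headers["llm_provider-x-ratelimit-remaining-requests"])
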

@@ -1988,10 +1988,18 @@ async def test_hf_completion_tgi_stream():
 # test on openai completion call
 def test_openai_chat_completion_call():
-    try:
     litellm.set_verbose = False
+    litellm.return_response_headers = True
     print(f"making openai chat completion call")
     response = completion(model="gpt-3.5-turbo", messages=messages, stream=True)
+    assert isinstance(
+        response._hidden_params["additional_headers"][
+            "llm_provider-x-ratelimit-remaining-requests"
+        ],
+        str,
+    )
+    print(f"response._hidden_params: {response._hidden_params}")
     complete_response = ""
     start_time = time.time()
     for idx, chunk in enumerate(response):
@@ -2004,9 +2012,6 @@ def test_openai_chat_completion_call():
         if complete_response.strip() == "":
             raise Exception("Empty response received")
         print(f"complete response: {complete_response}")
-    except:
-        print(f"error occurred: {traceback.format_exc()}")
-        pass
 
 # test_openai_chat_completion_call()
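
Dropping the `try`/bare-`except` wrapper is what gives the new assertion teeth: a bare `except:` that only prints and `pass`es turns any failure, including the header check, into a silent test pass. The removed anti-pattern in isolation:

    try:
        assert False  # any failing assertion here...
    except:
        pass  # ...is swallowed, and the test still "passes"
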

@@ -5679,13 +5679,13 @@ def convert_to_model_response_object(
 ):
     received_args = locals()
     if _response_headers is not None:
-        if hidden_params is not None:
-            hidden_params["additional_headers"] = {
-                "{}-{}".format("llm_provider", k): v
-                for k, v in _response_headers.items()
-            }
+        llm_response_headers = {
+            "{}-{}".format("llm_provider", k): v for k, v in _response_headers.items()
+        }
+        if hidden_params is not None:
+            hidden_params["additional_headers"] = llm_response_headers
         else:
-            hidden_params = {"additional_headers": _response_headers}
+            hidden_params = {"additional_headers": llm_response_headers}
     ### CHECK IF ERROR IN RESPONSE ### - openrouter returns these in the dictionary
     if (
         response_object is not None
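
This refactor builds the prefixed header dict once and reuses it on both branches; previously the `else` branch stored the raw, un-prefixed `_response_headers`, so the key format depended on which branch ran. The prefixing step in isolation:

    def prefix_provider_headers(response_headers: dict) -> dict:
        # "x-ratelimit-remaining-requests" -> "llm_provider-x-ratelimit-remaining-requests"
        return {"{}-{}".format("llm_provider", k): v for k, v in response_headers.items()}

    assert prefix_provider_headers({"x-ratelimit-remaining-requests": "9999"}) == {
        "llm_provider-x-ratelimit-remaining-requests": "9999"
    }
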
@@ -8320,8 +8320,13 @@ class CustomStreamWrapper:
             or {}
         )
         self._hidden_params = {
-            "model_id": (_model_info.get("id", None))
+            "model_id": (_model_info.get("id", None)),
         }  # returned as x-litellm-model-id response header in proxy
+        if _response_headers is not None:
+            self._hidden_params["additional_headers"] = {
+                "{}-{}".format("llm_provider", k): v
+                for k, v in _response_headers.items()
+            }
         self._response_headers = _response_headers
         self.response_id = None
         self.logging_loop = None
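
With this last hunk, streaming responses expose the same `additional_headers` shape as non-streaming ones, which is exactly what the updated streaming test asserts. A self-contained sketch of the constructor behavior, where `StreamWrapperSketch` is a stand-in, not litellm's `CustomStreamWrapper`:

    class StreamWrapperSketch:
        # stand-in for the relevant slice of CustomStreamWrapper.__init__
        def __init__(self, _model_info=None, _response_headers=None):
            _model_info = _model_info or {}
            self._hidden_params = {"model_id": _model_info.get("id", None)}
            if _response_headers is not None:
                self._hidden_params["additional_headers"] = {
                    "{}-{}".format("llm_provider", k): v
                    for k, v in _response_headers.items()
                }
            self._response_headers = _response_headers

    w = StreamWrapperSketch(_response_headers={"x-ratelimit-remaining-requests": "9999"})
    assert (
        w._hidden_params["additional_headers"]["llm_provider-x-ratelimit-remaining-requests"]
        == "9999"
    )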