From f64a3309d14c57ee9ee5c5f4e13d2addb21ed6c4 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Tue, 23 Jul 2024 11:58:58 -0700
Subject: [PATCH] fix(utils.py): support raw response headers for streaming
 requests

---
 litellm/proxy/proxy_server.py         |  8 +++--
 litellm/tests/test_completion.py      | 18 +++++++++++
 litellm/tests/test_completion_cost.py |  4 +--
 litellm/tests/test_streaming.py       | 43 +++++++++++++++------------
 litellm/utils.py                      | 17 +++++++----
 5 files changed, 60 insertions(+), 30 deletions(-)

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 040348275..0ac1d82e0 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -2909,6 +2909,7 @@ async def chat_completion(
         fastest_response_batch_completion = hidden_params.get(
             "fastest_response_batch_completion", None
         )
+        additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
 
         # Post Call Processing
         if llm_router is not None:
@@ -2931,6 +2932,7 @@ async def chat_completion(
                 response_cost=response_cost,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                 fastest_response_batch_completion=fastest_response_batch_completion,
+                **additional_headers,
             )
             selected_data_generator = select_data_generator(
                 response=response,
@@ -2948,8 +2950,10 @@ async def chat_completion(
             user_api_key_dict=user_api_key_dict, response=response
         )
 
-        hidden_params = getattr(response, "_hidden_params", {}) or {}
-        additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
+        hidden_params = (
+            getattr(response, "_hidden_params", {}) or {}
+        )  # get any updated response headers
+        additional_headers = hidden_params.get("additional_headers", {}) or {}
 
         fastapi_response.headers.update(
             get_custom_headers(
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 770498962..c2ce836ef 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -1364,6 +1364,12 @@ def test_completion_openai_response_headers():
     print("response_headers=", response._response_headers)
     assert response._response_headers is not None
     assert "x-ratelimit-remaining-tokens" in response._response_headers
+    assert isinstance(
+        response._hidden_params["additional_headers"][
+            "llm_provider-x-ratelimit-remaining-requests"
+        ],
+        str,
+    )
 
     # /chat/completion - with streaming
 
@@ -1376,6 +1382,12 @@
     print("streaming response_headers=", response_headers)
     assert response_headers is not None
     assert "x-ratelimit-remaining-tokens" in response_headers
+    assert isinstance(
+        response._hidden_params["additional_headers"][
+            "llm_provider-x-ratelimit-remaining-requests"
+        ],
+        str,
+    )
 
     for chunk in streaming_response:
         print("chunk=", chunk)
@@ -1390,6 +1402,12 @@ def test_completion_openai_response_headers():
     print("embedding_response_headers=", embedding_response_headers)
     assert embedding_response_headers is not None
     assert "x-ratelimit-remaining-tokens" in embedding_response_headers
+    assert isinstance(
+        response._hidden_params["additional_headers"][
+            "llm_provider-x-ratelimit-remaining-requests"
+        ],
+        str,
+    )
 
     litellm.return_response_headers = False
 
diff --git a/litellm/tests/test_completion_cost.py b/litellm/tests/test_completion_cost.py
index 6e4425fb6..289e200d9 100644
--- a/litellm/tests/test_completion_cost.py
+++ b/litellm/tests/test_completion_cost.py
@@ -881,6 +881,7 @@ def test_completion_azure_ai():
 @pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
 async def test_completion_cost_hidden_params(sync_mode):
+    litellm.return_response_headers = True
     if sync_mode:
         response = litellm.completion(
             model="gpt-3.5-turbo",
@@ -896,9 +897,6 @@ async def test_completion_cost_hidden_params(sync_mode):
 
     assert "response_cost" in response._hidden_params
     assert isinstance(response._hidden_params["response_cost"], float)
-    assert isinstance(
-        response._hidden_params["llm_provider-x-ratelimit-remaining-requests"], float
-    )
 
 
 def test_vertex_ai_gemini_predict_cost():
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 64c2eb4ab..768c8752c 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -1988,25 +1988,30 @@ async def test_hf_completion_tgi_stream():
 
 # test on openai completion call
 def test_openai_chat_completion_call():
-    try:
-        litellm.set_verbose = False
-        print(f"making openai chat completion call")
-        response = completion(model="gpt-3.5-turbo", messages=messages, stream=True)
-        complete_response = ""
-        start_time = time.time()
-        for idx, chunk in enumerate(response):
-            chunk, finished = streaming_format_tests(idx, chunk)
-            print(f"outside chunk: {chunk}")
-            if finished:
-                break
-            complete_response += chunk
-        # print(f'complete_chunk: {complete_response}')
-        if complete_response.strip() == "":
-            raise Exception("Empty response received")
-        print(f"complete response: {complete_response}")
-    except:
-        print(f"error occurred: {traceback.format_exc()}")
-        pass
+    litellm.set_verbose = False
+    litellm.return_response_headers = True
+    print(f"making openai chat completion call")
+    response = completion(model="gpt-3.5-turbo", messages=messages, stream=True)
+    assert isinstance(
+        response._hidden_params["additional_headers"][
+            "llm_provider-x-ratelimit-remaining-requests"
+        ],
+        str,
+    )
+
+    print(f"response._hidden_params: {response._hidden_params}")
+    complete_response = ""
+    start_time = time.time()
+    for idx, chunk in enumerate(response):
+        chunk, finished = streaming_format_tests(idx, chunk)
+        print(f"outside chunk: {chunk}")
+        if finished:
+            break
+        complete_response += chunk
+    # print(f'complete_chunk: {complete_response}')
+    if complete_response.strip() == "":
+        raise Exception("Empty response received")
+    print(f"complete response: {complete_response}")
 
 
 # test_openai_chat_completion_call()
diff --git a/litellm/utils.py b/litellm/utils.py
index 0beb041e9..7f615ab61 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -5679,13 +5679,13 @@ def convert_to_model_response_object(
 ):
     received_args = locals()
     if _response_headers is not None:
+        llm_response_headers = {
+            "{}-{}".format("llm_provider", k): v for k, v in _response_headers.items()
+        }
         if hidden_params is not None:
-            hidden_params["additional_headers"] = {
-                "{}-{}".format("llm_provider", k): v
-                for k, v in _response_headers.items()
-            }
+            hidden_params["additional_headers"] = llm_response_headers
         else:
-            hidden_params = {"additional_headers": _response_headers}
+            hidden_params = {"additional_headers": llm_response_headers}
     ### CHECK IF ERROR IN RESPONSE ### - openrouter returns these in the dictionary
     if (
         response_object is not None
@@ -8320,8 +8320,13 @@ class CustomStreamWrapper:
             or {}
         )
         self._hidden_params = {
-            "model_id": (_model_info.get("id", None))
+            "model_id": (_model_info.get("id", None)),
         }  # returned as x-litellm-model-id response header in proxy
+        if _response_headers is not None:
+            self._hidden_params["additional_headers"] = {
+                "{}-{}".format("llm_provider", k): v
+                for k, v in _response_headers.items()
+            }
         self._response_headers = _response_headers
         self.response_id = None
         self.logging_loop = None
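
Note: a minimal sketch of the header mapping this patch standardizes across the
sync and streaming paths (illustrative values, not the library code itself).
Both convert_to_model_response_object and CustomStreamWrapper now prefix each
raw provider response header with "llm_provider-" before storing the result on
_hidden_params["additional_headers"]:

    # Sketch: mirrors the dict comprehension used in the diff above.
    # The raw header values here are assumed samples, not real API output.
    raw_headers = {
        "x-ratelimit-remaining-requests": "199",
        "x-ratelimit-remaining-tokens": "39999",
    }

    # Prefix each provider header so callers can tell it apart from
    # litellm's own response headers.
    additional_headers = {
        "{}-{}".format("llm_provider", k): v for k, v in raw_headers.items()
    }

    assert additional_headers["llm_provider-x-ratelimit-remaining-requests"] == "199"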
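And a usage sketch mirroring the updated tests (assumes a configured
OPENAI_API_KEY; exact header names and values depend on the provider). With
litellm.return_response_headers enabled, streaming responses expose the
prefixed headers on the stream wrapper's hidden params, and the proxy forwards
them to clients via **additional_headers in get_custom_headers:

    import litellm
    from litellm import completion

    litellm.return_response_headers = True

    response = completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        stream=True,
    )

    # Header values arrive as strings, e.g. "199".
    print(
        response._hidden_params["additional_headers"][
            "llm_provider-x-ratelimit-remaining-requests"
        ]
    )

    for chunk in response:
        pass  # consume the stream as usual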