diff --git a/litellm/tests/test_custom_logger.py b/litellm/tests/test_custom_logger.py index e1c87f1a32..a7b0c937f0 100644 --- a/litellm/tests/test_custom_logger.py +++ b/litellm/tests/test_custom_logger.py @@ -483,9 +483,12 @@ def test_redis_cache_completion_stream(): max_tokens=40, temperature=0.2, stream=True, + caching=True, ) response_1_content = "" + response_1_id = None for chunk in response1: + response_1_id = chunk.id print(chunk) response_1_content += chunk.choices[0].delta.content or "" print(response_1_content) @@ -497,16 +500,22 @@ def test_redis_cache_completion_stream(): max_tokens=40, temperature=0.2, stream=True, + caching=True, ) response_2_content = "" + response_2_id = None for chunk in response2: + response_2_id = chunk.id print(chunk) response_2_content += chunk.choices[0].delta.content or "" print("\nresponse 1", response_1_content) print("\nresponse 2", response_2_content) assert ( - response_1_content == response_2_content + response_1_id == response_2_id ), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}" + # assert ( + # response_1_content == response_2_content + # ), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}" litellm.success_callback = [] litellm._async_success_callback = [] litellm.cache = None diff --git a/litellm/utils.py b/litellm/utils.py index 0e718d31a3..09c23a7253 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -1169,7 +1169,7 @@ class Logging: verbose_logger.debug(f"success callbacks: {litellm.success_callback}") ## BUILD COMPLETE STREAMED RESPONSE complete_streaming_response = None - if self.stream: + if self.stream and isinstance(result, ModelResponse): if ( result.choices[0].finish_reason is not None ): # if it's the last chunk @@ -8654,6 +8654,8 @@ class CustomStreamWrapper: completion_obj["content"] = response_obj["text"] print_verbose(f"completion obj content: {completion_obj['content']}") + if hasattr(chunk, "id"): + model_response.id = chunk.id if response_obj["is_finished"]: model_response.choices[0].finish_reason = response_obj[ "finish_reason" @@ -8676,6 +8678,8 @@ class CustomStreamWrapper: model_response.system_fingerprint = getattr( response_obj["original_chunk"], "system_fingerprint", None ) + if hasattr(response_obj["original_chunk"], "id"): + model_response.id = response_obj["original_chunk"].id if response_obj["logprobs"] is not None: model_response.choices[0].logprobs = response_obj["logprobs"]