fix(utils.py): fix redis cache test

Krrish Dholakia 2024-02-26 22:04:24 -08:00
parent a428501e68
commit 1447621128
2 changed files with 15 additions and 2 deletions

@@ -483,9 +483,12 @@ def test_redis_cache_completion_stream():
             max_tokens=40,
             temperature=0.2,
             stream=True,
+            caching=True,
         )
         response_1_content = ""
+        response_1_id = None
         for chunk in response1:
+            response_1_id = chunk.id
             print(chunk)
             response_1_content += chunk.choices[0].delta.content or ""
         print(response_1_content)
@@ -497,16 +500,22 @@ def test_redis_cache_completion_stream():
             max_tokens=40,
             temperature=0.2,
             stream=True,
+            caching=True,
         )
         response_2_content = ""
+        response_2_id = None
         for chunk in response2:
+            response_2_id = chunk.id
             print(chunk)
             response_2_content += chunk.choices[0].delta.content or ""
         print("\nresponse 1", response_1_content)
         print("\nresponse 2", response_2_content)
         assert (
-            response_1_content == response_2_content
+            response_1_id == response_2_id
         ), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
+        # assert (
+        #     response_1_content == response_2_content
+        # ), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
         litellm.success_callback = []
         litellm._async_success_callback = []
         litellm.cache = None
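
The test hunks above turn caching on explicitly with caching=True and switch the cache-hit assertion from comparing concatenated stream content to comparing the streamed response ids. A minimal standalone sketch of the same check, assuming a reachable Redis instance, REDIS_HOST / REDIS_PORT / REDIS_PASSWORD environment variables, and the litellm.caching.Cache constructor; the stream_once helper is illustrative, not part of the repo:

# Sketch: a second identical streamed completion served from the Redis cache
# should replay the same response id, which is what the updated test asserts.
import os
import time
import litellm
from litellm import completion
from litellm.caching import Cache

litellm.cache = Cache(
    type="redis",
    host=os.environ["REDIS_HOST"],
    port=os.environ["REDIS_PORT"],
    password=os.environ["REDIS_PASSWORD"],
)

messages = [{"role": "user", "content": "write a one-liner about caching"}]

def stream_once():
    chunk_id, text = None, ""
    for chunk in completion(
        model="gpt-3.5-turbo",
        messages=messages,
        max_tokens=40,
        temperature=0.2,
        stream=True,
        caching=True,
    ):
        chunk_id = chunk.id  # ids now survive caching (see the utils.py hunks below)
        text += chunk.choices[0].delta.content or ""
    return chunk_id, text

id_1, _ = stream_once()
time.sleep(0.5)  # give the cache write a moment, as the test does
id_2, _ = stream_once()
assert id_1 == id_2, "a cache hit should replay the same response id"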

@@ -1169,7 +1169,7 @@ class Logging:
         verbose_logger.debug(f"success callbacks: {litellm.success_callback}")
         ## BUILD COMPLETE STREAMED RESPONSE
         complete_streaming_response = None
-        if self.stream:
+        if self.stream and isinstance(result, ModelResponse):
             if (
                 result.choices[0].finish_reason is not None
             ):  # if it's the last chunk
@@ -8654,6 +8654,8 @@ class CustomStreamWrapper:
                 completion_obj["content"] = response_obj["text"]
                 print_verbose(f"completion obj content: {completion_obj['content']}")
+                if hasattr(chunk, "id"):
+                    model_response.id = chunk.id
                 if response_obj["is_finished"]:
                     model_response.choices[0].finish_reason = response_obj[
                         "finish_reason"
@@ -8676,6 +8678,8 @@ class CustomStreamWrapper:
                 model_response.system_fingerprint = getattr(
                     response_obj["original_chunk"], "system_fingerprint", None
                 )
+                if hasattr(response_obj["original_chunk"], "id"):
+                    model_response.id = response_obj["original_chunk"].id
                 if response_obj["logprobs"] is not None:
                     model_response.choices[0].logprobs = response_obj["logprobs"]
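
These utils.py hunks supply the behavior the test now depends on: CustomStreamWrapper copies the upstream chunk's id onto the wrapped response (both for plain text chunks and for chunks that carry an original_chunk), so a cached replay of a stream keeps a stable, comparable id, and the Logging success handler only aggregates the stream when the result is actually a ModelResponse. A rough, self-contained sketch of the id-propagation idea; ProviderChunk and WrappedChunk are hypothetical stand-ins, not litellm's real classes:

# Simplified stand-in for the pattern added in this commit: reuse the
# provider chunk's id instead of minting a fresh one for every wrapped chunk.
from dataclasses import dataclass, field
import uuid

@dataclass
class ProviderChunk:  # hypothetical upstream chunk
    id: str
    text: str

@dataclass
class WrappedChunk:  # hypothetical wrapper output
    id: str = field(default_factory=lambda: f"chatcmpl-{uuid.uuid4()}")
    content: str = ""

def wrap(chunk: ProviderChunk) -> WrappedChunk:
    wrapped = WrappedChunk(content=chunk.text)
    if hasattr(chunk, "id"):  # mirror the hasattr guard in the diff
        wrapped.id = chunk.id
    return wrapped

raw = ProviderChunk(id="chatcmpl-123", text="hello")
assert wrap(raw).id == "chatcmpl-123"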