fix(utils.py): persist response id across chunks

This commit is contained in:
Krrish Dholakia 2024-03-25 18:20:43 -07:00
parent dc2c4af631
commit 1ac641165b
2 changed files with 11 additions and 3 deletions

View file

@@ -490,7 +490,7 @@ def test_redis_cache_completion_stream():
response_1_content += chunk.choices[0].delta.content or "" response_1_content += chunk.choices[0].delta.content or ""
print(response_1_content) print(response_1_content)
time.sleep(0.1) # sleep for 0.1 seconds allow set cache to occur time.sleep(1) # sleep for 1 second to allow set cache to occur
response2 = completion( response2 = completion(
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
messages=messages, messages=messages,
@@ -505,8 +505,10 @@ def test_redis_cache_completion_stream():
response_2_id = chunk.id response_2_id = chunk.id
print(chunk) print(chunk)
response_2_content += chunk.choices[0].delta.content or "" response_2_content += chunk.choices[0].delta.content or ""
print("\nresponse 1", response_1_content) print(
print("\nresponse 2", response_2_content) f"\nresponse 1: {response_1_content}",
)
print(f"\nresponse 2: {response_2_content}")
assert ( assert (
response_1_id == response_2_id response_1_id == response_2_id
), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}" ), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
@@ -516,6 +518,7 @@ def test_redis_cache_completion_stream():
litellm.success_callback = [] litellm.success_callback = []
litellm._async_success_callback = [] litellm._async_success_callback = []
litellm.cache = None litellm.cache = None
raise Exception("it worked!")
except Exception as e: except Exception as e:
print(e) print(e)
litellm.success_callback = [] litellm.success_callback = []

View file

@@ -8458,6 +8458,7 @@ class CustomStreamWrapper:
self.completion_stream = completion_stream self.completion_stream = completion_stream
self.sent_first_chunk = False self.sent_first_chunk = False
self.sent_last_chunk = False self.sent_last_chunk = False
self.system_fingerprint: Optional[str] = None
self.received_finish_reason: Optional[str] = None self.received_finish_reason: Optional[str] = None
self.special_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"] self.special_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
self.holding_chunk = "" self.holding_chunk = ""
@@ -9373,6 +9374,7 @@ class CustomStreamWrapper:
print_verbose(f"completion obj content: {completion_obj['content']}") print_verbose(f"completion obj content: {completion_obj['content']}")
if hasattr(chunk, "id"): if hasattr(chunk, "id"):
model_response.id = chunk.id model_response.id = chunk.id
self.response_id = chunk.id
if response_obj["is_finished"]: if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"] self.received_finish_reason = response_obj["finish_reason"]
else: # openai / azure chat model else: # openai / azure chat model
@@ -9397,6 +9399,7 @@ class CustomStreamWrapper:
) )
if hasattr(response_obj["original_chunk"], "id"): if hasattr(response_obj["original_chunk"], "id"):
model_response.id = response_obj["original_chunk"].id model_response.id = response_obj["original_chunk"].id
self.response_id = model_response.id
if response_obj["logprobs"] is not None: if response_obj["logprobs"] is not None:
model_response.choices[0].logprobs = response_obj["logprobs"] model_response.choices[0].logprobs = response_obj["logprobs"]
@@ -9412,6 +9415,7 @@ class CustomStreamWrapper:
# enter this branch when no content has been passed in response # enter this branch when no content has been passed in response
original_chunk = response_obj.get("original_chunk", None) original_chunk = response_obj.get("original_chunk", None)
model_response.id = original_chunk.id model_response.id = original_chunk.id
self.response_id = original_chunk.id
if len(original_chunk.choices) > 0: if len(original_chunk.choices) > 0:
if ( if (
original_chunk.choices[0].delta.function_call is not None original_chunk.choices[0].delta.function_call is not None
@@ -9493,6 +9497,7 @@ class CustomStreamWrapper:
original_chunk = response_obj.get("original_chunk", None) original_chunk = response_obj.get("original_chunk", None)
if original_chunk: if original_chunk:
model_response.id = original_chunk.id model_response.id = original_chunk.id
self.response_id = original_chunk.id
if len(original_chunk.choices) > 0: if len(original_chunk.choices) > 0:
try: try:
delta = dict(original_chunk.choices[0].delta) delta = dict(original_chunk.choices[0].delta)