fix(utils.py): ensure last chunk is always an empty delta w/ finish reason

Makes sure our streaming is OpenAI-compatible. Adds stricter tests for this as well.
Krrish Dholakia 2024-03-25 16:33:41 -07:00
parent f153889738
commit 9e1e97528d
3 changed files with 221 additions and 285 deletions
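
For context: the change targets OpenAI's streaming contract, where the terminal chunk carries a finish_reason and an otherwise empty delta. A minimal sketch of that shape (placeholder values; the last_openai_chunk_example fixture in the test diff below captures the same thing):

# Illustrative final chunk of an OpenAI-compatible stream; id and timestamp are placeholders.
last_chunk = {
    "id": "chatcmpl-123",
    "object": "chat.completion.chunk",
    "created": 1711402871,
    "model": "gpt-3.5-turbo-0125",
    "system_fingerprint": "fp_3bc1b5746c",
    "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
}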


@@ -1,119 +1,6 @@
============================= test session starts ==============================
platform darwin -- Python 3.11.6, pytest-7.3.1, pluggy-1.3.0
rootdir: /Users/krrishdholakia/Documents/litellm/litellm/tests
plugins: timeout-2.2.0, asyncio-0.23.2, anyio-3.7.1, xdist-3.3.1
asyncio: mode=Mode.STRICT
collected 1 item
test_completion.py . [100%]
=============================== warnings summary ===============================
../../../../../../opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:271
../../../../../../opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:271
../../../../../../opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:271
../../../../../../opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:271
../../../../../../opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:271
../../../../../../opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:271
../../../../../../opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:271
../../../../../../opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:271
../../../../../../opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:271
/opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:271: PydanticDeprecatedSince20: Support for class-based `config` is deprecated, use ConfigDict instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning)
../proxy/_types.py:102
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:102: PydanticDeprecatedSince20: `pydantic.config.Extra` is deprecated, use literal values instead (e.g. `extra='allow'`). Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
extra = Extra.allow # Allow extra fields
../proxy/_types.py:105
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:105: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True)
../proxy/_types.py:134
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:134: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True)
../proxy/_types.py:180
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:180: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True)
../proxy/_types.py:241
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:241: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True)
../proxy/_types.py:253
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:253: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True)
../proxy/_types.py:292
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:292: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True)
../proxy/_types.py:319
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:319: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True)
../proxy/_types.py:570
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:570: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True)
../proxy/_types.py:591
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:591: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True)
../utils.py:35
/Users/krrishdholakia/Documents/litellm/litellm/utils.py:35: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
import pkg_resources
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871: 10 warnings
/opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google')`.
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
declare_namespace(pkg)
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871
/opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google.cloud')`.
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
declare_namespace(pkg)
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2350
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2350
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2350
/opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2350: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google')`.
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
declare_namespace(parent)
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871
/opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google.logging')`.
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
declare_namespace(pkg)
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871
/opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google.iam')`.
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
declare_namespace(pkg)
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871
/opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('mpl_toolkits')`.
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
declare_namespace(pkg)
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871
/opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('sphinxcontrib')`.
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
declare_namespace(pkg)
../llms/prompt_templates/factory.py:6
/Users/krrishdholakia/Documents/litellm/litellm/llms/prompt_templates/factory.py:6: DeprecationWarning: 'imghdr' is deprecated and slated for removal in Python 3.13
import imghdr, base64
test_completion.py::test_completion_claude_3_stream
../utils.py:3249
../utils.py:3249
/Users/krrishdholakia/Documents/litellm/litellm/utils.py:3249: DeprecationWarning: open_text is deprecated. Use files() instead. Refer to https://importlib-resources.readthedocs.io/en/latest/using.html#migrating-from-legacy for migration advice.
with resources.open_text(
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================== 1 passed, 46 warnings in 3.14s ========================

+ <litellm.utils.CustomStreamWrapper object at 0x118bd82d0>
+ chunk: ModelResponse(id='chatcmpl-95b7d389-ff5a-4e09-a084-02584ba2cf1e', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content='In the United States of America, the Supreme Court has ultimate judicial authority, and it is the one that rules on legal disputes between the states, or on the interpretation of the', role='assistant', function_call=None, tool_calls=None), logprobs=None)], created=1711406570, model='ai21.j2-mid-v1', object='chat.completion.chunk', system_fingerprint=None, usage=Usage())
+ extracted chunk: In the United States of America, the Supreme Court has ultimate judicial authority, and it is the one that rules on legal disputes between the states, or on the interpretation of the
+ chunk: ModelResponse(id='chatcmpl-95b7d389-ff5a-4e09-a084-02584ba2cf1e', choices=[StreamingChoices(finish_reason='stop', index=0, delta=Delta(content=None, role=None, function_call=None, tool_calls=None), logprobs=None)], created=1711406570, model='ai21.j2-mid-v1', object='chat.completion.chunk', system_fingerprint=None, usage=Usage())
+ extracted chunk:
+ completion_response: In the United States of America, the Supreme Court has ultimate judicial authority, and it is the one that rules on legal disputes between the states, or on the interpretation of the
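
The new log above comes from a streamed ai21 run; a sketch of the kind of call that produces it (model taken from the log, prompt illustrative):

import litellm

# With this fix the final chunk carries finish_reason="stop" and an empty delta,
# matching the two chunks shown in the log above.
response = litellm.completion(
    model="ai21.j2-mid-v1",
    messages=[{"role": "user", "content": "Who has ultimate judicial authority in the US?"}],
    stream=True,
)
for chunk in response:
    print(f"chunk: {chunk}")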


@@ -108,8 +108,19 @@ last_openai_chunk_example = {
    "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
}

+"""
+Final chunk (sdk):
+chunk: ChatCompletionChunk(id='chatcmpl-96mM3oNBlxh2FDWVLKsgaFBBcULmI',
+choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, role=None,
+tool_calls=None), finish_reason='stop', index=0, logprobs=None)],
+created=1711402871, model='gpt-3.5-turbo-0125', object='chat.completion.chunk', system_fingerprint='fp_3bc1b5746c')
+"""
+
+
def validate_last_format(chunk):
+    """
+    Ensure last chunk has no remaining content or tools
+    """
    assert isinstance(chunk, ModelResponse), "Chunk should be a dictionary."
    assert isinstance(chunk["id"], str), "'id' should be a string."
    assert isinstance(chunk["object"], str), "'object' should be a string."
@@ -119,6 +130,10 @@ def validate_last_format(chunk):
    for choice in chunk["choices"]:
        assert isinstance(choice["index"], int), "'index' should be an integer."
+        assert choice["delta"]["content"] is None
+        assert choice["delta"]["function_call"] is None
+        assert choice["delta"]["role"] is None
+        assert choice["delta"]["tool_calls"] is None
        assert isinstance(
            choice["finish_reason"], str
        ), "'finish_reason' should be a string."
@@ -493,13 +508,15 @@ def test_completion_mistral_api_stream():
            stream=True,
        )
        complete_response = ""
+        has_finish_reason = False
        for idx, chunk in enumerate(response):
-            print(chunk)
-            # print(chunk.choices[0].delta)
            chunk, finished = streaming_format_tests(idx, chunk)
            if finished:
+                has_finish_reason = True
                break
            complete_response += chunk
+        if has_finish_reason == False:
+            raise Exception("finish reason not set")
        if complete_response.strip() == "":
            raise Exception("Empty response received")
        print(f"completion_response: {complete_response}")
@@ -534,11 +551,15 @@ def test_completion_deep_infra_stream():
        complete_response = ""
        # Add any assertions here to check the response
+        has_finish_reason = False
        for idx, chunk in enumerate(response):
            chunk, finished = streaming_format_tests(idx, chunk)
            if finished:
+                has_finish_reason = True
                break
            complete_response += chunk
+        if has_finish_reason == False:
+            raise Exception("finish reason not set")
        if complete_response.strip() == "":
            raise Exception("Empty response received")
        print(f"completion_response: {complete_response}")
@@ -608,11 +629,15 @@ def test_completion_claude_stream_bad_key():
        )
        complete_response = ""
        # Add any assertions here to check the response
+        has_finish_reason = False
        for idx, chunk in enumerate(response):
            chunk, finished = streaming_format_tests(idx, chunk)
            if finished:
+                has_finish_reason = True
                break
            complete_response += chunk
+        if has_finish_reason == False:
+            raise Exception("finish reason not set")
        if complete_response.strip() == "":
            raise Exception("Empty response received")
        print(f"1234completion_response: {complete_response}")
@@ -626,6 +651,45 @@ def test_completion_claude_stream_bad_key():
# test_completion_claude_stream_bad_key()
# test_completion_replicate_stream()
+
+
+def test_vertex_ai_stream():
+    from litellm.tests.test_amazing_vertex_completion import load_vertex_ai_credentials
+
+    load_vertex_ai_credentials()
+    litellm.set_verbose = True
+    litellm.vertex_project = "reliablekeys"
+    import random
+
+    test_models = ["gemini-1.0-pro"]
+    for model in test_models:
+        try:
+            print("making request", model)
+            response = completion(
+                model=model,
+                messages=[
+                    {"role": "user", "content": "write 10 line code code for saying hi"}
+                ],
+                stream=True,
+            )
+            complete_response = ""
+            is_finished = False
+            for idx, chunk in enumerate(response):
+                print(f"chunk in response: {chunk}")
+                chunk, finished = streaming_format_tests(idx, chunk)
+                if finished:
+                    is_finished = True
+                    break
+                complete_response += chunk
+            if complete_response.strip() == "":
+                raise Exception("Empty response received")
+            print(f"completion_response: {complete_response}")
+            assert is_finished == True
+        except litellm.RateLimitError as e:
+            pass
+        except Exception as e:
+            pytest.fail(f"Error occurred: {e}")
+
+
# def test_completion_vertexai_stream():
#     try:
#         import os
@@ -742,11 +806,15 @@ def test_bedrock_claude_3_streaming():
        )
        complete_response = ""
        # Add any assertions here to check the response
+        has_finish_reason = False
        for idx, chunk in enumerate(response):
            chunk, finished = streaming_format_tests(idx, chunk)
            if finished:
+                has_finish_reason = True
                break
            complete_response += chunk
+        if has_finish_reason == False:
+            raise Exception("finish reason not set")
        if complete_response.strip() == "":
            raise Exception("Empty response received")
        print(f"completion_response: {complete_response}")
@@ -1705,7 +1773,7 @@ def test_success_callback_streaming():
    messages = [{"role": "user", "content": "hello"}]
    print("TESTING LITELLM COMPLETION CALL")
    response = litellm.completion(
-        model="j2-light",
+        model="gpt-3.5-turbo",
        messages=messages,
        stream=True,
        max_tokens=5,
@@ -2072,7 +2140,7 @@ def test_completion_claude_3_function_call_with_streaming():
        )
        idx = 0
        for chunk in response:
-            # print(f"chunk: {chunk}")
+            print(f"chunk in response: {chunk}")
            if idx == 0:
                assert (
                    chunk.choices[0].delta.tool_calls[0].function.arguments is not None
@@ -2081,7 +2149,7 @@ def test_completion_claude_3_function_call_with_streaming():
                    chunk.choices[0].delta.tool_calls[0].function.arguments, str
                )
                validate_first_streaming_function_calling_chunk(chunk=chunk)
-            elif idx == 1:
+            elif idx == 1 and chunk.choices[0].finish_reason is None:
                validate_second_streaming_function_calling_chunk(chunk=chunk)
            elif chunk.choices[0].finish_reason is not None:  # last chunk
                validate_final_streaming_function_calling_chunk(chunk=chunk)
@@ -2136,7 +2204,7 @@ async def test_acompletion_claude_3_function_call_with_streaming():
                    chunk.choices[0].delta.tool_calls[0].function.arguments, str
                )
                validate_first_streaming_function_calling_chunk(chunk=chunk)
-            elif idx == 1:
+            elif idx == 1 and chunk.choices[0].finish_reason is None:
                validate_second_streaming_function_calling_chunk(chunk=chunk)
            elif chunk.choices[0].finish_reason is not None:  # last chunk
                validate_final_streaming_function_calling_chunk(chunk=chunk)
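
The per-provider streaming tests above all repeat one pattern: accumulate text, record whether a finish_reason ever arrived, and fail if it did not. A condensed sketch of that pattern (the helper name is illustrative; streaming_format_tests is the existing test utility used above):

def assert_stream_finishes(response):
    # Illustrative condensation of the has_finish_reason checks added in the tests above.
    complete_response = ""
    has_finish_reason = False
    for idx, chunk in enumerate(response):
        chunk, finished = streaming_format_tests(idx, chunk)
        if finished:
            has_finish_reason = True
            break
        complete_response += chunk
    if has_finish_reason == False:
        raise Exception("finish reason not set")
    if complete_response.strip() == "":
        raise Exception("Empty response received")
    return complete_response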


@@ -8458,6 +8458,7 @@ class CustomStreamWrapper:
        self.completion_stream = completion_stream
        self.sent_first_chunk = False
        self.sent_last_chunk = False
+        self.received_finish_reason: Optional[str] = None
        self.special_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
        self.holding_chunk = ""
        self.complete_response = ""
@@ -9131,7 +9132,7 @@ class CustomStreamWrapper:
            "finish_reason": finish_reason,
        }

-    def chunk_creator(self, chunk):
+    def model_response_creator(self):
        model_response = ModelResponse(stream=True, model=self.model)
        if self.response_id is not None:
            model_response.id = self.response_id
@@ -9141,6 +9142,20 @@ class CustomStreamWrapper:
        model_response._hidden_params["created_at"] = time.time()
        model_response.choices = [StreamingChoices()]
        model_response.choices[0].finish_reason = None
+        return model_response
+
+    def is_delta_empty(self, delta: Delta) -> bool:
+        is_empty = True
+        if delta.content is not None:
+            is_empty = False
+        elif delta.tool_calls is not None:
+            is_empty = False
+        elif delta.function_call is not None:
+            is_empty = False
+        return is_empty
+
+    def chunk_creator(self, chunk):
+        model_response = self.model_response_creator()
        response_obj = {}
        try:
            # return this for all models
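
A quick illustration of what is_delta_empty() treats as empty, built with litellm's Delta type (the values are made up for the example):

from litellm.utils import Delta

# All fields None -> empty: the finish_reason may be attached to this chunk.
empty_delta = Delta(content=None, function_call=None, tool_calls=None)

# Remaining text or tool-call fragments -> not empty: the finish_reason is held back
# until the buffered content has been flushed.
text_delta = Delta(content="partial text")
tool_delta = Delta(
    content=None,
    tool_calls=[
        {
            "id": "call_example",
            "type": "function",
            "function": {"name": "get_weather", "arguments": "{}"},
        }
    ],
)
# For a CustomStreamWrapper instance `w`: w.is_delta_empty(delta=empty_delta) is True,
# while the other two deltas return False.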
@@ -9149,30 +9164,22 @@ class CustomStreamWrapper:
                response_obj = self.handle_anthropic_chunk(chunk)
                completion_obj["content"] = response_obj["text"]
                if response_obj["is_finished"]:
-                    model_response.choices[0].finish_reason = response_obj[
-                        "finish_reason"
-                    ]
+                    self.received_finish_reason = response_obj["finish_reason"]
            elif self.model == "replicate" or self.custom_llm_provider == "replicate":
                response_obj = self.handle_replicate_chunk(chunk)
                completion_obj["content"] = response_obj["text"]
                if response_obj["is_finished"]:
-                    model_response.choices[0].finish_reason = response_obj[
-                        "finish_reason"
-                    ]
+                    self.received_finish_reason = response_obj["finish_reason"]
            elif self.custom_llm_provider and self.custom_llm_provider == "together_ai":
                response_obj = self.handle_together_ai_chunk(chunk)
                completion_obj["content"] = response_obj["text"]
                if response_obj["is_finished"]:
-                    model_response.choices[0].finish_reason = response_obj[
-                        "finish_reason"
-                    ]
+                    self.received_finish_reason = response_obj["finish_reason"]
            elif self.custom_llm_provider and self.custom_llm_provider == "huggingface":
                response_obj = self.handle_huggingface_chunk(chunk)
                completion_obj["content"] = response_obj["text"]
                if response_obj["is_finished"]:
-                    model_response.choices[0].finish_reason = response_obj[
-                        "finish_reason"
-                    ]
+                    self.received_finish_reason = response_obj["finish_reason"]
            elif (
                self.custom_llm_provider and self.custom_llm_provider == "baseten"
            ):  # baseten doesn't provide streaming
@@ -9183,16 +9190,12 @@ class CustomStreamWrapper:
                response_obj = self.handle_ai21_chunk(chunk)
                completion_obj["content"] = response_obj["text"]
                if response_obj["is_finished"]:
-                    model_response.choices[0].finish_reason = response_obj[
-                        "finish_reason"
-                    ]
+                    self.received_finish_reason = response_obj["finish_reason"]
            elif self.custom_llm_provider and self.custom_llm_provider == "maritalk":
                response_obj = self.handle_maritalk_chunk(chunk)
                completion_obj["content"] = response_obj["text"]
                if response_obj["is_finished"]:
-                    model_response.choices[0].finish_reason = response_obj[
-                        "finish_reason"
-                    ]
+                    self.received_finish_reason = response_obj["finish_reason"]
            elif self.custom_llm_provider and self.custom_llm_provider == "vllm":
                completion_obj["content"] = chunk[0].outputs[0].text
            elif (
@@ -9201,152 +9204,116 @@ class CustomStreamWrapper:
                response_obj = self.handle_aleph_alpha_chunk(chunk)
                completion_obj["content"] = response_obj["text"]
                if response_obj["is_finished"]:
-                    model_response.choices[0].finish_reason = response_obj[
-                        "finish_reason"
-                    ]
+                    self.received_finish_reason = response_obj["finish_reason"]
            elif self.custom_llm_provider == "nlp_cloud":
                try:
                    response_obj = self.handle_nlp_cloud_chunk(chunk)
                    completion_obj["content"] = response_obj["text"]
                    if response_obj["is_finished"]:
-                        model_response.choices[0].finish_reason = response_obj[
-                            "finish_reason"
-                        ]
+                        self.received_finish_reason = response_obj["finish_reason"]
                except Exception as e:
-                    if self.sent_last_chunk:
+                    if self.received_finish_reason:
                        raise e
                    else:
                        if self.sent_first_chunk is False:
                            raise Exception("An unknown error occurred with the stream")
-                        model_response.choices[0].finish_reason = "stop"
-                        self.sent_last_chunk = True
+                        self.received_finish_reason = "stop"
            elif self.custom_llm_provider == "gemini":
-                try:
-                    if hasattr(chunk, "parts") == True:
-                        try:
-                            if len(chunk.parts) > 0:
-                                completion_obj["content"] = chunk.parts[0].text
-                            if hasattr(chunk.parts[0], "finish_reason"):
-                                model_response.choices[0].finish_reason = (
-                                    map_finish_reason(chunk.parts[0].finish_reason.name)
-                                )
-                        except:
-                            if chunk.parts[0].finish_reason.name == "SAFETY":
-                                raise Exception(
-                                    f"The response was blocked by VertexAI. {str(chunk)}"
-                                )
-                    else:
-                        completion_obj["content"] = str(chunk)
-                except StopIteration as e:
-                    if self.sent_last_chunk:
-                        raise e
-                    else:
-                        model_response.choices[0].finish_reason = "stop"
-                        self.sent_last_chunk = True
+                if hasattr(chunk, "parts") == True:
+                    try:
+                        if len(chunk.parts) > 0:
+                            completion_obj["content"] = chunk.parts[0].text
+                        if hasattr(chunk.parts[0], "finish_reason"):
+                            self.received_finish_reason = chunk.parts[
+                                0
+                            ].finish_reason.name
+                    except:
+                        if chunk.parts[0].finish_reason.name == "SAFETY":
+                            raise Exception(
+                                f"The response was blocked by VertexAI. {str(chunk)}"
+                            )
+                else:
+                    completion_obj["content"] = str(chunk)
            elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"):
-                try:
-                    if hasattr(chunk, "candidates") == True:
-                        try:
-                            try:
-                                completion_obj["content"] = chunk.text
-                            except Exception as e:
-                                if "Part has no text." in str(e):
-                                    ## check for function calling
-                                    function_call = (
-                                        chunk.candidates[0]
-                                        .content.parts[0]
-                                        .function_call
-                                    )
-                                    args_dict = {}
-                                    for k, v in function_call.args.items():
-                                        args_dict[k] = v
-                                    args_str = json.dumps(args_dict)
-                                    _delta_obj = litellm.utils.Delta(
-                                        content=None,
-                                        tool_calls=[
-                                            {
-                                                "id": f"call_{str(uuid.uuid4())}",
-                                                "function": {
-                                                    "arguments": args_str,
-                                                    "name": function_call.name,
-                                                },
-                                                "type": "function",
-                                            }
-                                        ],
-                                    )
-                                    _streaming_response = StreamingChoices(
-                                        delta=_delta_obj
-                                    )
-                                    _model_response = ModelResponse(stream=True)
-                                    _model_response.choices = [_streaming_response]
-                                    response_obj = {"original_chunk": _model_response}
-                                else:
-                                    raise e
-                            if (
-                                hasattr(chunk.candidates[0], "finish_reason")
-                                and chunk.candidates[0].finish_reason.name
-                                != "FINISH_REASON_UNSPECIFIED"
-                            ):  # every non-final chunk in vertex ai has this
-                                model_response.choices[0].finish_reason = (
-                                    map_finish_reason(
-                                        chunk.candidates[0].finish_reason.name
-                                    )
-                                )
-                        except Exception as e:
-                            if chunk.candidates[0].finish_reason.name == "SAFETY":
-                                raise Exception(
-                                    f"The response was blocked by VertexAI. {str(chunk)}"
-                                )
-                    else:
-                        completion_obj["content"] = str(chunk)
-                except StopIteration as e:
-                    if self.sent_last_chunk:
-                        raise e
-                    else:
-                        model_response.choices[0].finish_reason = "stop"
-                        self.sent_last_chunk = True
+                if hasattr(chunk, "candidates") == True:
+                    try:
+                        try:
+                            completion_obj["content"] = chunk.text
+                        except Exception as e:
+                            if "Part has no text." in str(e):
+                                ## check for function calling
+                                function_call = (
+                                    chunk.candidates[0].content.parts[0].function_call
+                                )
+                                args_dict = {}
+                                for k, v in function_call.args.items():
+                                    args_dict[k] = v
+                                args_str = json.dumps(args_dict)
+                                _delta_obj = litellm.utils.Delta(
+                                    content=None,
+                                    tool_calls=[
+                                        {
+                                            "id": f"call_{str(uuid.uuid4())}",
+                                            "function": {
+                                                "arguments": args_str,
+                                                "name": function_call.name,
+                                            },
+                                            "type": "function",
+                                        }
+                                    ],
+                                )
+                                _streaming_response = StreamingChoices(delta=_delta_obj)
+                                _model_response = ModelResponse(stream=True)
+                                _model_response.choices = [_streaming_response]
+                                response_obj = {"original_chunk": _model_response}
+                            else:
+                                raise e
+                        if (
+                            hasattr(chunk.candidates[0], "finish_reason")
+                            and chunk.candidates[0].finish_reason.name
+                            != "FINISH_REASON_UNSPECIFIED"
+                        ):  # every non-final chunk in vertex ai has this
+                            self.received_finish_reason = chunk.candidates[
+                                0
+                            ].finish_reason.name
+                    except Exception as e:
+                        if chunk.candidates[0].finish_reason.name == "SAFETY":
+                            raise Exception(
+                                f"The response was blocked by VertexAI. {str(chunk)}"
+                            )
+                else:
+                    completion_obj["content"] = str(chunk)
            elif self.custom_llm_provider == "cohere":
                response_obj = self.handle_cohere_chunk(chunk)
                completion_obj["content"] = response_obj["text"]
                if response_obj["is_finished"]:
-                    model_response.choices[0].finish_reason = response_obj[
-                        "finish_reason"
-                    ]
+                    self.received_finish_reason = response_obj["finish_reason"]
            elif self.custom_llm_provider == "cohere_chat":
                response_obj = self.handle_cohere_chat_chunk(chunk)
                if response_obj is None:
                    return
                completion_obj["content"] = response_obj["text"]
                if response_obj["is_finished"]:
-                    model_response.choices[0].finish_reason = response_obj[
-                        "finish_reason"
-                    ]
+                    self.received_finish_reason = response_obj["finish_reason"]
            elif self.custom_llm_provider == "bedrock":
-                if self.sent_last_chunk:
+                if self.received_finish_reason is not None:
                    raise StopIteration
                response_obj = self.handle_bedrock_stream(chunk)
                completion_obj["content"] = response_obj["text"]
                if response_obj["is_finished"]:
-                    model_response.choices[0].finish_reason = response_obj[
-                        "finish_reason"
-                    ]
-                    self.sent_last_chunk = True
+                    self.received_finish_reason = response_obj["finish_reason"]
            elif self.custom_llm_provider == "sagemaker":
-                verbose_logger.debug(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
+                print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
                response_obj = self.handle_sagemaker_stream(chunk)
                completion_obj["content"] = response_obj["text"]
                if response_obj["is_finished"]:
-                    model_response.choices[0].finish_reason = response_obj[
-                        "finish_reason"
-                    ]
-                    self.sent_last_chunk = True
+                    self.received_finish_reason = response_obj["finish_reason"]
            elif self.custom_llm_provider == "petals":
                if len(self.completion_stream) == 0:
-                    if self.sent_last_chunk:
+                    if self.received_finish_reason is not None:
                        raise StopIteration
                    else:
-                        model_response.choices[0].finish_reason = "stop"
-                        self.sent_last_chunk = True
+                        self.received_finish_reason = "stop"
                chunk_size = 30
                new_chunk = self.completion_stream[:chunk_size]
                completion_obj["content"] = new_chunk
@@ -9356,11 +9323,10 @@ class CustomStreamWrapper:
                # fake streaming
                response_obj = {}
                if len(self.completion_stream) == 0:
-                    if self.sent_last_chunk:
+                    if self.received_finish_reason is not None:
                        raise StopIteration
                    else:
-                        model_response.choices[0].finish_reason = "stop"
-                        self.sent_last_chunk = True
+                        self.received_finish_reason = "stop"
                chunk_size = 30
                new_chunk = self.completion_stream[:chunk_size]
                completion_obj["content"] = new_chunk
@@ -9371,41 +9337,31 @@ class CustomStreamWrapper:
                completion_obj["content"] = response_obj["text"]
                print_verbose(f"completion obj content: {completion_obj['content']}")
                if response_obj["is_finished"]:
-                    model_response.choices[0].finish_reason = response_obj[
-                        "finish_reason"
-                    ]
+                    self.received_finish_reason = response_obj["finish_reason"]
            elif self.custom_llm_provider == "ollama_chat":
                response_obj = self.handle_ollama_chat_stream(chunk)
                completion_obj["content"] = response_obj["text"]
                print_verbose(f"completion obj content: {completion_obj['content']}")
                if response_obj["is_finished"]:
-                    model_response.choices[0].finish_reason = response_obj[
-                        "finish_reason"
-                    ]
+                    self.received_finish_reason = response_obj["finish_reason"]
            elif self.custom_llm_provider == "cloudflare":
                response_obj = self.handle_cloudlfare_stream(chunk)
                completion_obj["content"] = response_obj["text"]
                print_verbose(f"completion obj content: {completion_obj['content']}")
                if response_obj["is_finished"]:
-                    model_response.choices[0].finish_reason = response_obj[
-                        "finish_reason"
-                    ]
+                    self.received_finish_reason = response_obj["finish_reason"]
            elif self.custom_llm_provider == "text-completion-openai":
                response_obj = self.handle_openai_text_completion_chunk(chunk)
                completion_obj["content"] = response_obj["text"]
                print_verbose(f"completion obj content: {completion_obj['content']}")
                if response_obj["is_finished"]:
-                    model_response.choices[0].finish_reason = response_obj[
-                        "finish_reason"
-                    ]
+                    self.received_finish_reason = response_obj["finish_reason"]
            elif self.custom_llm_provider == "azure_text":
                response_obj = self.handle_azure_text_completion_chunk(chunk)
                completion_obj["content"] = response_obj["text"]
                print_verbose(f"completion obj content: {completion_obj['content']}")
                if response_obj["is_finished"]:
-                    model_response.choices[0].finish_reason = response_obj[
-                        "finish_reason"
-                    ]
+                    self.received_finish_reason = response_obj["finish_reason"]
            elif self.custom_llm_provider == "cached_response":
                response_obj = {
                    "text": chunk.choices[0].delta.content,
@@ -9419,9 +9375,7 @@ class CustomStreamWrapper:
                if hasattr(chunk, "id"):
                    model_response.id = chunk.id
                if response_obj["is_finished"]:
-                    model_response.choices[0].finish_reason = response_obj[
-                        "finish_reason"
-                    ]
+                    self.received_finish_reason = response_obj["finish_reason"]
            else:  # openai / azure chat model
                if self.custom_llm_provider == "azure":
                    if hasattr(chunk, "model"):
@@ -9437,9 +9391,7 @@ class CustomStreamWrapper:
                        raise Exception(
                            "Mistral API raised a streaming error - finish_reason: error, no content string given."
                        )
-                    model_response.choices[0].finish_reason = response_obj[
-                        "finish_reason"
-                    ]
+                    self.received_finish_reason = response_obj["finish_reason"]
                if response_obj.get("original_chunk", None) is not None:
                    model_response.system_fingerprint = getattr(
                        response_obj["original_chunk"], "system_fingerprint", None
@@ -9451,7 +9403,7 @@ class CustomStreamWrapper:
            model_response.model = self.model
            print_verbose(
-                f"model_response finish reason 3: {model_response.choices[0].finish_reason}; response_obj={response_obj}"
+                f"model_response finish reason 3: {self.received_finish_reason}; response_obj={response_obj}"
            )
            ## FUNCTION CALL PARSING
            if (
@@ -9580,7 +9532,7 @@ class CustomStreamWrapper:
                    return model_response
                else:
                    return
-            elif model_response.choices[0].finish_reason is not None:
+            elif self.received_finish_reason is not None:
                # flush any remaining holding chunk
                if len(self.holding_chunk) > 0:
                    if model_response.choices[0].delta.content is None:
@@ -9590,10 +9542,17 @@ class CustomStreamWrapper:
                        self.holding_chunk + model_response.choices[0].delta.content
                    )
                    self.holding_chunk = ""
-                # get any function call arguments
-                model_response.choices[0].finish_reason = map_finish_reason(
-                    model_response.choices[0].finish_reason
-                )  # ensure consistent output to openai
+                # if delta is None
+                is_delta_empty = self.is_delta_empty(
+                    delta=model_response.choices[0].delta
+                )
+
+                if is_delta_empty:
+                    # get any function call arguments
+                    model_response.choices[0].finish_reason = map_finish_reason(
+                        finish_reason=self.received_finish_reason
+                    )  # ensure consistent output to openai
+                    self.sent_last_chunk = True
                return model_response
            elif (
                model_response.choices[0].delta.tool_calls is not None
@@ -9653,6 +9612,16 @@ class CustomStreamWrapper:
                ## SYNC LOGGING
                self.logging_obj.success_handler(processed_chunk)
+
+    def finish_reason_handler(self):
+        model_response = self.model_response_creator()
+        if self.received_finish_reason is not None:
+            model_response.choices[0].finish_reason = map_finish_reason(
+                finish_reason=self.received_finish_reason
+            )
+        else:
+            model_response.choices[0].finish_reason = "stop"
+        return model_response
    ## needs to handle the empty string case (even starting chunk can be an empty string)
    def __next__(self):
        try:
@@ -9687,7 +9656,11 @@ class CustomStreamWrapper:
                # RETURN RESULT
                return response
        except StopIteration:
-            raise  # Re-raise StopIteration
+            if self.sent_last_chunk == True:
+                raise  # Re-raise StopIteration
+            else:
+                self.sent_last_chunk = True
+                return self.finish_reason_handler()
        except Exception as e:
            traceback_exception = traceback.format_exc()
            # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
@@ -9792,9 +9765,17 @@ class CustomStreamWrapper:
                # RETURN RESULT
                return processed_chunk
        except StopAsyncIteration:
-            raise
+            if self.sent_last_chunk == True:
+                raise  # Re-raise StopIteration
+            else:
+                self.sent_last_chunk = True
+                return self.finish_reason_handler()
        except StopIteration:
-            raise StopAsyncIteration  # Re-raise StopIteration
+            if self.sent_last_chunk == True:
+                raise StopAsyncIteration
+            else:
+                self.sent_last_chunk = True
+                return self.finish_reason_handler()
        except Exception as e:
            traceback_exception = traceback.format_exc()
            # Handle any exceptions that might occur during streaming
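
Net effect of the wrapper changes: finish reasons are buffered on self.received_finish_reason instead of being written straight onto each chunk, and if the provider stream ends without ever reporting one, __next__/__anext__ now synthesize a final chunk via finish_reason_handler(). A minimal standalone sketch of that fallback pattern (a toy class, not litellm's actual implementation):

from typing import Iterator, Optional

class FinalChunkGuard:
    # Toy iterator wrapper: guarantees the stream ends with a finish-reason chunk.
    def __init__(self, stream: Iterator[dict]):
        self.stream = stream
        self.received_finish_reason: Optional[str] = None
        self.sent_last_chunk = False

    def __iter__(self):
        return self

    def __next__(self) -> dict:
        try:
            chunk = next(self.stream)
            if chunk.get("finish_reason") is not None:
                self.received_finish_reason = chunk["finish_reason"]
                self.sent_last_chunk = True
            return chunk
        except StopIteration:
            if self.sent_last_chunk:
                raise  # a terminal chunk was already emitted
            # provider stream ended without a finish reason: synthesize one
            self.sent_last_chunk = True
            return {"delta": {}, "finish_reason": self.received_finish_reason or "stop"}

# Every consumer now sees a terminal chunk, even if the provider never sends one.
chunks = list(FinalChunkGuard(iter([{"delta": {"content": "hi"}, "finish_reason": None}])))
assert chunks[-1]["finish_reason"] == "stop"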