diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index f1640d97d..8854e24d8 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -4,6 +4,7 @@
 import sys, os, asyncio
 import traceback
 import time, pytest
+from pydantic import BaseModel
 
 sys.path.insert(
     0, os.path.abspath("../..")
@@ -92,7 +93,7 @@ def validate_second_format(chunk):
     for choice in chunk["choices"]:
         assert isinstance(choice["index"], int), "'index' should be an integer."
-        assert "role" not in choice["delta"], "'role' should be a string."
+        assert hasattr(choice["delta"], "role"), "'role' should exist in the delta."
         # openai v1.0.0 returns content as None
         assert (choice["finish_reason"] is None) or isinstance(
             choice["finish_reason"], str
@@ -1455,8 +1456,8 @@ first_openai_function_call_example = {
 
 
 def validate_first_function_call_chunk_structure(item):
-    if not isinstance(item, dict):
-        raise Exception("Incorrect format")
+    if not (isinstance(item, dict) or isinstance(item, litellm.ModelResponse)):
+        raise Exception(f"Incorrect format, type of item: {type(item)}")
 
     required_keys = {"id", "object", "created", "model", "choices"}
     for key in required_keys:
@@ -1468,27 +1469,42 @@ def validate_first_function_call_chunk_structure(item):
 
     required_keys_in_choices_array = {"index", "delta", "finish_reason"}
     for choice in item["choices"]:
-        if not isinstance(choice, dict):
-            raise Exception("Incorrect format")
+        if not (
+            isinstance(choice, dict)
+            or isinstance(choice, litellm.utils.StreamingChoices)
+        ):
+            raise Exception(f"Incorrect format, type of choice: {type(choice)}")
         for key in required_keys_in_choices_array:
             if key not in choice:
                 raise Exception("Incorrect format")
 
-        if not isinstance(choice["delta"], dict):
-            raise Exception("Incorrect format")
+        if not (
+            isinstance(choice["delta"], dict)
+            or isinstance(choice["delta"], litellm.utils.Delta)
+        ):
+            raise Exception(
+                f"Incorrect format, type of delta: {type(choice['delta'])}"
+            )
         required_keys_in_delta = {"role", "content", "function_call"}
         for key in required_keys_in_delta:
             if key not in choice["delta"]:
                 raise Exception("Incorrect format")
 
-        if not isinstance(choice["delta"]["function_call"], dict):
-            raise Exception("Incorrect format")
+        if not (
+            isinstance(choice["delta"]["function_call"], dict)
+            or isinstance(choice["delta"]["function_call"], BaseModel)
+        ):
+            raise Exception(
+                f"Incorrect format, type of function call: {type(choice['delta']['function_call'])}"
+            )
         required_keys_in_function_call = {"name", "arguments"}
         for key in required_keys_in_function_call:
-            if key not in choice["delta"]["function_call"]:
-                raise Exception("Incorrect format")
+            if not hasattr(choice["delta"]["function_call"], key):
+                raise Exception(
+                    f"Incorrect format, expected key={key}; actual function_call: {choice['delta']['function_call']}"
+                )
 
     return True
@@ -1547,7 +1563,7 @@ final_function_call_chunk_example = {
 
 
 def validate_final_function_call_chunk_structure(data):
-    if not isinstance(data, dict):
+    if not (isinstance(data, dict) or isinstance(data, litellm.ModelResponse)):
         raise Exception("Incorrect format")
 
     required_keys = {"id", "object", "created", "model", "choices"}
@@ -1560,7 +1576,9 @@ def validate_final_function_call_chunk_structure(data):
 
     required_keys_in_choices_array = {"index", "delta", "finish_reason"}
     for choice in data["choices"]:
-        if not isinstance(choice, dict):
+        if not (
+            isinstance(choice, dict) or isinstance(choice, litellm.utils.StreamingChoices)
+        ):
             raise Exception("Incorrect format")
         for key in required_keys_in_choices_array:
             if key not in choice:
@@ -1592,37 +1610,91 @@ def streaming_and_function_calling_format_tests(idx, chunk):
     return extracted_chunk, finished
 
 
-# def test_openai_streaming_and_function_calling():
-#     function1 = [
-#         {
-#             "name": "get_current_weather",
-#             "description": "Get the current weather in a given location",
-#             "parameters": {
-#                 "type": "object",
-#                 "properties": {
-#                     "location": {
-#                         "type": "string",
-#                         "description": "The city and state, e.g. San Francisco, CA",
-#                     },
-#                     "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
-#                 },
-#                 "required": ["location"],
-#             },
-#         }
-#     ]
-#     messages=[{"role": "user", "content": "What is the weather like in Boston?"}]
-#     try:
-#         response = completion(
-#             model="gpt-3.5-turbo", functions=function1, messages=messages, stream=True,
-#         )
-#         # Add any assertions here to check the response
-#         for idx, chunk in enumerate(response):
-#             streaming_and_function_calling_format_tests(idx=idx, chunk=chunk)
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
-#         raise e
+def test_openai_streaming_and_function_calling():
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
+    messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
+    try:
+        response = completion(
+            model="gpt-3.5-turbo",
+            tools=tools,
+            messages=messages,
+            stream=True,
+        )
+        # Add any assertions here to check the response
+        for idx, chunk in enumerate(response):
+            streaming_and_function_calling_format_tests(idx=idx, chunk=chunk)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+        raise e
 
-# test_openai_streaming_and_function_calling()
+
+def test_azure_streaming_and_function_calling():
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } + ] + messages = [{"role": "user", "content": "What is the weather like in Boston?"}] + try: + response = completion( + model="azure/gpt-4-nov-release", + tools=tools, + tool_choice="auto", + messages=messages, + stream=True, + api_base=os.getenv("AZURE_FRANCE_API_BASE"), + api_key=os.getenv("AZURE_FRANCE_API_KEY"), + api_version="2024-02-15-preview", + ) + # Add any assertions here to check the response + for idx, chunk in enumerate(response): + if idx == 0: + assert ( + chunk.choices[0].delta.tool_calls[0].function.arguments is not None + ) + assert isinstance( + chunk.choices[0].delta.tool_calls[0].function.arguments, str + ) + except Exception as e: + pytest.fail(f"Error occurred: {e}") + raise e + + +# test_azure_streaming_and_function_calling() def test_success_callback_streaming(): diff --git a/litellm/utils.py b/litellm/utils.py index 4260ee6e1..8c6529544 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -258,11 +258,14 @@ class Message(OpenAIObject): class Delta(OpenAIObject): - def __init__(self, content=None, role=None, **params): + def __init__( + self, content=None, role=None, function_call=None, tool_calls=None, **params + ): super(Delta, self).__init__(**params) self.content = content - if role is not None: - self.role = role + self.role = role + self.function_call = function_call + self.tool_calls = tool_calls def __contains__(self, key): # Define custom behavior for the 'in' operator @@ -8675,8 +8678,37 @@ class CustomStreamWrapper: ): try: delta = dict(original_chunk.choices[0].delta) + ## AZURE - check if arguments is not None + if ( + original_chunk.choices[0].delta.function_call + is not None + ): + if ( + getattr( + original_chunk.choices[0].delta.function_call, + "arguments", + ) + is None + ): + original_chunk.choices[ + 0 + ].delta.function_call.arguments = "" + elif original_chunk.choices[0].delta.tool_calls is not None: + if isinstance( + original_chunk.choices[0].delta.tool_calls, list + ): + for t in original_chunk.choices[0].delta.tool_calls: + if ( + getattr( + t.function, + "arguments", + ) + is None + ): + t.function.arguments = "" model_response.choices[0].delta = Delta(**delta) except Exception as e: + traceback.print_exc() model_response.choices[0].delta = Delta() else: return