From 97b9d570a643bdf027e428e26a4b743429823e9d Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 23 Feb 2024 20:55:32 -0800 Subject: [PATCH 1/2] fix(utils.py): stricter azure function calling tests --- litellm/tests/test_streaming.py | 286 ++++++++++++++++++++++++++++---- litellm/utils.py | 8 +- 2 files changed, 261 insertions(+), 33 deletions(-) diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 7baa927ad..66e8be4cb 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -1655,6 +1655,202 @@ def test_openai_streaming_and_function_calling(): raise e +# test_azure_streaming_and_function_calling() + + +def test_success_callback_streaming(): + def success_callback(kwargs, completion_response, start_time, end_time): + print( + { + "success": True, + "input": kwargs, + "output": completion_response, + "start_time": start_time, + "end_time": end_time, + } + ) + + litellm.success_callback = [success_callback] + + messages = [{"role": "user", "content": "hello"}] + print("TESTING LITELLM COMPLETION CALL") + response = litellm.completion( + model="j2-light", + messages=messages, + stream=True, + max_tokens=5, + ) + print(response) + + for chunk in response: + print(chunk["choices"][0]) + + +# test_success_callback_streaming() + +#### STREAMING + FUNCTION CALLING ### +from pydantic import BaseModel +from typing import List, Optional + + +class Function(BaseModel): + name: str + arguments: str + + +class ToolCalls(BaseModel): + index: int + id: str + type: str + function: Function + + +class Delta(BaseModel): + role: str + content: Optional[str] + tool_calls: List[ToolCalls] + + +class Choices(BaseModel): + index: int + delta: Delta + logprobs: Optional[str] + finish_reason: Optional[str] + + +class Chunk(BaseModel): + id: str + object: str + created: int + model: str + system_fingerprint: str + choices: List[Choices] + + +def validate_first_streaming_function_calling_chunk(chunk: ModelResponse): + chunk_instance = Chunk(**chunk.model_dump()) + + +### Chunk 1 + + +# { +# "id": "chatcmpl-8vdVjtzxc0JqGjq93NxC79dMp6Qcs", +# "object": "chat.completion.chunk", +# "created": 1708747267, +# "model": "gpt-3.5-turbo-0125", +# "system_fingerprint": "fp_86156a94a0", +# "choices": [ +# { +# "index": 0, +# "delta": { +# "role": "assistant", +# "content": null, +# "tool_calls": [ +# { +# "index": 0, +# "id": "call_oN10vaaC9iA8GLFRIFwjCsN7", +# "type": "function", +# "function": { +# "name": "get_current_weather", +# "arguments": "" +# } +# } +# ] +# }, +# "logprobs": null, +# "finish_reason": null +# } +# ] +# } +class Function2(BaseModel): + arguments: str + + +class ToolCalls2(BaseModel): + index: int + function: Optional[Function2] + + +class Delta2(BaseModel): + tool_calls: List[ToolCalls2] + + +class Choices2(BaseModel): + index: int + delta: Delta2 + logprobs: Optional[str] + finish_reason: Optional[str] + + +class Chunk2(BaseModel): + id: str + object: str + created: int + model: str + system_fingerprint: str + choices: List[Choices2] + + +## Chunk 2 + +# { +# "id": "chatcmpl-8vdVjtzxc0JqGjq93NxC79dMp6Qcs", +# "object": "chat.completion.chunk", +# "created": 1708747267, +# "model": "gpt-3.5-turbo-0125", +# "system_fingerprint": "fp_86156a94a0", +# "choices": [ +# { +# "index": 0, +# "delta": { +# "tool_calls": [ +# { +# "index": 0, +# "function": { +# "arguments": "{\"" +# } +# } +# ] +# }, +# "logprobs": null, +# "finish_reason": null +# } +# ] +# } + + +def validate_second_streaming_function_calling_chunk(chunk: ModelResponse): + chunk_instance = Chunk2(**chunk.model_dump()) + + +class Delta3(BaseModel): + content: Optional[str] = None + role: Optional[str] = None + function_call: Optional[dict] = None + tool_calls: Optional[List] = None + + +class Choices3(BaseModel): + index: int + delta: Delta3 + logprobs: Optional[str] + finish_reason: str + + +class Chunk3(BaseModel): + id: str + object: str + created: int + model: str + system_fingerprint: str + choices: List[Choices3] + + +def validate_final_streaming_function_calling_chunk(chunk: ModelResponse): + chunk_instance = Chunk3(**chunk.model_dump()) + + def test_azure_streaming_and_function_calling(): tools = [ { @@ -1690,6 +1886,7 @@ def test_azure_streaming_and_function_calling(): ) # Add any assertions here to check the response for idx, chunk in enumerate(response): + print(f"chunk: {chunk}") if idx == 0: assert ( chunk.choices[0].delta.tool_calls[0].function.arguments is not None @@ -1697,40 +1894,69 @@ def test_azure_streaming_and_function_calling(): assert isinstance( chunk.choices[0].delta.tool_calls[0].function.arguments, str ) + validate_first_streaming_function_calling_chunk(chunk=chunk) + elif idx == 1: + validate_second_streaming_function_calling_chunk(chunk=chunk) + elif chunk.choices[0].finish_reason is not None: # last chunk + validate_final_streaming_function_calling_chunk(chunk=chunk) + except Exception as e: pytest.fail(f"Error occurred: {e}") raise e -# test_azure_streaming_and_function_calling() - - -def test_success_callback_streaming(): - def success_callback(kwargs, completion_response, start_time, end_time): - print( - { - "success": True, - "input": kwargs, - "output": completion_response, - "start_time": start_time, - "end_time": end_time, - } +@pytest.mark.asyncio +async def test_azure_astreaming_and_function_calling(): + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } + ] + messages = [{"role": "user", "content": "What is the weather like in Boston?"}] + try: + response = await litellm.acompletion( + model="azure/gpt-4-nov-release", + tools=tools, + tool_choice="auto", + messages=messages, + stream=True, + api_base=os.getenv("AZURE_FRANCE_API_BASE"), + api_key=os.getenv("AZURE_FRANCE_API_KEY"), + api_version="2024-02-15-preview", ) + # Add any assertions here to check the response + idx = 0 + async for chunk in response: + print(f"chunk: {chunk}") + if idx == 0: + assert ( + chunk.choices[0].delta.tool_calls[0].function.arguments is not None + ) + assert isinstance( + chunk.choices[0].delta.tool_calls[0].function.arguments, str + ) + validate_first_streaming_function_calling_chunk(chunk=chunk) + elif idx == 1: + validate_second_streaming_function_calling_chunk(chunk=chunk) + elif chunk.choices[0].finish_reason is not None: # last chunk + validate_final_streaming_function_calling_chunk(chunk=chunk) + idx += 1 - litellm.success_callback = [success_callback] - - messages = [{"role": "user", "content": "hello"}] - print("TESTING LITELLM COMPLETION CALL") - response = litellm.completion( - model="j2-light", - messages=messages, - stream=True, - max_tokens=5, - ) - print(response) - - for chunk in response: - print(chunk["choices"][0]) - - -# test_success_callback_streaming() + except Exception as e: + pytest.fail(f"Error occurred: {e}") + raise e diff --git a/litellm/utils.py b/litellm/utils.py index 4468d7c50..b8f4fc1e7 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -376,11 +376,9 @@ class StreamingChoices(OpenAIObject): self.delta = delta else: self.delta = Delta() - - if logprobs is not None: - self.logprobs = logprobs if enhancements is not None: self.enhancements = enhancements + self.logprobs = logprobs def __contains__(self, key): # Define custom behavior for the 'in' operator @@ -8623,6 +8621,10 @@ class CustomStreamWrapper: model_response.choices[0].finish_reason = response_obj[ "finish_reason" ] + if response_obj.get("original_chunk", None) is not None: + model_response.system_fingerprint = getattr( + response_obj["original_chunk"], "system_fingerprint", None + ) if response_obj["logprobs"] is not None: model_response.choices[0].logprobs = response_obj["logprobs"] From cd43630ab8a56cf1e8f12feb8d8f19bc7940cae8 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 23 Feb 2024 21:56:37 -0800 Subject: [PATCH 2/2] test(test_custom_logger.py): skip flaky test --- litellm/tests/test_custom_logger.py | 1 + 1 file changed, 1 insertion(+) diff --git a/litellm/tests/test_custom_logger.py b/litellm/tests/test_custom_logger.py index 2747d33e9..e1c87f1a3 100644 --- a/litellm/tests/test_custom_logger.py +++ b/litellm/tests/test_custom_logger.py @@ -206,6 +206,7 @@ def test_async_custom_handler_stream(): # test_async_custom_handler_stream() +@pytest.mark.skip(reason="Flaky test") def test_azure_completion_stream(): # [PROD Test] - Do not DELETE # test if completion() + sync custom logger get the same complete stream response