diff --git a/litellm/main.py b/litellm/main.py
index 33fad52cce..7ffb4db7ba 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -3677,6 +3677,7 @@ def stream_chunk_builder(
         response["usage"]["total_tokens"] = (
             response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"]
         )
+
     return convert_to_model_response_object(
         response_object=response,
         model_response_object=model_response,
diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index 34370b057b..d58d68507a 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -138,7 +138,7 @@ def test_vertex_ai():
 
 def test_vertex_ai_stream():
     load_vertex_ai_credentials()
-    litellm.set_verbose = False
+    litellm.set_verbose = True
     litellm.vertex_project = "reliablekeys"
 
     import random
diff --git a/litellm/tests/test_function_calling.py b/litellm/tests/test_function_calling.py
index 2fcbdc9460..ffef8f6594 100644
--- a/litellm/tests/test_function_calling.py
+++ b/litellm/tests/test_function_calling.py
@@ -124,11 +124,12 @@ def test_parallel_function_call():
         pytest.fail(f"Error occurred: {e}")
 
 
-test_parallel_function_call()
+# test_parallel_function_call()
 
 
 def test_parallel_function_call_stream():
     try:
+        litellm.set_verbose = True
         # Step 1: send the conversation and available functions to the model
         messages = [
             {
@@ -217,4 +218,4 @@ def test_parallel_function_call_stream():
         pytest.fail(f"Error occurred: {e}")
 
 
-test_parallel_function_call_stream()
+# test_parallel_function_call_stream()
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 66e8be4cbe..8c3187bd4a 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -801,7 +801,6 @@ def test_completion_bedrock_claude_stream():
             raise Exception("finish reason not set for last chunk")
         if complete_response.strip() == "":
             raise Exception("Empty response received")
-        print(f"completion_response: {complete_response}")
     except RateLimitError:
         pass
     except Exception as e:
@@ -1907,6 +1906,8 @@ def test_azure_streaming_and_function_calling():
 
 @pytest.mark.asyncio
 async def test_azure_astreaming_and_function_calling():
+    import uuid
+
     tools = [
         {
             "type": "function",
@@ -1927,7 +1928,20 @@ async def test_azure_astreaming_and_function_calling():
             },
         }
     ]
-    messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
+    messages = [
+        {
+            "role": "user",
+            "content": f"What is the weather like in Boston? {uuid.uuid4()}",
+        }
+    ]
+    from litellm.caching import Cache
+
+    litellm.cache = Cache(
+        type="redis",
+        host=os.environ["REDIS_HOST"],
+        port=os.environ["REDIS_PORT"],
+        password=os.environ["REDIS_PASSWORD"],
+    )
 
     try:
         response = await litellm.acompletion(
@@ -1938,6 +1952,7 @@ async def test_azure_astreaming_and_function_calling():
             api_base=os.getenv("AZURE_FRANCE_API_BASE"),
             api_key=os.getenv("AZURE_FRANCE_API_KEY"),
             api_version="2024-02-15-preview",
+            caching=True,
         )
         # Add any assertions here to check the response
         idx = 0
@@ -1957,6 +1972,36 @@ async def test_azure_astreaming_and_function_calling():
                 validate_final_streaming_function_calling_chunk(chunk=chunk)
             idx += 1
+        ## CACHING TEST
+        print("\n\nCACHING TESTS\n\n")
+        response = await litellm.acompletion(
+            model="azure/gpt-4-nov-release",
+            tools=tools,
+            tool_choice="auto",
+            messages=messages,
+            stream=True,
+            api_base=os.getenv("AZURE_FRANCE_API_BASE"),
+            api_key=os.getenv("AZURE_FRANCE_API_KEY"),
+            api_version="2024-02-15-preview",
+            caching=True,
+        )
+        # Add any assertions here to check the response
+        idx = 0
+        async for chunk in response:
+            print(f"chunk: {chunk}")
+            if idx == 0:
+                assert (
+                    chunk.choices[0].delta.tool_calls[0].function.arguments is not None
+                )
+                assert isinstance(
+                    chunk.choices[0].delta.tool_calls[0].function.arguments, str
+                )
+                validate_first_streaming_function_calling_chunk(chunk=chunk)
+            elif idx == 1:
+                validate_second_streaming_function_calling_chunk(chunk=chunk)
+            elif chunk.choices[0].finish_reason is not None:  # last chunk
+                validate_final_streaming_function_calling_chunk(chunk=chunk)
+            idx += 1
 
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
         raise e
diff --git a/litellm/utils.py b/litellm/utils.py
index 0e718d31a3..c59a3c1e5c 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -213,6 +213,13 @@ class Function(OpenAIObject):
     name: str
 
 
+class ChatCompletionDeltaToolCall(OpenAIObject):
+    id: str
+    function: Function
+    type: str
+    index: int
+
+
 class ChatCompletionMessageToolCall(OpenAIObject):
     id: str
     function: Function
@@ -269,7 +276,14 @@ class Delta(OpenAIObject):
         self.content = content
         self.role = role
         self.function_call = function_call
-        self.tool_calls = tool_calls
+        if tool_calls is not None and isinstance(tool_calls, dict):
+            self.tool_calls = []
+            for tool_call in tool_calls:
+                if tool_call.get("index", None) is None:
+                    tool_call["index"] = 0
+                self.tool_calls.append(ChatCompletionDeltaToolCall(**tool_call))
+        else:
+            self.tool_calls = tool_calls
 
     def __contains__(self, key):
         # Define custom behavior for the 'in' operator
@@ -1182,7 +1196,8 @@ class Logging:
                         start_time=start_time,
                         end_time=end_time,
                     )
-                except:
+                except Exception as e:
+                    complete_streaming_response = None
             else:
                 self.sync_streaming_chunks.append(result)
 
@@ -5847,6 +5862,18 @@ async def convert_to_streaming_response_async(response_object: Optional[dict] =
 
     choice_list = []
     for idx, choice in enumerate(response_object["choices"]):
+        if (
+            choice["message"].get("tool_calls", None) is not None
+            and isinstance(choice["message"]["tool_calls"], list)
+            and len(choice["message"]["tool_calls"]) > 0
+            and isinstance(choice["message"]["tool_calls"][0], dict)
+        ):
+            pydantic_tool_calls = []
+            for index, t in enumerate(choice["message"]["tool_calls"]):
+                if "index" not in t:
+                    t["index"] = index
+                pydantic_tool_calls.append(ChatCompletionDeltaToolCall(**t))
+            choice["message"]["tool_calls"] = pydantic_tool_calls
         delta = Delta(
             content=choice["message"].get("content", None),
role=choice["message"]["role"], @@ -8650,6 +8677,7 @@ class CustomStreamWrapper: "text": chunk.choices[0].delta.content, "is_finished": True, "finish_reason": chunk.choices[0].finish_reason, + "original_chunk": chunk, } completion_obj["content"] = response_obj["text"] @@ -8681,13 +8709,82 @@ class CustomStreamWrapper: model_response.model = self.model print_verbose( - f"model_response: {model_response}; completion_obj: {completion_obj}" - ) - print_verbose( - f"model_response finish reason 3: {model_response.choices[0].finish_reason}" + f"model_response finish reason 3: {model_response.choices[0].finish_reason}; response_obj={response_obj}" ) + ## FUNCTION CALL PARSING if ( - len(completion_obj["content"]) > 0 + response_obj is not None + and response_obj.get("original_chunk", None) is not None + ): # function / tool calling branch - only set for openai/azure compatible endpoints + # enter this branch when no content has been passed in response + original_chunk = response_obj.get("original_chunk", None) + model_response.id = original_chunk.id + if len(original_chunk.choices) > 0: + if ( + original_chunk.choices[0].delta.function_call is not None + or original_chunk.choices[0].delta.tool_calls is not None + ): + try: + delta = dict(original_chunk.choices[0].delta) + model_response.system_fingerprint = ( + original_chunk.system_fingerprint + ) + ## AZURE - check if arguments is not None + if ( + original_chunk.choices[0].delta.function_call + is not None + ): + if ( + getattr( + original_chunk.choices[0].delta.function_call, + "arguments", + ) + is None + ): + original_chunk.choices[ + 0 + ].delta.function_call.arguments = "" + elif original_chunk.choices[0].delta.tool_calls is not None: + if isinstance( + original_chunk.choices[0].delta.tool_calls, list + ): + for t in original_chunk.choices[0].delta.tool_calls: + if hasattr(t, "functions") and hasattr( + t.functions, "arguments" + ): + if ( + getattr( + t.function, + "arguments", + ) + is None + ): + t.function.arguments = "" + model_response.choices[0].delta = Delta(**delta) + except Exception as e: + traceback.print_exc() + model_response.choices[0].delta = Delta() + else: + try: + delta = dict(original_chunk.choices[0].delta) + print_verbose(f"original delta: {delta}") + model_response.choices[0].delta = Delta(**delta) + print_verbose( + f"new delta: {model_response.choices[0].delta}" + ) + except Exception as e: + model_response.choices[0].delta = Delta() + else: + return + print_verbose( + f"model_response.choices[0].delta: {model_response.choices[0].delta}; completion_obj: {completion_obj}" + ) + print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}") + ## RETURN ARG + if ( + "content" in completion_obj + and isinstance(completion_obj["content"], str) + and len(completion_obj["content"]) > 0 ): # cannot set content of an OpenAI Object to be an empty string hold, model_response_str = self.check_special_tokens( chunk=completion_obj["content"], @@ -8739,7 +8836,7 @@ class CustomStreamWrapper: return model_response else: return - elif model_response.choices[0].finish_reason: + elif model_response.choices[0].finish_reason is not None: # flush any remaining holding chunk if len(self.holding_chunk) > 0: if model_response.choices[0].delta.content is None: @@ -8749,61 +8846,15 @@ class CustomStreamWrapper: self.holding_chunk + model_response.choices[0].delta.content ) self.holding_chunk = "" + # get any function call arguments model_response.choices[0].finish_reason = map_finish_reason( model_response.choices[0].finish_reason ) # 
ensure consistent output to openai return model_response elif ( - response_obj is not None - and response_obj.get("original_chunk", None) is not None - ): # function / tool calling branch - only set for openai/azure compatible endpoints - # enter this branch when no content has been passed in response - original_chunk = response_obj.get("original_chunk", None) - model_response.id = original_chunk.id - if len(original_chunk.choices) > 0: - if ( - original_chunk.choices[0].delta.function_call is not None - or original_chunk.choices[0].delta.tool_calls is not None - ): - try: - delta = dict(original_chunk.choices[0].delta) - ## AZURE - check if arguments is not None - if ( - original_chunk.choices[0].delta.function_call - is not None - ): - if ( - getattr( - original_chunk.choices[0].delta.function_call, - "arguments", - ) - is None - ): - original_chunk.choices[ - 0 - ].delta.function_call.arguments = "" - elif original_chunk.choices[0].delta.tool_calls is not None: - if isinstance( - original_chunk.choices[0].delta.tool_calls, list - ): - for t in original_chunk.choices[0].delta.tool_calls: - if ( - getattr( - t.function, - "arguments", - ) - is None - ): - t.function.arguments = "" - model_response.choices[0].delta = Delta(**delta) - except Exception as e: - traceback.print_exc() - model_response.choices[0].delta = Delta() - else: - return - else: - return - model_response.system_fingerprint = original_chunk.system_fingerprint + model_response.choices[0].delta.tool_calls is not None + or model_response.choices[0].delta.function_call is not None + ): if self.sent_first_chunk == False: model_response.choices[0].delta["role"] = "assistant" self.sent_first_chunk = True @@ -8856,6 +8907,7 @@ class CustomStreamWrapper: print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}") response: Optional[ModelResponse] = self.chunk_creator(chunk=chunk) print_verbose(f"PROCESSED CHUNK POST CHUNK CREATOR: {response}") + if response is None: continue ## LOGGING @@ -8900,7 +8952,11 @@ class CustomStreamWrapper: print_verbose(f"value of async chunk: {chunk}") if chunk == "None" or chunk is None: raise Exception - elif self.custom_llm_provider == "gemini" and len(chunk.parts) == 0: + elif ( + self.custom_llm_provider == "gemini" + and hasattr(chunk, "parts") + and len(chunk.parts) == 0 + ): continue # chunk_creator() does logging/stream chunk building. We need to let it know its being called in_async_func, so we don't double add chunks. # __anext__ also calls async_success_handler, which does logging @@ -8929,6 +8985,7 @@ class CustomStreamWrapper: self.rules.post_call_rules( input=self.response_uptil_now, model=self.model ) + print_verbose(f"final returned processed chunk: {processed_chunk}") return processed_chunk raise StopAsyncIteration else: # temporary patch for non-aiohttp async calls