diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index e310092bd..7d099afc6 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -755,27 +755,40 @@ async def test_completion_gemini_stream(sync_mode):
     try:
         litellm.set_verbose = True
         print("Streaming gemini response")
-        messages = [
-            {"role": "system", "content": "You are a helpful assistant."},
+        function1 = [
             {
-                "role": "user",
-                "content": "Who was Alexander?",
-            },
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            }
         ]
+        messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
         print("testing gemini streaming")
         complete_response = ""
         # Add any assertions here to check the response
         non_empty_chunks = 0
-
+        chunks = []
         if sync_mode:
             response = completion(
                 model="gemini/gemini-1.5-flash",
                 messages=messages,
                 stream=True,
+                functions=function1,
             )
 
             for idx, chunk in enumerate(response):
                 print(chunk)
+                chunks.append(chunk)
                 # print(chunk.choices[0].delta)
                 chunk, finished = streaming_format_tests(idx, chunk)
                 if finished:
@@ -787,11 +800,13 @@ async def test_completion_gemini_stream(sync_mode):
                 model="gemini/gemini-1.5-flash",
                 messages=messages,
                 stream=True,
+                functions=function1,
             )
 
             idx = 0
             async for chunk in response:
                 print(chunk)
+                chunks.append(chunk)
                 # print(chunk.choices[0].delta)
                 chunk, finished = streaming_format_tests(idx, chunk)
                 if finished:
@@ -800,10 +815,17 @@ async def test_completion_gemini_stream(sync_mode):
                 complete_response += chunk
                 idx += 1
 
-        if complete_response.strip() == "":
-            raise Exception("Empty response received")
+        # if complete_response.strip() == "":
+        #     raise Exception("Empty response received")
         print(f"completion_response: {complete_response}")
-        assert non_empty_chunks > 1
+
+        complete_response = litellm.stream_chunk_builder(
+            chunks=chunks, messages=messages
+        )
+
+        assert complete_response.choices[0].message.function_call is not None
+
+        # assert non_empty_chunks > 1
     except litellm.InternalServerError as e:
         pass
     except litellm.RateLimitError as e:
diff --git a/litellm/utils.py b/litellm/utils.py
index 61beabe95..b386a54ec 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -8771,6 +8771,7 @@ class CustomStreamWrapper:
         self.chunks: List = (
             []
         )  # keep track of the returned chunks - used for calculating the input/output tokens for stream options
+        self.is_function_call = self.check_is_function_call(logging_obj=logging_obj)
 
     def __iter__(self):
         return self
@@ -8778,6 +8779,19 @@ class CustomStreamWrapper:
     def __aiter__(self):
         return self
 
+    def check_is_function_call(self, logging_obj) -> bool:
+        if hasattr(logging_obj, "optional_params") and isinstance(
+            logging_obj.optional_params, dict
+        ):
+            if (
+                "litellm_param_is_function_call" in logging_obj.optional_params
+                and logging_obj.optional_params["litellm_param_is_function_call"]
+                is not None
+            ):
+                return True
+
+        return False
+
     def process_chunk(self, chunk: str):
         """
         NLP Cloud streaming returns the entire response, for each chunk. Process this, to only return the delta.
@@ -10275,6 +10289,12 @@ class CustomStreamWrapper:
 
             ## CHECK FOR TOOL USE
             if "tool_calls" in completion_obj and len(completion_obj["tool_calls"]) > 0:
+                if self.is_function_call is True:  # user passed in 'functions' param
+                    completion_obj["function_call"] = completion_obj["tool_calls"][0][
+                        "function"
+                    ]
+                    completion_obj["tool_calls"] = None
+
                 self.tool_call = True
 
             ## RETURN ARG
@@ -10286,8 +10306,13 @@ class CustomStreamWrapper:
                 )
                 or (
                     "tool_calls" in completion_obj
+                    and completion_obj["tool_calls"] is not None
                     and len(completion_obj["tool_calls"]) > 0
                 )
+                or (
+                    "function_call" in completion_obj
+                    and completion_obj["function_call"] is not None
+                )
             ):  # cannot set content of an OpenAI Object to be an empty string
                 self.safety_checker()
                 hold, model_response_str = self.check_special_tokens(
@@ -10347,6 +10372,7 @@ class CustomStreamWrapper:
             if self.sent_first_chunk is False:
                 completion_obj["role"] = "assistant"
                 self.sent_first_chunk = True
+
             model_response.choices[0].delta = Delta(**completion_obj)
             if completion_obj.get("index") is not None:
                 model_response.choices[0].index = completion_obj.get(
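
The `is_function_call` flag is computed once at wrapper construction and then consulted on every chunk to decide whether `tool_calls` should be remapped to the legacy `function_call` delta field. A minimal standalone sketch of that detection logic, isolated from CustomStreamWrapper (the SimpleNamespace stand-ins for the logging object are hypothetical; only `optional_params` matters here):

    from types import SimpleNamespace

    def check_is_function_call(logging_obj) -> bool:
        # Mirrors the helper added to CustomStreamWrapper above: the flag is
        # True only when 'litellm_param_is_function_call' is present in
        # optional_params and is non-None.
        if hasattr(logging_obj, "optional_params") and isinstance(
            logging_obj.optional_params, dict
        ):
            if (
                logging_obj.optional_params.get("litellm_param_is_function_call")
                is not None
            ):
                return True
        return False

    # Hypothetical logging objects: one from a request that passed 'functions',
    # one from a request that used 'tools' (or neither).
    with_functions = SimpleNamespace(
        optional_params={"litellm_param_is_function_call": True}
    )
    tools_only = SimpleNamespace(optional_params={})

    assert check_is_function_call(with_functions) is True
    assert check_is_function_call(tools_only) is False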
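For reference, a minimal sketch of the end-to-end behavior this diff enables, mirroring the updated test: streaming a Gemini completion with the legacy `functions` param, collecting the chunks, and rebuilding them with `litellm.stream_chunk_builder`. Assumes GEMINI_API_KEY is configured in the environment; model name and function schema are copied from the test above.

    import litellm
    from litellm import completion

    function1 = [
        {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location"],
            },
        }
    ]
    messages = [{"role": "user", "content": "What is the weather like in Boston?"}]

    # With the utils.py changes above, each streamed delta should carry
    # 'function_call' rather than 'tool_calls', since 'functions' was passed.
    response = completion(
        model="gemini/gemini-1.5-flash",
        messages=messages,
        stream=True,
        functions=function1,
    )
    chunks = [chunk for chunk in response]

    # Reassemble the streamed deltas into one ModelResponse, as the test does.
    full_response = litellm.stream_chunk_builder(chunks=chunks, messages=messages)
    print(full_response.choices[0].message.function_call)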