diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 1b0f91fe1..eab202406 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -3040,8 +3040,11 @@ def test_completion_claude_3_function_call_with_streaming(): @pytest.mark.parametrize( - "model", ["gemini/gemini-1.5-flash"] -) # "claude-3-opus-20240229", + "model", + [ + "gemini/gemini-1.5-flash", + ], # "claude-3-opus-20240229" +) # @pytest.mark.asyncio async def test_acompletion_claude_3_function_call_with_streaming(model): litellm.set_verbose = True @@ -3049,41 +3052,45 @@ async def test_acompletion_claude_3_function_call_with_streaming(model): { "type": "function", "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", + "name": "generate_series_of_questions", + "description": "Generate a series of questions, given a topic.", "parameters": { "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA", + "questions": { + "type": "array", + "description": "The questions to be generated.", + "items": {"type": "string"}, }, - "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, }, - "required": ["location"], + "required": ["questions"], }, }, - } + }, ] + SYSTEM_PROMPT = "You are an AI assistant" messages = [ + {"role": "system", "content": SYSTEM_PROMPT}, { "role": "user", - "content": "What's the weather like in Boston today in fahrenheit?", - } + "content": "Generate 3 questions about civil engineering.", + }, ] try: # test without max tokens response = await acompletion( model=model, + # model="claude-3-5-sonnet-20240620", messages=messages, - tools=tools, - tool_choice="required", stream=True, + temperature=0.75, + tools=tools, + stream_options={"include_usage": True}, ) idx = 0 print(f"response: {response}") async for chunk in response: - # print(f"chunk: {chunk}") + print(f"chunk in test: {chunk}") if idx == 0: assert ( chunk.choices[0].delta.tool_calls[0].function.arguments is not None @@ -3513,3 +3520,56 @@ def test_unit_test_custom_stream_wrapper_function_call(): if chunk.choices[0].finish_reason is not None: finish_reason = chunk.choices[0].finish_reason assert finish_reason == "tool_calls" + + ## UNIT TEST RECREATING MODEL RESPONSE + from litellm.types.utils import ( + ChatCompletionDeltaToolCall, + Delta, + Function, + StreamingChoices, + Usage, + ) + + initial_model_response = litellm.ModelResponse( + id="chatcmpl-842826b6-75a1-4ed4-8a68-7655e60654b3", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta( + content="", + role="assistant", + function_call=None, + tool_calls=[ + ChatCompletionDeltaToolCall( + id="7ee88721-bfee-4584-8662-944a23d4c7a5", + function=Function( + arguments='{"questions": ["What are the main challenges facing civil engineers today?", "How has technology impacted the field of civil engineering?", "What are some of the most innovative projects in civil engineering in recent years?"]}', + name="generate_series_of_questions", + ), + type="function", + index=0, + ) + ], + ), + logprobs=None, + ) + ], + created=1720755257, + model="gemini-1.5-flash", + object="chat.completion.chunk", + system_fingerprint=None, + usage=Usage(prompt_tokens=67, completion_tokens=55, total_tokens=122), + stream=True, + ) + + obj_dict = initial_model_response.dict() + + if "usage" in obj_dict: + del obj_dict["usage"] + + new_model = response.model_response_creator(chunk=obj_dict) + + print("\n\n{}\n\n".format(new_model)) + + assert len(new_model.choices[0].delta.tool_calls) > 0 diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 3ecf36ba2..4747a9a87 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -573,6 +573,8 @@ class ModelResponse(OpenAIObject): _new_choice = choice # type: ignore elif isinstance(choice, dict): _new_choice = Choices(**choice) # type: ignore + else: + _new_choice = choice new_choices.append(_new_choice) choices = new_choices else: diff --git a/litellm/utils.py b/litellm/utils.py index d32800764..27906591f 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -8951,7 +8951,16 @@ class CustomStreamWrapper: model_response.system_fingerprint = self.system_fingerprint model_response._hidden_params["custom_llm_provider"] = _logging_obj_llm_provider model_response._hidden_params["created_at"] = time.time() - model_response.choices = [StreamingChoices(finish_reason=None)] + + if ( + len(model_response.choices) > 0 + and hasattr(model_response.choices[0], "delta") + and model_response.choices[0].delta is not None + ): + # do nothing, if object instantiated + pass + else: + model_response.choices = [StreamingChoices(finish_reason=None)] return model_response def is_delta_empty(self, delta: Delta) -> bool: @@ -9892,7 +9901,6 @@ class CustomStreamWrapper: self.rules.post_call_rules( input=self.response_uptil_now, model=self.model ) - print_verbose(f"final returned processed chunk: {processed_chunk}") self.chunks.append(processed_chunk) if hasattr( processed_chunk, "usage" @@ -9906,6 +9914,7 @@ class CustomStreamWrapper: # Create a new object without the removed attribute processed_chunk = self.model_response_creator(chunk=obj_dict) + print_verbose(f"final returned processed chunk: {processed_chunk}") return processed_chunk raise StopAsyncIteration else: # temporary patch for non-aiohttp async calls