fix(utils.py): fix recreating model response object when stream usage is true

This commit is contained in:
Krrish Dholakia 2024-07-11 21:00:46 -07:00
parent e112379d2f
commit b2e46086dd
3 changed files with 88 additions and 17 deletions

View file

@ -3040,8 +3040,11 @@ def test_completion_claude_3_function_call_with_streaming():
@pytest.mark.parametrize(
"model", ["gemini/gemini-1.5-flash"]
) # "claude-3-opus-20240229",
"model",
[
"gemini/gemini-1.5-flash",
], # "claude-3-opus-20240229"
) #
@pytest.mark.asyncio
async def test_acompletion_claude_3_function_call_with_streaming(model):
litellm.set_verbose = True
@ -3049,41 +3052,45 @@ async def test_acompletion_claude_3_function_call_with_streaming(model):
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"name": "generate_series_of_questions",
"description": "Generate a series of questions, given a topic.",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
"questions": {
"type": "array",
"description": "The questions to be generated.",
"items": {"type": "string"},
},
},
"required": ["questions"],
},
},
},
}
]
SYSTEM_PROMPT = "You are an AI assistant"
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{
"role": "user",
"content": "What's the weather like in Boston today in fahrenheit?",
}
"content": "Generate 3 questions about civil engineering.",
},
]
try:
# test without max tokens
response = await acompletion(
model=model,
# model="claude-3-5-sonnet-20240620",
messages=messages,
tools=tools,
tool_choice="required",
stream=True,
temperature=0.75,
tools=tools,
stream_options={"include_usage": True},
)
idx = 0
print(f"response: {response}")
async for chunk in response:
# print(f"chunk: {chunk}")
print(f"chunk in test: {chunk}")
if idx == 0:
assert (
chunk.choices[0].delta.tool_calls[0].function.arguments is not None
@ -3513,3 +3520,56 @@ def test_unit_test_custom_stream_wrapper_function_call():
if chunk.choices[0].finish_reason is not None:
finish_reason = chunk.choices[0].finish_reason
assert finish_reason == "tool_calls"
## UNIT TEST RECREATING MODEL RESPONSE
from litellm.types.utils import (
ChatCompletionDeltaToolCall,
Delta,
Function,
StreamingChoices,
Usage,
)
initial_model_response = litellm.ModelResponse(
id="chatcmpl-842826b6-75a1-4ed4-8a68-7655e60654b3",
choices=[
StreamingChoices(
finish_reason=None,
index=0,
delta=Delta(
content="",
role="assistant",
function_call=None,
tool_calls=[
ChatCompletionDeltaToolCall(
id="7ee88721-bfee-4584-8662-944a23d4c7a5",
function=Function(
arguments='{"questions": ["What are the main challenges facing civil engineers today?", "How has technology impacted the field of civil engineering?", "What are some of the most innovative projects in civil engineering in recent years?"]}',
name="generate_series_of_questions",
),
type="function",
index=0,
)
],
),
logprobs=None,
)
],
created=1720755257,
model="gemini-1.5-flash",
object="chat.completion.chunk",
system_fingerprint=None,
usage=Usage(prompt_tokens=67, completion_tokens=55, total_tokens=122),
stream=True,
)
obj_dict = initial_model_response.dict()
if "usage" in obj_dict:
del obj_dict["usage"]
new_model = response.model_response_creator(chunk=obj_dict)
print("\n\n{}\n\n".format(new_model))
assert len(new_model.choices[0].delta.tool_calls) > 0

View file

@ -573,6 +573,8 @@ class ModelResponse(OpenAIObject):
_new_choice = choice # type: ignore
elif isinstance(choice, dict):
_new_choice = Choices(**choice) # type: ignore
else:
_new_choice = choice
new_choices.append(_new_choice)
choices = new_choices
else:

View file

@ -8951,6 +8951,15 @@ class CustomStreamWrapper:
model_response.system_fingerprint = self.system_fingerprint
model_response._hidden_params["custom_llm_provider"] = _logging_obj_llm_provider
model_response._hidden_params["created_at"] = time.time()
if (
len(model_response.choices) > 0
and hasattr(model_response.choices[0], "delta")
and model_response.choices[0].delta is not None
):
# do nothing, if object instantiated
pass
else:
model_response.choices = [StreamingChoices(finish_reason=None)]
return model_response
@ -9892,7 +9901,6 @@ class CustomStreamWrapper:
self.rules.post_call_rules(
input=self.response_uptil_now, model=self.model
)
print_verbose(f"final returned processed chunk: {processed_chunk}")
self.chunks.append(processed_chunk)
if hasattr(
processed_chunk, "usage"
@ -9906,6 +9914,7 @@ class CustomStreamWrapper:
# Create a new object without the removed attribute
processed_chunk = self.model_response_creator(chunk=obj_dict)
print_verbose(f"final returned processed chunk: {processed_chunk}")
return processed_chunk
raise StopAsyncIteration
else: # temporary patch for non-aiohttp async calls