fix(utils.py): fix recreating model response object when stream usage is true

Author: Krrish Dholakia
Date:   2024-07-11 21:00:46 -07:00
Parent: e112379d2f
Commit: b2e46086dd

3 changed files with 88 additions and 17 deletions
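
The change is easiest to read next to the call pattern that triggers it: a streamed, tool-calling completion with stream_options={"include_usage": True}, which sends the chunk carrying usage back through model_response_creator after its usage key is stripped (see the utils.py hunks below). A minimal sketch of that call, reduced from the updated test in this commit (model name, prompt, and tool schema are taken from that test rather than prescribed here):

    import asyncio
    from litellm import acompletion

    # Tool schema matching the one used in the updated test below.
    tools = [
        {
            "type": "function",
            "function": {
                "name": "generate_series_of_questions",
                "description": "Generate a series of questions, given a topic.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "questions": {
                            "type": "array",
                            "description": "The questions to be generated.",
                            "items": {"type": "string"},
                        },
                    },
                    "required": ["questions"],
                },
            },
        }
    ]

    async def main():
        response = await acompletion(
            model="gemini/gemini-1.5-flash",
            messages=[{"role": "user", "content": "Generate 3 questions about civil engineering."}],
            tools=tools,
            stream=True,
            stream_options={"include_usage": True},  # exercises the usage-stripping path fixed here
        )
        async for chunk in response:
            # Before this fix, the recreated final chunk could come back with an
            # empty StreamingChoices and lose the tool-call delta.
            print(chunk)

    asyncio.run(main())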


@@ -3040,8 +3040,11 @@ def test_completion_claude_3_function_call_with_streaming():
 @pytest.mark.parametrize(
-    "model", ["gemini/gemini-1.5-flash"]
-)  # "claude-3-opus-20240229",
+    "model",
+    [
+        "gemini/gemini-1.5-flash",
+    ],  # "claude-3-opus-20240229"
+)  #
 @pytest.mark.asyncio
 async def test_acompletion_claude_3_function_call_with_streaming(model):
     litellm.set_verbose = True
@@ -3049,41 +3052,45 @@ async def test_acompletion_claude_3_function_call_with_streaming(model):
         {
             "type": "function",
             "function": {
-                "name": "get_current_weather",
-                "description": "Get the current weather in a given location",
+                "name": "generate_series_of_questions",
+                "description": "Generate a series of questions, given a topic.",
                 "parameters": {
                     "type": "object",
                     "properties": {
-                        "location": {
-                            "type": "string",
-                            "description": "The city and state, e.g. San Francisco, CA",
-                        },
-                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                        "questions": {
+                            "type": "array",
+                            "description": "The questions to be generated.",
+                            "items": {"type": "string"},
+                        },
                     },
-                    "required": ["location"],
+                    "required": ["questions"],
                 },
             },
         }
     ]
+    SYSTEM_PROMPT = "You are an AI assistant"
     messages = [
+        {"role": "system", "content": SYSTEM_PROMPT},
         {
             "role": "user",
-            "content": "What's the weather like in Boston today in fahrenheit?",
-        }
+            "content": "Generate 3 questions about civil engineering.",
+        },
     ]
     try:
         # test without max tokens
         response = await acompletion(
             model=model,
+            # model="claude-3-5-sonnet-20240620",
             messages=messages,
-            tools=tools,
-            tool_choice="required",
             stream=True,
+            temperature=0.75,
+            tools=tools,
+            stream_options={"include_usage": True},
         )
         idx = 0
         print(f"response: {response}")
         async for chunk in response:
-            # print(f"chunk: {chunk}")
+            print(f"chunk in test: {chunk}")
             if idx == 0:
                 assert (
                     chunk.choices[0].delta.tool_calls[0].function.arguments is not None
@@ -3513,3 +3520,56 @@ def test_unit_test_custom_stream_wrapper_function_call():
         if chunk.choices[0].finish_reason is not None:
             finish_reason = chunk.choices[0].finish_reason
     assert finish_reason == "tool_calls"
+
+    ## UNIT TEST RECREATING MODEL RESPONSE
+    from litellm.types.utils import (
+        ChatCompletionDeltaToolCall,
+        Delta,
+        Function,
+        StreamingChoices,
+        Usage,
+    )
+
+    initial_model_response = litellm.ModelResponse(
+        id="chatcmpl-842826b6-75a1-4ed4-8a68-7655e60654b3",
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                index=0,
+                delta=Delta(
+                    content="",
+                    role="assistant",
+                    function_call=None,
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            id="7ee88721-bfee-4584-8662-944a23d4c7a5",
+                            function=Function(
+                                arguments='{"questions": ["What are the main challenges facing civil engineers today?", "How has technology impacted the field of civil engineering?", "What are some of the most innovative projects in civil engineering in recent years?"]}',
+                                name="generate_series_of_questions",
+                            ),
+                            type="function",
+                            index=0,
+                        )
+                    ],
+                ),
+                logprobs=None,
+            )
+        ],
+        created=1720755257,
+        model="gemini-1.5-flash",
+        object="chat.completion.chunk",
+        system_fingerprint=None,
+        usage=Usage(prompt_tokens=67, completion_tokens=55, total_tokens=122),
+        stream=True,
+    )
+
+    obj_dict = initial_model_response.dict()
+
+    if "usage" in obj_dict:
+        del obj_dict["usage"]
+
+    new_model = response.model_response_creator(chunk=obj_dict)
+
+    print("\n\n{}\n\n".format(new_model))
+
+    assert len(new_model.choices[0].delta.tool_calls) > 0


@@ -573,6 +573,8 @@ class ModelResponse(OpenAIObject):
                     _new_choice = choice  # type: ignore
                 elif isinstance(choice, dict):
                     _new_choice = Choices(**choice)  # type: ignore
+                else:
+                    _new_choice = choice
                 new_choices.append(_new_choice)
             choices = new_choices
         else:
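
Note on the ModelResponse hunk above: the added else branch handles choices that arrive as already-typed objects (neither a Choices instance nor a plain dict, e.g. an already-constructed StreamingChoices) by appending them unchanged; previously such a choice left _new_choice stale or undefined. A standalone sketch of the pattern, using generic names rather than litellm's real classes:

    def normalize_choices(raw_choices, target_cls):
        # Coerce plain dicts into target_cls; pass any other object through unchanged.
        normalized = []
        for choice in raw_choices:
            if isinstance(choice, target_cls):
                normalized.append(choice)                # already the right type
            elif isinstance(choice, dict):
                normalized.append(target_cls(**choice))  # rebuild from a raw dict
            else:
                normalized.append(choice)                # fallback: keep the object as-is
        return normalized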


@@ -8951,6 +8951,15 @@ class CustomStreamWrapper:
             model_response.system_fingerprint = self.system_fingerprint
         model_response._hidden_params["custom_llm_provider"] = _logging_obj_llm_provider
         model_response._hidden_params["created_at"] = time.time()
-        model_response.choices = [StreamingChoices(finish_reason=None)]
+
+        if (
+            len(model_response.choices) > 0
+            and hasattr(model_response.choices[0], "delta")
+            and model_response.choices[0].delta is not None
+        ):
+            # do nothing, if object instantiated
+            pass
+        else:
+            model_response.choices = [StreamingChoices(finish_reason=None)]
         return model_response
@@ -9892,7 +9901,6 @@ class CustomStreamWrapper:
                     self.rules.post_call_rules(
                         input=self.response_uptil_now, model=self.model
                     )
-                    print_verbose(f"final returned processed chunk: {processed_chunk}")
                     self.chunks.append(processed_chunk)
                     if hasattr(
                         processed_chunk, "usage"
@@ -9906,6 +9914,7 @@ class CustomStreamWrapper:
                         # Create a new object without the removed attribute
                         processed_chunk = self.model_response_creator(chunk=obj_dict)
+                    print_verbose(f"final returned processed chunk: {processed_chunk}")
                     return processed_chunk
                 raise StopAsyncIteration
             else:  # temporary patch for non-aiohttp async calls
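
Taken together, the CustomStreamWrapper hunks describe the path being fixed: when the final processed chunk carries usage, it is serialized to a dict, the usage key is removed, and the chunk is recreated via model_response_creator; before the guard added above, that recreation replaced an already-populated delta (tool calls included) with an empty StreamingChoices(finish_reason=None). A rough standalone approximation of the round trip, assuming the creator essentially forwards the chunk dict into ModelResponse(stream=True, **chunk) (an assumption about litellm internals, not shown in this diff), with illustrative token counts:

    import litellm
    from litellm.types.utils import Delta, StreamingChoices, Usage

    # A final streamed chunk that already carries a populated delta and usage.
    chunk = litellm.ModelResponse(
        choices=[
            StreamingChoices(
                finish_reason=None,
                index=0,
                delta=Delta(content="", role="assistant"),
            )
        ],
        usage=Usage(prompt_tokens=67, completion_tokens=55, total_tokens=122),
        stream=True,
    )

    obj_dict = chunk.dict()      # serialize the chunk, usage still included
    obj_dict.pop("usage", None)  # drop usage before recreating the object

    # Assumed equivalent of what model_response_creator does with the dict.
    rebuilt = litellm.ModelResponse(stream=True, **obj_dict)

    # The point of the fix: the recreated chunk keeps its delta instead of
    # being reset to an empty StreamingChoices.
    assert rebuilt.choices[0].delta is not None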