forked from phoenix/litellm-mirror
fix(utils.py): fix recreating model response object when stream usage is true
parent e112379d2f
commit b2e46086dd
3 changed files with 88 additions and 17 deletions
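For orientation (this note and the sketch are not part of the diff): the code path being fixed is the one a caller hits when asking for token usage on a streamed completion. A minimal sketch of that calling pattern, with the model name and prompt copied from the test changed below:

import asyncio

import litellm


async def main() -> None:
    response = await litellm.acompletion(
        model="gemini/gemini-1.5-flash",  # model name copied from the test below
        messages=[{"role": "user", "content": "Generate 3 questions about civil engineering."}],
        stream=True,
        stream_options={"include_usage": True},  # ask for a final chunk carrying token usage
    )
    async for chunk in response:
        print(chunk)  # the last chunk should report usage alongside the tool-call deltas


asyncio.run(main())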
@@ -3040,8 +3040,11 @@ def test_completion_claude_3_function_call_with_streaming():
 @pytest.mark.parametrize(
-    "model", ["gemini/gemini-1.5-flash"]
-)  # "claude-3-opus-20240229",
+    "model",
+    [
+        "gemini/gemini-1.5-flash",
+    ],  # "claude-3-opus-20240229"
+)  #
 @pytest.mark.asyncio
 async def test_acompletion_claude_3_function_call_with_streaming(model):
     litellm.set_verbose = True
@@ -3049,41 +3052,45 @@ async def test_acompletion_claude_3_function_call_with_streaming(model):
         {
             "type": "function",
             "function": {
-                "name": "get_current_weather",
-                "description": "Get the current weather in a given location",
+                "name": "generate_series_of_questions",
+                "description": "Generate a series of questions, given a topic.",
                 "parameters": {
                     "type": "object",
                     "properties": {
-                        "location": {
-                            "type": "string",
-                            "description": "The city and state, e.g. San Francisco, CA",
+                        "questions": {
+                            "type": "array",
+                            "description": "The questions to be generated.",
+                            "items": {"type": "string"},
                         },
-                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                     },
-                    "required": ["location"],
+                    "required": ["questions"],
                 },
             },
-        }
+        },
     ]
     SYSTEM_PROMPT = "You are an AI assistant"
     messages = [
         {"role": "system", "content": SYSTEM_PROMPT},
         {
             "role": "user",
-            "content": "What's the weather like in Boston today in fahrenheit?",
-        }
+            "content": "Generate 3 questions about civil engineering.",
+        },
     ]
     try:
         # test without max tokens
         response = await acompletion(
             model=model,
+            # model="claude-3-5-sonnet-20240620",
             messages=messages,
-            tools=tools,
+            tool_choice="required",
             stream=True,
+            temperature=0.75,
+            tools=tools,
+            stream_options={"include_usage": True},
         )
         idx = 0
         print(f"response: {response}")
         async for chunk in response:
             # print(f"chunk: {chunk}")
+            print(f"chunk in test: {chunk}")
             if idx == 0:
                 assert (
                     chunk.choices[0].delta.tool_calls[0].function.arguments is not None
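Side note (not part of the commit): the assertion above only checks the first chunk. The streamed tool-call argument fragments can be joined into a single JSON document with a small helper along these lines; the attribute path mirrors the one asserted on in the test, the helper itself is illustrative and not part of litellm.

def collect_tool_arguments(chunks) -> str:
    """Join delta.tool_calls[0].function.arguments across streamed chunks."""
    parts = []
    for chunk in chunks:
        delta = getattr(chunk.choices[0], "delta", None)
        if delta is not None and delta.tool_calls:
            fragment = delta.tool_calls[0].function.arguments
            if fragment:
                parts.append(fragment)
    return "".join(parts)

# Intended use against the test's response object:
#     collected = [chunk async for chunk in response]
#     questions = json.loads(collect_tool_arguments(collected))["questions"]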
@@ -3513,3 +3520,56 @@ def test_unit_test_custom_stream_wrapper_function_call():
         if chunk.choices[0].finish_reason is not None:
             finish_reason = chunk.choices[0].finish_reason
     assert finish_reason == "tool_calls"
+
+    ## UNIT TEST RECREATING MODEL RESPONSE
+    from litellm.types.utils import (
+        ChatCompletionDeltaToolCall,
+        Delta,
+        Function,
+        StreamingChoices,
+        Usage,
+    )
+
+    initial_model_response = litellm.ModelResponse(
+        id="chatcmpl-842826b6-75a1-4ed4-8a68-7655e60654b3",
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                index=0,
+                delta=Delta(
+                    content="",
+                    role="assistant",
+                    function_call=None,
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            id="7ee88721-bfee-4584-8662-944a23d4c7a5",
+                            function=Function(
+                                arguments='{"questions": ["What are the main challenges facing civil engineers today?", "How has technology impacted the field of civil engineering?", "What are some of the most innovative projects in civil engineering in recent years?"]}',
+                                name="generate_series_of_questions",
+                            ),
+                            type="function",
+                            index=0,
+                        )
+                    ],
+                ),
+                logprobs=None,
+            )
+        ],
+        created=1720755257,
+        model="gemini-1.5-flash",
+        object="chat.completion.chunk",
+        system_fingerprint=None,
+        usage=Usage(prompt_tokens=67, completion_tokens=55, total_tokens=122),
+        stream=True,
+    )
+
+    obj_dict = initial_model_response.dict()
+
+    if "usage" in obj_dict:
+        del obj_dict["usage"]
+
+    new_model = response.model_response_creator(chunk=obj_dict)
+
+    print("\n\n{}\n\n".format(new_model))
+
+    assert len(new_model.choices[0].delta.tool_calls) > 0
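As a quick companion check (not part of the commit), the arguments payload embedded in that fixture is valid JSON containing exactly three questions; the string below is copied verbatim from the test above.

import json

arguments = (
    '{"questions": ["What are the main challenges facing civil engineers today?", '
    '"How has technology impacted the field of civil engineering?", '
    '"What are some of the most innovative projects in civil engineering in recent years?"]}'
)
questions = json.loads(arguments)["questions"]
assert len(questions) == 3
print(questions)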
@@ -573,6 +573,8 @@ class ModelResponse(OpenAIObject):
                     _new_choice = choice  # type: ignore
                 elif isinstance(choice, dict):
                     _new_choice = Choices(**choice)  # type: ignore
+                else:
+                    _new_choice = choice
                 new_choices.append(_new_choice)
             choices = new_choices
         else:
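Restated outside the class (illustrative only; choice_cls stands in for litellm's choice types and the loop is paraphrased from the visible context, not copied from the library), the normalisation pattern the new else branch completes is:

def normalize_choices(raw_choices, choice_cls):
    new_choices = []
    for choice in raw_choices:
        if isinstance(choice, choice_cls):
            new_choice = choice  # already the right type, keep it
        elif isinstance(choice, dict):
            new_choice = choice_cls(**choice)  # coerce plain dicts into the model type
        else:
            new_choice = choice  # new fallback: pass anything else through unchanged
        new_choices.append(new_choice)
    return new_choices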
@@ -8951,7 +8951,16 @@ class CustomStreamWrapper:
         model_response.system_fingerprint = self.system_fingerprint
         model_response._hidden_params["custom_llm_provider"] = _logging_obj_llm_provider
         model_response._hidden_params["created_at"] = time.time()
-        model_response.choices = [StreamingChoices(finish_reason=None)]
+
+        if (
+            len(model_response.choices) > 0
+            and hasattr(model_response.choices[0], "delta")
+            and model_response.choices[0].delta is not None
+        ):
+            # do nothing, if object instantiated
+            pass
+        else:
+            model_response.choices = [StreamingChoices(finish_reason=None)]
         return model_response

     def is_delta_empty(self, delta: Delta) -> bool:
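A standalone illustration of what the new guard preserves (assumes litellm is installed; the types are the ones used in the test earlier in this commit): if the response being rebuilt already carries an instantiated delta, overwriting choices with a fresh StreamingChoices(finish_reason=None) would discard it.

from litellm.types.utils import Delta, StreamingChoices

existing_choice = StreamingChoices(
    finish_reason=None,
    index=0,
    delta=Delta(content="", role="assistant"),
)

choices = [existing_choice]
# Patched behaviour: keep an already-instantiated delta instead of clobbering it.
if len(choices) > 0 and hasattr(choices[0], "delta") and choices[0].delta is not None:
    pass  # do nothing, the object is already instantiated
else:
    choices = [StreamingChoices(finish_reason=None)]

assert choices[0] is existing_choice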
@@ -9892,7 +9901,6 @@ class CustomStreamWrapper:
                     self.rules.post_call_rules(
                         input=self.response_uptil_now, model=self.model
                     )
-                    print_verbose(f"final returned processed chunk: {processed_chunk}")
                     self.chunks.append(processed_chunk)
                     if hasattr(
                         processed_chunk, "usage"
@@ -9906,6 +9914,7 @@ class CustomStreamWrapper:

                         # Create a new object without the removed attribute
                         processed_chunk = self.model_response_creator(chunk=obj_dict)
+                        print_verbose(f"final returned processed chunk: {processed_chunk}")
                         return processed_chunk
                     raise StopAsyncIteration
                 else:  # temporary patch for non-aiohttp async calls
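Putting the pieces together, a runnable sketch of the strip-usage-then-recreate step (trimmed from the unit test added in this commit; in the library the rebuild goes through CustomStreamWrapper.model_response_creator, which needs a live wrapper instance and is therefore only indicated in a comment here):

import litellm
from litellm.types.utils import Delta, StreamingChoices, Usage

final_chunk = litellm.ModelResponse(
    choices=[
        StreamingChoices(
            finish_reason=None,
            index=0,
            delta=Delta(content="", role="assistant"),
        )
    ],
    usage=Usage(prompt_tokens=67, completion_tokens=55, total_tokens=122),
    stream=True,
)

obj_dict = final_chunk.dict()
if "usage" in obj_dict:
    del obj_dict["usage"]  # usage is removed before the chunk object is recreated

# In the library: processed_chunk = self.model_response_creator(chunk=obj_dict)
print(obj_dict["choices"][0]["delta"])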