Mirror of https://github.com/BerriAI/litellm.git
fix(utils.py): ensure argument is always a string

commit b6017115e3 (parent 673e9f0703)
2 changed files with 147 additions and 46 deletions
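Context for the fix: when a provider streams a function/tool call, the first
delta can arrive with arguments=None instead of an empty string, and any
consumer that concatenates the streamed fragments then fails. The diff below
normalizes None to "" before the delta is rebuilt. A minimal standalone
sketch of the failure mode and the normalization (FunctionCall here is a
hypothetical stand-in for the provider's object, not litellm's type):

    # Stand-in for the function_call object on a streamed delta.
    class FunctionCall:
        def __init__(self, name=None, arguments=None):
            self.name = name
            self.arguments = arguments

    chunks = [
        FunctionCall(name="get_current_weather", arguments=None),
        FunctionCall(arguments='{"location": "Boston"}'),
    ]

    # Without normalization, joining the fragments raises:
    # TypeError: sequence item 0: expected str instance, NoneType found
    for c in chunks:
        if getattr(c, "arguments", None) is None:
            c.arguments = ""  # coerce the missing field to an empty string
    full_args = "".join(c.arguments for c in chunks)  # '{"location": "Boston"}'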
litellm/tests/test_streaming.py

@@ -4,6 +4,7 @@
 import sys, os, asyncio
 import traceback
 import time, pytest
+from pydantic import BaseModel

 sys.path.insert(
     0, os.path.abspath("../..")
@@ -92,7 +93,7 @@ def validate_second_format(chunk):

     for choice in chunk["choices"]:
         assert isinstance(choice["index"], int), "'index' should be an integer."
-        assert "role" not in choice["delta"], "'role' should be a string."
+        assert hasattr(choice["delta"], "role"), "'role' should be a string."
         # openai v1.0.0 returns content as None
         assert (choice["finish_reason"] is None) or isinstance(
             choice["finish_reason"], str
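A note on the assertion change above: with openai v1.x the streamed delta is
a pydantic-style object whose fields always exist (defaulting to None), so
the old absence check ("role" not in delta) no longer holds; the test now
asserts the attribute is present instead. A small illustration, assuming a
plain pydantic model as a stand-in for the delta:

    from typing import Optional
    from pydantic import BaseModel

    class Delta(BaseModel):
        role: Optional[str] = None
        content: Optional[str] = None

    d = Delta(content="hi")      # a later chunk that carries no explicit role
    print("role" in dict(d))     # True - the field exists on the model anyway
    print(hasattr(d, "role"))    # True - hence the switch to a presence check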
@@ -1455,8 +1456,8 @@ first_openai_function_call_example = {


 def validate_first_function_call_chunk_structure(item):
-    if not isinstance(item, dict):
-        raise Exception("Incorrect format")
+    if not (isinstance(item, dict) or isinstance(item, litellm.ModelResponse)):
+        raise Exception(f"Incorrect format, type of item: {type(item)}")

     required_keys = {"id", "object", "created", "model", "choices"}
     for key in required_keys:
@@ -1468,27 +1469,42 @@ def validate_first_function_call_chunk_structure(item):

     required_keys_in_choices_array = {"index", "delta", "finish_reason"}
     for choice in item["choices"]:
-        if not isinstance(choice, dict):
-            raise Exception("Incorrect format")
+        if not (
+            isinstance(choice, dict)
+            or isinstance(choice, litellm.utils.StreamingChoices)
+        ):
+            raise Exception(f"Incorrect format, type of choice: {type(choice)}")
         for key in required_keys_in_choices_array:
             if key not in choice:
                 raise Exception("Incorrect format")

-        if not isinstance(choice["delta"], dict):
-            raise Exception("Incorrect format")
+        if not (
+            isinstance(choice["delta"], dict)
+            or isinstance(choice["delta"], litellm.utils.Delta)
+        ):
+            raise Exception(
+                f"Incorrect format, type of choice: {type(choice['delta'])}"
+            )

         required_keys_in_delta = {"role", "content", "function_call"}
         for key in required_keys_in_delta:
             if key not in choice["delta"]:
                 raise Exception("Incorrect format")

-        if not isinstance(choice["delta"]["function_call"], dict):
-            raise Exception("Incorrect format")
+        if not (
+            isinstance(choice["delta"]["function_call"], dict)
+            or isinstance(choice["delta"]["function_call"], BaseModel)
+        ):
+            raise Exception(
+                f"Incorrect format, type of function call: {type(choice['delta']['function_call'])}"
+            )

         required_keys_in_function_call = {"name", "arguments"}
         for key in required_keys_in_function_call:
-            if key not in choice["delta"]["function_call"]:
-                raise Exception("Incorrect format")
+            if not hasattr(choice["delta"]["function_call"], key):
+                raise Exception(
+                    f"Incorrect format, expected key={key}; actual keys: {choice['delta']['function_call']}, eval: {hasattr(choice['delta']['function_call'], key)}"
+                )

     return True
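The recurring pattern in this validator - widen each isinstance check to
accept litellm's response objects alongside plain dicts, and check fields via
hasattr - can be sketched generically as follows (validate_chunk is
illustrative only, not part of the commit):

    from pydantic import BaseModel

    def validate_chunk(item):
        # Accept a raw dict or any pydantic-style model.
        if not (isinstance(item, dict) or isinstance(item, BaseModel)):
            raise Exception(f"Incorrect format, type of item: {type(item)}")
        for key in ("id", "object", "created", "model", "choices"):
            # 'in' covers dict keys; hasattr covers model attributes.
            present = key in item if isinstance(item, dict) else hasattr(item, key)
            if not present:
                raise Exception(f"Incorrect format, missing key: {key}")
        return True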
@@ -1547,7 +1563,7 @@ final_function_call_chunk_example = {


 def validate_final_function_call_chunk_structure(data):
-    if not isinstance(data, dict):
+    if not (isinstance(data, dict) or isinstance(data, litellm.ModelResponse)):
         raise Exception("Incorrect format")

     required_keys = {"id", "object", "created", "model", "choices"}
@@ -1560,7 +1576,9 @@ def validate_final_function_call_chunk_structure(data):

     required_keys_in_choices_array = {"index", "delta", "finish_reason"}
     for choice in data["choices"]:
-        if not isinstance(choice, dict):
+        if not (
+            isinstance(choice, dict) or isinstance(choice["delta"], litellm.utils.Delta)
+        ):
             raise Exception("Incorrect format")
         for key in required_keys_in_choices_array:
             if key not in choice:
@@ -1592,37 +1610,88 @@ def streaming_and_function_calling_format_tests(idx, chunk):
     return extracted_chunk, finished


-# def test_openai_streaming_and_function_calling():
-#     function1 = [
-#         {
-#             "name": "get_current_weather",
-#             "description": "Get the current weather in a given location",
-#             "parameters": {
-#                 "type": "object",
-#                 "properties": {
-#                     "location": {
-#                         "type": "string",
-#                         "description": "The city and state, e.g. San Francisco, CA",
-#                     },
-#                     "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
-#                 },
-#                 "required": ["location"],
-#             },
-#         }
-#     ]
-#     messages=[{"role": "user", "content": "What is the weather like in Boston?"}]
-#     try:
-#         response = completion(
-#             model="gpt-3.5-turbo", functions=function1, messages=messages, stream=True,
-#         )
-#         # Add any assertions here to check the response
-#         for idx, chunk in enumerate(response):
-#             streaming_and_function_calling_format_tests(idx=idx, chunk=chunk)
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
-#         raise e
+def test_openai_streaming_and_function_calling():
+    tools = [
+        {
+            "name": "get_current_weather",
+            "description": "Get the current weather in a given location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city and state, e.g. San Francisco, CA",
+                    },
+                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                },
+                "required": ["location"],
+            },
+        }
+    ]
+    messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
+    try:
+        response = completion(
+            model="gpt-3.5-turbo",
+            tools=tools,
+            messages=messages,
+            stream=True,
+        )
+        # Add any assertions here to check the response
+        for idx, chunk in enumerate(response):
+            streaming_and_function_calling_format_tests(idx=idx, chunk=chunk)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+        raise e
+
+
+# test_openai_streaming_and_function_calling()
+
+
+def test_azure_streaming_and_function_calling():
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
+    messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
+    try:
+        response = completion(
+            model="azure/gpt-4-nov-release",
+            tools=tools,
+            tool_choice="auto",
+            messages=messages,
+            stream=True,
+            api_base=os.getenv("AZURE_FRANCE_API_BASE"),
+            api_key=os.getenv("AZURE_FRANCE_API_KEY"),
+            api_version="2024-02-15-preview",
+        )
+        # Add any assertions here to check the response
+        for idx, chunk in enumerate(response):
+            if idx == 0:
+                assert (
+                    chunk.choices[0].delta.tool_calls[0].function.arguments is not None
+                )
+                assert isinstance(
+                    chunk.choices[0].delta.tool_calls[0].function.arguments, str
+                )
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+        raise e
+
+
+# test_azure_streaming_and_function_calling()


 def test_success_callback_streaming():
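The new Azure test pins down the contract the utils.py change establishes: on
the very first streamed chunk, tool_calls[0].function.arguments is already a
str (possibly empty) rather than None. That guarantee is what lets callers
accumulate streamed arguments naively, as in this sketch (the response shape
follows the test above):

    full_arguments = ""
    for chunk in response:
        tool_calls = chunk.choices[0].delta.tool_calls
        if tool_calls:
            # Safe to concatenate: arguments is guaranteed to be a string.
            full_arguments += tool_calls[0].function.arguments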
litellm/utils.py

@@ -258,11 +258,14 @@ class Message(OpenAIObject):


 class Delta(OpenAIObject):
-    def __init__(self, content=None, role=None, **params):
+    def __init__(
+        self, content=None, role=None, function_call=None, tool_calls=None, **params
+    ):
         super(Delta, self).__init__(**params)
         self.content = content
         if role is not None:
             self.role = role
+        self.function_call = function_call
+        self.tool_calls = tool_calls

     def __contains__(self, key):
         # Define custom behavior for the 'in' operator
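Making function_call and tool_calls explicit keyword parameters means every
Delta instance carries both attributes (defaulting to None), so the rebuild
via Delta(**delta) in the streaming wrapper below keeps those fields instead
of losing them in **params. A hedged sketch of the effect (SimpleDelta is
illustrative, not litellm's class):

    class SimpleDelta:
        def __init__(self, content=None, role=None, function_call=None,
                     tool_calls=None, **params):
            self.content = content
            self.function_call = function_call
            self.tool_calls = tool_calls

    raw = {"content": None, "function_call": {"name": "get_current_weather", "arguments": ""}}
    d = SimpleDelta(**raw)   # rebuild from a provider chunk, as the wrapper does
    print(d.function_call)   # the field survives as a real attribute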
@@ -8675,8 +8678,37 @@ class CustomStreamWrapper:
                 ):
                     try:
                         delta = dict(original_chunk.choices[0].delta)
+                        ## AZURE - check if arguments is not None
+                        if (
+                            original_chunk.choices[0].delta.function_call
+                            is not None
+                        ):
+                            if (
+                                getattr(
+                                    original_chunk.choices[0].delta.function_call,
+                                    "arguments",
+                                )
+                                is None
+                            ):
+                                original_chunk.choices[
+                                    0
+                                ].delta.function_call.arguments = ""
+                        elif original_chunk.choices[0].delta.tool_calls is not None:
+                            if isinstance(
+                                original_chunk.choices[0].delta.tool_calls, list
+                            ):
+                                for t in original_chunk.choices[0].delta.tool_calls:
+                                    if (
+                                        getattr(
+                                            t.function,
+                                            "arguments",
+                                        )
+                                        is None
+                                    ):
+                                        t.function.arguments = ""
                         model_response.choices[0].delta = Delta(**delta)
                     except Exception as e:
                         traceback.print_exc()
                         model_response.choices[0].delta = Delta()
                 else:
                     return
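One subtlety in the hunk above: delta is captured before the normalization
runs, but dict() takes a shallow copy, so the nested function_call and
tool_call objects inside delta are the very objects being patched - the ""
written into original_chunk is therefore visible when Delta(**delta) is
built. A compact sketch of that aliasing (types are stand-ins):

    class FunctionCall:
        def __init__(self, arguments=None):
            self.arguments = arguments

    original = {"content": None, "function_call": FunctionCall(arguments=None)}
    delta = dict(original)                    # shallow copy: values are shared
    original["function_call"].arguments = ""  # patch through the original...
    print(delta["function_call"].arguments)   # '' - ...and the copy sees it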