fix(utils.py): ensure argument is always a string

This commit is contained in:
Krrish Dholakia 2024-02-22 15:15:56 -08:00
parent 673e9f0703
commit b6017115e3
2 changed files with 147 additions and 46 deletions

View file

@ -4,6 +4,7 @@
import sys, os, asyncio
import traceback
import time, pytest
from pydantic import BaseModel
sys.path.insert(
0, os.path.abspath("../..")
@ -92,7 +93,7 @@ def validate_second_format(chunk):
for choice in chunk["choices"]:
assert isinstance(choice["index"], int), "'index' should be an integer."
assert "role" not in choice["delta"], "'role' should be a string."
assert hasattr(choice["delta"], "role"), "'role' should be a string."
# openai v1.0.0 returns content as None
assert (choice["finish_reason"] is None) or isinstance(
choice["finish_reason"], str
@ -1455,8 +1456,8 @@ first_openai_function_call_example = {
def validate_first_function_call_chunk_structure(item):
if not isinstance(item, dict):
raise Exception("Incorrect format")
if not (isinstance(item, dict) or isinstance(item, litellm.ModelResponse)):
raise Exception(f"Incorrect format, type of item: {type(item)}")
required_keys = {"id", "object", "created", "model", "choices"}
for key in required_keys:
@ -1468,27 +1469,42 @@ def validate_first_function_call_chunk_structure(item):
required_keys_in_choices_array = {"index", "delta", "finish_reason"}
for choice in item["choices"]:
if not isinstance(choice, dict):
raise Exception("Incorrect format")
if not (
isinstance(choice, dict)
or isinstance(choice, litellm.utils.StreamingChoices)
):
raise Exception(f"Incorrect format, type of choice: {type(choice)}")
for key in required_keys_in_choices_array:
if key not in choice:
raise Exception("Incorrect format")
if not isinstance(choice["delta"], dict):
raise Exception("Incorrect format")
if not (
isinstance(choice["delta"], dict)
or isinstance(choice["delta"], litellm.utils.Delta)
):
raise Exception(
f"Incorrect format, type of choice: {type(choice['delta'])}"
)
required_keys_in_delta = {"role", "content", "function_call"}
for key in required_keys_in_delta:
if key not in choice["delta"]:
raise Exception("Incorrect format")
if not isinstance(choice["delta"]["function_call"], dict):
raise Exception("Incorrect format")
if not (
isinstance(choice["delta"]["function_call"], dict)
or isinstance(choice["delta"]["function_call"], BaseModel)
):
raise Exception(
f"Incorrect format, type of function call: {type(choice['delta']['function_call'])}"
)
required_keys_in_function_call = {"name", "arguments"}
for key in required_keys_in_function_call:
if key not in choice["delta"]["function_call"]:
raise Exception("Incorrect format")
if not hasattr(choice["delta"]["function_call"], key):
raise Exception(
f"Incorrect format, expected key={key}; actual keys: {choice['delta']['function_call']}, eval: {hasattr(choice['delta']['function_call'], key)}"
)
return True
@ -1547,7 +1563,7 @@ final_function_call_chunk_example = {
def validate_final_function_call_chunk_structure(data):
if not isinstance(data, dict):
if not (isinstance(data, dict) or isinstance(data, litellm.ModelResponse)):
raise Exception("Incorrect format")
required_keys = {"id", "object", "created", "model", "choices"}
@ -1560,7 +1576,9 @@ def validate_final_function_call_chunk_structure(data):
required_keys_in_choices_array = {"index", "delta", "finish_reason"}
for choice in data["choices"]:
if not isinstance(choice, dict):
if not (
isinstance(choice, dict) or isinstance(choice["delta"], litellm.utils.Delta)
):
raise Exception("Incorrect format")
for key in required_keys_in_choices_array:
if key not in choice:
@ -1592,37 +1610,88 @@ def streaming_and_function_calling_format_tests(idx, chunk):
return extracted_chunk, finished
# def test_openai_streaming_and_function_calling():
# function1 = [
# {
# "name": "get_current_weather",
# "description": "Get the current weather in a given location",
# "parameters": {
# "type": "object",
# "properties": {
# "location": {
# "type": "string",
# "description": "The city and state, e.g. San Francisco, CA",
# },
# "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
# },
# "required": ["location"],
# },
# }
# ]
# messages=[{"role": "user", "content": "What is the weather like in Boston?"}]
# try:
# response = completion(
# model="gpt-3.5-turbo", functions=function1, messages=messages, stream=True,
# )
# # Add any assertions here to check the response
# for idx, chunk in enumerate(response):
# streaming_and_function_calling_format_tests(idx=idx, chunk=chunk)
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# raise e
def test_openai_streaming_and_function_calling():
tools = [
{
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
}
]
messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
try:
response = completion(
model="gpt-3.5-turbo",
tools=tools,
messages=messages,
stream=True,
)
# Add any assertions here to check the response
for idx, chunk in enumerate(response):
streaming_and_function_calling_format_tests(idx=idx, chunk=chunk)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
raise e
# test_openai_streaming_and_function_calling()
def test_azure_streaming_and_function_calling():
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
try:
response = completion(
model="azure/gpt-4-nov-release",
tools=tools,
tool_choice="auto",
messages=messages,
stream=True,
api_base=os.getenv("AZURE_FRANCE_API_BASE"),
api_key=os.getenv("AZURE_FRANCE_API_KEY"),
api_version="2024-02-15-preview",
)
# Add any assertions here to check the response
for idx, chunk in enumerate(response):
if idx == 0:
assert (
chunk.choices[0].delta.tool_calls[0].function.arguments is not None
)
assert isinstance(
chunk.choices[0].delta.tool_calls[0].function.arguments, str
)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
raise e
# test_azure_streaming_and_function_calling()
def test_success_callback_streaming():

View file

@ -258,11 +258,14 @@ class Message(OpenAIObject):
class Delta(OpenAIObject):
def __init__(self, content=None, role=None, **params):
def __init__(
self, content=None, role=None, function_call=None, tool_calls=None, **params
):
super(Delta, self).__init__(**params)
self.content = content
if role is not None:
self.role = role
self.function_call = function_call
self.tool_calls = tool_calls
def __contains__(self, key):
# Define custom behavior for the 'in' operator
@ -8675,8 +8678,37 @@ class CustomStreamWrapper:
):
try:
delta = dict(original_chunk.choices[0].delta)
## AZURE - check if arguments is not None
if (
original_chunk.choices[0].delta.function_call
is not None
):
if (
getattr(
original_chunk.choices[0].delta.function_call,
"arguments",
)
is None
):
original_chunk.choices[
0
].delta.function_call.arguments = ""
elif original_chunk.choices[0].delta.tool_calls is not None:
if isinstance(
original_chunk.choices[0].delta.tool_calls, list
):
for t in original_chunk.choices[0].delta.tool_calls:
if (
getattr(
t.function,
"arguments",
)
is None
):
t.function.arguments = ""
model_response.choices[0].delta = Delta(**delta)
except Exception as e:
traceback.print_exc()
model_response.choices[0].delta = Delta()
else:
return