fix(utils.py): support returning cached streaming responses for function-calling streaming calls

Krrish Dholakia 2024-02-26 12:31:00 -08:00
parent 92ff9a1a79
commit dfb1d34e26
2 changed files with 134 additions and 80 deletions
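
For context, this change makes litellm's cache usable for streaming function/tool-calling responses: the first streaming call is written to the cache, and a repeat call replays it as well-formed delta chunks. A minimal sketch of the calling pattern, assuming Redis credentials and Azure API settings are in the environment; the deployment name comes from the test below and the tool schema is an abbreviated illustration:

    import asyncio
    import os

    import litellm
    from litellm.caching import Cache

    # Point litellm's cache at Redis (same env vars the test uses)
    litellm.cache = Cache(
        type="redis",
        host=os.environ["REDIS_HOST"],
        port=os.environ["REDIS_PORT"],
        password=os.environ["REDIS_PASSWORD"],
    )

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",  # illustrative tool schema
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }
    ]

    async def main():
        # Identical streaming tool calls: the second is served from the cache
        # and must still arrive as valid function-calling delta chunks.
        for _ in range(2):
            response = await litellm.acompletion(
                model="azure/gpt-4-nov-release",  # deployment name from the test
                messages=[{"role": "user", "content": "What is the weather like in Boston?"}],
                tools=tools,
                tool_choice="auto",
                stream=True,
                caching=True,
            )
            async for chunk in response:
                print(chunk.choices[0].delta)

    asyncio.run(main())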


@@ -1907,6 +1907,8 @@ def test_azure_streaming_and_function_calling():

 @pytest.mark.asyncio
 async def test_azure_astreaming_and_function_calling():
+    import uuid
+
     tools = [
         {
             "type": "function",
@@ -1927,7 +1929,20 @@ async def test_azure_astreaming_and_function_calling():
             },
         }
     ]
-    messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
+    messages = [
+        {
+            "role": "user",
+            "content": f"What is the weather like in Boston? {uuid.uuid4()}",
+        }
+    ]
+    from litellm.caching import Cache
+
+    litellm.cache = Cache(
+        type="redis",
+        host=os.environ["REDIS_HOST"],
+        port=os.environ["REDIS_PORT"],
+        password=os.environ["REDIS_PASSWORD"],
+    )
     try:
         response = await litellm.acompletion(
             model="azure/gpt-4-nov-release",
@@ -1938,6 +1953,7 @@ async def test_azure_astreaming_and_function_calling():
             api_base=os.getenv("AZURE_FRANCE_API_BASE"),
             api_key=os.getenv("AZURE_FRANCE_API_KEY"),
             api_version="2024-02-15-preview",
+            caching=True,
         )
         # Add any assertions here to check the response
         idx = 0
@@ -1957,6 +1973,36 @@ async def test_azure_astreaming_and_function_calling():
                 validate_final_streaming_function_calling_chunk(chunk=chunk)
             idx += 1

+        ## CACHING TEST
+        print("\n\nCACHING TESTS\n\n")
+        response = await litellm.acompletion(
+            model="azure/gpt-4-nov-release",
+            tools=tools,
+            tool_choice="auto",
+            messages=messages,
+            stream=True,
+            api_base=os.getenv("AZURE_FRANCE_API_BASE"),
+            api_key=os.getenv("AZURE_FRANCE_API_KEY"),
+            api_version="2024-02-15-preview",
+            caching=True,
+        )
+        # Add any assertions here to check the response
+        idx = 0
+        async for chunk in response:
+            print(f"chunk: {chunk}")
+            if idx == 0:
+                assert (
+                    chunk.choices[0].delta.tool_calls[0].function.arguments is not None
+                )
+                assert isinstance(
+                    chunk.choices[0].delta.tool_calls[0].function.arguments, str
+                )
+                validate_first_streaming_function_calling_chunk(chunk=chunk)
+            elif idx == 1:
+                validate_second_streaming_function_calling_chunk(chunk=chunk)
+            elif chunk.choices[0].finish_reason is not None:  # last chunk
+                validate_final_streaming_function_calling_chunk(chunk=chunk)
+            idx += 1
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
         raise e
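
Note the uuid.uuid4() appended to the prompt: it makes the messages unique per test run, so the second acompletion call within a run is a guaranteed cache hit while reruns never collide with stale Redis entries.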


@@ -213,6 +213,13 @@ class Function(OpenAIObject):
     name: str


+class ChatCompletionDeltaToolCall(OpenAIObject):
+    id: str
+    function: Function
+    type: str
+    index: int
+
+
 class ChatCompletionMessageToolCall(OpenAIObject):
     id: str
     function: Function
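
ChatCompletionDeltaToolCall mirrors the existing ChatCompletionMessageToolCall but adds the index field that streaming tool-call deltas carry. A small sketch of rehydrating one from a cached dict; the values are made up, and pydantic validates the nested dict into a Function:

    cached_tool_call = {
        "id": "call_abc123",  # illustrative id
        "type": "function",
        "function": {"name": "get_current_weather", "arguments": '{"location": "Boston"}'},
    }
    cached_tool_call.setdefault("index", 0)  # cached dicts may lack the streaming index
    tool_call = ChatCompletionDeltaToolCall(**cached_tool_call)
    assert tool_call.index == 0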
@@ -269,7 +276,15 @@ class Delta(OpenAIObject):
         self.content = content
         self.role = role
         self.function_call = function_call
-        self.tool_calls = tool_calls
+        # tool_calls may arrive as plain dicts (e.g. from a cached response): convert and default index
+        if isinstance(tool_calls, list) and len(tool_calls) > 0 and isinstance(tool_calls[0], dict):
+            self.tool_calls = []
+            for tool_call in tool_calls:
+                if tool_call.get("index", None) is None:
+                    tool_call["index"] = 0
+                self.tool_calls.append(ChatCompletionDeltaToolCall(**tool_call))
+        else:
+            self.tool_calls = tool_calls

     def __contains__(self, key):
         # Define custom behavior for the 'in' operator
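
With this, Delta accepts tool_calls either as already-typed objects or as the plain dicts a cache round-trip produces, defaulting a missing index to 0. Roughly, with illustrative values:

    delta = Delta(
        role="assistant",
        tool_calls=[
            {
                "id": "call_abc123",
                "type": "function",
                "function": {"name": "get_current_weather", "arguments": ""},
                # no "index" key: Delta fills in 0
            }
        ],
    )
    assert delta.tool_calls[0].index == 0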
@@ -5847,6 +5862,18 @@ async def convert_to_streaming_response_async(response_object: Optional[dict] =
     choice_list = []
     for idx, choice in enumerate(response_object["choices"]):
+        if (
+            choice["message"].get("tool_calls", None) is not None
+            and isinstance(choice["message"]["tool_calls"], list)
+            and len(choice["message"]["tool_calls"]) > 0
+            and isinstance(choice["message"]["tool_calls"][0], dict)
+        ):
+            pydantic_tool_calls = []
+            for index, t in enumerate(choice["message"]["tool_calls"]):
+                if "index" not in t:
+                    t["index"] = index
+                pydantic_tool_calls.append(ChatCompletionDeltaToolCall(**t))
+            choice["message"]["tool_calls"] = pydantic_tool_calls
         delta = Delta(
             content=choice["message"].get("content", None),
             role=choice["message"]["role"],
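
convert_to_streaming_response_async is the path that replays a cached, fully-materialized response as a stream; before building each Delta it converts raw tool_calls dicts into ChatCompletionDeltaToolCall, defaulting index to the list position. The defaulting rule in isolation, with illustrative data:

    raw = [
        {"id": "call_1", "type": "function", "function": {"name": "f", "arguments": "{}"}},
        {"id": "call_2", "type": "function", "function": {"name": "g", "arguments": "{}"}, "index": 5},
    ]
    converted = []
    for index, t in enumerate(raw):
        if "index" not in t:  # position wins only when "index" is absent
            t["index"] = index
        converted.append(ChatCompletionDeltaToolCall(**t))
    assert converted[0].index == 0 and converted[1].index == 5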
@@ -8646,6 +8673,7 @@ class CustomStreamWrapper:
                     "text": chunk.choices[0].delta.content,
                     "is_finished": True,
                     "finish_reason": chunk.choices[0].finish_reason,
+                    "original_chunk": chunk,
                 }
                 completion_obj["content"] = response_obj["text"]
@@ -8676,80 +8704,11 @@
                 model_response.choices[0].logprobs = response_obj["logprobs"]

             model_response.model = self.model
-            print_verbose(
-                f"model_response: {model_response}; completion_obj: {completion_obj}"
-            )
             print_verbose(
                 f"model_response finish reason 3: {model_response.choices[0].finish_reason}"
             )
-            ## FUNCTION CALL PARSING
-            if (
-                len(completion_obj["content"]) > 0
-            ):  # cannot set content of an OpenAI Object to be an empty string
-                hold, model_response_str = self.check_special_tokens(
-                    chunk=completion_obj["content"],
-                    finish_reason=model_response.choices[0].finish_reason,
-                )  # filter out bos/eos tokens from openai-compatible hf endpoints
-                print_verbose(
-                    f"hold - {hold}, model_response_str - {model_response_str}"
-                )
-                if hold is False:
-                    ## check if openai/azure chunk
-                    original_chunk = response_obj.get("original_chunk", None)
-                    if original_chunk:
-                        model_response.id = original_chunk.id
-                        if len(original_chunk.choices) > 0:
-                            try:
-                                delta = dict(original_chunk.choices[0].delta)
-                                print_verbose(f"original delta: {delta}")
-                                model_response.choices[0].delta = Delta(**delta)
-                            except Exception as e:
-                                model_response.choices[0].delta = Delta()
-                        else:
-                            return
-                        model_response.system_fingerprint = (
-                            original_chunk.system_fingerprint
-                        )
-                        print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}")
-                        if self.sent_first_chunk == False:
-                            model_response.choices[0].delta["role"] = "assistant"
-                            self.sent_first_chunk = True
-                        elif self.sent_first_chunk == True and hasattr(
-                            model_response.choices[0].delta, "role"
-                        ):
-                            _initial_delta = model_response.choices[
-                                0
-                            ].delta.model_dump()
-                            _initial_delta.pop("role", None)
-                            model_response.choices[0].delta = Delta(**_initial_delta)
-                        print_verbose(
-                            f"model_response.choices[0].delta: {model_response.choices[0].delta}"
-                        )
-                    else:
-                        ## else
-                        completion_obj["content"] = model_response_str
-                        if self.sent_first_chunk == False:
-                            completion_obj["role"] = "assistant"
-                            self.sent_first_chunk = True
-                        model_response.choices[0].delta = Delta(**completion_obj)
-                    print_verbose(f"returning model_response: {model_response}")
-                    return model_response
-                else:
-                    return
-            elif model_response.choices[0].finish_reason:
-                # flush any remaining holding chunk
-                if len(self.holding_chunk) > 0:
-                    if model_response.choices[0].delta.content is None:
-                        model_response.choices[0].delta.content = self.holding_chunk
-                    else:
-                        model_response.choices[0].delta.content = (
-                            self.holding_chunk + model_response.choices[0].delta.content
-                        )
-                    self.holding_chunk = ""
-                model_response.choices[0].finish_reason = map_finish_reason(
-                    model_response.choices[0].finish_reason
-                )  # ensure consistent output to openai
-                return model_response
-            elif (
+            if (
                 response_obj is not None
                 and response_obj.get("original_chunk", None) is not None
             ):  # function / tool calling branch - only set for openai/azure compatible endpoints
@@ -8783,26 +8742,75 @@
                                 original_chunk.choices[0].delta.tool_calls, list
                             ):
                                 for t in original_chunk.choices[0].delta.tool_calls:
-                                    if (
-                                        getattr(
-                                            t.function,
-                                            "arguments",
-                                        )
-                                        is None
-                                    ):
-                                        t.function.arguments = ""
+                                    if hasattr(t, "function") and hasattr(
+                                        t.function, "arguments"
+                                    ):
+                                        if (
+                                            getattr(
+                                                t.function,
+                                                "arguments",
+                                            )
+                                            is None
+                                        ):
+                                            t.function.arguments = ""
                         model_response.choices[0].delta = Delta(**delta)
                     except Exception as e:
                         traceback.print_exc()
                         model_response.choices[0].delta = Delta()
                 else:
-                    return
+                    try:
+                        delta = dict(original_chunk.choices[0].delta)
+                        print_verbose(f"original delta: {delta}")
+                        model_response.choices[0].delta = Delta(**delta)
+                    except Exception as e:
+                        model_response.choices[0].delta = Delta()
             else:
                 return
             model_response.system_fingerprint = original_chunk.system_fingerprint
             if self.sent_first_chunk == False:
                 model_response.choices[0].delta["role"] = "assistant"
                 self.sent_first_chunk = True
+
+            ## RETURN ARG
+            if (
+                response_obj.get("text", None) is not None
+                or response_obj.get("original_chunk", None) is not None
+            ):
+                hold = False
+                if response_obj.get("content", None) is not None:
+                    hold, model_response_str = self.check_special_tokens(
+                        chunk=completion_obj["content"],
+                        finish_reason=model_response.choices[0].finish_reason,
+                    )  # filter out bos/eos tokens from openai-compatible hf endpoints
+                    print_verbose(
+                        f"hold - {hold}, model_response_str - {model_response_str}"
+                    )
+                if hold is False:
+                    original_chunk = response_obj.get("original_chunk", None)
+                    if original_chunk is None:
+                        completion_obj["content"] = model_response_str
+                        if self.sent_first_chunk == False:
+                            completion_obj["role"] = "assistant"
+                            self.sent_first_chunk = True
+                        model_response.choices[0].delta = Delta(**completion_obj)
+                    print_verbose(f"returning model_response: {model_response}")
+                    return model_response
+                else:
+                    return
+            elif model_response.choices[0].finish_reason is not None:
+                # flush any remaining holding chunk
+                if len(self.holding_chunk) > 0:
+                    if model_response.choices[0].delta.content is None:
+                        model_response.choices[0].delta.content = self.holding_chunk
+                    else:
+                        model_response.choices[0].delta.content = (
+                            self.holding_chunk + model_response.choices[0].delta.content
+                        )
+                    self.holding_chunk = ""
+                # get any function call arguments
+                model_response.choices[0].finish_reason = map_finish_reason(
+                    model_response.choices[0].finish_reason
+                )  # ensure consistent output to openai
                 return model_response
             else:
                 return
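
Net effect of the two CustomStreamWrapper hunks: the return decision moves out of the content-only branch into a single ## RETURN ARG block, so chunks that carry a tool-call delta but no text content (exactly what a cached function-calling stream replays) are returned instead of silently dropped. A rough sketch of the consolidated guard, not verbatim from the code above:

    def _should_emit(response_obj, model_response) -> bool:
        if (
            response_obj.get("text") is not None
            or response_obj.get("original_chunk") is not None
        ):
            return True  # plain text, or an openai/azure tool-calling chunk
        # otherwise only the final chunk (finish_reason set) is flushed
        return model_response.choices[0].finish_reason is not None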