forked from phoenix/litellm-mirror
Merge pull request #2175 from BerriAI/litellm_stricter_function_calling_streaming_tests
fix(utils.py): stricter azure function calling tests
This commit is contained in:
commit
a83bc2fd1e
3 changed files with 262 additions and 33 deletions
|
@ -206,6 +206,7 @@ def test_async_custom_handler_stream():
|
||||||
# test_async_custom_handler_stream()
|
# test_async_custom_handler_stream()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="Flaky test")
|
||||||
def test_azure_completion_stream():
|
def test_azure_completion_stream():
|
||||||
# [PROD Test] - Do not DELETE
|
# [PROD Test] - Do not DELETE
|
||||||
# test if completion() + sync custom logger get the same complete stream response
|
# test if completion() + sync custom logger get the same complete stream response
|
||||||
|
|
|
@ -1655,6 +1655,202 @@ def test_openai_streaming_and_function_calling():
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
# test_azure_streaming_and_function_calling()
|
||||||
|
|
||||||
|
|
||||||
|
def test_success_callback_streaming():
|
||||||
|
def success_callback(kwargs, completion_response, start_time, end_time):
|
||||||
|
print(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"input": kwargs,
|
||||||
|
"output": completion_response,
|
||||||
|
"start_time": start_time,
|
||||||
|
"end_time": end_time,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
litellm.success_callback = [success_callback]
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": "hello"}]
|
||||||
|
print("TESTING LITELLM COMPLETION CALL")
|
||||||
|
response = litellm.completion(
|
||||||
|
model="j2-light",
|
||||||
|
messages=messages,
|
||||||
|
stream=True,
|
||||||
|
max_tokens=5,
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
for chunk in response:
|
||||||
|
print(chunk["choices"][0])
|
||||||
|
|
||||||
|
|
||||||
|
# test_success_callback_streaming()
|
||||||
|
|
||||||
|
#### STREAMING + FUNCTION CALLING ###
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class Function(BaseModel):
|
||||||
|
name: str
|
||||||
|
arguments: str
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCalls(BaseModel):
|
||||||
|
index: int
|
||||||
|
id: str
|
||||||
|
type: str
|
||||||
|
function: Function
|
||||||
|
|
||||||
|
|
||||||
|
class Delta(BaseModel):
|
||||||
|
role: str
|
||||||
|
content: Optional[str]
|
||||||
|
tool_calls: List[ToolCalls]
|
||||||
|
|
||||||
|
|
||||||
|
class Choices(BaseModel):
|
||||||
|
index: int
|
||||||
|
delta: Delta
|
||||||
|
logprobs: Optional[str]
|
||||||
|
finish_reason: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class Chunk(BaseModel):
|
||||||
|
id: str
|
||||||
|
object: str
|
||||||
|
created: int
|
||||||
|
model: str
|
||||||
|
system_fingerprint: str
|
||||||
|
choices: List[Choices]
|
||||||
|
|
||||||
|
|
||||||
|
def validate_first_streaming_function_calling_chunk(chunk: ModelResponse):
|
||||||
|
chunk_instance = Chunk(**chunk.model_dump())
|
||||||
|
|
||||||
|
|
||||||
|
### Chunk 1
|
||||||
|
|
||||||
|
|
||||||
|
# {
|
||||||
|
# "id": "chatcmpl-8vdVjtzxc0JqGjq93NxC79dMp6Qcs",
|
||||||
|
# "object": "chat.completion.chunk",
|
||||||
|
# "created": 1708747267,
|
||||||
|
# "model": "gpt-3.5-turbo-0125",
|
||||||
|
# "system_fingerprint": "fp_86156a94a0",
|
||||||
|
# "choices": [
|
||||||
|
# {
|
||||||
|
# "index": 0,
|
||||||
|
# "delta": {
|
||||||
|
# "role": "assistant",
|
||||||
|
# "content": null,
|
||||||
|
# "tool_calls": [
|
||||||
|
# {
|
||||||
|
# "index": 0,
|
||||||
|
# "id": "call_oN10vaaC9iA8GLFRIFwjCsN7",
|
||||||
|
# "type": "function",
|
||||||
|
# "function": {
|
||||||
|
# "name": "get_current_weather",
|
||||||
|
# "arguments": ""
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
# ]
|
||||||
|
# },
|
||||||
|
# "logprobs": null,
|
||||||
|
# "finish_reason": null
|
||||||
|
# }
|
||||||
|
# ]
|
||||||
|
# }
|
||||||
|
class Function2(BaseModel):
|
||||||
|
arguments: str
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCalls2(BaseModel):
|
||||||
|
index: int
|
||||||
|
function: Optional[Function2]
|
||||||
|
|
||||||
|
|
||||||
|
class Delta2(BaseModel):
|
||||||
|
tool_calls: List[ToolCalls2]
|
||||||
|
|
||||||
|
|
||||||
|
class Choices2(BaseModel):
|
||||||
|
index: int
|
||||||
|
delta: Delta2
|
||||||
|
logprobs: Optional[str]
|
||||||
|
finish_reason: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class Chunk2(BaseModel):
|
||||||
|
id: str
|
||||||
|
object: str
|
||||||
|
created: int
|
||||||
|
model: str
|
||||||
|
system_fingerprint: str
|
||||||
|
choices: List[Choices2]
|
||||||
|
|
||||||
|
|
||||||
|
## Chunk 2
|
||||||
|
|
||||||
|
# {
|
||||||
|
# "id": "chatcmpl-8vdVjtzxc0JqGjq93NxC79dMp6Qcs",
|
||||||
|
# "object": "chat.completion.chunk",
|
||||||
|
# "created": 1708747267,
|
||||||
|
# "model": "gpt-3.5-turbo-0125",
|
||||||
|
# "system_fingerprint": "fp_86156a94a0",
|
||||||
|
# "choices": [
|
||||||
|
# {
|
||||||
|
# "index": 0,
|
||||||
|
# "delta": {
|
||||||
|
# "tool_calls": [
|
||||||
|
# {
|
||||||
|
# "index": 0,
|
||||||
|
# "function": {
|
||||||
|
# "arguments": "{\""
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
# ]
|
||||||
|
# },
|
||||||
|
# "logprobs": null,
|
||||||
|
# "finish_reason": null
|
||||||
|
# }
|
||||||
|
# ]
|
||||||
|
# }
|
||||||
|
|
||||||
|
|
||||||
|
def validate_second_streaming_function_calling_chunk(chunk: ModelResponse):
|
||||||
|
chunk_instance = Chunk2(**chunk.model_dump())
|
||||||
|
|
||||||
|
|
||||||
|
class Delta3(BaseModel):
|
||||||
|
content: Optional[str] = None
|
||||||
|
role: Optional[str] = None
|
||||||
|
function_call: Optional[dict] = None
|
||||||
|
tool_calls: Optional[List] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Choices3(BaseModel):
|
||||||
|
index: int
|
||||||
|
delta: Delta3
|
||||||
|
logprobs: Optional[str]
|
||||||
|
finish_reason: str
|
||||||
|
|
||||||
|
|
||||||
|
class Chunk3(BaseModel):
|
||||||
|
id: str
|
||||||
|
object: str
|
||||||
|
created: int
|
||||||
|
model: str
|
||||||
|
system_fingerprint: str
|
||||||
|
choices: List[Choices3]
|
||||||
|
|
||||||
|
|
||||||
|
def validate_final_streaming_function_calling_chunk(chunk: ModelResponse):
|
||||||
|
chunk_instance = Chunk3(**chunk.model_dump())
|
||||||
|
|
||||||
|
|
||||||
def test_azure_streaming_and_function_calling():
|
def test_azure_streaming_and_function_calling():
|
||||||
tools = [
|
tools = [
|
||||||
{
|
{
|
||||||
|
@ -1690,6 +1886,7 @@ def test_azure_streaming_and_function_calling():
|
||||||
)
|
)
|
||||||
# Add any assertions here to check the response
|
# Add any assertions here to check the response
|
||||||
for idx, chunk in enumerate(response):
|
for idx, chunk in enumerate(response):
|
||||||
|
print(f"chunk: {chunk}")
|
||||||
if idx == 0:
|
if idx == 0:
|
||||||
assert (
|
assert (
|
||||||
chunk.choices[0].delta.tool_calls[0].function.arguments is not None
|
chunk.choices[0].delta.tool_calls[0].function.arguments is not None
|
||||||
|
@ -1697,40 +1894,69 @@ def test_azure_streaming_and_function_calling():
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
chunk.choices[0].delta.tool_calls[0].function.arguments, str
|
chunk.choices[0].delta.tool_calls[0].function.arguments, str
|
||||||
)
|
)
|
||||||
|
validate_first_streaming_function_calling_chunk(chunk=chunk)
|
||||||
|
elif idx == 1:
|
||||||
|
validate_second_streaming_function_calling_chunk(chunk=chunk)
|
||||||
|
elif chunk.choices[0].finish_reason is not None: # last chunk
|
||||||
|
validate_final_streaming_function_calling_chunk(chunk=chunk)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
# test_azure_streaming_and_function_calling()
|
@pytest.mark.asyncio
|
||||||
|
async def test_azure_astreaming_and_function_calling():
|
||||||
|
tools = [
|
||||||
def test_success_callback_streaming():
|
{
|
||||||
def success_callback(kwargs, completion_response, start_time, end_time):
|
"type": "function",
|
||||||
print(
|
"function": {
|
||||||
{
|
"name": "get_current_weather",
|
||||||
"success": True,
|
"description": "Get the current weather in a given location",
|
||||||
"input": kwargs,
|
"parameters": {
|
||||||
"output": completion_response,
|
"type": "object",
|
||||||
"start_time": start_time,
|
"properties": {
|
||||||
"end_time": end_time,
|
"location": {
|
||||||
}
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA",
|
||||||
|
},
|
||||||
|
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||||
|
},
|
||||||
|
"required": ["location"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
|
||||||
|
try:
|
||||||
|
response = await litellm.acompletion(
|
||||||
|
model="azure/gpt-4-nov-release",
|
||||||
|
tools=tools,
|
||||||
|
tool_choice="auto",
|
||||||
|
messages=messages,
|
||||||
|
stream=True,
|
||||||
|
api_base=os.getenv("AZURE_FRANCE_API_BASE"),
|
||||||
|
api_key=os.getenv("AZURE_FRANCE_API_KEY"),
|
||||||
|
api_version="2024-02-15-preview",
|
||||||
)
|
)
|
||||||
|
# Add any assertions here to check the response
|
||||||
|
idx = 0
|
||||||
|
async for chunk in response:
|
||||||
|
print(f"chunk: {chunk}")
|
||||||
|
if idx == 0:
|
||||||
|
assert (
|
||||||
|
chunk.choices[0].delta.tool_calls[0].function.arguments is not None
|
||||||
|
)
|
||||||
|
assert isinstance(
|
||||||
|
chunk.choices[0].delta.tool_calls[0].function.arguments, str
|
||||||
|
)
|
||||||
|
validate_first_streaming_function_calling_chunk(chunk=chunk)
|
||||||
|
elif idx == 1:
|
||||||
|
validate_second_streaming_function_calling_chunk(chunk=chunk)
|
||||||
|
elif chunk.choices[0].finish_reason is not None: # last chunk
|
||||||
|
validate_final_streaming_function_calling_chunk(chunk=chunk)
|
||||||
|
idx += 1
|
||||||
|
|
||||||
litellm.success_callback = [success_callback]
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
messages = [{"role": "user", "content": "hello"}]
|
raise e
|
||||||
print("TESTING LITELLM COMPLETION CALL")
|
|
||||||
response = litellm.completion(
|
|
||||||
model="j2-light",
|
|
||||||
messages=messages,
|
|
||||||
stream=True,
|
|
||||||
max_tokens=5,
|
|
||||||
)
|
|
||||||
print(response)
|
|
||||||
|
|
||||||
for chunk in response:
|
|
||||||
print(chunk["choices"][0])
|
|
||||||
|
|
||||||
|
|
||||||
# test_success_callback_streaming()
|
|
||||||
|
|
|
@ -376,11 +376,9 @@ class StreamingChoices(OpenAIObject):
|
||||||
self.delta = delta
|
self.delta = delta
|
||||||
else:
|
else:
|
||||||
self.delta = Delta()
|
self.delta = Delta()
|
||||||
|
|
||||||
if logprobs is not None:
|
|
||||||
self.logprobs = logprobs
|
|
||||||
if enhancements is not None:
|
if enhancements is not None:
|
||||||
self.enhancements = enhancements
|
self.enhancements = enhancements
|
||||||
|
self.logprobs = logprobs
|
||||||
|
|
||||||
def __contains__(self, key):
|
def __contains__(self, key):
|
||||||
# Define custom behavior for the 'in' operator
|
# Define custom behavior for the 'in' operator
|
||||||
|
@ -8631,6 +8629,10 @@ class CustomStreamWrapper:
|
||||||
model_response.choices[0].finish_reason = response_obj[
|
model_response.choices[0].finish_reason = response_obj[
|
||||||
"finish_reason"
|
"finish_reason"
|
||||||
]
|
]
|
||||||
|
if response_obj.get("original_chunk", None) is not None:
|
||||||
|
model_response.system_fingerprint = getattr(
|
||||||
|
response_obj["original_chunk"], "system_fingerprint", None
|
||||||
|
)
|
||||||
if response_obj["logprobs"] is not None:
|
if response_obj["logprobs"] is not None:
|
||||||
model_response.choices[0].logprobs = response_obj["logprobs"]
|
model_response.choices[0].logprobs = response_obj["logprobs"]
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue