This commit is contained in:
Kaushik Deka 2025-04-24 01:00:35 -07:00 committed by GitHub
commit a1dcbe95e4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 60 additions and 1 deletions

View file

@ -1339,8 +1339,9 @@ class CustomStreamWrapper:
and "function" in tool
and isinstance(tool["function"], dict)
and ("type" not in tool or tool["type"] is None)
and tool.get('id')
):
# if function returned but type set to None - mistral's api returns type: None
# Refer to the test test_function_calling_tool_type
tool["type"] = "function"
model_response.choices[0].delta = Delta(**_json_delta)
except Exception as e:

View file

@ -880,3 +880,61 @@ async def test_function_calling_with_dbrx():
json_data = json.loads(mock_completion.call_args.kwargs["data"])
assert "tools" in json_data
assert "tool_choice" in json_data
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model",
[
"gpt-4-1106-preview",
"mistral/mistral-small-latest",
],
)
async def test_function_calling_tool_type(model):
"""
Test the API's response when streaming tool calls.
- The first streamed chunk contains a tool call with an ID and type "function".
- Subsequent chunks may only contain arguments, with ID and type being None.
"""
response = completion(
model=model,
messages=[{
"role": "user",
"content": "Search for recent AI breakthroughs"
}],
tools=[{
"type": "function",
"function": {
"name": "web_search",
"description": "Searches the web",
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string"}
},
"required": ["query"]
}
}
}],
stream=True,
temperature=0.00000001
)
async for chunk in response:
print(json.dumps(chunk, indent=2, default=str))
if "choices" not in chunk or not chunk.choices:
raise ValueError("Unexpected chunk structure: 'choices' missing or empty")
delta = chunk.choices[0].delta
tool_calls = getattr(delta, "tool_calls", None)
if not tool_calls or not isinstance(tool_calls, list):
continue
tool_call = tool_calls[0]
id = getattr(tool_call, "id", None)
type = getattr(tool_call, "type", None)
if id:
assert type == "function", f"Expected type 'function' for id {id}, got '{type}'"