feat(databricks/chat/transformation.py): add tools and 'tool_choice' param support (#8076)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 38s

* feat(databricks/chat/transformation.py): add tools and 'tool_choice' param support

Closes https://github.com/BerriAI/litellm/issues/7788

* refactor: cleanup redundant file

* test: mark flaky test

* test: mark all parallel request tests as flaky
This commit is contained in:
Krish Dholakia 2025-01-29 21:09:07 -08:00 committed by GitHub
parent 9fa44a4fbe
commit ba8ba9eddb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 92 additions and 0 deletions

View file

@ -687,3 +687,85 @@ async def test_watsonx_tool_choice(sync_mode):
pytest.skip("Skipping test due to timeout")
else:
raise e
@pytest.mark.asyncio
async def test_function_calling_with_dbrx():
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
client = AsyncHTTPHandler()
with patch.object(client, "post", return_value=MagicMock()) as mock_completion:
try:
resp = await litellm.acompletion(
model="databricks/databricks-dbrx-instruct",
messages=[
{
"role": "system",
"content": "You are a helpful customer support assistant. Use the supplied tools to assist the user.",
},
{
"role": "user",
"content": "Hi, can you tell me the delivery date for my order?",
},
{
"role": "assistant",
"content": "Hi there! I can help with that. Can you please provide your order ID?",
},
{
"role": "user",
"content": "i think it is order_12345, also what is the weather in Phoenix, AZ?",
},
],
tools=[
{
"type": "function",
"function": {
"name": "get_delivery_date",
"description": "Get the delivery date for a customer'''s order. Call this whenever you need to know the delivery date, for example when a customer asks '''Where is my package'''",
"parameters": {
"type": "object",
"properties": {
"order_id": {
"type": "string",
"description": "The customer'''s order ID.",
}
},
"required": ["order_id"],
"additionalProperties": False,
},
},
},
{
"type": "function",
"function": {
"name": "check_weather",
"description": "Check the current weather in a location. For example when asked: '''What is the temperature in San Fransisco, CA?'''",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city to check the weather for.",
},
"state": {
"type": "string",
"description": "The state to check the weather for.",
},
},
"required": ["city", "state"],
"additionalProperties": False,
},
},
},
],
client=client,
tool_choice="auto",
)
except Exception as e:
print(e)
mock_completion.assert_called_once()
print(mock_completion.call_args.kwargs)
json_data = json.loads(mock_completion.call_args.kwargs["data"])
assert "tools" in json_data
assert "tool_choice" in json_data