mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
feat(databricks/chat/transformation.py): add tools and 'tool_choice' param support (#8076)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 38s
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 38s
* feat(databricks/chat/transformation.py): add tools and 'tool_choice' param support Closes https://github.com/BerriAI/litellm/issues/7788 * refactor: cleanup redundant file * test: mark flaky test * test: mark all parallel request tests as flaky
This commit is contained in:
parent
9fa44a4fbe
commit
ba8ba9eddb
3 changed files with 92 additions and 0 deletions
|
@ -687,3 +687,85 @@ async def test_watsonx_tool_choice(sync_mode):
|
|||
pytest.skip("Skipping test due to timeout")
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_function_calling_with_dbrx():
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
|
||||
client = AsyncHTTPHandler()
|
||||
with patch.object(client, "post", return_value=MagicMock()) as mock_completion:
|
||||
try:
|
||||
resp = await litellm.acompletion(
|
||||
model="databricks/databricks-dbrx-instruct",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful customer support assistant. Use the supplied tools to assist the user.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hi, can you tell me the delivery date for my order?",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hi there! I can help with that. Can you please provide your order ID?",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "i think it is order_12345, also what is the weather in Phoenix, AZ?",
|
||||
},
|
||||
],
|
||||
tools=[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_delivery_date",
|
||||
"description": "Get the delivery date for a customer'''s order. Call this whenever you need to know the delivery date, for example when a customer asks '''Where is my package'''",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"order_id": {
|
||||
"type": "string",
|
||||
"description": "The customer'''s order ID.",
|
||||
}
|
||||
},
|
||||
"required": ["order_id"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "check_weather",
|
||||
"description": "Check the current weather in a location. For example when asked: '''What is the temperature in San Fransisco, CA?'''",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The city to check the weather for.",
|
||||
},
|
||||
"state": {
|
||||
"type": "string",
|
||||
"description": "The state to check the weather for.",
|
||||
},
|
||||
},
|
||||
"required": ["city", "state"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
client=client,
|
||||
tool_choice="auto",
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
mock_completion.assert_called_once()
|
||||
print(mock_completion.call_args.kwargs)
|
||||
json_data = json.loads(mock_completion.call_args.kwargs["data"])
|
||||
assert "tools" in json_data
|
||||
assert "tool_choice" in json_data
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue