mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
feat(databricks/chat/transformation.py): add tools and 'tool_choice' param support (#8076)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 38s
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 38s
* feat(databricks/chat/transformation.py): add tools and 'tool_choice' param support Closes https://github.com/BerriAI/litellm/issues/7788 * refactor: cleanup redundant file * test: mark flaky test * test: mark all parallel request tests as flaky
This commit is contained in:
parent
9fa44a4fbe
commit
ba8ba9eddb
3 changed files with 92 additions and 0 deletions
|
@ -73,6 +73,8 @@ class DatabricksConfig(OpenAILikeChatConfig):
|
|||
"max_completion_tokens",
|
||||
"n",
|
||||
"response_format",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
]
|
||||
|
||||
def _should_fake_stream(self, optional_params: dict) -> bool:
|
||||
|
|
|
@ -687,3 +687,85 @@ async def test_watsonx_tool_choice(sync_mode):
|
|||
pytest.skip("Skipping test due to timeout")
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_function_calling_with_dbrx():
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
|
||||
client = AsyncHTTPHandler()
|
||||
with patch.object(client, "post", return_value=MagicMock()) as mock_completion:
|
||||
try:
|
||||
resp = await litellm.acompletion(
|
||||
model="databricks/databricks-dbrx-instruct",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful customer support assistant. Use the supplied tools to assist the user.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hi, can you tell me the delivery date for my order?",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hi there! I can help with that. Can you please provide your order ID?",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "i think it is order_12345, also what is the weather in Phoenix, AZ?",
|
||||
},
|
||||
],
|
||||
tools=[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_delivery_date",
|
||||
"description": "Get the delivery date for a customer'''s order. Call this whenever you need to know the delivery date, for example when a customer asks '''Where is my package'''",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"order_id": {
|
||||
"type": "string",
|
||||
"description": "The customer'''s order ID.",
|
||||
}
|
||||
},
|
||||
"required": ["order_id"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "check_weather",
|
||||
"description": "Check the current weather in a location. For example when asked: '''What is the temperature in San Fransisco, CA?'''",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The city to check the weather for.",
|
||||
},
|
||||
"state": {
|
||||
"type": "string",
|
||||
"description": "The state to check the weather for.",
|
||||
},
|
||||
},
|
||||
"required": ["city", "state"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
client=client,
|
||||
tool_choice="auto",
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
mock_completion.assert_called_once()
|
||||
print(mock_completion.call_args.kwargs)
|
||||
json_data = json.loads(mock_completion.call_args.kwargs["data"])
|
||||
assert "tools" in json_data
|
||||
assert "tool_choice" in json_data
|
||||
|
|
|
@ -454,6 +454,7 @@ Test with Router
|
|||
"""
|
||||
|
||||
|
||||
@pytest.mark.flaky(retries=6, delay=2)
|
||||
@pytest.mark.asyncio
|
||||
async def test_normal_router_call():
|
||||
model_list = [
|
||||
|
@ -528,6 +529,7 @@ async def test_normal_router_call():
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.flaky(retries=6, delay=2)
|
||||
@pytest.mark.asyncio
|
||||
async def test_normal_router_tpm_limit():
|
||||
import logging
|
||||
|
@ -615,6 +617,7 @@ async def test_normal_router_tpm_limit():
|
|||
assert e.status_code == 429
|
||||
|
||||
|
||||
@pytest.mark.flaky(retries=6, delay=2)
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_router_call():
|
||||
model_list = [
|
||||
|
@ -690,6 +693,7 @@ async def test_streaming_router_call():
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.flaky(retries=6, delay=2)
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_router_tpm_limit():
|
||||
litellm.set_verbose = True
|
||||
|
@ -845,6 +849,7 @@ async def test_bad_router_call():
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.flaky(retries=6, delay=2)
|
||||
@pytest.mark.asyncio
|
||||
async def test_bad_router_tpm_limit():
|
||||
model_list = [
|
||||
|
@ -923,6 +928,7 @@ async def test_bad_router_tpm_limit():
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.flaky(retries=6, delay=2)
|
||||
@pytest.mark.asyncio
|
||||
async def test_bad_router_tpm_limit_per_model():
|
||||
model_list = [
|
||||
|
@ -1023,6 +1029,7 @@ async def test_bad_router_tpm_limit_per_model():
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.flaky(retries=6, delay=2)
|
||||
@pytest.mark.asyncio
|
||||
async def test_pre_call_hook_rpm_limits_per_model():
|
||||
"""
|
||||
|
@ -1101,6 +1108,7 @@ async def test_pre_call_hook_rpm_limits_per_model():
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.flaky(retries=6, delay=2)
|
||||
@pytest.mark.asyncio
|
||||
async def test_pre_call_hook_tpm_limits_per_model():
|
||||
"""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue