Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)
feat(databricks/chat): support structured outputs on databricks
Closes https://github.com/BerriAI/litellm/pull/6978

- handles content as list for dbrx
- handles streaming+response_format for dbrx
parent 12aea45447
commit 0caf804f4c

18 changed files with 538 additions and 193 deletions
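As a rough, hedged illustration of what the feature described above enables (this snippet is not part of the commit), callers can pass response_format to a Databricks model through litellm, with and without streaming. The model name mirrors the one used in the tests below; the prompt and the JSON-object format are illustrative assumptions.

# Hedged sketch, not code from this commit: structured outputs on Databricks
# via litellm. The prompt and response_format value are illustrative.
import litellm

messages = [{"role": "user", "content": "Return two colors as a JSON object."}]

# Non-streaming structured output request.
response = litellm.completion(
    model="databricks/databricks-dbrx-instruct",
    messages=messages,
    response_format={"type": "json_object"},
)
print(response.choices[0].message.content)

# Streaming combined with response_format, which this commit also handles.
stream = litellm.completion(
    model="databricks/databricks-dbrx-instruct",
    messages=messages,
    response_format={"type": "json_object"},
    stream=True,
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")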
@@ -4,7 +4,8 @@ import json
 import pytest
 import sys
 from typing import Any, Dict, List
-from unittest.mock import MagicMock, Mock, patch
+from unittest.mock import MagicMock, Mock, patch, ANY

 import os

 sys.path.insert(
@@ -14,6 +15,7 @@ import litellm
 from litellm.exceptions import BadRequestError
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.utils import CustomStreamWrapper
+from base_llm_unit_tests import BaseLLMChatTest

 try:
     import databricks.sdk
@@ -333,6 +335,7 @@ def test_completions_with_async_http_handler(monkeypatch):
                 "Authorization": f"Bearer {api_key}",
                 "Content-Type": "application/json",
             },
+            timeout=ANY,
             data=json.dumps(
                 {
                     "model": "dbrx-instruct-071224",
@@ -376,18 +379,22 @@ def test_completions_streaming_with_sync_http_handler(monkeypatch):
                 "Authorization": f"Bearer {api_key}",
                 "Content-Type": "application/json",
             },
-            data=json.dumps(
-                {
-                    "model": "dbrx-instruct-071224",
-                    "messages": messages,
-                    "temperature": 0.5,
-                    "stream": True,
-                    "extraparam": "testpassingextraparam",
-                }
-            ),
+            data=ANY,
             stream=True,
         )

+        actual_data = json.loads(
+            mock_post.call_args.kwargs["data"]
+        )  # Deserialize the actual data
+        expected_data = {
+            "model": "dbrx-instruct-071224",
+            "messages": messages,
+            "temperature": 0.5,
+            "stream": True,
+            "extraparam": "testpassingextraparam",
+        }
+        assert actual_data == expected_data, f"Unexpected JSON data: {actual_data}"


 def test_completions_streaming_with_async_http_handler(monkeypatch):
     base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
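For context on the hunk above: the streaming tests stop asserting on the exact json.dumps string and instead pass data=ANY to the call assertion, then deserialize the captured payload and compare dicts. A minimal self-contained sketch of that pattern follows; the URL and payload are made up for illustration.

# Sketch of the assertion pattern adopted above: match any serialized body in
# the mock call assertion, then json.loads the payload that was actually sent
# and compare it as a dict, so serialization details can't break the test.
import json
from unittest.mock import MagicMock, ANY

mock_post = MagicMock()

# Stand-in for the HTTP call made by the code under test (hypothetical URL).
mock_post(
    "https://example.databricks.com/serving-endpoints/chat/completions",
    data=json.dumps({"model": "dbrx-instruct-071224", "stream": True}),
)

# The assertion ignores the exact serialized string...
mock_post.assert_called_once_with(ANY, data=ANY)

# ...and the payload is checked after deserializing it.
actual_data = json.loads(mock_post.call_args.kwargs["data"])
assert actual_data == {"model": "dbrx-instruct-071224", "stream": True}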
@@ -429,21 +436,27 @@ def test_completions_streaming_with_async_http_handler(monkeypatch):
                 "Authorization": f"Bearer {api_key}",
                 "Content-Type": "application/json",
             },
-            data=json.dumps(
-                {
-                    "model": "dbrx-instruct-071224",
-                    "messages": messages,
-                    "temperature": 0.5,
-                    "stream": True,
-                    "extraparam": "testpassingextraparam",
-                }
-            ),
+            data=ANY,
             stream=True,
         )

+        actual_data = json.loads(
+            mock_post.call_args.kwargs["data"]
+        )  # Deserialize the actual data
+        expected_data = {
+            "model": "dbrx-instruct-071224",
+            "messages": messages,
+            "temperature": 0.5,
+            "stream": True,
+            "extraparam": "testpassingextraparam",
+        }
+        assert actual_data == expected_data, f"Unexpected JSON data: {actual_data}"


 @pytest.mark.skipif(not databricks_sdk_installed, reason="Databricks SDK not installed")
 def test_completions_uses_databricks_sdk_if_api_key_and_base_not_specified(monkeypatch):
     monkeypatch.delenv("DATABRICKS_API_BASE")
     monkeypatch.delenv("DATABRICKS_API_KEY")

     from databricks.sdk import WorkspaceClient
     from databricks.sdk.config import Config
@@ -637,3 +650,48 @@ def test_embeddings_uses_databricks_sdk_if_api_key_and_base_not_specified(monkeypatch):
                 }
             ),
         )
+
+
+class TestDatabricksCompletion(BaseLLMChatTest):
+    def get_base_completion_call_args(self) -> dict:
+        return {"model": "databricks/databricks-dbrx-instruct"}
+
+    def test_pdf_handling(self, pdf_messages):
+        pytest.skip("Databricks does not support PDF handling")
+
+    def test_tool_call_no_arguments(self, tool_call_no_arguments):
+        """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
+        pytest.skip("Databricks is openai compatible")
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_databricks_embeddings(sync_mode):
+    import openai
+
+    try:
+        litellm.set_verbose = True
+        litellm.drop_params = True
+
+        if sync_mode:
+            response = litellm.embedding(
+                model="databricks/databricks-bge-large-en",
+                input=["good morning from litellm"],
+                instruction="Represent this sentence for searching relevant passages:",
+            )
+        else:
+            response = await litellm.aembedding(
+                model="databricks/databricks-bge-large-en",
+                input=["good morning from litellm"],
+                instruction="Represent this sentence for searching relevant passages:",
+            )
+
+        print(f"response: {response}")
+
+        openai.types.CreateEmbeddingResponse.model_validate(
+            response.model_dump(), strict=True
+        )
+        # stubbed endpoint is setup to return this
+        # assert response.data[0]["embedding"] == [0.1, 0.2, 0.3]
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")