feat(databricks/chat): support structured outputs on databricks

Closes https://github.com/BerriAI/litellm/pull/6978

- handles content as list for dbrx, - handles streaming+response_format for dbrx
This commit is contained in:
Krrish Dholakia 2024-12-02 18:23:05 -08:00
parent 12aea45447
commit 0caf804f4c
18 changed files with 538 additions and 193 deletions

View file

@ -4,7 +4,8 @@ import json
import pytest
import sys
from typing import Any, Dict, List
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import MagicMock, Mock, patch, ANY
import os
sys.path.insert(
@ -14,6 +15,7 @@ import litellm
from litellm.exceptions import BadRequestError
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.utils import CustomStreamWrapper
from base_llm_unit_tests import BaseLLMChatTest
try:
import databricks.sdk
@ -333,6 +335,7 @@ def test_completions_with_async_http_handler(monkeypatch):
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
timeout=ANY,
data=json.dumps(
{
"model": "dbrx-instruct-071224",
@ -376,18 +379,22 @@ def test_completions_streaming_with_sync_http_handler(monkeypatch):
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
data=json.dumps(
{
"model": "dbrx-instruct-071224",
"messages": messages,
"temperature": 0.5,
"stream": True,
"extraparam": "testpassingextraparam",
}
),
data=ANY,
stream=True,
)
actual_data = json.loads(
mock_post.call_args.kwargs["data"]
) # Deserialize the actual data
expected_data = {
"model": "dbrx-instruct-071224",
"messages": messages,
"temperature": 0.5,
"stream": True,
"extraparam": "testpassingextraparam",
}
assert actual_data == expected_data, f"Unexpected JSON data: {actual_data}"
def test_completions_streaming_with_async_http_handler(monkeypatch):
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
@ -429,21 +436,27 @@ def test_completions_streaming_with_async_http_handler(monkeypatch):
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
data=json.dumps(
{
"model": "dbrx-instruct-071224",
"messages": messages,
"temperature": 0.5,
"stream": True,
"extraparam": "testpassingextraparam",
}
),
data=ANY,
stream=True,
)
actual_data = json.loads(
mock_post.call_args.kwargs["data"]
) # Deserialize the actual data
expected_data = {
"model": "dbrx-instruct-071224",
"messages": messages,
"temperature": 0.5,
"stream": True,
"extraparam": "testpassingextraparam",
}
assert actual_data == expected_data, f"Unexpected JSON data: {actual_data}"
@pytest.mark.skipif(not databricks_sdk_installed, reason="Databricks SDK not installed")
def test_completions_uses_databricks_sdk_if_api_key_and_base_not_specified(monkeypatch):
monkeypatch.delenv("DATABRICKS_API_BASE")
monkeypatch.delenv("DATABRICKS_API_KEY")
from databricks.sdk import WorkspaceClient
from databricks.sdk.config import Config
@ -637,3 +650,48 @@ def test_embeddings_uses_databricks_sdk_if_api_key_and_base_not_specified(monkey
}
),
)
class TestDatabricksCompletion(BaseLLMChatTest):
def get_base_completion_call_args(self) -> dict:
return {"model": "databricks/databricks-dbrx-instruct"}
def test_pdf_handling(self, pdf_messages):
pytest.skip("Databricks does not support PDF handling")
def test_tool_call_no_arguments(self, tool_call_no_arguments):
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
pytest.skip("Databricks is openai compatible")
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_databricks_embeddings(sync_mode):
import openai
try:
litellm.set_verbose = True
litellm.drop_params = True
if sync_mode:
response = litellm.embedding(
model="databricks/databricks-bge-large-en",
input=["good morning from litellm"],
instruction="Represent this sentence for searching relevant passages:",
)
else:
response = await litellm.aembedding(
model="databricks/databricks-bge-large-en",
input=["good morning from litellm"],
instruction="Represent this sentence for searching relevant passages:",
)
print(f"response: {response}")
openai.types.CreateEmbeddingResponse.model_validate(
response.model_dump(), strict=True
)
# stubbed endpoint is setup to return this
# assert response.data[0]["embedding"] == [0.1, 0.2, 0.3]
except Exception as e:
pytest.fail(f"Error occurred: {e}")