feat(databricks/chat): support structured outputs on databricks

Closes https://github.com/BerriAI/litellm/pull/6978

- handles content as list for dbrx
- handles streaming + response_format for dbrx
Krrish Dholakia 2024-12-02 18:23:05 -08:00
parent 12aea45447
commit 0caf804f4c
18 changed files with 538 additions and 193 deletions
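
For illustration, a minimal usage sketch of what this commit enables (the model name is taken from the test suite; any Databricks chat endpoint behaves the same):

import litellm

# 'response_format' is now a supported OpenAI param for databricks/* chat
# models, and list-form message content is flattened to a plain string.
response = litellm.completion(
    model="databricks/databricks-dbrx-instruct",
    messages=[
        {
            "role": "user",
            "content": [{"type": "text", "text": "List three colors as JSON."}],
        }
    ],
    response_format={"type": "json_object"},
)
print(response.choices[0].message.content)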

@@ -1021,7 +1021,8 @@ from .llms.anthropic.experimental_pass_through.transformation import (
 )
 from .llms.groq.stt.transformation import GroqSTTConfig
 from .llms.anthropic.completion import AnthropicTextConfig
-from .llms.databricks.chat import DatabricksConfig, DatabricksEmbeddingConfig
+from .llms.databricks.chat.transformation import DatabricksConfig
+from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig
 from .llms.predibase import PredibaseConfig
 from .llms.replicate import ReplicateConfig
 from .llms.cohere.completion import CohereConfig

@@ -178,7 +178,7 @@ def get_supported_openai_params(  # noqa: PLR0915
         ]
     elif custom_llm_provider == "databricks":
         if request_type == "chat_completion":
-            return litellm.DatabricksConfig().get_supported_openai_params()
+            return litellm.DatabricksConfig().get_supported_openai_params(model=model)
        elif request_type == "embeddings":
            return litellm.DatabricksEmbeddingConfig().get_supported_openai_params()
    elif custom_llm_provider == "palm" or custom_llm_provider == "gemini":

@ -0,0 +1,82 @@
"""
Handles the chat completion request for Databricks
"""
from typing import Any, Callable, Literal, Optional, Tuple, Union
from httpx._config import Timeout
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.utils import CustomStreamingDecoder
from litellm.utils import ModelResponse
from ...openai_like.chat.handler import OpenAILikeChatHandler
from ..common_utils import DatabricksBase
from ..exceptions import DatabricksError
from .transformation import DatabricksConfig
class DatabricksChatCompletion(OpenAILikeChatHandler, DatabricksBase):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def completion(
self,
*,
model: str,
messages: list,
api_base: str,
custom_llm_provider: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key: Optional[str],
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers: Optional[dict] = None,
timeout: Optional[Union[float, Timeout]] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
custom_endpoint: Optional[bool] = None,
streaming_decoder: Optional[CustomStreamingDecoder] = None,
fake_stream: bool = False,
):
messages = DatabricksConfig()._transform_messages(messages) # type: ignore
api_base, headers = self.databricks_validate_environment(
api_base=api_base,
api_key=api_key,
endpoint_type="chat_completions",
custom_endpoint=custom_endpoint,
headers=headers,
)
if optional_params.get("stream") is True:
fake_stream = DatabricksConfig()._should_fake_stream(optional_params)
else:
fake_stream = False
return super().completion(
model=model,
messages=messages,
api_base=api_base,
custom_llm_provider=custom_llm_provider,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=optional_params,
acompletion=acompletion,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
timeout=timeout,
client=client,
custom_endpoint=True,
streaming_decoder=streaming_decoder,
fake_stream=fake_stream,
)
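
The handler defers the streaming decision to the config: because Databricks rejects 'response_format' on streaming requests, a stream + response_format call is sent as a regular request and replayed as a stream. A small sketch of that check (import path matches the new module layout):

from litellm.llms.databricks.chat.transformation import DatabricksConfig

config = DatabricksConfig()
# response_format present -> downgrade to a non-streaming call, then fake the stream
assert config._should_fake_stream({"response_format": {"type": "json_object"}}) is True
# plain streaming params pass through untouched
assert config._should_fake_stream({"temperature": 0.5}) is False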

@@ -13,6 +13,7 @@ import httpx  # type: ignore
 import requests  # type: ignore
 import litellm
+from litellm import LlmProviders
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
@@ -33,141 +34,17 @@ from litellm.types.utils import (
     GenericStreamingChunk,
     ProviderField,
 )
-from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage
-
-from ..base import BaseLLM
-from ..prompt_templates.factory import custom_prompt, prompt_factory
-
-
-class DatabricksConfig:
-    """
-    Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
-    """
-
-    max_tokens: Optional[int] = None
-    temperature: Optional[int] = None
-    top_p: Optional[int] = None
-    top_k: Optional[int] = None
-    stop: Optional[Union[List[str], str]] = None
-    n: Optional[int] = None
-
-    def __init__(
-        self,
-        max_tokens: Optional[int] = None,
-        temperature: Optional[int] = None,
-        top_p: Optional[int] = None,
-        top_k: Optional[int] = None,
-        stop: Optional[Union[List[str], str]] = None,
-        n: Optional[int] = None,
-    ) -> None:
-        locals_ = locals()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-    def get_required_params(self) -> List[ProviderField]:
-        """For a given provider, return its required fields with a description"""
-        return [
-            ProviderField(
-                field_name="api_key",
-                field_type="string",
-                field_description="Your Databricks API Key.",
-                field_value="dapi...",
-            ),
-            ProviderField(
-                field_name="api_base",
-                field_type="string",
-                field_description="Your Databricks API Base.",
-                field_value="https://adb-..",
-            ),
-        ]
-
-    def get_supported_openai_params(self):
-        return [
-            "stream",
-            "stop",
-            "temperature",
-            "top_p",
-            "max_tokens",
-            "max_completion_tokens",
-            "n",
-        ]
-
-    def map_openai_params(self, non_default_params: dict, optional_params: dict):
-        for param, value in non_default_params.items():
-            if param == "max_tokens" or param == "max_completion_tokens":
-                optional_params["max_tokens"] = value
-            if param == "n":
-                optional_params["n"] = value
-            if param == "stream" and value is True:
-                optional_params["stream"] = value
-            if param == "temperature":
-                optional_params["temperature"] = value
-            if param == "top_p":
-                optional_params["top_p"] = value
-            if param == "stop":
-                optional_params["stop"] = value
-        return optional_params
-
-
-class DatabricksEmbeddingConfig:
-    """
-    Reference: https://learn.microsoft.com/en-us/azure/databricks/machine-learning/foundation-models/api-reference#--embedding-task
-    """
-
-    instruction: Optional[str] = (
-        None  # An optional instruction to pass to the embedding model. BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries
-    )
-
-    def __init__(self, instruction: Optional[str] = None) -> None:
-        locals_ = locals()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-    def get_supported_openai_params(
-        self,
-    ):  # no optional openai embedding params supported
-        return []
-
-    def map_openai_params(self, non_default_params: dict, optional_params: dict):
-        return optional_params
+from litellm.utils import (
+    CustomStreamWrapper,
+    EmbeddingResponse,
+    ModelResponse,
+    ProviderConfigManager,
+    Usage,
+)
+
+from ...base import BaseLLM
+from ...prompt_templates.factory import custom_prompt, prompt_factory
+from .transformation import DatabricksConfig

 async def make_call(
@@ -477,6 +354,12 @@ class DatabricksChatCompletion(BaseLLM):
         )  # [TODO] add max retry support at llm api call level
         optional_params["stream"] = stream
+        if messages is not None and custom_llm_provider is not None:
+            provider_config = ProviderConfigManager.get_provider_config(
+                model=model, provider=LlmProviders(custom_llm_provider)
+            )
+            messages = provider_config._transform_messages(messages)
+
         data = {
             "model": model,
             "messages": messages,

@ -0,0 +1,143 @@
"""
Translates from OpenAI's `/v1/chat/completions` to Databricks' `/chat/completions`
"""
import types
from typing import List, Optional, Union
from pydantic import BaseModel
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ProviderField
from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig
from ...prompt_templates.common_utils import (
handle_messages_with_content_list_to_str_conversion,
strip_name_from_messages,
)
class DatabricksConfig(OpenAIGPTConfig):
"""
Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
"""
max_tokens: Optional[int] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
top_k: Optional[int] = None
stop: Optional[Union[List[str], str]] = None
n: Optional[int] = None
def __init__(
self,
max_tokens: Optional[int] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
top_k: Optional[int] = None,
stop: Optional[Union[List[str], str]] = None,
n: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_required_params(self) -> List[ProviderField]:
"""For a given provider, return it's required fields with a description"""
return [
ProviderField(
field_name="api_key",
field_type="string",
field_description="Your Databricks API Key.",
field_value="dapi...",
),
ProviderField(
field_name="api_base",
field_type="string",
field_description="Your Databricks API Base.",
field_value="https://adb-..",
),
]
def get_supported_openai_params(self, model: Optional[str] = None) -> list:
return [
"stream",
"stop",
"temperature",
"top_p",
"max_tokens",
"max_completion_tokens",
"n",
"response_format",
]
def _should_fake_stream(self, optional_params: dict) -> bool:
"""
Databricks doesn't support 'response_format' while streaming
"""
if optional_params.get("response_format") is not None:
return True
return False
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
):
for param, value in non_default_params.items():
if param == "max_tokens" or param == "max_completion_tokens":
optional_params["max_tokens"] = value
if param == "n":
optional_params["n"] = value
if param == "stream" and value is True:
optional_params["stream"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "stop":
optional_params["stop"] = value
if param == "response_format":
optional_params["response_format"] = value
return optional_params
def _transform_messages(
self, messages: List[AllMessageValues]
) -> List[AllMessageValues]:
"""
Databricks does not support:
- content in list format.
- 'name' in user message.
"""
new_messages = []
for idx, message in enumerate(messages):
if isinstance(message, BaseModel):
_message = message.model_dump()
else:
_message = message
new_messages.append(_message)
new_messages = handle_messages_with_content_list_to_str_conversion(new_messages)
new_messages = strip_name_from_messages(new_messages)
return super()._transform_messages(new_messages)
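
To make the transformation concrete, here is an illustrative before/after for _transform_messages (the message values are hypothetical):

from litellm.llms.databricks.chat.transformation import DatabricksConfig

messages = [
    {
        "role": "user",
        "name": "test_name",  # Databricks rejects 'name' on user messages
        "content": [{"type": "text", "text": "Hello"}],  # and list-form content
    }
]
transformed = DatabricksConfig()._transform_messages(messages)
# expected result: [{"role": "user", "content": "Hello"}]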

@ -0,0 +1,82 @@
from typing import Literal, Optional, Tuple
from .exceptions import DatabricksError
class DatabricksBase:
def _get_databricks_credentials(
self, api_key: Optional[str], api_base: Optional[str], headers: Optional[dict]
) -> Tuple[str, dict]:
headers = headers or {"Content-Type": "application/json"}
try:
from databricks.sdk import WorkspaceClient
databricks_client = WorkspaceClient()
api_base = api_base or f"{databricks_client.config.host}/serving-endpoints"
if api_key is None:
databricks_auth_headers: dict[str, str] = (
databricks_client.config.authenticate()
)
headers = {**databricks_auth_headers, **headers}
return api_base, headers
except ImportError:
raise DatabricksError(
status_code=400,
message=(
"If the Databricks base URL and API key are not set, the databricks-sdk "
"Python library must be installed. Please install the databricks-sdk, set "
"{LLM_PROVIDER}_API_BASE and {LLM_PROVIDER}_API_KEY environment variables, "
"or provide the base URL and API key as arguments."
),
)
def databricks_validate_environment(
self,
api_key: Optional[str],
api_base: Optional[str],
endpoint_type: Literal["chat_completions", "embeddings"],
custom_endpoint: Optional[bool],
headers: Optional[dict],
) -> Tuple[str, dict]:
if api_key is None and headers is None:
if custom_endpoint is not None:
raise DatabricksError(
status_code=400,
message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
)
else:
api_base, headers = self._get_databricks_credentials(
api_base=api_base, api_key=api_key, headers=headers
)
if api_base is None:
if custom_endpoint:
raise DatabricksError(
status_code=400,
message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
)
else:
api_base, headers = self._get_databricks_credentials(
api_base=api_base, api_key=api_key, headers=headers
)
if headers is None:
headers = {
"Authorization": "Bearer {}".format(api_key),
"Content-Type": "application/json",
}
else:
if api_key is not None:
headers.update({"Authorization": "Bearer {}".format(api_key)})
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"
if endpoint_type == "chat_completions" and custom_endpoint is not True:
api_base = "{}/chat/completions".format(api_base)
elif endpoint_type == "embeddings" and custom_endpoint is not True:
api_base = "{}/embeddings".format(api_base)
return api_base, headers
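
As a sketch of the resolution logic above, with an explicit key and base (hypothetical values) no databricks-sdk fallback is triggered:

from litellm.llms.databricks.common_utils import DatabricksBase

# "dapi-example" and the workspace URL are illustrative placeholders.
api_base, headers = DatabricksBase().databricks_validate_environment(
    api_key="dapi-example",
    api_base="https://adb-1234.azuredatabricks.net/serving-endpoints",
    endpoint_type="chat_completions",
    custom_endpoint=None,
    headers=None,
)
# api_base -> "https://adb-1234.azuredatabricks.net/serving-endpoints/chat/completions"
# headers  -> {"Authorization": "Bearer dapi-example", "Content-Type": "application/json"}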

@ -0,0 +1,50 @@
"""
Calling logic for Databricks embeddings
"""
from typing import Optional
import litellm
from litellm.utils import EmbeddingResponse
from ...openai_like.embedding.handler import OpenAILikeEmbeddingHandler
from ..common_utils import DatabricksBase
class DatabricksEmbeddingHandler(OpenAILikeEmbeddingHandler, DatabricksBase):
def embedding(
self,
model: str,
input: list,
timeout: float,
logging_obj,
api_key: Optional[str],
api_base: Optional[str],
optional_params: dict,
model_response: Optional[litellm.utils.EmbeddingResponse] = None,
client=None,
aembedding=None,
custom_endpoint: Optional[bool] = None,
headers: Optional[dict] = None,
) -> EmbeddingResponse:
api_base, headers = self.databricks_validate_environment(
api_base=api_base,
api_key=api_key,
endpoint_type="embeddings",
custom_endpoint=custom_endpoint,
headers=headers,
)
return super().embedding(
model=model,
input=input,
timeout=timeout,
logging_obj=logging_obj,
api_key=api_key,
api_base=api_base,
optional_params=optional_params,
model_response=model_response,
client=client,
aembedding=aembedding,
custom_endpoint=True,
headers=headers,
)

@ -0,0 +1,48 @@
"""
Translates from OpenAI's `/v1/embeddings` to Databricks' `/embeddings`
"""
import types
from typing import Optional
class DatabricksEmbeddingConfig:
"""
Reference: https://learn.microsoft.com/en-us/azure/databricks/machine-learning/foundation-models/api-reference#--embedding-task
"""
instruction: Optional[str] = (
None # An optional instruction to pass to the embedding model. BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries
)
def __init__(self, instruction: Optional[str] = None) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(
self,
): # no optional openai embedding params supported
return []
def map_openai_params(self, non_default_params: dict, optional_params: dict):
return optional_params
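
The 'instruction' param is Databricks-specific; the test moved into test_databricks.py below exercises it for BGE-style retrieval embeddings. A minimal usage sketch, assuming litellm.drop_params is set as in that test:

import litellm

litellm.drop_params = True
response = litellm.embedding(
    model="databricks/databricks-bge-large-en",
    input=["good morning from litellm"],
    instruction="Represent this sentence for searching relevant passages:",
)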

@@ -37,6 +37,22 @@ def handle_messages_with_content_list_to_str_conversion(
     return messages
+
+
+def strip_name_from_messages(
+    messages: List[AllMessageValues],
+) -> List[AllMessageValues]:
+    """
+    Removes 'name' from messages
+    """
+    new_messages = []
+    for message in messages:
+        msg_role = message.get("role")
+        msg_copy = message.copy()
+        if msg_role == "user":
+            msg_copy.pop("name", None)  # type: ignore
+        new_messages.append(msg_copy)
+    return new_messages
+
+
 def convert_content_list_to_str(message: AllMessageValues) -> str:
     """
     - handles scenario where content is list and not string

@@ -273,7 +273,7 @@ class SagemakerLLM(BaseAWSLLM):
         model_id = optional_params.get("model_id", None)

         if use_messages_api is True:
-            from litellm.llms.databricks.chat import DatabricksChatCompletion
+            from litellm.llms.databricks.chat.handler import DatabricksChatCompletion

             openai_like_chat_completions = DatabricksChatCompletion()
             inference_params["stream"] = True if stream is True else False

@@ -90,7 +90,7 @@ class VertexAIPartnerModels(VertexBase):
             from google.cloud import aiplatform

             from litellm.llms.anthropic.chat import AnthropicChatCompletion
-            from litellm.llms.databricks.chat import DatabricksChatCompletion
+            from litellm.llms.databricks.chat.handler import DatabricksChatCompletion
             from litellm.llms.OpenAI.openai import OpenAIChatCompletion
             from litellm.llms.text_completion_codestral import CodestralTextCompletion
             from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (

@@ -76,7 +76,7 @@ class VertexAIModelGardenModels(VertexBase):
             from google.cloud import aiplatform

             from litellm.llms.anthropic.chat import AnthropicChatCompletion
-            from litellm.llms.databricks.chat import DatabricksChatCompletion
+            from litellm.llms.databricks.chat.handler import DatabricksChatCompletion
             from litellm.llms.OpenAI.openai import OpenAIChatCompletion
             from litellm.llms.text_completion_codestral import CodestralTextCompletion
             from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (

@@ -115,7 +115,8 @@ from .llms.cohere import chat as cohere_chat
 from .llms.cohere import completion as cohere_completion  # type: ignore
 from .llms.cohere.embed import handler as cohere_embed
 from .llms.custom_llm import CustomLLM, custom_chat_llm_router
-from .llms.databricks.chat import DatabricksChatCompletion
+from .llms.databricks.chat.handler import DatabricksChatCompletion
+from .llms.databricks.embed.handler import DatabricksEmbeddingHandler
 from .llms.groq.chat.handler import GroqChatCompletion
 from .llms.huggingface_restapi import Huggingface
 from .llms.OpenAI.audio_transcriptions import OpenAIAudioTranscription

@@ -230,6 +231,7 @@ watsonxai = IBMWatsonXAI()
 sagemaker_llm = SagemakerLLM()
 watsonx_chat_completion = WatsonXChatHandler()
 openai_like_embedding = OpenAILikeEmbeddingHandler()
+databricks_embedding = DatabricksEmbeddingHandler()

 ####### COMPLETION ENDPOINTS ################

@@ -3475,7 +3477,7 @@ def embedding(  # noqa: PLR0915
         )  # type: ignore
         ## EMBEDDING CALL
-        response = databricks_chat_completions.embedding(
+        response = databricks_embedding.embedding(
             model=model,
             input=input,
             api_base=api_base,

@@ -3418,7 +3418,14 @@ def get_optional_params(  # noqa: PLR0915
         )
         _check_valid_arg(supported_params=supported_params)
         optional_params = litellm.DatabricksConfig().map_openai_params(
-            non_default_params=non_default_params, optional_params=optional_params
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
         )
     elif custom_llm_provider == "nvidia_nim":
         supported_params = get_supported_openai_params(
@@ -6182,6 +6189,8 @@ class ProviderConfigManager:
             return litellm.DeepSeekChatConfig()
         elif litellm.LlmProviders.GROQ == provider:
             return litellm.GroqChatConfig()
+        elif litellm.LlmProviders.DATABRICKS == provider:
+            return litellm.DatabricksConfig()

         return OpenAIGPTConfig()
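
With the registry entry above, the generic lookup used by the chat path resolves to the Databricks config; a minimal sketch:

import litellm
from litellm import LlmProviders
from litellm.utils import ProviderConfigManager

config = ProviderConfigManager.get_provider_config(
    model="databricks-dbrx-instruct", provider=LlmProviders.DATABRICKS
)
assert isinstance(config, litellm.DatabricksConfig)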

@@ -55,6 +55,7 @@ class BaseLLMChatTest(ABC):
         assert response.choices[0].message.content is not None

     def test_message_with_name(self):
+        litellm.set_verbose = True
         base_completion_call_args = self.get_base_completion_call_args()
         messages = [
             {"role": "user", "content": "Hello", "name": "test_name"},

@@ -69,6 +70,7 @@
             {"type": "text"},
         ],
     )
+    @pytest.mark.flaky(retries=6, delay=1)
     def test_json_response_format(self, response_format):
         """
         Test that the JSON response format is supported by the LLM API

@@ -4,7 +4,8 @@ import json
 import pytest
 import sys
 from typing import Any, Dict, List
-from unittest.mock import MagicMock, Mock, patch
+from unittest.mock import MagicMock, Mock, patch, ANY
 import os

 sys.path.insert(

@@ -14,6 +15,7 @@ import litellm
 from litellm.exceptions import BadRequestError
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.utils import CustomStreamWrapper
+from base_llm_unit_tests import BaseLLMChatTest

 try:
     import databricks.sdk

@@ -333,6 +335,7 @@ def test_completions_with_async_http_handler(monkeypatch):
             "Authorization": f"Bearer {api_key}",
             "Content-Type": "application/json",
         },
+        timeout=ANY,
         data=json.dumps(
             {
                 "model": "dbrx-instruct-071224",

@@ -376,17 +379,21 @@ def test_completions_streaming_with_sync_http_handler(monkeypatch):
             "Authorization": f"Bearer {api_key}",
             "Content-Type": "application/json",
         },
-        data=json.dumps(
-            {
-                "model": "dbrx-instruct-071224",
-                "messages": messages,
-                "temperature": 0.5,
-                "stream": True,
-                "extraparam": "testpassingextraparam",
-            }
-        ),
-        stream=True,
-    )
+        data=ANY,
+        stream=True,
+    )
+
+    actual_data = json.loads(
+        mock_post.call_args.kwargs["data"]
+    )  # Deserialize the actual data
+    expected_data = {
+        "model": "dbrx-instruct-071224",
+        "messages": messages,
+        "temperature": 0.5,
+        "stream": True,
+        "extraparam": "testpassingextraparam",
+    }
+    assert actual_data == expected_data, f"Unexpected JSON data: {actual_data}"

 def test_completions_streaming_with_async_http_handler(monkeypatch):

@@ -429,21 +436,27 @@ def test_completions_streaming_with_async_http_handler(monkeypatch):
             "Authorization": f"Bearer {api_key}",
             "Content-Type": "application/json",
         },
-        data=json.dumps(
-            {
-                "model": "dbrx-instruct-071224",
-                "messages": messages,
-                "temperature": 0.5,
-                "stream": True,
-                "extraparam": "testpassingextraparam",
-            }
-        ),
-        stream=True,
-    )
+        data=ANY,
+        stream=True,
+    )
+
+    actual_data = json.loads(
+        mock_post.call_args.kwargs["data"]
+    )  # Deserialize the actual data
+    expected_data = {
+        "model": "dbrx-instruct-071224",
+        "messages": messages,
+        "temperature": 0.5,
+        "stream": True,
+        "extraparam": "testpassingextraparam",
+    }
+    assert actual_data == expected_data, f"Unexpected JSON data: {actual_data}"

 @pytest.mark.skipif(not databricks_sdk_installed, reason="Databricks SDK not installed")
 def test_completions_uses_databricks_sdk_if_api_key_and_base_not_specified(monkeypatch):
+    monkeypatch.delenv("DATABRICKS_API_BASE")
+    monkeypatch.delenv("DATABRICKS_API_KEY")
+
     from databricks.sdk import WorkspaceClient
     from databricks.sdk.config import Config
@@ -637,3 +650,48 @@ def test_embeddings_uses_databricks_sdk_if_api_key_and_base_not_specified(monkeypatch):
                 }
             ),
         )
+
+
+class TestDatabricksCompletion(BaseLLMChatTest):
+    def get_base_completion_call_args(self) -> dict:
+        return {"model": "databricks/databricks-dbrx-instruct"}
+
+    def test_pdf_handling(self, pdf_messages):
+        pytest.skip("Databricks does not support PDF handling")
+
+    def test_tool_call_no_arguments(self, tool_call_no_arguments):
+        """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
+        pytest.skip("Databricks is openai compatible")
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_databricks_embeddings(sync_mode):
+    import openai
+
+    try:
+        litellm.set_verbose = True
+        litellm.drop_params = True
+
+        if sync_mode:
+            response = litellm.embedding(
+                model="databricks/databricks-bge-large-en",
+                input=["good morning from litellm"],
+                instruction="Represent this sentence for searching relevant passages:",
+            )
+        else:
+            response = await litellm.aembedding(
+                model="databricks/databricks-bge-large-en",
+                input=["good morning from litellm"],
+                instruction="Represent this sentence for searching relevant passages:",
+            )
+
+        print(f"response: {response}")
+
+        openai.types.CreateEmbeddingResponse.model_validate(
+            response.model_dump(), strict=True
+        )
+        # stubbed endpoint is setup to return this
+        # assert response.data[0]["embedding"] == [0.1, 0.2, 0.3]
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")

@@ -292,7 +292,7 @@ def test_all_model_configs():
         optional_params={},
     ) == {"max_tokens_to_sample": 10}

-    from litellm.llms.databricks.chat import DatabricksConfig
+    from litellm.llms.databricks.chat.handler import DatabricksConfig

     assert "max_completion_tokens" in DatabricksConfig().get_supported_openai_params()

@@ -932,37 +932,6 @@ async def test_gemini_embeddings(sync_mode, input):
         pytest.fail(f"Error occurred: {e}")


-@pytest.mark.parametrize("sync_mode", [True, False])
-@pytest.mark.asyncio
-async def test_databricks_embeddings(sync_mode):
-    try:
-        litellm.set_verbose = True
-        litellm.drop_params = True
-
-        if sync_mode:
-            response = litellm.embedding(
-                model="databricks/databricks-bge-large-en",
-                input=["good morning from litellm"],
-                instruction="Represent this sentence for searching relevant passages:",
-            )
-        else:
-            response = await litellm.aembedding(
-                model="databricks/databricks-bge-large-en",
-                input=["good morning from litellm"],
-                instruction="Represent this sentence for searching relevant passages:",
-            )
-
-        print(f"response: {response}")
-
-        openai.types.CreateEmbeddingResponse.model_validate(
-            response.model_dump(), strict=True
-        )
-        # stubbed endpoint is setup to return this
-        # assert response.data[0]["embedding"] == [0.1, 0.2, 0.3]
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-
-
 # test_voyage_embeddings()
 # def test_xinference_embeddings():
 #     try: