diff --git a/litellm/__init__.py b/litellm/__init__.py index 95172efa04..4dad0ba085 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -1021,7 +1021,8 @@ from .llms.anthropic.experimental_pass_through.transformation import ( ) from .llms.groq.stt.transformation import GroqSTTConfig from .llms.anthropic.completion import AnthropicTextConfig -from .llms.databricks.chat import DatabricksConfig, DatabricksEmbeddingConfig +from .llms.databricks.chat.transformation import DatabricksConfig +from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig from .llms.predibase import PredibaseConfig from .llms.replicate import ReplicateConfig from .llms.cohere.completion import CohereConfig diff --git a/litellm/litellm_core_utils/get_supported_openai_params.py b/litellm/litellm_core_utils/get_supported_openai_params.py index e6ce4d2df8..554df8092b 100644 --- a/litellm/litellm_core_utils/get_supported_openai_params.py +++ b/litellm/litellm_core_utils/get_supported_openai_params.py @@ -178,7 +178,7 @@ def get_supported_openai_params( # noqa: PLR0915 ] elif custom_llm_provider == "databricks": if request_type == "chat_completion": - return litellm.DatabricksConfig().get_supported_openai_params() + return litellm.DatabricksConfig().get_supported_openai_params(model=model) elif request_type == "embeddings": return litellm.DatabricksEmbeddingConfig().get_supported_openai_params() elif custom_llm_provider == "palm" or custom_llm_provider == "gemini": diff --git a/litellm/llms/databricks/chat/handler.py b/litellm/llms/databricks/chat/handler.py new file mode 100644 index 0000000000..078235a284 --- /dev/null +++ b/litellm/llms/databricks/chat/handler.py @@ -0,0 +1,82 @@ +""" +Handles the chat completion request for Databricks +""" + +from typing import Any, Callable, Literal, Optional, Tuple, Union + +from httpx._config import Timeout + +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler +from litellm.types.utils import CustomStreamingDecoder +from litellm.utils import ModelResponse + +from ...openai_like.chat.handler import OpenAILikeChatHandler +from ..common_utils import DatabricksBase +from ..exceptions import DatabricksError +from .transformation import DatabricksConfig + + +class DatabricksChatCompletion(OpenAILikeChatHandler, DatabricksBase): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def completion( + self, + *, + model: str, + messages: list, + api_base: str, + custom_llm_provider: str, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + api_key: Optional[str], + logging_obj, + optional_params: dict, + acompletion=None, + litellm_params=None, + logger_fn=None, + headers: Optional[dict] = None, + timeout: Optional[Union[float, Timeout]] = None, + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, + custom_endpoint: Optional[bool] = None, + streaming_decoder: Optional[CustomStreamingDecoder] = None, + fake_stream: bool = False, + ): + messages = DatabricksConfig()._transform_messages(messages) # type: ignore + api_base, headers = self.databricks_validate_environment( + api_base=api_base, + api_key=api_key, + endpoint_type="chat_completions", + custom_endpoint=custom_endpoint, + headers=headers, + ) + + if optional_params.get("stream") is True: + fake_stream = DatabricksConfig()._should_fake_stream(optional_params) + else: + fake_stream = False + + return super().completion( + model=model, + messages=messages, + api_base=api_base, + 
custom_llm_provider=custom_llm_provider, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + api_key=api_key, + logging_obj=logging_obj, + optional_params=optional_params, + acompletion=acompletion, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=headers, + timeout=timeout, + client=client, + custom_endpoint=True, + streaming_decoder=streaming_decoder, + fake_stream=fake_stream, + ) diff --git a/litellm/llms/databricks/chat.py b/litellm/llms/databricks/chat/old_handler.py similarity index 82% rename from litellm/llms/databricks/chat.py rename to litellm/llms/databricks/chat/old_handler.py index 5d40dd9ed4..95cc1cfc6d 100644 --- a/litellm/llms/databricks/chat.py +++ b/litellm/llms/databricks/chat/old_handler.py @@ -13,6 +13,7 @@ import httpx # type: ignore import requests # type: ignore import litellm +from litellm import LlmProviders from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.llms.custom_httpx.http_handler import ( AsyncHTTPHandler, @@ -33,141 +34,17 @@ from litellm.types.utils import ( GenericStreamingChunk, ProviderField, ) -from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage +from litellm.utils import ( + CustomStreamWrapper, + EmbeddingResponse, + ModelResponse, + ProviderConfigManager, + Usage, +) -from ..base import BaseLLM -from ..prompt_templates.factory import custom_prompt, prompt_factory - - -class DatabricksConfig: - """ - Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request - """ - - max_tokens: Optional[int] = None - temperature: Optional[int] = None - top_p: Optional[int] = None - top_k: Optional[int] = None - stop: Optional[Union[List[str], str]] = None - n: Optional[int] = None - - def __init__( - self, - max_tokens: Optional[int] = None, - temperature: Optional[int] = None, - top_p: Optional[int] = None, - top_k: Optional[int] = None, - stop: Optional[Union[List[str], str]] = None, - n: Optional[int] = None, - ) -> None: - locals_ = locals() - for key, value in locals_.items(): - if key != "self" and value is not None: - setattr(self.__class__, key, value) - - @classmethod - def get_config(cls): - return { - k: v - for k, v in cls.__dict__.items() - if not k.startswith("__") - and not isinstance( - v, - ( - types.FunctionType, - types.BuiltinFunctionType, - classmethod, - staticmethod, - ), - ) - and v is not None - } - - def get_required_params(self) -> List[ProviderField]: - """For a given provider, return it's required fields with a description""" - return [ - ProviderField( - field_name="api_key", - field_type="string", - field_description="Your Databricks API Key.", - field_value="dapi...", - ), - ProviderField( - field_name="api_base", - field_type="string", - field_description="Your Databricks API Base.", - field_value="https://adb-..", - ), - ] - - def get_supported_openai_params(self): - return [ - "stream", - "stop", - "temperature", - "top_p", - "max_tokens", - "max_completion_tokens", - "n", - ] - - def map_openai_params(self, non_default_params: dict, optional_params: dict): - for param, value in non_default_params.items(): - if param == "max_tokens" or param == "max_completion_tokens": - optional_params["max_tokens"] = value - if param == "n": - optional_params["n"] = value - if param == "stream" and value is True: - optional_params["stream"] = value - if param == "temperature": - optional_params["temperature"] = value - if param == 
"top_p": - optional_params["top_p"] = value - if param == "stop": - optional_params["stop"] = value - return optional_params - - -class DatabricksEmbeddingConfig: - """ - Reference: https://learn.microsoft.com/en-us/azure/databricks/machine-learning/foundation-models/api-reference#--embedding-task - """ - - instruction: Optional[str] = ( - None # An optional instruction to pass to the embedding model. BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries - ) - - def __init__(self, instruction: Optional[str] = None) -> None: - locals_ = locals() - for key, value in locals_.items(): - if key != "self" and value is not None: - setattr(self.__class__, key, value) - - @classmethod - def get_config(cls): - return { - k: v - for k, v in cls.__dict__.items() - if not k.startswith("__") - and not isinstance( - v, - ( - types.FunctionType, - types.BuiltinFunctionType, - classmethod, - staticmethod, - ), - ) - and v is not None - } - - def get_supported_openai_params( - self, - ): # no optional openai embedding params supported - return [] - - def map_openai_params(self, non_default_params: dict, optional_params: dict): - return optional_params +from ...base import BaseLLM +from ...prompt_templates.factory import custom_prompt, prompt_factory +from .transformation import DatabricksConfig async def make_call( @@ -477,6 +354,12 @@ class DatabricksChatCompletion(BaseLLM): ) # [TODO] add max retry support at llm api call level optional_params["stream"] = stream + if messages is not None and custom_llm_provider is not None: + provider_config = ProviderConfigManager.get_provider_config( + model=model, provider=LlmProviders(custom_llm_provider) + ) + messages = provider_config._transform_messages(messages) + data = { "model": model, "messages": messages, diff --git a/litellm/llms/databricks/chat/transformation.py b/litellm/llms/databricks/chat/transformation.py new file mode 100644 index 0000000000..009e5e1894 --- /dev/null +++ b/litellm/llms/databricks/chat/transformation.py @@ -0,0 +1,143 @@ +""" +Translates from OpenAI's `/v1/chat/completions` to Databricks' `/chat/completions` +""" + +import types +from typing import List, Optional, Union + +from pydantic import BaseModel + +from litellm.types.llms.openai import AllMessageValues +from litellm.types.utils import ProviderField + +from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig +from ...prompt_templates.common_utils import ( + handle_messages_with_content_list_to_str_conversion, + strip_name_from_messages, +) + + +class DatabricksConfig(OpenAIGPTConfig): + """ + Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request + """ + + max_tokens: Optional[int] = None + temperature: Optional[int] = None + top_p: Optional[int] = None + top_k: Optional[int] = None + stop: Optional[Union[List[str], str]] = None + n: Optional[int] = None + + def __init__( + self, + max_tokens: Optional[int] = None, + temperature: Optional[int] = None, + top_p: Optional[int] = None, + top_k: Optional[int] = None, + stop: Optional[Union[List[str], str]] = None, + n: Optional[int] = None, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) 
+ and v is not None + } + + def get_required_params(self) -> List[ProviderField]: + """For a given provider, return it's required fields with a description""" + return [ + ProviderField( + field_name="api_key", + field_type="string", + field_description="Your Databricks API Key.", + field_value="dapi...", + ), + ProviderField( + field_name="api_base", + field_type="string", + field_description="Your Databricks API Base.", + field_value="https://adb-..", + ), + ] + + def get_supported_openai_params(self, model: Optional[str] = None) -> list: + return [ + "stream", + "stop", + "temperature", + "top_p", + "max_tokens", + "max_completion_tokens", + "n", + "response_format", + ] + + def _should_fake_stream(self, optional_params: dict) -> bool: + """ + Databricks doesn't support 'response_format' while streaming + """ + if optional_params.get("response_format") is not None: + return True + + return False + + def map_openai_params( + self, + non_default_params: dict, + optional_params: dict, + model: str, + drop_params: bool, + ): + for param, value in non_default_params.items(): + if param == "max_tokens" or param == "max_completion_tokens": + optional_params["max_tokens"] = value + if param == "n": + optional_params["n"] = value + if param == "stream" and value is True: + optional_params["stream"] = value + if param == "temperature": + optional_params["temperature"] = value + if param == "top_p": + optional_params["top_p"] = value + if param == "stop": + optional_params["stop"] = value + if param == "response_format": + optional_params["response_format"] = value + return optional_params + + def _transform_messages( + self, messages: List[AllMessageValues] + ) -> List[AllMessageValues]: + """ + Databricks does not support: + - content in list format. + - 'name' in user message. + """ + new_messages = [] + for idx, message in enumerate(messages): + if isinstance(message, BaseModel): + _message = message.model_dump() + else: + _message = message + new_messages.append(_message) + new_messages = handle_messages_with_content_list_to_str_conversion(new_messages) + new_messages = strip_name_from_messages(new_messages) + return super()._transform_messages(new_messages) diff --git a/litellm/llms/databricks/common_utils.py b/litellm/llms/databricks/common_utils.py new file mode 100644 index 0000000000..e8481e25b2 --- /dev/null +++ b/litellm/llms/databricks/common_utils.py @@ -0,0 +1,82 @@ +from typing import Literal, Optional, Tuple + +from .exceptions import DatabricksError + + +class DatabricksBase: + def _get_databricks_credentials( + self, api_key: Optional[str], api_base: Optional[str], headers: Optional[dict] + ) -> Tuple[str, dict]: + headers = headers or {"Content-Type": "application/json"} + try: + from databricks.sdk import WorkspaceClient + + databricks_client = WorkspaceClient() + + api_base = api_base or f"{databricks_client.config.host}/serving-endpoints" + + if api_key is None: + databricks_auth_headers: dict[str, str] = ( + databricks_client.config.authenticate() + ) + headers = {**databricks_auth_headers, **headers} + + return api_base, headers + except ImportError: + raise DatabricksError( + status_code=400, + message=( + "If the Databricks base URL and API key are not set, the databricks-sdk " + "Python library must be installed. Please install the databricks-sdk, set " + "{LLM_PROVIDER}_API_BASE and {LLM_PROVIDER}_API_KEY environment variables, " + "or provide the base URL and API key as arguments." 
+ ), + ) + + def databricks_validate_environment( + self, + api_key: Optional[str], + api_base: Optional[str], + endpoint_type: Literal["chat_completions", "embeddings"], + custom_endpoint: Optional[bool], + headers: Optional[dict], + ) -> Tuple[str, dict]: + if api_key is None and headers is None: + if custom_endpoint is not None: + raise DatabricksError( + status_code=400, + message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params", + ) + else: + api_base, headers = self._get_databricks_credentials( + api_base=api_base, api_key=api_key, headers=headers + ) + + if api_base is None: + if custom_endpoint: + raise DatabricksError( + status_code=400, + message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params", + ) + else: + api_base, headers = self._get_databricks_credentials( + api_base=api_base, api_key=api_key, headers=headers + ) + + if headers is None: + headers = { + "Authorization": "Bearer {}".format(api_key), + "Content-Type": "application/json", + } + else: + if api_key is not None: + headers.update({"Authorization": "Bearer {}".format(api_key)}) + + if api_key is not None: + headers["Authorization"] = f"Bearer {api_key}" + + if endpoint_type == "chat_completions" and custom_endpoint is not True: + api_base = "{}/chat/completions".format(api_base) + elif endpoint_type == "embeddings" and custom_endpoint is not True: + api_base = "{}/embeddings".format(api_base) + return api_base, headers diff --git a/litellm/llms/databricks/embed/handler.py b/litellm/llms/databricks/embed/handler.py new file mode 100644 index 0000000000..4ed5853762 --- /dev/null +++ b/litellm/llms/databricks/embed/handler.py @@ -0,0 +1,50 @@ +""" +Calling logic for Databricks embeddings +""" + +from typing import Optional + +import litellm +from litellm.utils import EmbeddingResponse + +from ...openai_like.embedding.handler import OpenAILikeEmbeddingHandler +from ..common_utils import DatabricksBase + + +class DatabricksEmbeddingHandler(OpenAILikeEmbeddingHandler, DatabricksBase): + def embedding( + self, + model: str, + input: list, + timeout: float, + logging_obj, + api_key: Optional[str], + api_base: Optional[str], + optional_params: dict, + model_response: Optional[litellm.utils.EmbeddingResponse] = None, + client=None, + aembedding=None, + custom_endpoint: Optional[bool] = None, + headers: Optional[dict] = None, + ) -> EmbeddingResponse: + api_base, headers = self.databricks_validate_environment( + api_base=api_base, + api_key=api_key, + endpoint_type="embeddings", + custom_endpoint=custom_endpoint, + headers=headers, + ) + return super().embedding( + model=model, + input=input, + timeout=timeout, + logging_obj=logging_obj, + api_key=api_key, + api_base=api_base, + optional_params=optional_params, + model_response=model_response, + client=client, + aembedding=aembedding, + custom_endpoint=True, + headers=headers, + ) diff --git a/litellm/llms/databricks/embed/transformation.py b/litellm/llms/databricks/embed/transformation.py new file mode 100644 index 0000000000..8c7e119714 --- /dev/null +++ b/litellm/llms/databricks/embed/transformation.py @@ -0,0 +1,48 @@ +""" +Translates from OpenAI's `/v1/embeddings` to Databricks' `/embeddings` +""" + +import types +from typing import Optional + + +class DatabricksEmbeddingConfig: + """ + Reference: 
https://learn.microsoft.com/en-us/azure/databricks/machine-learning/foundation-models/api-reference#--embedding-task + """ + + instruction: Optional[str] = ( + None # An optional instruction to pass to the embedding model. BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries + ) + + def __init__(self, instruction: Optional[str] = None) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + def get_supported_openai_params( + self, + ): # no optional openai embedding params supported + return [] + + def map_openai_params(self, non_default_params: dict, optional_params: dict): + return optional_params diff --git a/litellm/llms/prompt_templates/common_utils.py b/litellm/llms/prompt_templates/common_utils.py index 24cb7b4512..c0798f3b22 100644 --- a/litellm/llms/prompt_templates/common_utils.py +++ b/litellm/llms/prompt_templates/common_utils.py @@ -37,6 +37,22 @@ def handle_messages_with_content_list_to_str_conversion( return messages +def strip_name_from_messages( + messages: List[AllMessageValues], +) -> List[AllMessageValues]: + """ + Removes 'name' from messages + """ + new_messages = [] + for message in messages: + msg_role = message.get("role") + msg_copy = message.copy() + if msg_role == "user": + msg_copy.pop("name", None) # type: ignore + new_messages.append(msg_copy) + return new_messages + + def convert_content_list_to_str(message: AllMessageValues) -> str: """ - handles scenario where content is list and not string diff --git a/litellm/llms/sagemaker/sagemaker.py b/litellm/llms/sagemaker/sagemaker.py index ecfa40b8cf..88a0adc1ef 100644 --- a/litellm/llms/sagemaker/sagemaker.py +++ b/litellm/llms/sagemaker/sagemaker.py @@ -273,7 +273,7 @@ class SagemakerLLM(BaseAWSLLM): model_id = optional_params.get("model_id", None) if use_messages_api is True: - from litellm.llms.databricks.chat import DatabricksChatCompletion + from litellm.llms.databricks.chat.handler import DatabricksChatCompletion openai_like_chat_completions = DatabricksChatCompletion() inference_params["stream"] = True if stream is True else False diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/main.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/main.py index f335f53d9a..62668f5b03 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/main.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/main.py @@ -90,7 +90,7 @@ class VertexAIPartnerModels(VertexBase): from google.cloud import aiplatform from litellm.llms.anthropic.chat import AnthropicChatCompletion - from litellm.llms.databricks.chat import DatabricksChatCompletion + from litellm.llms.databricks.chat.handler import DatabricksChatCompletion from litellm.llms.OpenAI.openai import OpenAIChatCompletion from litellm.llms.text_completion_codestral import CodestralTextCompletion from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_model_garden/main.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_model_garden/main.py index 
4285c4dcbc..4c467f7c71 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_model_garden/main.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_model_garden/main.py @@ -76,7 +76,7 @@ class VertexAIModelGardenModels(VertexBase): from google.cloud import aiplatform from litellm.llms.anthropic.chat import AnthropicChatCompletion - from litellm.llms.databricks.chat import DatabricksChatCompletion + from litellm.llms.databricks.chat.handler import DatabricksChatCompletion from litellm.llms.OpenAI.openai import OpenAIChatCompletion from litellm.llms.text_completion_codestral import CodestralTextCompletion from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( diff --git a/litellm/main.py b/litellm/main.py index 6da7bb604a..a32e8b6c05 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -115,7 +115,8 @@ from .llms.cohere import chat as cohere_chat from .llms.cohere import completion as cohere_completion # type: ignore from .llms.cohere.embed import handler as cohere_embed from .llms.custom_llm import CustomLLM, custom_chat_llm_router -from .llms.databricks.chat import DatabricksChatCompletion +from .llms.databricks.chat.handler import DatabricksChatCompletion +from .llms.databricks.embed.handler import DatabricksEmbeddingHandler from .llms.groq.chat.handler import GroqChatCompletion from .llms.huggingface_restapi import Huggingface from .llms.OpenAI.audio_transcriptions import OpenAIAudioTranscription @@ -230,6 +231,7 @@ watsonxai = IBMWatsonXAI() sagemaker_llm = SagemakerLLM() watsonx_chat_completion = WatsonXChatHandler() openai_like_embedding = OpenAILikeEmbeddingHandler() +databricks_embedding = DatabricksEmbeddingHandler() ####### COMPLETION ENDPOINTS ################ @@ -3475,7 +3477,7 @@ def embedding( # noqa: PLR0915 ) # type: ignore ## EMBEDDING CALL - response = databricks_chat_completions.embedding( + response = databricks_embedding.embedding( model=model, input=input, api_base=api_base, diff --git a/litellm/utils.py b/litellm/utils.py index 3cbe038ef3..946d81982b 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -3418,7 +3418,14 @@ def get_optional_params( # noqa: PLR0915 ) _check_valid_arg(supported_params=supported_params) optional_params = litellm.DatabricksConfig().map_openai_params( - non_default_params=non_default_params, optional_params=optional_params + non_default_params=non_default_params, + optional_params=optional_params, + model=model, + drop_params=( + drop_params + if drop_params is not None and isinstance(drop_params, bool) + else False + ), ) elif custom_llm_provider == "nvidia_nim": supported_params = get_supported_openai_params( @@ -6182,6 +6189,8 @@ class ProviderConfigManager: return litellm.DeepSeekChatConfig() elif litellm.LlmProviders.GROQ == provider: return litellm.GroqChatConfig() + elif litellm.LlmProviders.DATABRICKS == provider: + return litellm.DatabricksConfig() return OpenAIGPTConfig() diff --git a/tests/llm_translation/base_llm_unit_tests.py b/tests/llm_translation/base_llm_unit_tests.py index d4c2777448..cde016125b 100644 --- a/tests/llm_translation/base_llm_unit_tests.py +++ b/tests/llm_translation/base_llm_unit_tests.py @@ -55,6 +55,7 @@ class BaseLLMChatTest(ABC): assert response.choices[0].message.content is not None def test_message_with_name(self): + litellm.set_verbose = True base_completion_call_args = self.get_base_completion_call_args() messages = [ {"role": "user", "content": "Hello", "name": "test_name"}, @@ -69,6 +70,7 @@ class BaseLLMChatTest(ABC): {"type": 
"text"}, ], ) + @pytest.mark.flaky(retries=6, delay=1) def test_json_response_format(self, response_format): """ Test that the JSON response format is supported by the LLM API diff --git a/tests/llm_translation/test_databricks.py b/tests/llm_translation/test_databricks.py index 89ad6832ba..9ea6b6f576 100644 --- a/tests/llm_translation/test_databricks.py +++ b/tests/llm_translation/test_databricks.py @@ -4,7 +4,8 @@ import json import pytest import sys from typing import Any, Dict, List -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, Mock, patch, ANY + import os sys.path.insert( @@ -14,6 +15,7 @@ import litellm from litellm.exceptions import BadRequestError from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.utils import CustomStreamWrapper +from base_llm_unit_tests import BaseLLMChatTest try: import databricks.sdk @@ -333,6 +335,7 @@ def test_completions_with_async_http_handler(monkeypatch): "Authorization": f"Bearer {api_key}", "Content-Type": "application/json", }, + timeout=ANY, data=json.dumps( { "model": "dbrx-instruct-071224", @@ -376,18 +379,22 @@ def test_completions_streaming_with_sync_http_handler(monkeypatch): "Authorization": f"Bearer {api_key}", "Content-Type": "application/json", }, - data=json.dumps( - { - "model": "dbrx-instruct-071224", - "messages": messages, - "temperature": 0.5, - "stream": True, - "extraparam": "testpassingextraparam", - } - ), + data=ANY, stream=True, ) + actual_data = json.loads( + mock_post.call_args.kwargs["data"] + ) # Deserialize the actual data + expected_data = { + "model": "dbrx-instruct-071224", + "messages": messages, + "temperature": 0.5, + "stream": True, + "extraparam": "testpassingextraparam", + } + assert actual_data == expected_data, f"Unexpected JSON data: {actual_data}" + def test_completions_streaming_with_async_http_handler(monkeypatch): base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints" @@ -429,21 +436,27 @@ def test_completions_streaming_with_async_http_handler(monkeypatch): "Authorization": f"Bearer {api_key}", "Content-Type": "application/json", }, - data=json.dumps( - { - "model": "dbrx-instruct-071224", - "messages": messages, - "temperature": 0.5, - "stream": True, - "extraparam": "testpassingextraparam", - } - ), + data=ANY, stream=True, ) + actual_data = json.loads( + mock_post.call_args.kwargs["data"] + ) # Deserialize the actual data + expected_data = { + "model": "dbrx-instruct-071224", + "messages": messages, + "temperature": 0.5, + "stream": True, + "extraparam": "testpassingextraparam", + } + assert actual_data == expected_data, f"Unexpected JSON data: {actual_data}" + @pytest.mark.skipif(not databricks_sdk_installed, reason="Databricks SDK not installed") def test_completions_uses_databricks_sdk_if_api_key_and_base_not_specified(monkeypatch): + monkeypatch.delenv("DATABRICKS_API_BASE") + monkeypatch.delenv("DATABRICKS_API_KEY") from databricks.sdk import WorkspaceClient from databricks.sdk.config import Config @@ -637,3 +650,48 @@ def test_embeddings_uses_databricks_sdk_if_api_key_and_base_not_specified(monkey } ), ) + + +class TestDatabricksCompletion(BaseLLMChatTest): + def get_base_completion_call_args(self) -> dict: + return {"model": "databricks/databricks-dbrx-instruct"} + + def test_pdf_handling(self, pdf_messages): + pytest.skip("Databricks does not support PDF handling") + + def test_tool_call_no_arguments(self, tool_call_no_arguments): + """Test that tool calls with no arguments is translated correctly. 
Relevant issue: https://github.com/BerriAI/litellm/issues/6833""" + pytest.skip("Databricks is openai compatible") + + +@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.asyncio +async def test_databricks_embeddings(sync_mode): + import openai + + try: + litellm.set_verbose = True + litellm.drop_params = True + + if sync_mode: + response = litellm.embedding( + model="databricks/databricks-bge-large-en", + input=["good morning from litellm"], + instruction="Represent this sentence for searching relevant passages:", + ) + else: + response = await litellm.aembedding( + model="databricks/databricks-bge-large-en", + input=["good morning from litellm"], + instruction="Represent this sentence for searching relevant passages:", + ) + + print(f"response: {response}") + + openai.types.CreateEmbeddingResponse.model_validate( + response.model_dump(), strict=True + ) + # stubbed endpoint is setup to return this + # assert response.data[0]["embedding"] == [0.1, 0.2, 0.3] + except Exception as e: + pytest.fail(f"Error occurred: {e}") diff --git a/tests/llm_translation/test_max_completion_tokens.py b/tests/llm_translation/test_max_completion_tokens.py index 6ac681b80a..422141d472 100644 --- a/tests/llm_translation/test_max_completion_tokens.py +++ b/tests/llm_translation/test_max_completion_tokens.py @@ -292,7 +292,7 @@ def test_all_model_configs(): optional_params={}, ) == {"max_tokens_to_sample": 10} - from litellm.llms.databricks.chat import DatabricksConfig + from litellm.llms.databricks.chat.handler import DatabricksConfig assert "max_completion_tokens" in DatabricksConfig().get_supported_openai_params() diff --git a/tests/local_testing/test_embedding.py b/tests/local_testing/test_embedding.py index 096dfc4190..5930e16d11 100644 --- a/tests/local_testing/test_embedding.py +++ b/tests/local_testing/test_embedding.py @@ -932,37 +932,6 @@ async def test_gemini_embeddings(sync_mode, input): pytest.fail(f"Error occurred: {e}") -@pytest.mark.parametrize("sync_mode", [True, False]) -@pytest.mark.asyncio -async def test_databricks_embeddings(sync_mode): - try: - litellm.set_verbose = True - litellm.drop_params = True - - if sync_mode: - response = litellm.embedding( - model="databricks/databricks-bge-large-en", - input=["good morning from litellm"], - instruction="Represent this sentence for searching relevant passages:", - ) - else: - response = await litellm.aembedding( - model="databricks/databricks-bge-large-en", - input=["good morning from litellm"], - instruction="Represent this sentence for searching relevant passages:", - ) - - print(f"response: {response}") - - openai.types.CreateEmbeddingResponse.model_validate( - response.model_dump(), strict=True - ) - # stubbed endpoint is setup to return this - # assert response.data[0]["embedding"] == [0.1, 0.2, 0.3] - except Exception as e: - pytest.fail(f"Error occurred: {e}") - - # test_voyage_embeddings() # def test_xinference_embeddings(): # try:
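
Illustrative usage sketch (appended for review, not part of the diff above): the refactored DatabricksConfig in litellm/llms/databricks/chat/transformation.py can be exercised directly to see the behavior this change introduces. This is a minimal sketch assuming the modules from this diff are importable; the model name and parameter values are examples only.

# Sketch 1: new chat transformation behavior (assumes this diff is applied).
from litellm.llms.databricks.chat.transformation import DatabricksConfig

config = DatabricksConfig()

# 'response_format' is now advertised as a supported OpenAI param for Databricks chat.
assert "response_format" in config.get_supported_openai_params(model="databricks-dbrx-instruct")

# Databricks does not support 'response_format' while streaming, so the handler
# falls back to fake streaming whenever it is present.
assert config._should_fake_stream({"stream": True, "response_format": {"type": "json_object"}})
assert not config._should_fake_stream({"stream": True})

# map_openai_params now accepts `model` and `drop_params`; 'max_completion_tokens'
# still maps to 'max_tokens', and 'response_format' is passed through.
print(config.map_openai_params(
    non_default_params={"max_completion_tokens": 256, "response_format": {"type": "json_object"}},
    optional_params={},
    model="databricks-dbrx-instruct",
    drop_params=False,
))  # {'max_tokens': 256, 'response_format': {'type': 'json_object'}}

# _transform_messages flattens list content to a plain string and strips 'name'
# from user messages, since Databricks accepts neither.
print(config._transform_messages(
    [{"role": "user", "name": "test_name", "content": [{"type": "text", "text": "Hello"}]}]
))  # expected: [{'role': 'user', 'content': 'Hello'}]

The shared DatabricksBase helper in common_utils.py now owns credential validation and endpoint routing for both chat and embeddings; a quick check of the URL suffixing, using placeholder credentials:

# Sketch 2: endpoint routing in the shared DatabricksBase helper (placeholder values).
from litellm.llms.databricks.common_utils import DatabricksBase

api_base, headers = DatabricksBase().databricks_validate_environment(
    api_key="dapi-example",
    api_base="https://adb-1234.5.azuredatabricks.net/serving-endpoints",
    endpoint_type="chat_completions",
    custom_endpoint=None,
    headers=None,
)
print(api_base)  # .../serving-endpoints/chat/completions
print(headers["Authorization"])  # Bearer dapi-example

With endpoint_type="embeddings" the same helper appends /embeddings instead, which is what the new DatabricksEmbeddingHandler relies on.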