diff --git a/litellm/__init__.py b/litellm/__init__.py
index ebdac8c6e..8c910c3d5 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -732,7 +732,7 @@ from .utils import (
 )
 from .llms.huggingface_restapi import HuggingfaceConfig
 from .llms.anthropic import AnthropicConfig
-from .llms.databricks import DatabricksConfig
+from .llms.databricks import DatabricksConfig, DatabricksEmbeddingConfig
 from .llms.predibase import PredibaseConfig
 from .llms.anthropic_text import AnthropicTextConfig
 from .llms.replicate import ReplicateConfig
diff --git a/litellm/llms/databricks.py b/litellm/llms/databricks.py
index 3212c7ad1..b306d425e 100644
--- a/litellm/llms/databricks.py
+++ b/litellm/llms/databricks.py
@@ -5,8 +5,14 @@ import json
 from enum import Enum
 import requests, copy  # type: ignore
 import time
-from typing import Callable, Optional, List, Union, Tuple
-from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
+from typing import Callable, Optional, List, Union, Tuple, Literal
+from litellm.utils import (
+    ModelResponse,
+    Usage,
+    map_finish_reason,
+    CustomStreamWrapper,
+    EmbeddingResponse,
+)
 import litellm
 from .prompt_templates.factory import prompt_factory, custom_prompt
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
@@ -155,6 +161,48 @@ class DatabricksConfig:
             raise e
 
 
+class DatabricksEmbeddingConfig:
+    """
+    Reference: https://learn.microsoft.com/en-us/azure/databricks/machine-learning/foundation-models/api-reference#--embedding-task
+    """
+
+    instruction: Optional[str] = (
+        None  # An optional instruction to pass to the embedding model. BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries
+    )
+
+    def __init__(self, instruction: Optional[str] = None) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(
+        self,
+    ):  # no optional openai embedding params supported
+        return []
+
+    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+        return optional_params
+
+
 class DatabricksChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
@@ -162,7 +210,10 @@ class DatabricksChatCompletion(BaseLLM):
 
     # makes headers for API call
     def _validate_environment(
-        self, api_key: Optional[str], api_base: Optional[str]
+        self,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        endpoint_type: Literal["chat_completions", "embeddings"],
     ) -> Tuple[str, dict]:
         if api_key is None:
             raise DatabricksError(
@@ -181,7 +232,10 @@ class DatabricksChatCompletion(BaseLLM):
             "Content-Type": "application/json",
         }
 
-        api_base = "{}/chat/completions".format(api_base)
+        if endpoint_type == "chat_completions":
+            api_base = "{}/chat/completions".format(api_base)
+        elif endpoint_type == "embeddings":
+            api_base = "{}/embeddings".format(api_base)
         return api_base, headers
 
     def process_response(
@@ -374,7 +428,7 @@ class DatabricksChatCompletion(BaseLLM):
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ):
         api_base, headers = self._validate_environment(
-            api_base=api_base, api_key=api_key
+            api_base=api_base, api_key=api_key, endpoint_type="chat_completions"
         )
         ## Load Config
         config = litellm.DatabricksConfig().get_config()
@@ -501,6 +555,124 @@ class DatabricksChatCompletion(BaseLLM):
 
         return ModelResponse(**response_json)
 
-    def embedding(self):
-        # logic for parsing in - calling - parsing out model embedding calls
-        pass
+    async def aembedding(
+        self,
+        input: list,
+        data: dict,
+        model_response: ModelResponse,
+        timeout: float,
+        api_key: str,
+        api_base: str,
+        logging_obj,
+        headers: dict,
+        client=None,
+    ) -> EmbeddingResponse:
+        response = None
+        try:
+            if client is None or isinstance(client, AsyncHTTPHandler):
+                self.async_client = AsyncHTTPHandler(timeout=timeout)  # type: ignore
+            else:
+                self.async_client = client
+
+            try:
+                response = await self.async_client.post(
+                    api_base,
+                    headers=headers,
+                    data=json.dumps(data),
+                )  # type: ignore
+
+                response.raise_for_status()
+
+                response_json = response.json()
+            except httpx.HTTPStatusError as e:
+                raise DatabricksError(
+                    status_code=e.response.status_code,
+                    message=response.text if response else str(e),
+                )
+            except httpx.TimeoutException as e:
+                raise DatabricksError(
+                    status_code=408, message="Timeout error occurred."
+                )
+            except Exception as e:
+                raise DatabricksError(status_code=500, message=str(e))
+
+            ## LOGGING
+            logging_obj.post_call(
+                input=input,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=response_json,
+            )
+            return EmbeddingResponse(**response_json)
+        except Exception as e:
+            ## LOGGING
+            logging_obj.post_call(
+                input=input,
+                api_key=api_key,
+                original_response=str(e),
+            )
+            raise e
+
+    def embedding(
+        self,
+        model: str,
+        input: list,
+        timeout: float,
+        logging_obj,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        optional_params: dict,
+        model_response: Optional[litellm.utils.EmbeddingResponse] = None,
+        client=None,
+        aembedding=None,
+    ) -> EmbeddingResponse:
+        api_base, headers = self._validate_environment(
+            api_base=api_base, api_key=api_key, endpoint_type="embeddings"
+        )
+        model = model
+        data = {"model": model, "input": input, **optional_params}
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=input,
+            api_key=api_key,
+            additional_args={"complete_input_dict": data, "api_base": api_base},
+        )
+
+        if aembedding == True:
+            return self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, headers=headers)  # type: ignore
+        if client is None or isinstance(client, AsyncHTTPHandler):
+            self.client = HTTPHandler(timeout=timeout)  # type: ignore
+        else:
+            self.client = client
+
+        ## EMBEDDING CALL
+        try:
+            response = self.client.post(
+                api_base,
+                headers=headers,
+                data=json.dumps(data),
+            )  # type: ignore
+
+            response.raise_for_status()  # type: ignore
+
+            response_json = response.json()  # type: ignore
+        except httpx.HTTPStatusError as e:
+            raise DatabricksError(
+                status_code=e.response.status_code,
+                message=response.text if response else str(e),
+            )
+        except httpx.TimeoutException as e:
+            raise DatabricksError(status_code=408, message="Timeout error occurred.")
+        except Exception as e:
+            raise DatabricksError(status_code=500, message=str(e))
+
+        ## LOGGING
+        logging_obj.post_call(
+            input=input,
+            api_key=api_key,
+            additional_args={"complete_input_dict": data},
+            original_response=response_json,
+        )
+
+        return litellm.EmbeddingResponse(**response_json)
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 0bac50639..2e0196faa 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -795,10 +795,10 @@ class OpenAIChatCompletion(BaseLLM):
         model: str,
         input: list,
         timeout: float,
+        logging_obj,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
         model_response: Optional[litellm.utils.EmbeddingResponse] = None,
-        logging_obj=None,
         optional_params=None,
         client=None,
         aembedding=None,
diff --git a/litellm/main.py b/litellm/main.py
index 7757fead1..dc4cf0001 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2727,7 +2727,7 @@ def batch_completion_models_all_responses(*args, **kwargs):
 
 ### EMBEDDING ENDPOINTS ####################
 @client
-async def aembedding(*args, **kwargs):
+async def aembedding(*args, **kwargs) -> EmbeddingResponse:
     """
     Asynchronously calls the `embedding` function with the given arguments and keyword arguments.
 
@@ -2772,12 +2772,13 @@ async def aembedding(*args, **kwargs):
             or custom_llm_provider == "fireworks_ai"
             or custom_llm_provider == "ollama"
             or custom_llm_provider == "vertex_ai"
+            or custom_llm_provider == "databricks"
         ):  # currently implemented aiohttp calls for just azure and openai, soon all.
             # Await normally
             init_response = await loop.run_in_executor(None, func_with_context)
-            if isinstance(init_response, dict) or isinstance(
-                init_response, ModelResponse
-            ):  ## CACHING SCENARIO
+            if isinstance(init_response, dict):
+                response = EmbeddingResponse(**init_response)
+            elif isinstance(init_response, EmbeddingResponse):  ## CACHING SCENARIO
                 response = init_response
             elif asyncio.iscoroutine(init_response):
                 response = await init_response
@@ -2817,7 +2818,7 @@ def embedding(
     litellm_logging_obj=None,
     logger_fn=None,
     **kwargs,
-):
+) -> EmbeddingResponse:
     """
     Embedding function that calls an API to generate embeddings for the given input.
 
@@ -2965,7 +2966,7 @@ def embedding(
     )
     try:
         response = None
-        logging = litellm_logging_obj
+        logging: Logging = litellm_logging_obj  # type: ignore
         logging.update_environment_variables(
             model=model,
             user=user,
@@ -3055,6 +3056,32 @@ def embedding(
             client=client,
             aembedding=aembedding,
         )
+    elif custom_llm_provider == "databricks":
+        api_base = (
+            api_base or litellm.api_base or get_secret("DATABRICKS_API_BASE")
+        )  # type: ignore
+
+        # set API KEY
+        api_key = (
+            api_key
+            or litellm.api_key
+            or litellm.databricks_key
+            or get_secret("DATABRICKS_API_KEY")
+        )  # type: ignore
+
+        ## EMBEDDING CALL
+        response = databricks_chat_completions.embedding(
+            model=model,
+            input=input,
+            api_base=api_base,
+            api_key=api_key,
+            logging_obj=logging,
+            timeout=timeout,
+            model_response=EmbeddingResponse(),
+            optional_params=optional_params,
+            client=client,
+            aembedding=aembedding,
+        )
     elif custom_llm_provider == "cohere":
         cohere_key = (
             api_key
diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py
index a441b0e70..30988dba1 100644
--- a/litellm/tests/test_embedding.py
+++ b/litellm/tests/test_embedding.py
@@ -535,6 +535,37 @@ async def test_triton_embeddings():
         pytest.fail(f"Error occurred: {e}")
 
 
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_databricks_embeddings(sync_mode):
+    try:
+        litellm.set_verbose = True
+        litellm.drop_params = True
+
+        if sync_mode:
+            response = litellm.embedding(
+                model="databricks/databricks-bge-large-en",
+                input=["good morning from litellm"],
+                instruction="Represent this sentence for searching relevant passages:",
+            )
+        else:
+            response = await litellm.aembedding(
+                model="databricks/databricks-bge-large-en",
+                input=["good morning from litellm"],
+                instruction="Represent this sentence for searching relevant passages:",
+            )
+
+        print(f"response: {response}")
+
+        openai.types.CreateEmbeddingResponse.model_validate(
+            response.model_dump(), strict=True
+        )
+        # stubbed endpoint is setup to return this
+        # assert response.data[0]["embedding"] == [0.1, 0.2, 0.3]
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 # test_voyage_embeddings()
 # def test_xinference_embeddings():
 #     try:
diff --git a/litellm/tests/test_optional_params.py b/litellm/tests/test_optional_params.py
index 5c33cfa0e..31ee3d99b 100644
--- a/litellm/tests/test_optional_params.py
+++ b/litellm/tests/test_optional_params.py
@@ -83,6 +83,20 @@ def test_azure_optional_params_embeddings():
     assert optional_params["user"] == "John"
 
 
+def test_databricks_optional_params():
+    litellm.drop_params = True
+    optional_params = get_optional_params(
+        model="",
+        user="John",
+        custom_llm_provider="databricks",
+        max_tokens=10,
+        temperature=0.2,
+    )
+    print(f"optional_params: {optional_params}")
+    assert len(optional_params) == 2
+    assert "user" not in optional_params
+
+
 def test_azure_gpt_optional_params_gpt_vision():
     # for OpenAI, Azure all extra params need to get passed as extra_body to OpenAI python. We assert we actually set extra_body here
     optional_params = litellm.utils.get_optional_params(
diff --git a/litellm/utils.py b/litellm/utils.py
index 8189ee058..051717236 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -766,7 +766,13 @@ class EmbeddingResponse(OpenAIObject):
     _hidden_params: dict = {}
 
     def __init__(
-        self, model=None, usage=None, stream=False, response_ms=None, data=None
+        self,
+        model=None,
+        usage=None,
+        stream=False,
+        response_ms=None,
+        data=None,
+        **params,
     ):
         object = "list"
         if response_ms:
@@ -5033,6 +5039,19 @@ def get_optional_params_embeddings(
 
     default_params = {"user": None, "encoding_format": None, "dimensions": None}
 
+    def _check_valid_arg(supported_params: Optional[list]):
+        if supported_params is None:
+            return
+        unsupported_params = {}
+        for k in non_default_params.keys():
+            if k not in supported_params:
+                unsupported_params[k] = non_default_params[k]
+        if unsupported_params and not litellm.drop_params:
+            raise UnsupportedParamsError(
+                status_code=500,
+                message=f"{custom_llm_provider} does not support parameters: {unsupported_params}, for model={model}. To drop these, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\n",
+            )
+
     non_default_params = {
         k: v
         for k, v in passed_params.items()
@@ -5058,6 +5077,18 @@ def get_optional_params_embeddings(
             non_default_params.pop(k, None)
         final_params = {**non_default_params, **kwargs}
         return final_params
+    if custom_llm_provider == "databricks":
+        supported_params = get_supported_openai_params(
+            model=model or "",
+            custom_llm_provider="databricks",
+            request_type="embeddings",
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.DatabricksEmbeddingConfig().map_openai_params(
+            non_default_params=non_default_params, optional_params={}
+        )
+        final_params = {**optional_params, **kwargs}
+        return final_params
     if custom_llm_provider == "vertex_ai":
         if len(non_default_params.keys()) > 0:
             if litellm.drop_params is True:  # drop the unsupported non-default values
@@ -5844,6 +5875,14 @@ def get_optional_params(
         optional_params = litellm.MistralConfig().map_openai_params(
             non_default_params=non_default_params, optional_params=optional_params
         )
+    elif custom_llm_provider == "databricks":
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.DatabricksConfig().map_openai_params(
+            non_default_params=non_default_params, optional_params=optional_params
+        )
     elif custom_llm_provider == "groq":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
@@ -6331,7 +6370,11 @@ def get_first_chars_messages(kwargs: dict) -> str:
         return ""
 
 
-def get_supported_openai_params(model: str, custom_llm_provider: str) -> Optional[list]:
+def get_supported_openai_params(
+    model: str,
+    custom_llm_provider: str,
+    request_type: Literal["chat_completion", "embeddings"] = "chat_completion",
+) -> Optional[list]:
     """
     Returns the supported openai params for a given model + provider
 
@@ -6504,6 +6547,11 @@ def get_supported_openai_params(model: str, custom_llm_provider: str) -> Optiona
             "frequency_penalty",
             "presence_penalty",
         ]
+    elif custom_llm_provider == "databricks":
+        if request_type == "chat_completion":
+            return litellm.DatabricksConfig().get_supported_openai_params()
+        elif request_type == "embeddings":
+            return litellm.DatabricksEmbeddingConfig().get_supported_openai_params()
     elif custom_llm_provider == "palm" or custom_llm_provider == "gemini":
        return ["temperature", "top_p", "stream", "n", "stop", "max_tokens"]
     elif custom_llm_provider == "vertex_ai":
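For reference, a minimal usage sketch of the embedding route this patch adds (not part of the diff itself). The model name, input, and instruction mirror the new test in litellm/tests/test_embedding.py; it assumes DATABRICKS_API_KEY and DATABRICKS_API_BASE are exported in the environment.

# Usage sketch only; mirrors test_databricks_embeddings above.
# Assumes DATABRICKS_API_KEY and DATABRICKS_API_BASE are set in the environment.
import asyncio

import litellm

litellm.drop_params = True  # drop OpenAI-style params the Databricks embedding endpoint does not accept


def sync_embedding():
    # Synchronous path: litellm.embedding() routes "databricks/" models to
    # DatabricksChatCompletion.embedding() added in this patch.
    return litellm.embedding(
        model="databricks/databricks-bge-large-en",
        input=["good morning from litellm"],
        instruction="Represent this sentence for searching relevant passages:",
    )


async def async_embedding():
    # Async path: litellm.aembedding() sets aembedding=True, which dispatches to
    # the new aembedding() coroutine in litellm/llms/databricks.py.
    return await litellm.aembedding(
        model="databricks/databricks-bge-large-en",
        input=["good morning from litellm"],
        instruction="Represent this sentence for searching relevant passages:",
    )


if __name__ == "__main__":
    print(sync_embedding())
    print(asyncio.run(async_embedding()))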