Mirror of https://github.com/BerriAI/litellm.git
Synced 2025-04-26 03:04:13 +00:00
feat(databricks/chat): support structured outputs on databricks

Closes https://github.com/BerriAI/litellm/pull/6978

- handles content as list for dbrx
- handles streaming + response_format for dbrx

parent 12aea45447
commit 0caf804f4c

18 changed files with 538 additions and 193 deletions
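For orientation, the caller-visible effect of this change is that `response_format` can now be passed on Databricks chat calls. A minimal sketch (model name and prompt are illustrative; assumes `DATABRICKS_API_KEY` and `DATABRICKS_API_BASE` are set):

```python
# Sketch only: the model name is illustrative; assumes DATABRICKS_API_KEY
# and DATABRICKS_API_BASE are exported in the environment.
import litellm

response = litellm.completion(
    model="databricks/databricks-dbrx-instruct",
    messages=[{"role": "user", "content": "Reply with a JSON object."}],
    response_format={"type": "json_object"},  # newly supported for Databricks
)
print(response.choices[0].message.content)
```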
@@ -1021,7 +1021,8 @@ from .llms.anthropic.experimental_pass_through.transformation import (
 )
 from .llms.groq.stt.transformation import GroqSTTConfig
 from .llms.anthropic.completion import AnthropicTextConfig
-from .llms.databricks.chat import DatabricksConfig, DatabricksEmbeddingConfig
+from .llms.databricks.chat.transformation import DatabricksConfig
+from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig
 from .llms.predibase import PredibaseConfig
 from .llms.replicate import ReplicateConfig
 from .llms.cohere.completion import CohereConfig

@@ -178,7 +178,7 @@ def get_supported_openai_params(  # noqa: PLR0915
         ]
     elif custom_llm_provider == "databricks":
         if request_type == "chat_completion":
-            return litellm.DatabricksConfig().get_supported_openai_params()
+            return litellm.DatabricksConfig().get_supported_openai_params(model=model)
         elif request_type == "embeddings":
             return litellm.DatabricksEmbeddingConfig().get_supported_openai_params()
     elif custom_llm_provider == "palm" or custom_llm_provider == "gemini":

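A quick way to see the effect of the `model=model` change above is through `litellm.get_supported_openai_params`, the public wrapper around this helper; the model string below is an example:

```python
# Illustrative check of the param list surfaced for Databricks chat models.
import litellm

params = litellm.get_supported_openai_params(
    model="databricks-dbrx-instruct", custom_llm_provider="databricks"
)
assert "response_format" in params  # added by this commit's DatabricksConfig
```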
litellm/llms/databricks/chat/handler.py (new file, 82 lines)
@@ -0,0 +1,82 @@
"""
Handles the chat completion request for Databricks
"""

from typing import Any, Callable, Literal, Optional, Tuple, Union

from httpx._config import Timeout

from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.utils import CustomStreamingDecoder
from litellm.utils import ModelResponse

from ...openai_like.chat.handler import OpenAILikeChatHandler
from ..common_utils import DatabricksBase
from ..exceptions import DatabricksError
from .transformation import DatabricksConfig


class DatabricksChatCompletion(OpenAILikeChatHandler, DatabricksBase):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def completion(
        self,
        *,
        model: str,
        messages: list,
        api_base: str,
        custom_llm_provider: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key: Optional[str],
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers: Optional[dict] = None,
        timeout: Optional[Union[float, Timeout]] = None,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        custom_endpoint: Optional[bool] = None,
        streaming_decoder: Optional[CustomStreamingDecoder] = None,
        fake_stream: bool = False,
    ):
        messages = DatabricksConfig()._transform_messages(messages)  # type: ignore
        api_base, headers = self.databricks_validate_environment(
            api_base=api_base,
            api_key=api_key,
            endpoint_type="chat_completions",
            custom_endpoint=custom_endpoint,
            headers=headers,
        )

        if optional_params.get("stream") is True:
            fake_stream = DatabricksConfig()._should_fake_stream(optional_params)
        else:
            fake_stream = False

        return super().completion(
            model=model,
            messages=messages,
            api_base=api_base,
            custom_llm_provider=custom_llm_provider,
            custom_prompt_dict=custom_prompt_dict,
            model_response=model_response,
            print_verbose=print_verbose,
            encoding=encoding,
            api_key=api_key,
            logging_obj=logging_obj,
            optional_params=optional_params,
            acompletion=acompletion,
            litellm_params=litellm_params,
            logger_fn=logger_fn,
            headers=headers,
            timeout=timeout,
            client=client,
            custom_endpoint=True,
            streaming_decoder=streaming_decoder,
            fake_stream=fake_stream,
        )

@@ -13,6 +13,7 @@ import httpx  # type: ignore
 import requests  # type: ignore
 
 import litellm
+from litellm import LlmProviders
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,

@@ -33,141 +34,17 @@ from litellm.types.utils import (
     GenericStreamingChunk,
     ProviderField,
 )
-from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage
+from litellm.utils import (
+    CustomStreamWrapper,
+    EmbeddingResponse,
+    ModelResponse,
+    ProviderConfigManager,
+    Usage,
+)
 
-from ..base import BaseLLM
-from ..prompt_templates.factory import custom_prompt, prompt_factory
+from ...base import BaseLLM
+from ...prompt_templates.factory import custom_prompt, prompt_factory
+from .transformation import DatabricksConfig
 
-
-class DatabricksConfig:
-    """
-    Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
-    """
-
-    max_tokens: Optional[int] = None
-    temperature: Optional[int] = None
-    top_p: Optional[int] = None
-    top_k: Optional[int] = None
-    stop: Optional[Union[List[str], str]] = None
-    n: Optional[int] = None
-
-    def __init__(
-        self,
-        max_tokens: Optional[int] = None,
-        temperature: Optional[int] = None,
-        top_p: Optional[int] = None,
-        top_k: Optional[int] = None,
-        stop: Optional[Union[List[str], str]] = None,
-        n: Optional[int] = None,
-    ) -> None:
-        locals_ = locals()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-    def get_required_params(self) -> List[ProviderField]:
-        """For a given provider, return its required fields with a description"""
-        return [
-            ProviderField(
-                field_name="api_key",
-                field_type="string",
-                field_description="Your Databricks API Key.",
-                field_value="dapi...",
-            ),
-            ProviderField(
-                field_name="api_base",
-                field_type="string",
-                field_description="Your Databricks API Base.",
-                field_value="https://adb-..",
-            ),
-        ]
-
-    def get_supported_openai_params(self):
-        return [
-            "stream",
-            "stop",
-            "temperature",
-            "top_p",
-            "max_tokens",
-            "max_completion_tokens",
-            "n",
-        ]
-
-    def map_openai_params(self, non_default_params: dict, optional_params: dict):
-        for param, value in non_default_params.items():
-            if param == "max_tokens" or param == "max_completion_tokens":
-                optional_params["max_tokens"] = value
-            if param == "n":
-                optional_params["n"] = value
-            if param == "stream" and value is True:
-                optional_params["stream"] = value
-            if param == "temperature":
-                optional_params["temperature"] = value
-            if param == "top_p":
-                optional_params["top_p"] = value
-            if param == "stop":
-                optional_params["stop"] = value
-        return optional_params
-
-
-class DatabricksEmbeddingConfig:
-    """
-    Reference: https://learn.microsoft.com/en-us/azure/databricks/machine-learning/foundation-models/api-reference#--embedding-task
-    """
-
-    instruction: Optional[str] = (
-        None  # An optional instruction to pass to the embedding model. BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries
-    )
-
-    def __init__(self, instruction: Optional[str] = None) -> None:
-        locals_ = locals()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-    def get_supported_openai_params(
-        self,
-    ):  # no optional openai embedding params supported
-        return []
-
-    def map_openai_params(self, non_default_params: dict, optional_params: dict):
-        return optional_params
-
-
 async def make_call(

@@ -477,6 +354,12 @@ class DatabricksChatCompletion(BaseLLM):
         )  # [TODO] add max retry support at llm api call level
         optional_params["stream"] = stream
 
+        if messages is not None and custom_llm_provider is not None:
+            provider_config = ProviderConfigManager.get_provider_config(
+                model=model, provider=LlmProviders(custom_llm_provider)
+            )
+            messages = provider_config._transform_messages(messages)
+
         data = {
             "model": model,
             "messages": messages,

litellm/llms/databricks/chat/transformation.py (new file, 143 lines)
@@ -0,0 +1,143 @@
"""
Translates from OpenAI's `/v1/chat/completions` to Databricks' `/chat/completions`
"""

import types
from typing import List, Optional, Union

from pydantic import BaseModel

from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ProviderField

from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig
from ...prompt_templates.common_utils import (
    handle_messages_with_content_list_to_str_conversion,
    strip_name_from_messages,
)


class DatabricksConfig(OpenAIGPTConfig):
    """
    Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
    """

    max_tokens: Optional[int] = None
    temperature: Optional[int] = None
    top_p: Optional[int] = None
    top_k: Optional[int] = None
    stop: Optional[Union[List[str], str]] = None
    n: Optional[int] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[int] = None,
        top_p: Optional[int] = None,
        top_k: Optional[int] = None,
        stop: Optional[Union[List[str], str]] = None,
        n: Optional[int] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_required_params(self) -> List[ProviderField]:
        """For a given provider, return its required fields with a description"""
        return [
            ProviderField(
                field_name="api_key",
                field_type="string",
                field_description="Your Databricks API Key.",
                field_value="dapi...",
            ),
            ProviderField(
                field_name="api_base",
                field_type="string",
                field_description="Your Databricks API Base.",
                field_value="https://adb-..",
            ),
        ]

    def get_supported_openai_params(self, model: Optional[str] = None) -> list:
        return [
            "stream",
            "stop",
            "temperature",
            "top_p",
            "max_tokens",
            "max_completion_tokens",
            "n",
            "response_format",
        ]

    def _should_fake_stream(self, optional_params: dict) -> bool:
        """
        Databricks doesn't support 'response_format' while streaming
        """
        if optional_params.get("response_format") is not None:
            return True

        return False

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ):
        for param, value in non_default_params.items():
            if param == "max_tokens" or param == "max_completion_tokens":
                optional_params["max_tokens"] = value
            if param == "n":
                optional_params["n"] = value
            if param == "stream" and value is True:
                optional_params["stream"] = value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "stop":
                optional_params["stop"] = value
            if param == "response_format":
                optional_params["response_format"] = value
        return optional_params

    def _transform_messages(
        self, messages: List[AllMessageValues]
    ) -> List[AllMessageValues]:
        """
        Databricks does not support:
        - content in list format.
        - 'name' in user message.
        """
        new_messages = []
        for idx, message in enumerate(messages):
            if isinstance(message, BaseModel):
                _message = message.model_dump()
            else:
                _message = message
            new_messages.append(_message)
        new_messages = handle_messages_with_content_list_to_str_conversion(new_messages)
        new_messages = strip_name_from_messages(new_messages)
        return super()._transform_messages(new_messages)

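Taken together, `map_openai_params` and `_should_fake_stream` are what route a streaming structured-output request through the fake-stream fallback. A small sketch of how they compose (values illustrative, not part of the diff):

```python
# Illustrative: response_format survives param mapping, and its presence
# makes _should_fake_stream() request the non-SSE fallback in the handler.
from litellm.llms.databricks.chat.transformation import DatabricksConfig

config = DatabricksConfig()
params = config.map_openai_params(
    non_default_params={"stream": True, "response_format": {"type": "json_object"}},
    optional_params={},
    model="databricks-dbrx-instruct",
    drop_params=False,
)
assert config._should_fake_stream(params) is True
```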
litellm/llms/databricks/common_utils.py (new file, 82 lines)
@@ -0,0 +1,82 @@
from typing import Literal, Optional, Tuple

from .exceptions import DatabricksError


class DatabricksBase:
    def _get_databricks_credentials(
        self, api_key: Optional[str], api_base: Optional[str], headers: Optional[dict]
    ) -> Tuple[str, dict]:
        headers = headers or {"Content-Type": "application/json"}
        try:
            from databricks.sdk import WorkspaceClient

            databricks_client = WorkspaceClient()

            api_base = api_base or f"{databricks_client.config.host}/serving-endpoints"

            if api_key is None:
                databricks_auth_headers: dict[str, str] = (
                    databricks_client.config.authenticate()
                )
                headers = {**databricks_auth_headers, **headers}

            return api_base, headers
        except ImportError:
            raise DatabricksError(
                status_code=400,
                message=(
                    "If the Databricks base URL and API key are not set, the databricks-sdk "
                    "Python library must be installed. Please install the databricks-sdk, set "
                    "{LLM_PROVIDER}_API_BASE and {LLM_PROVIDER}_API_KEY environment variables, "
                    "or provide the base URL and API key as arguments."
                ),
            )

    def databricks_validate_environment(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        endpoint_type: Literal["chat_completions", "embeddings"],
        custom_endpoint: Optional[bool],
        headers: Optional[dict],
    ) -> Tuple[str, dict]:
        if api_key is None and headers is None:
            if custom_endpoint is not None:
                raise DatabricksError(
                    status_code=400,
                    message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
                )
            else:
                api_base, headers = self._get_databricks_credentials(
                    api_base=api_base, api_key=api_key, headers=headers
                )

        if api_base is None:
            if custom_endpoint:
                raise DatabricksError(
                    status_code=400,
                    message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
                )
            else:
                api_base, headers = self._get_databricks_credentials(
                    api_base=api_base, api_key=api_key, headers=headers
                )

        if headers is None:
            headers = {
                "Authorization": "Bearer {}".format(api_key),
                "Content-Type": "application/json",
            }
        else:
            if api_key is not None:
                headers.update({"Authorization": "Bearer {}".format(api_key)})

        if api_key is not None:
            headers["Authorization"] = f"Bearer {api_key}"

        if endpoint_type == "chat_completions" and custom_endpoint is not True:
            api_base = "{}/chat/completions".format(api_base)
        elif endpoint_type == "embeddings" and custom_endpoint is not True:
            api_base = "{}/embeddings".format(api_base)
        return api_base, headers

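For reference, a sketch of what `databricks_validate_environment` resolves when explicit credentials are passed (placeholder values; only the URL suffixing and header construction are shown):

```python
# Placeholder credentials; illustrates the resolved endpoint and headers.
from litellm.llms.databricks.common_utils import DatabricksBase

api_base, headers = DatabricksBase().databricks_validate_environment(
    api_key="dapi-example",
    api_base="https://my.workspace.cloud.databricks.com/serving-endpoints",
    endpoint_type="chat_completions",
    custom_endpoint=None,
    headers=None,
)
assert api_base.endswith("/chat/completions")
assert headers["Authorization"] == "Bearer dapi-example"
```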
litellm/llms/databricks/embed/handler.py (new file, 50 lines)
@@ -0,0 +1,50 @@
"""
Calling logic for Databricks embeddings
"""

from typing import Optional

import litellm
from litellm.utils import EmbeddingResponse

from ...openai_like.embedding.handler import OpenAILikeEmbeddingHandler
from ..common_utils import DatabricksBase


class DatabricksEmbeddingHandler(OpenAILikeEmbeddingHandler, DatabricksBase):
    def embedding(
        self,
        model: str,
        input: list,
        timeout: float,
        logging_obj,
        api_key: Optional[str],
        api_base: Optional[str],
        optional_params: dict,
        model_response: Optional[litellm.utils.EmbeddingResponse] = None,
        client=None,
        aembedding=None,
        custom_endpoint: Optional[bool] = None,
        headers: Optional[dict] = None,
    ) -> EmbeddingResponse:
        api_base, headers = self.databricks_validate_environment(
            api_base=api_base,
            api_key=api_key,
            endpoint_type="embeddings",
            custom_endpoint=custom_endpoint,
            headers=headers,
        )
        return super().embedding(
            model=model,
            input=input,
            timeout=timeout,
            logging_obj=logging_obj,
            api_key=api_key,
            api_base=api_base,
            optional_params=optional_params,
            model_response=model_response,
            client=client,
            aembedding=aembedding,
            custom_endpoint=True,
            headers=headers,
        )

litellm/llms/databricks/embed/transformation.py (new file, 48 lines)
@@ -0,0 +1,48 @@
"""
Translates from OpenAI's `/v1/embeddings` to Databricks' `/embeddings`
"""

import types
from typing import Optional


class DatabricksEmbeddingConfig:
    """
    Reference: https://learn.microsoft.com/en-us/azure/databricks/machine-learning/foundation-models/api-reference#--embedding-task
    """

    instruction: Optional[str] = (
        None  # An optional instruction to pass to the embedding model. BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries
    )

    def __init__(self, instruction: Optional[str] = None) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(
        self,
    ):  # no optional openai embedding params supported
        return []

    def map_openai_params(self, non_default_params: dict, optional_params: dict):
        return optional_params

@@ -37,6 +37,22 @@ def handle_messages_with_content_list_to_str_conversion(
     return messages
 
 
+def strip_name_from_messages(
+    messages: List[AllMessageValues],
+) -> List[AllMessageValues]:
+    """
+    Removes 'name' from messages
+    """
+    new_messages = []
+    for message in messages:
+        msg_role = message.get("role")
+        msg_copy = message.copy()
+        if msg_role == "user":
+            msg_copy.pop("name", None)  # type: ignore
+        new_messages.append(msg_copy)
+    return new_messages
+
+
 def convert_content_list_to_str(message: AllMessageValues) -> str:
     """
     - handles scenario where content is list and not string

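A usage sketch of the new helper (module path inferred from the relative imports in transformation.py; output shown as a comment):

```python
# Assumed module path, based on the relative imports used elsewhere in this diff.
from litellm.llms.prompt_templates.common_utils import strip_name_from_messages

messages = [
    {"role": "user", "content": "Hello", "name": "test_name"},
    {"role": "assistant", "content": "Hi!", "name": "bot"},
]
print(strip_name_from_messages(messages))
# -> [{'role': 'user', 'content': 'Hello'},
#     {'role': 'assistant', 'content': 'Hi!', 'name': 'bot'}]
```

Note that only user messages lose their `name` field, since that is what Databricks rejects.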
@@ -273,7 +273,7 @@ class SagemakerLLM(BaseAWSLLM):
         model_id = optional_params.get("model_id", None)
 
         if use_messages_api is True:
-            from litellm.llms.databricks.chat import DatabricksChatCompletion
+            from litellm.llms.databricks.chat.handler import DatabricksChatCompletion
 
             openai_like_chat_completions = DatabricksChatCompletion()
             inference_params["stream"] = True if stream is True else False

@@ -90,7 +90,7 @@ class VertexAIPartnerModels(VertexBase):
             from google.cloud import aiplatform
 
             from litellm.llms.anthropic.chat import AnthropicChatCompletion
-            from litellm.llms.databricks.chat import DatabricksChatCompletion
+            from litellm.llms.databricks.chat.handler import DatabricksChatCompletion
             from litellm.llms.OpenAI.openai import OpenAIChatCompletion
             from litellm.llms.text_completion_codestral import CodestralTextCompletion
             from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (

@@ -76,7 +76,7 @@ class VertexAIModelGardenModels(VertexBase):
             from google.cloud import aiplatform
 
             from litellm.llms.anthropic.chat import AnthropicChatCompletion
-            from litellm.llms.databricks.chat import DatabricksChatCompletion
+            from litellm.llms.databricks.chat.handler import DatabricksChatCompletion
             from litellm.llms.OpenAI.openai import OpenAIChatCompletion
             from litellm.llms.text_completion_codestral import CodestralTextCompletion
             from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (

@@ -115,7 +115,8 @@ from .llms.cohere import chat as cohere_chat
 from .llms.cohere import completion as cohere_completion  # type: ignore
 from .llms.cohere.embed import handler as cohere_embed
 from .llms.custom_llm import CustomLLM, custom_chat_llm_router
-from .llms.databricks.chat import DatabricksChatCompletion
+from .llms.databricks.chat.handler import DatabricksChatCompletion
+from .llms.databricks.embed.handler import DatabricksEmbeddingHandler
 from .llms.groq.chat.handler import GroqChatCompletion
 from .llms.huggingface_restapi import Huggingface
 from .llms.OpenAI.audio_transcriptions import OpenAIAudioTranscription

@@ -230,6 +231,7 @@ watsonxai = IBMWatsonXAI()
 sagemaker_llm = SagemakerLLM()
 watsonx_chat_completion = WatsonXChatHandler()
 openai_like_embedding = OpenAILikeEmbeddingHandler()
+databricks_embedding = DatabricksEmbeddingHandler()
 ####### COMPLETION ENDPOINTS ################

@@ -3475,7 +3477,7 @@ def embedding(  # noqa: PLR0915
     )  # type: ignore
 
     ## EMBEDDING CALL
-    response = databricks_chat_completions.embedding(
+    response = databricks_embedding.embedding(
         model=model,
         input=input,
         api_base=api_base,

@@ -3418,7 +3418,14 @@ def get_optional_params(  # noqa: PLR0915
         )
         _check_valid_arg(supported_params=supported_params)
         optional_params = litellm.DatabricksConfig().map_openai_params(
-            non_default_params=non_default_params, optional_params=optional_params
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
         )
     elif custom_llm_provider == "nvidia_nim":
         supported_params = get_supported_openai_params(

@@ -6182,6 +6189,8 @@ class ProviderConfigManager:
         return litellm.DeepSeekChatConfig()
     elif litellm.LlmProviders.GROQ == provider:
         return litellm.GroqChatConfig()
+    elif litellm.LlmProviders.DATABRICKS == provider:
+        return litellm.DatabricksConfig()
 
     return OpenAIGPTConfig()

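With the branch above, the provider-config registry now hands Databricks requests the new transformation config. A sketch (model string illustrative):

```python
# Illustrative: dispatch through the registry returns DatabricksConfig.
import litellm
from litellm import LlmProviders
from litellm.utils import ProviderConfigManager

config = ProviderConfigManager.get_provider_config(
    model="databricks-dbrx-instruct", provider=LlmProviders.DATABRICKS
)
assert isinstance(config, litellm.DatabricksConfig)
```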
@@ -55,6 +55,7 @@ class BaseLLMChatTest(ABC):
         assert response.choices[0].message.content is not None
 
     def test_message_with_name(self):
+        litellm.set_verbose = True
         base_completion_call_args = self.get_base_completion_call_args()
         messages = [
             {"role": "user", "content": "Hello", "name": "test_name"},

@@ -69,6 +70,7 @@ class BaseLLMChatTest(ABC):
             {"type": "text"},
         ],
     )
+    @pytest.mark.flaky(retries=6, delay=1)
    def test_json_response_format(self, response_format):
        """
        Test that the JSON response format is supported by the LLM API

@@ -4,7 +4,8 @@ import json
 import pytest
 import sys
 from typing import Any, Dict, List
-from unittest.mock import MagicMock, Mock, patch
+from unittest.mock import MagicMock, Mock, patch, ANY
 
 import os
 
 sys.path.insert(

@@ -14,6 +15,7 @@ import litellm
 from litellm.exceptions import BadRequestError
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.utils import CustomStreamWrapper
+from base_llm_unit_tests import BaseLLMChatTest
 
 try:
     import databricks.sdk

@@ -333,6 +335,7 @@ def test_completions_with_async_http_handler(monkeypatch):
             "Authorization": f"Bearer {api_key}",
             "Content-Type": "application/json",
         },
+        timeout=ANY,
         data=json.dumps(
             {
                 "model": "dbrx-instruct-071224",

@@ -376,18 +379,22 @@ def test_completions_streaming_with_sync_http_handler(monkeypatch):
             "Authorization": f"Bearer {api_key}",
             "Content-Type": "application/json",
         },
-        data=json.dumps(
-            {
-                "model": "dbrx-instruct-071224",
-                "messages": messages,
-                "temperature": 0.5,
-                "stream": True,
-                "extraparam": "testpassingextraparam",
-            }
-        ),
+        data=ANY,
         stream=True,
     )
 
+    actual_data = json.loads(
+        mock_post.call_args.kwargs["data"]
+    )  # Deserialize the actual data
+    expected_data = {
+        "model": "dbrx-instruct-071224",
+        "messages": messages,
+        "temperature": 0.5,
+        "stream": True,
+        "extraparam": "testpassingextraparam",
+    }
+    assert actual_data == expected_data, f"Unexpected JSON data: {actual_data}"
+
 
 def test_completions_streaming_with_async_http_handler(monkeypatch):
     base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"

@@ -429,21 +436,27 @@ def test_completions_streaming_with_async_http_handler(monkeypatch):
             "Authorization": f"Bearer {api_key}",
             "Content-Type": "application/json",
         },
-        data=json.dumps(
-            {
-                "model": "dbrx-instruct-071224",
-                "messages": messages,
-                "temperature": 0.5,
-                "stream": True,
-                "extraparam": "testpassingextraparam",
-            }
-        ),
+        data=ANY,
         stream=True,
     )
 
+    actual_data = json.loads(
+        mock_post.call_args.kwargs["data"]
+    )  # Deserialize the actual data
+    expected_data = {
+        "model": "dbrx-instruct-071224",
+        "messages": messages,
+        "temperature": 0.5,
+        "stream": True,
+        "extraparam": "testpassingextraparam",
+    }
+    assert actual_data == expected_data, f"Unexpected JSON data: {actual_data}"
+
 
 @pytest.mark.skipif(not databricks_sdk_installed, reason="Databricks SDK not installed")
 def test_completions_uses_databricks_sdk_if_api_key_and_base_not_specified(monkeypatch):
+    monkeypatch.delenv("DATABRICKS_API_BASE")
+    monkeypatch.delenv("DATABRICKS_API_KEY")
     from databricks.sdk import WorkspaceClient
     from databricks.sdk.config import Config

@@ -637,3 +650,48 @@ def test_embeddings_uses_databricks_sdk_if_api_key_and_base_not_specified(monkeypatch):
             }
         ),
     )
+
+
+class TestDatabricksCompletion(BaseLLMChatTest):
+    def get_base_completion_call_args(self) -> dict:
+        return {"model": "databricks/databricks-dbrx-instruct"}
+
+    def test_pdf_handling(self, pdf_messages):
+        pytest.skip("Databricks does not support PDF handling")
+
+    def test_tool_call_no_arguments(self, tool_call_no_arguments):
+        """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
+        pytest.skip("Databricks is openai compatible")
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_databricks_embeddings(sync_mode):
+    import openai
+
+    try:
+        litellm.set_verbose = True
+        litellm.drop_params = True
+
+        if sync_mode:
+            response = litellm.embedding(
+                model="databricks/databricks-bge-large-en",
+                input=["good morning from litellm"],
+                instruction="Represent this sentence for searching relevant passages:",
+            )
+        else:
+            response = await litellm.aembedding(
+                model="databricks/databricks-bge-large-en",
+                input=["good morning from litellm"],
+                instruction="Represent this sentence for searching relevant passages:",
+            )
+
+        print(f"response: {response}")
+
+        openai.types.CreateEmbeddingResponse.model_validate(
+            response.model_dump(), strict=True
+        )
+        # stubbed endpoint is setup to return this
+        # assert response.data[0]["embedding"] == [0.1, 0.2, 0.3]
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")

@@ -292,7 +292,7 @@ def test_all_model_configs():
         optional_params={},
     ) == {"max_tokens_to_sample": 10}
 
-    from litellm.llms.databricks.chat import DatabricksConfig
+    from litellm.llms.databricks.chat.handler import DatabricksConfig
 
     assert "max_completion_tokens" in DatabricksConfig().get_supported_openai_params()

@@ -932,37 +932,6 @@ async def test_gemini_embeddings(sync_mode, input):
         pytest.fail(f"Error occurred: {e}")
 
 
-@pytest.mark.parametrize("sync_mode", [True, False])
-@pytest.mark.asyncio
-async def test_databricks_embeddings(sync_mode):
-    try:
-        litellm.set_verbose = True
-        litellm.drop_params = True
-
-        if sync_mode:
-            response = litellm.embedding(
-                model="databricks/databricks-bge-large-en",
-                input=["good morning from litellm"],
-                instruction="Represent this sentence for searching relevant passages:",
-            )
-        else:
-            response = await litellm.aembedding(
-                model="databricks/databricks-bge-large-en",
-                input=["good morning from litellm"],
-                instruction="Represent this sentence for searching relevant passages:",
-            )
-
-        print(f"response: {response}")
-
-        openai.types.CreateEmbeddingResponse.model_validate(
-            response.model_dump(), strict=True
-        )
-        # stubbed endpoint is setup to return this
-        # assert response.data[0]["embedding"] == [0.1, 0.2, 0.3]
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-
-
 # test_voyage_embeddings()
 # def test_xinference_embeddings():
 #     try: