Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
feat(databricks/chat): support structured outputs on databricks
Closes https://github.com/BerriAI/litellm/pull/6978
- handles content as list for dbrx
- handles streaming+response_format for dbrx
parent: af817276c0
commit: 70f7d7e787
18 changed files with 538 additions and 193 deletions
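In practical terms, the commit lets callers pass `response_format` through to Databricks chat models. A minimal usage sketch of what this enables (the model name and prompt are illustrative, not part of this diff):

```python
import litellm

# JSON mode via response_format, now mapped through for the databricks provider
response = litellm.completion(
    model="databricks/databricks-dbrx-instruct",
    messages=[{"role": "user", "content": "List three fruits as JSON."}],
    response_format={"type": "json_object"},
)
print(response.choices[0].message.content)
```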
```diff
@@ -1021,7 +1021,8 @@ from .llms.anthropic.experimental_pass_through.transformation import (
 )
 from .llms.groq.stt.transformation import GroqSTTConfig
 from .llms.anthropic.completion import AnthropicTextConfig
-from .llms.databricks.chat import DatabricksConfig, DatabricksEmbeddingConfig
+from .llms.databricks.chat.transformation import DatabricksConfig
+from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig
 from .llms.predibase import PredibaseConfig
 from .llms.replicate import ReplicateConfig
 from .llms.cohere.completion import CohereConfig
```
```diff
@@ -178,7 +178,7 @@ def get_supported_openai_params(  # noqa: PLR0915
         ]
     elif custom_llm_provider == "databricks":
         if request_type == "chat_completion":
-            return litellm.DatabricksConfig().get_supported_openai_params()
+            return litellm.DatabricksConfig().get_supported_openai_params(model=model)
         elif request_type == "embeddings":
             return litellm.DatabricksEmbeddingConfig().get_supported_openai_params()
     elif custom_llm_provider == "palm" or custom_llm_provider == "gemini":
```
litellm/llms/databricks/chat/handler.py (new file, 82 lines)
```python
"""
Handles the chat completion request for Databricks
"""

from typing import Any, Callable, Literal, Optional, Tuple, Union

from httpx._config import Timeout

from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.utils import CustomStreamingDecoder
from litellm.utils import ModelResponse

from ...openai_like.chat.handler import OpenAILikeChatHandler
from ..common_utils import DatabricksBase
from ..exceptions import DatabricksError
from .transformation import DatabricksConfig


class DatabricksChatCompletion(OpenAILikeChatHandler, DatabricksBase):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def completion(
        self,
        *,
        model: str,
        messages: list,
        api_base: str,
        custom_llm_provider: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key: Optional[str],
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers: Optional[dict] = None,
        timeout: Optional[Union[float, Timeout]] = None,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        custom_endpoint: Optional[bool] = None,
        streaming_decoder: Optional[CustomStreamingDecoder] = None,
        fake_stream: bool = False,
    ):
        messages = DatabricksConfig()._transform_messages(messages)  # type: ignore
        api_base, headers = self.databricks_validate_environment(
            api_base=api_base,
            api_key=api_key,
            endpoint_type="chat_completions",
            custom_endpoint=custom_endpoint,
            headers=headers,
        )

        if optional_params.get("stream") is True:
            fake_stream = DatabricksConfig()._should_fake_stream(optional_params)
        else:
            fake_stream = False

        return super().completion(
            model=model,
            messages=messages,
            api_base=api_base,
            custom_llm_provider=custom_llm_provider,
            custom_prompt_dict=custom_prompt_dict,
            model_response=model_response,
            print_verbose=print_verbose,
            encoding=encoding,
            api_key=api_key,
            logging_obj=logging_obj,
            optional_params=optional_params,
            acompletion=acompletion,
            litellm_params=litellm_params,
            logger_fn=logger_fn,
            headers=headers,
            timeout=timeout,
            client=client,
            custom_endpoint=True,
            streaming_decoder=streaming_decoder,
            fake_stream=fake_stream,
        )
```
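Because `completion()` flips to the fake-stream path whenever `response_format` accompanies `stream=True`, an end-to-end streaming call keeps working instead of erroring. A hedged sketch (model name illustrative):

```python
import litellm

stream = litellm.completion(
    model="databricks/databricks-dbrx-instruct",
    messages=[{"role": "user", "content": 'Return {"ok": true} as JSON.'}],
    response_format={"type": "json_object"},
    stream=True,  # internally downgraded to a single request, replayed as chunks
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")
```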
```diff
@@ -13,6 +13,7 @@ import httpx  # type: ignore
 import requests  # type: ignore

 import litellm
+from litellm import LlmProviders
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
```
```diff
@@ -33,141 +34,17 @@ from litellm.types.utils import (
     GenericStreamingChunk,
     ProviderField,
 )
-from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage
+from litellm.utils import (
+    CustomStreamWrapper,
+    EmbeddingResponse,
+    ModelResponse,
+    ProviderConfigManager,
+    Usage,
+)

-from ..base import BaseLLM
-from ..prompt_templates.factory import custom_prompt, prompt_factory
-
-
-class DatabricksConfig:
-    """
-    Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
-    """
-
-    max_tokens: Optional[int] = None
-    temperature: Optional[int] = None
-    top_p: Optional[int] = None
-    top_k: Optional[int] = None
-    stop: Optional[Union[List[str], str]] = None
-    n: Optional[int] = None
-
-    def __init__(
-        self,
-        max_tokens: Optional[int] = None,
-        temperature: Optional[int] = None,
-        top_p: Optional[int] = None,
-        top_k: Optional[int] = None,
-        stop: Optional[Union[List[str], str]] = None,
-        n: Optional[int] = None,
-    ) -> None:
-        locals_ = locals()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-    def get_required_params(self) -> List[ProviderField]:
-        """For a given provider, return it's required fields with a description"""
-        return [
-            ProviderField(
-                field_name="api_key",
-                field_type="string",
-                field_description="Your Databricks API Key.",
-                field_value="dapi...",
-            ),
-            ProviderField(
-                field_name="api_base",
-                field_type="string",
-                field_description="Your Databricks API Base.",
-                field_value="https://adb-..",
-            ),
-        ]
-
-    def get_supported_openai_params(self):
-        return [
-            "stream",
-            "stop",
-            "temperature",
-            "top_p",
-            "max_tokens",
-            "max_completion_tokens",
-            "n",
-        ]
-
-    def map_openai_params(self, non_default_params: dict, optional_params: dict):
-        for param, value in non_default_params.items():
-            if param == "max_tokens" or param == "max_completion_tokens":
-                optional_params["max_tokens"] = value
-            if param == "n":
-                optional_params["n"] = value
-            if param == "stream" and value is True:
-                optional_params["stream"] = value
-            if param == "temperature":
-                optional_params["temperature"] = value
-            if param == "top_p":
-                optional_params["top_p"] = value
-            if param == "stop":
-                optional_params["stop"] = value
-        return optional_params
-
-
-class DatabricksEmbeddingConfig:
-    """
-    Reference: https://learn.microsoft.com/en-us/azure/databricks/machine-learning/foundation-models/api-reference#--embedding-task
-    """
-
-    instruction: Optional[str] = (
-        None  # An optional instruction to pass to the embedding model. BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries
-    )
-
-    def __init__(self, instruction: Optional[str] = None) -> None:
-        locals_ = locals()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-    def get_supported_openai_params(
-        self,
-    ):  # no optional openai embedding params supported
-        return []
-
-    def map_openai_params(self, non_default_params: dict, optional_params: dict):
-        return optional_params
+from ...base import BaseLLM
+from ...prompt_templates.factory import custom_prompt, prompt_factory
+from .transformation import DatabricksConfig


 async def make_call(
```
```diff
@@ -477,6 +354,12 @@ class DatabricksChatCompletion(BaseLLM):
         )  # [TODO] add max retry support at llm api call level
         optional_params["stream"] = stream

+        if messages is not None and custom_llm_provider is not None:
+            provider_config = ProviderConfigManager.get_provider_config(
+                model=model, provider=LlmProviders(custom_llm_provider)
+            )
+            messages = provider_config._transform_messages(messages)
+
         data = {
             "model": model,
             "messages": messages,
```
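The hunk above routes messages through the provider config's `_transform_messages` before the request body is built. A rough sketch of its effect for Databricks (the exact flattened string is approximate):

```python
from litellm.llms.databricks.chat.transformation import DatabricksConfig

messages = [
    {
        "role": "user",
        "name": "test_name",  # Databricks rejects 'name' on user messages
        "content": [  # content lists are flattened to a plain string
            {"type": "text", "text": "Hello "},
            {"type": "text", "text": "world"},
        ],
    }
]
print(DatabricksConfig()._transform_messages(messages))
# roughly: [{"role": "user", "content": "Hello world"}]
```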
litellm/llms/databricks/chat/transformation.py (new file, 143 lines)
```python
"""
Translates from OpenAI's `/v1/chat/completions` to Databricks' `/chat/completions`
"""

import types
from typing import List, Optional, Union

from pydantic import BaseModel

from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ProviderField

from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig
from ...prompt_templates.common_utils import (
    handle_messages_with_content_list_to_str_conversion,
    strip_name_from_messages,
)


class DatabricksConfig(OpenAIGPTConfig):
    """
    Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
    """

    max_tokens: Optional[int] = None
    temperature: Optional[int] = None
    top_p: Optional[int] = None
    top_k: Optional[int] = None
    stop: Optional[Union[List[str], str]] = None
    n: Optional[int] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[int] = None,
        top_p: Optional[int] = None,
        top_k: Optional[int] = None,
        stop: Optional[Union[List[str], str]] = None,
        n: Optional[int] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_required_params(self) -> List[ProviderField]:
        """For a given provider, return its required fields with a description"""
        return [
            ProviderField(
                field_name="api_key",
                field_type="string",
                field_description="Your Databricks API Key.",
                field_value="dapi...",
            ),
            ProviderField(
                field_name="api_base",
                field_type="string",
                field_description="Your Databricks API Base.",
                field_value="https://adb-..",
            ),
        ]

    def get_supported_openai_params(self, model: Optional[str] = None) -> list:
        return [
            "stream",
            "stop",
            "temperature",
            "top_p",
            "max_tokens",
            "max_completion_tokens",
            "n",
            "response_format",
        ]

    def _should_fake_stream(self, optional_params: dict) -> bool:
        """
        Databricks doesn't support 'response_format' while streaming
        """
        if optional_params.get("response_format") is not None:
            return True

        return False

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ):
        for param, value in non_default_params.items():
            if param == "max_tokens" or param == "max_completion_tokens":
                optional_params["max_tokens"] = value
            if param == "n":
                optional_params["n"] = value
            if param == "stream" and value is True:
                optional_params["stream"] = value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "stop":
                optional_params["stop"] = value
            if param == "response_format":
                optional_params["response_format"] = value
        return optional_params

    def _transform_messages(
        self, messages: List[AllMessageValues]
    ) -> List[AllMessageValues]:
        """
        Databricks does not support:
        - content in list format.
        - 'name' in user message.
        """
        new_messages = []
        for idx, message in enumerate(messages):
            if isinstance(message, BaseModel):
                _message = message.model_dump()
            else:
                _message = message
            new_messages.append(_message)
        new_messages = handle_messages_with_content_list_to_str_conversion(new_messages)
        new_messages = strip_name_from_messages(new_messages)
        return super()._transform_messages(new_messages)
```
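`_should_fake_stream` carries the whole streaming+response_format policy. A quick check of its behavior, as defined above:

```python
from litellm.llms.databricks.chat.transformation import DatabricksConfig

config = DatabricksConfig()
# fake stream only when a response_format is requested
assert config._should_fake_stream({"response_format": {"type": "json_object"}}) is True
assert config._should_fake_stream({"temperature": 0.2}) is False
```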
litellm/llms/databricks/common_utils.py (new file, 82 lines)
```python
from typing import Literal, Optional, Tuple

from .exceptions import DatabricksError


class DatabricksBase:
    def _get_databricks_credentials(
        self, api_key: Optional[str], api_base: Optional[str], headers: Optional[dict]
    ) -> Tuple[str, dict]:
        headers = headers or {"Content-Type": "application/json"}
        try:
            from databricks.sdk import WorkspaceClient

            databricks_client = WorkspaceClient()

            api_base = api_base or f"{databricks_client.config.host}/serving-endpoints"

            if api_key is None:
                databricks_auth_headers: dict[str, str] = (
                    databricks_client.config.authenticate()
                )
                headers = {**databricks_auth_headers, **headers}

            return api_base, headers
        except ImportError:
            raise DatabricksError(
                status_code=400,
                message=(
                    "If the Databricks base URL and API key are not set, the databricks-sdk "
                    "Python library must be installed. Please install the databricks-sdk, set "
                    "{LLM_PROVIDER}_API_BASE and {LLM_PROVIDER}_API_KEY environment variables, "
                    "or provide the base URL and API key as arguments."
                ),
            )

    def databricks_validate_environment(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        endpoint_type: Literal["chat_completions", "embeddings"],
        custom_endpoint: Optional[bool],
        headers: Optional[dict],
    ) -> Tuple[str, dict]:
        if api_key is None and headers is None:
            if custom_endpoint is not None:
                raise DatabricksError(
                    status_code=400,
                    message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
                )
            else:
                api_base, headers = self._get_databricks_credentials(
                    api_base=api_base, api_key=api_key, headers=headers
                )

        if api_base is None:
            if custom_endpoint:
                raise DatabricksError(
                    status_code=400,
                    message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_BASE) or via params",
                )
            else:
                api_base, headers = self._get_databricks_credentials(
                    api_base=api_base, api_key=api_key, headers=headers
                )

        if headers is None:
            headers = {
                "Authorization": "Bearer {}".format(api_key),
                "Content-Type": "application/json",
            }
        else:
            if api_key is not None:
                headers.update({"Authorization": "Bearer {}".format(api_key)})

        if api_key is not None:
            headers["Authorization"] = f"Bearer {api_key}"

        if endpoint_type == "chat_completions" and custom_endpoint is not True:
            api_base = "{}/chat/completions".format(api_base)
        elif endpoint_type == "embeddings" and custom_endpoint is not True:
            api_base = "{}/embeddings".format(api_base)
        return api_base, headers
```
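How the environment validation resolves credentials, sketched under the assumption that `databricks-sdk` is installed and a workspace is already configured locally (the resulting URL is illustrative):

```python
from litellm.llms.databricks.common_utils import DatabricksBase

# No explicit key/base: falls through to _get_databricks_credentials, which
# asks the databricks-sdk WorkspaceClient for the host and auth headers.
api_base, headers = DatabricksBase().databricks_validate_environment(
    api_key=None,
    api_base=None,
    endpoint_type="chat_completions",
    custom_endpoint=None,
    headers=None,
)
# api_base -> e.g. "https://<workspace-host>/serving-endpoints/chat/completions"
```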
litellm/llms/databricks/embed/handler.py (new file, 50 lines)
```python
"""
Calling logic for Databricks embeddings
"""

from typing import Optional

import litellm
from litellm.utils import EmbeddingResponse

from ...openai_like.embedding.handler import OpenAILikeEmbeddingHandler
from ..common_utils import DatabricksBase


class DatabricksEmbeddingHandler(OpenAILikeEmbeddingHandler, DatabricksBase):
    def embedding(
        self,
        model: str,
        input: list,
        timeout: float,
        logging_obj,
        api_key: Optional[str],
        api_base: Optional[str],
        optional_params: dict,
        model_response: Optional[litellm.utils.EmbeddingResponse] = None,
        client=None,
        aembedding=None,
        custom_endpoint: Optional[bool] = None,
        headers: Optional[dict] = None,
    ) -> EmbeddingResponse:
        api_base, headers = self.databricks_validate_environment(
            api_base=api_base,
            api_key=api_key,
            endpoint_type="embeddings",
            custom_endpoint=custom_endpoint,
            headers=headers,
        )
        return super().embedding(
            model=model,
            input=input,
            timeout=timeout,
            logging_obj=logging_obj,
            api_key=api_key,
            api_base=api_base,
            optional_params=optional_params,
            model_response=model_response,
            client=client,
            aembedding=aembedding,
            custom_endpoint=True,
            headers=headers,
        )
```
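A hedged usage sketch of the new embedding path; this mirrors the test added later in this commit and assumes `DATABRICKS_API_KEY`/`DATABRICKS_API_BASE` are set (or the databricks-sdk is configured):

```python
import litellm

response = litellm.embedding(
    model="databricks/databricks-bge-large-en",
    input=["good morning from litellm"],
    instruction="Represent this sentence for searching relevant passages:",
)
print(len(response.data[0]["embedding"]))
```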
litellm/llms/databricks/embed/transformation.py (new file, 48 lines)
```python
"""
Translates from OpenAI's `/v1/embeddings` to Databricks' `/embeddings`
"""

import types
from typing import Optional


class DatabricksEmbeddingConfig:
    """
    Reference: https://learn.microsoft.com/en-us/azure/databricks/machine-learning/foundation-models/api-reference#--embedding-task
    """

    instruction: Optional[str] = (
        None  # An optional instruction to pass to the embedding model. BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries
    )

    def __init__(self, instruction: Optional[str] = None) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(
        self,
    ):  # no optional openai embedding params supported
        return []

    def map_openai_params(self, non_default_params: dict, optional_params: dict):
        return optional_params
```
```diff
@@ -37,6 +37,22 @@ def handle_messages_with_content_list_to_str_conversion(
     return messages


+def strip_name_from_messages(
+    messages: List[AllMessageValues],
+) -> List[AllMessageValues]:
+    """
+    Removes 'name' from messages
+    """
+    new_messages = []
+    for message in messages:
+        msg_role = message.get("role")
+        msg_copy = message.copy()
+        if msg_role == "user":
+            msg_copy.pop("name", None)  # type: ignore
+        new_messages.append(msg_copy)
+    return new_messages
+
+
 def convert_content_list_to_str(message: AllMessageValues) -> str:
     """
     - handles scenario where content is list and not string
```
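A quick sketch of the new helper's behavior:

```python
from litellm.llms.prompt_templates.common_utils import strip_name_from_messages

msgs = [{"role": "user", "content": "Hello", "name": "test_name"}]
print(strip_name_from_messages(msgs))
# -> [{"role": "user", "content": "Hello"}]
```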
```diff
@@ -273,7 +273,7 @@ class SagemakerLLM(BaseAWSLLM):
         model_id = optional_params.get("model_id", None)

         if use_messages_api is True:
-            from litellm.llms.databricks.chat import DatabricksChatCompletion
+            from litellm.llms.databricks.chat.handler import DatabricksChatCompletion

             openai_like_chat_completions = DatabricksChatCompletion()
             inference_params["stream"] = True if stream is True else False
```
```diff
@@ -90,7 +90,7 @@ class VertexAIPartnerModels(VertexBase):
             from google.cloud import aiplatform

             from litellm.llms.anthropic.chat import AnthropicChatCompletion
-            from litellm.llms.databricks.chat import DatabricksChatCompletion
+            from litellm.llms.databricks.chat.handler import DatabricksChatCompletion
             from litellm.llms.OpenAI.openai import OpenAIChatCompletion
             from litellm.llms.text_completion_codestral import CodestralTextCompletion
             from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
```
```diff
@@ -76,7 +76,7 @@ class VertexAIModelGardenModels(VertexBase):
             from google.cloud import aiplatform

             from litellm.llms.anthropic.chat import AnthropicChatCompletion
-            from litellm.llms.databricks.chat import DatabricksChatCompletion
+            from litellm.llms.databricks.chat.handler import DatabricksChatCompletion
             from litellm.llms.OpenAI.openai import OpenAIChatCompletion
             from litellm.llms.text_completion_codestral import CodestralTextCompletion
             from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
```
```diff
@@ -115,7 +115,8 @@ from .llms.cohere import chat as cohere_chat
 from .llms.cohere import completion as cohere_completion  # type: ignore
 from .llms.cohere.embed import handler as cohere_embed
 from .llms.custom_llm import CustomLLM, custom_chat_llm_router
-from .llms.databricks.chat import DatabricksChatCompletion
+from .llms.databricks.chat.handler import DatabricksChatCompletion
+from .llms.databricks.embed.handler import DatabricksEmbeddingHandler
 from .llms.groq.chat.handler import GroqChatCompletion
 from .llms.huggingface_restapi import Huggingface
 from .llms.OpenAI.audio_transcriptions import OpenAIAudioTranscription
```
```diff
@@ -230,6 +231,7 @@ watsonxai = IBMWatsonXAI()
 sagemaker_llm = SagemakerLLM()
 watsonx_chat_completion = WatsonXChatHandler()
 openai_like_embedding = OpenAILikeEmbeddingHandler()
+databricks_embedding = DatabricksEmbeddingHandler()
 ####### COMPLETION ENDPOINTS ################
```
```diff
@@ -3475,7 +3477,7 @@ def embedding(  # noqa: PLR0915
         )  # type: ignore

         ## EMBEDDING CALL
-        response = databricks_chat_completions.embedding(
+        response = databricks_embedding.embedding(
             model=model,
             input=input,
             api_base=api_base,
```
```diff
@@ -3418,7 +3418,14 @@ def get_optional_params(  # noqa: PLR0915
         )
         _check_valid_arg(supported_params=supported_params)
         optional_params = litellm.DatabricksConfig().map_openai_params(
-            non_default_params=non_default_params, optional_params=optional_params
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
         )
     elif custom_llm_provider == "nvidia_nim":
         supported_params = get_supported_openai_params(
```
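What the widened `map_openai_params` signature looks like from the caller's side; a sketch with illustrative values:

```python
import litellm

optional_params = litellm.DatabricksConfig().map_openai_params(
    non_default_params={
        "max_completion_tokens": 256,  # mapped to Databricks' max_tokens
        "response_format": {"type": "json_object"},  # passed through as-is
    },
    optional_params={},
    model="databricks-dbrx-instruct",
    drop_params=False,
)
# -> {"max_tokens": 256, "response_format": {"type": "json_object"}}
```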
```diff
@@ -6182,6 +6189,8 @@ class ProviderConfigManager:
             return litellm.DeepSeekChatConfig()
         elif litellm.LlmProviders.GROQ == provider:
             return litellm.GroqChatConfig()
+        elif litellm.LlmProviders.DATABRICKS == provider:
+            return litellm.DatabricksConfig()

         return OpenAIGPTConfig()
```
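With the mapping above, Databricks now resolves to its own config class. A minimal check, assuming the `litellm.DatabricksConfig` export added earlier in this commit:

```python
import litellm
from litellm.utils import ProviderConfigManager

config = ProviderConfigManager.get_provider_config(
    model="databricks-dbrx-instruct",
    provider=litellm.LlmProviders.DATABRICKS,
)
assert isinstance(config, litellm.DatabricksConfig)
```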
```diff
@@ -55,6 +55,7 @@ class BaseLLMChatTest(ABC):
         assert response.choices[0].message.content is not None

     def test_message_with_name(self):
+        litellm.set_verbose = True
         base_completion_call_args = self.get_base_completion_call_args()
         messages = [
             {"role": "user", "content": "Hello", "name": "test_name"},
```

```diff
@@ -69,6 +70,7 @@ class BaseLLMChatTest(ABC):
             {"type": "text"},
         ],
     )
+    @pytest.mark.flaky(retries=6, delay=1)
     def test_json_response_format(self, response_format):
         """
         Test that the JSON response format is supported by the LLM API
```
```diff
@@ -4,7 +4,8 @@ import json
 import pytest
 import sys
 from typing import Any, Dict, List
-from unittest.mock import MagicMock, Mock, patch
+from unittest.mock import MagicMock, Mock, patch, ANY

 import os

 sys.path.insert(
```

```diff
@@ -14,6 +15,7 @@ import litellm
 from litellm.exceptions import BadRequestError
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.utils import CustomStreamWrapper
+from base_llm_unit_tests import BaseLLMChatTest

 try:
     import databricks.sdk
```
```diff
@@ -333,6 +335,7 @@ def test_completions_with_async_http_handler(monkeypatch):
                 "Authorization": f"Bearer {api_key}",
                 "Content-Type": "application/json",
             },
+            timeout=ANY,
             data=json.dumps(
                 {
                     "model": "dbrx-instruct-071224",
```
```diff
@@ -376,17 +379,21 @@ def test_completions_streaming_with_sync_http_handler(monkeypatch):
             "Authorization": f"Bearer {api_key}",
             "Content-Type": "application/json",
         },
-        data=json.dumps(
-            {
-                "model": "dbrx-instruct-071224",
-                "messages": messages,
-                "temperature": 0.5,
-                "stream": True,
-                "extraparam": "testpassingextraparam",
-            }
-        ),
-        stream=True,
-    )
+        data=ANY,
+        stream=True,
+    )
+
+    actual_data = json.loads(
+        mock_post.call_args.kwargs["data"]
+    )  # Deserialize the actual data
+    expected_data = {
+        "model": "dbrx-instruct-071224",
+        "messages": messages,
+        "temperature": 0.5,
+        "stream": True,
+        "extraparam": "testpassingextraparam",
+    }
+    assert actual_data == expected_data, f"Unexpected JSON data: {actual_data}"


 def test_completions_streaming_with_async_http_handler(monkeypatch):
```
```diff
@@ -429,21 +436,27 @@ def test_completions_streaming_with_async_http_handler(monkeypatch):
             "Authorization": f"Bearer {api_key}",
             "Content-Type": "application/json",
         },
-        data=json.dumps(
-            {
-                "model": "dbrx-instruct-071224",
-                "messages": messages,
-                "temperature": 0.5,
-                "stream": True,
-                "extraparam": "testpassingextraparam",
-            }
-        ),
-        stream=True,
-    )
+        data=ANY,
+        stream=True,
+    )
+
+    actual_data = json.loads(
+        mock_post.call_args.kwargs["data"]
+    )  # Deserialize the actual data
+    expected_data = {
+        "model": "dbrx-instruct-071224",
+        "messages": messages,
+        "temperature": 0.5,
+        "stream": True,
+        "extraparam": "testpassingextraparam",
+    }
+    assert actual_data == expected_data, f"Unexpected JSON data: {actual_data}"


 @pytest.mark.skipif(not databricks_sdk_installed, reason="Databricks SDK not installed")
 def test_completions_uses_databricks_sdk_if_api_key_and_base_not_specified(monkeypatch):
     monkeypatch.delenv("DATABRICKS_API_BASE")
     monkeypatch.delenv("DATABRICKS_API_KEY")
     from databricks.sdk import WorkspaceClient
     from databricks.sdk.config import Config
```
```diff
@@ -637,3 +650,48 @@ def test_embeddings_uses_databricks_sdk_if_api_key_and_base_not_specified(monkeypatch):
             }
         ),
     )
+
+
+class TestDatabricksCompletion(BaseLLMChatTest):
+    def get_base_completion_call_args(self) -> dict:
+        return {"model": "databricks/databricks-dbrx-instruct"}
+
+    def test_pdf_handling(self, pdf_messages):
+        pytest.skip("Databricks does not support PDF handling")
+
+    def test_tool_call_no_arguments(self, tool_call_no_arguments):
+        """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
+        pytest.skip("Databricks is openai compatible")
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_databricks_embeddings(sync_mode):
+    import openai
+
+    try:
+        litellm.set_verbose = True
+        litellm.drop_params = True
+
+        if sync_mode:
+            response = litellm.embedding(
+                model="databricks/databricks-bge-large-en",
+                input=["good morning from litellm"],
+                instruction="Represent this sentence for searching relevant passages:",
+            )
+        else:
+            response = await litellm.aembedding(
+                model="databricks/databricks-bge-large-en",
+                input=["good morning from litellm"],
+                instruction="Represent this sentence for searching relevant passages:",
+            )
+
+        print(f"response: {response}")
+
+        openai.types.CreateEmbeddingResponse.model_validate(
+            response.model_dump(), strict=True
+        )
+        # stubbed endpoint is setup to return this
+        # assert response.data[0]["embedding"] == [0.1, 0.2, 0.3]
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
```
```diff
@@ -292,7 +292,7 @@ def test_all_model_configs():
         optional_params={},
     ) == {"max_tokens_to_sample": 10}

-    from litellm.llms.databricks.chat import DatabricksConfig
+    from litellm.llms.databricks.chat.handler import DatabricksConfig

     assert "max_completion_tokens" in DatabricksConfig().get_supported_openai_params()
```
```diff
@@ -932,37 +932,6 @@ async def test_gemini_embeddings(sync_mode, input):
         pytest.fail(f"Error occurred: {e}")


-@pytest.mark.parametrize("sync_mode", [True, False])
-@pytest.mark.asyncio
-async def test_databricks_embeddings(sync_mode):
-    try:
-        litellm.set_verbose = True
-        litellm.drop_params = True
-
-        if sync_mode:
-            response = litellm.embedding(
-                model="databricks/databricks-bge-large-en",
-                input=["good morning from litellm"],
-                instruction="Represent this sentence for searching relevant passages:",
-            )
-        else:
-            response = await litellm.aembedding(
-                model="databricks/databricks-bge-large-en",
-                input=["good morning from litellm"],
-                instruction="Represent this sentence for searching relevant passages:",
-            )
-
-        print(f"response: {response}")
-
-        openai.types.CreateEmbeddingResponse.model_validate(
-            response.model_dump(), strict=True
-        )
-        # stubbed endpoint is setup to return this
-        # assert response.data[0]["embedding"] == [0.1, 0.2, 0.3]
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-
-
 # test_voyage_embeddings()
 # def test_xinference_embeddings():
 #     try:
```