mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
feat(databricks/chat): support structured outputs on databricks
Closes https://github.com/BerriAI/litellm/pull/6978 - handles content as list for dbrx, - handles streaming+response_format for dbrx
This commit is contained in:
parent
12aea45447
commit
0caf804f4c
18 changed files with 538 additions and 193 deletions
82
litellm/llms/databricks/common_utils.py
Normal file
82
litellm/llms/databricks/common_utils.py
Normal file
|
@ -0,0 +1,82 @@
|
|||
from typing import Literal, Optional, Tuple
|
||||
|
||||
from .exceptions import DatabricksError
|
||||
|
||||
|
||||
class DatabricksBase:
|
||||
def _get_databricks_credentials(
|
||||
self, api_key: Optional[str], api_base: Optional[str], headers: Optional[dict]
|
||||
) -> Tuple[str, dict]:
|
||||
headers = headers or {"Content-Type": "application/json"}
|
||||
try:
|
||||
from databricks.sdk import WorkspaceClient
|
||||
|
||||
databricks_client = WorkspaceClient()
|
||||
|
||||
api_base = api_base or f"{databricks_client.config.host}/serving-endpoints"
|
||||
|
||||
if api_key is None:
|
||||
databricks_auth_headers: dict[str, str] = (
|
||||
databricks_client.config.authenticate()
|
||||
)
|
||||
headers = {**databricks_auth_headers, **headers}
|
||||
|
||||
return api_base, headers
|
||||
except ImportError:
|
||||
raise DatabricksError(
|
||||
status_code=400,
|
||||
message=(
|
||||
"If the Databricks base URL and API key are not set, the databricks-sdk "
|
||||
"Python library must be installed. Please install the databricks-sdk, set "
|
||||
"{LLM_PROVIDER}_API_BASE and {LLM_PROVIDER}_API_KEY environment variables, "
|
||||
"or provide the base URL and API key as arguments."
|
||||
),
|
||||
)
|
||||
|
||||
def databricks_validate_environment(
|
||||
self,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
endpoint_type: Literal["chat_completions", "embeddings"],
|
||||
custom_endpoint: Optional[bool],
|
||||
headers: Optional[dict],
|
||||
) -> Tuple[str, dict]:
|
||||
if api_key is None and headers is None:
|
||||
if custom_endpoint is not None:
|
||||
raise DatabricksError(
|
||||
status_code=400,
|
||||
message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
|
||||
)
|
||||
else:
|
||||
api_base, headers = self._get_databricks_credentials(
|
||||
api_base=api_base, api_key=api_key, headers=headers
|
||||
)
|
||||
|
||||
if api_base is None:
|
||||
if custom_endpoint:
|
||||
raise DatabricksError(
|
||||
status_code=400,
|
||||
message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
|
||||
)
|
||||
else:
|
||||
api_base, headers = self._get_databricks_credentials(
|
||||
api_base=api_base, api_key=api_key, headers=headers
|
||||
)
|
||||
|
||||
if headers is None:
|
||||
headers = {
|
||||
"Authorization": "Bearer {}".format(api_key),
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
else:
|
||||
if api_key is not None:
|
||||
headers.update({"Authorization": "Bearer {}".format(api_key)})
|
||||
|
||||
if api_key is not None:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
if endpoint_type == "chat_completions" and custom_endpoint is not True:
|
||||
api_base = "{}/chat/completions".format(api_base)
|
||||
elif endpoint_type == "embeddings" and custom_endpoint is not True:
|
||||
api_base = "{}/embeddings".format(api_base)
|
||||
return api_base, headers
|
Loading…
Add table
Add a link
Reference in a new issue