feat: (re-)enable Databricks inference adapter (#3500)

# What does this PR do?

add/enable the Databricks inference adapter

Databricks inference adapter was broken, closes #3486 

- remove deprecated completion / chat_completion endpoints
- enable dynamic model listing w/o refresh, listing is not async
- use SecretStr instead of str for token
- backward incompatible change: for consistency with databricks docs,
env DATABRICKS_URL -> DATABRICKS_HOST and DATABRICKS_API_TOKEN ->
DATABRICKS_TOKEN
- databricks urls are custom per user/org, add special recorder handling
for databricks urls
- add integration test --setup databricks
- enable chat completions tests
- enable embeddings tests
- disable n > 1 tests
- disable embeddings base64 tests
- disable embeddings dimensions tests

note: reasoning models, e.g. gpt oss, fail because databricks has a
custom, incompatible response format

## Test Plan

ci and 

```
./scripts/integration-tests.sh --stack-config server:ci-tests --setup databricks --subdirs inference --pattern openai
```

note: databricks needs to be manually added to the ci-tests distro for
replay testing
This commit is contained in:
Matthew Farrellee 2025-09-23 15:37:23 -04:00 committed by GitHub
parent 9406a998b9
commit d07ebce4d9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
24 changed files with 11649 additions and 101 deletions

View file

@ -4,23 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from collections.abc import AsyncIterator
from typing import Any
from openai import OpenAI
from databricks.sdk import WorkspaceClient
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
OpenAICompletion,
ResponseFormat,
SamplingParams,
TextTruncation,
@ -29,49 +32,50 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
ProviderModelEntry,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import DatabricksImplConfig
SAFETY_MODELS_ENTRIES = []
logger = get_logger(name=__name__, category="inference::databricks")
# https://docs.databricks.com/aws/en/machine-learning/model-serving/foundation-model-overview
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"databricks-meta-llama-3-1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
# source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
EMBEDDING_MODEL_ENTRIES = {
"databricks-gte-large-en": ProviderModelEntry(
provider_model_id="databricks-gte-large-en",
metadata={
"embedding_dimension": 1024,
"context_length": 8192,
},
),
build_hf_repo_model_entry(
"databricks-meta-llama-3-1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value,
"databricks-bge-large-en": ProviderModelEntry(
provider_model_id="databricks-bge-large-en",
metadata={
"embedding_dimension": 1024,
"context_length": 512,
},
),
] + SAFETY_MODELS_ENTRIES
}
class DatabricksInferenceAdapter(
ModelRegistryHelper,
OpenAIMixin,
Inference,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: DatabricksImplConfig) -> None:
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
self.config = config
def get_api_key(self) -> str:
return self.config.api_token.get_secret_value()
def get_base_url(self) -> str:
return f"{self.config.url}/serving-endpoints"
async def initialize(self) -> None:
return
@ -80,72 +84,54 @@ class DatabricksInferenceAdapter(
async def completion(
self,
model: str,
model_id: str,
content: InterleavedContent,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
raise NotImplementedError()
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
raise NotImplementedError()
async def chat_completion(
self,
model: str,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
if stream:
return self._stream_chat_completion(request, client)
else:
return await self._nonstream_chat_completion(request, client)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
) -> ChatCompletionResponse:
params = self._get_params(request)
r = client.completions.create(**params)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
params = self._get_params(request)
async def _to_async_generator():
s = client.completions.create(**params)
for chunk in s:
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
"model": request.model,
"prompt": chat_completion_request_to_prompt(request, self.get_llama_model(request.model)),
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
raise NotImplementedError()
async def embeddings(
self,
@ -157,12 +143,39 @@ class DatabricksInferenceAdapter(
) -> EmbeddingsResponse:
raise NotImplementedError()
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def list_models(self) -> list[Model] | None:
self._model_cache = {} # from OpenAIMixin
ws_client = WorkspaceClient(host=self.config.url, token=self.get_api_key()) # TODO: this is not async
endpoints = ws_client.serving_endpoints.list()
for endpoint in endpoints:
model = Model(
provider_id=self.__provider_id__,
provider_resource_id=endpoint.name,
identifier=endpoint.name,
)
if endpoint.task == "llm/v1/chat":
model.model_type = ModelType.llm # this is redundant, but informative
elif endpoint.task == "llm/v1/embeddings":
if endpoint.name not in EMBEDDING_MODEL_ENTRIES:
logger.warning(f"No metadata information available for embedding model {endpoint.name}, skipping.")
continue
model.model_type = ModelType.embedding
model.metadata = EMBEDDING_MODEL_ENTRIES[endpoint.name].metadata
else:
logger.warning(f"Unknown model type, skipping: {endpoint}")
continue
self._model_cache[endpoint.name] = model
return list(self._model_cache.values())
async def register_model(self, model: Model) -> Model:
if not await self.check_model_availability(model.provider_resource_id):
raise ValueError(f"Model {model.provider_resource_id} is not available in Databricks workspace.")
return model
async def unregister_model(self, model_id: str) -> None:
pass
async def should_refresh_models(self) -> bool:
return False