feat: (re-)enable Databricks inference adapter

Databricks inference adapter was broken, would not start, see #3486

- remove deprecated completion / chat_completion endpoints
- enable dynamic model listing w/o refresh, listing is not async
- use SecretStr instead of str for token
- backward incompatible change: for consistency with databricks docs, env DATABRICKS_URL -> DATABRICKS_HOST and DATABRICKS_API_TOKEN -> DATABRICKS_TOKEN
- databricks urls are custom per user/org, add special recorder handling for databricks urls
- add integration test --setup databricks
- enable chat completions tests
- enable embeddings tests
- disable n > 1 tests
- disable embeddings base64 tests
- disable embeddings dimensions tests

note: reasoning models, e.g. gpt oss, fail because databricks has a custom, incompatible response format

test with: ./scripts/integration-tests.sh --stack-config server:ci-tests --setup databricks --subdirs inference --pattern openai

note: databricks needs to be manually added to the ci-tests distro for replay testing
This commit is contained in:
Matthew Farrellee 2025-09-20 05:05:05 -04:00
parent e66103c09d
commit ae804ed5a8
25 changed files with 11650 additions and 102 deletions

View file

@ -9,13 +9,13 @@ Databricks inference provider for running models on Databricks' unified analytic
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `url` | `<class 'str'>` | No | | The URL for the Databricks model serving endpoint |
| `api_token` | `<class 'str'>` | No | | The Databricks API token |
| `api_token` | `<class 'pydantic.types.SecretStr'>` | No | | The Databricks API token |
## Sample Configuration
```yaml
url: ${env.DATABRICKS_URL:=}
api_token: ${env.DATABRICKS_API_TOKEN:=}
url: ${env.DATABRICKS_HOST:=}
api_token: ${env.DATABRICKS_TOKEN:=}
```

View file

@ -152,7 +152,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter=AdapterSpec(
adapter_type="databricks",
pip_packages=[],
pip_packages=["databricks-sdk"],
module="llama_stack.providers.remote.inference.databricks",
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
description="Databricks inference provider for running models on Databricks' unified analytics platform.",

View file

@ -5,10 +5,11 @@
# the root directory of this source tree.
from .config import DatabricksImplConfig
from .databricks import DatabricksInferenceAdapter
async def get_adapter_impl(config: DatabricksImplConfig, _deps):
from .databricks import DatabricksInferenceAdapter
assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
impl = DatabricksInferenceAdapter(config)
await impl.initialize()

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, SecretStr
from llama_stack.schema_utils import json_schema_type
@ -17,16 +17,16 @@ class DatabricksImplConfig(BaseModel):
default=None,
description="The URL for the Databricks model serving endpoint",
)
api_token: str = Field(
default=None,
api_token: SecretStr = Field(
default=SecretStr(None),
description="The Databricks API token",
)
@classmethod
def sample_run_config(
cls,
url: str = "${env.DATABRICKS_URL:=}",
api_token: str = "${env.DATABRICKS_API_TOKEN:=}",
url: str = "${env.DATABRICKS_HOST:=}",
api_token: str = "${env.DATABRICKS_TOKEN:=}",
**kwargs: Any,
) -> dict[str, Any]:
return {

View file

@ -4,23 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from collections.abc import AsyncIterator
from typing import Any
from openai import OpenAI
from databricks.sdk import WorkspaceClient
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
OpenAICompletion,
ResponseFormat,
SamplingParams,
TextTruncation,
@ -29,49 +32,50 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
ProviderModelEntry,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import DatabricksImplConfig
SAFETY_MODELS_ENTRIES = []
logger = get_logger(name=__name__, category="inference::databricks")
# https://docs.databricks.com/aws/en/machine-learning/model-serving/foundation-model-overview
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"databricks-meta-llama-3-1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
# source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
EMBEDDING_MODEL_ENTRIES = {
"databricks-gte-large-en": ProviderModelEntry(
provider_model_id="databricks-gte-large-en",
metadata={
"embedding_dimension": 1024,
"context_length": 8192,
},
),
build_hf_repo_model_entry(
"databricks-meta-llama-3-1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value,
"databricks-bge-large-en": ProviderModelEntry(
provider_model_id="databricks-bge-large-en",
metadata={
"embedding_dimension": 1024,
"context_length": 512,
},
),
] + SAFETY_MODELS_ENTRIES
}
class DatabricksInferenceAdapter(
ModelRegistryHelper,
OpenAIMixin,
Inference,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: DatabricksImplConfig) -> None:
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
self.config = config
def get_api_key(self) -> str:
return self.config.api_token.get_secret_value()
def get_base_url(self) -> str:
return f"{self.config.url}/serving-endpoints"
async def initialize(self) -> None:
return
@ -80,72 +84,54 @@ class DatabricksInferenceAdapter(
async def completion(
self,
model: str,
model_id: str,
content: InterleavedContent,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
raise NotImplementedError()
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
raise NotImplementedError()
async def chat_completion(
self,
model: str,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
if stream:
return self._stream_chat_completion(request, client)
else:
return await self._nonstream_chat_completion(request, client)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
) -> ChatCompletionResponse:
params = self._get_params(request)
r = client.completions.create(**params)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
params = self._get_params(request)
async def _to_async_generator():
s = client.completions.create(**params)
for chunk in s:
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
"model": request.model,
"prompt": chat_completion_request_to_prompt(request, self.get_llama_model(request.model)),
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
raise NotImplementedError()
async def embeddings(
self,
@ -157,12 +143,39 @@ class DatabricksInferenceAdapter(
) -> EmbeddingsResponse:
raise NotImplementedError()
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def list_models(self) -> list[Model] | None:
self._model_cache = {} # from OpenAIMixin
ws_client = WorkspaceClient(host=self.config.url, token=self.get_api_key()) # TODO: this is not async
endpoints = ws_client.serving_endpoints.list()
for endpoint in endpoints:
model = Model(
provider_id=self.__provider_id__,
provider_resource_id=endpoint.name,
identifier=endpoint.name,
)
if endpoint.task == "llm/v1/chat":
model.model_type = ModelType.llm # this is redundant, but informative
elif endpoint.task == "llm/v1/embeddings":
if endpoint.name not in EMBEDDING_MODEL_ENTRIES:
logger.warning(f"No metadata information available for embedding model {endpoint.name}, skipping.")
continue
model.model_type = ModelType.embedding
model.metadata = EMBEDDING_MODEL_ENTRIES[endpoint.name].metadata
else:
logger.warning(f"Unknown model type, skipping: {endpoint}")
continue
self._model_cache[endpoint.name] = model
return list(self._model_cache.values())
async def register_model(self, model: Model) -> Model:
if not await self.check_model_availability(model.provider_resource_id):
raise ValueError(f"Model {model.provider_resource_id} is not available in Databricks workspace.")
return model
async def unregister_model(self, model_id: str) -> None:
pass
async def should_refresh_models(self) -> bool:
return False

View file

@ -296,7 +296,7 @@ class OpenAIMixin(ABC):
return OpenAIEmbeddingsResponse(
data=data,
model=response.model,
model=model,
usage=usage,
)

View file

@ -262,6 +262,10 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
raise ValueError(f"Unknown client type: {client_type}")
url = base_url.rstrip("/") + endpoint
# Special handling for Databricks URLs to avoid leaking workspace info
# e.g. https://adb-1234567890123456.7.cloud.databricks.com -> https://...cloud.databricks.com
if "cloud.databricks.com" in url:
url = "__databricks__" + url.split("cloud.databricks.com")[-1]
method = "POST"
headers = {}
body = kwargs

View file

@ -98,6 +98,7 @@ def skip_if_doesnt_support_n(client_with_models, model_id):
# the entered value was 2. Update the candidateCount value and try again.', 'status': 'INVALID_ARGUMENT'}
"remote::tgi", # TGI ignores n param silently
"remote::together", # `n` > 1 is not supported when streaming tokens. Please disable `stream`
"remote::databricks", # Bad request: parameter "n" must be equal to 1 for streaming mode
):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.")
@ -110,7 +111,6 @@ def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, mode
"inline::vllm",
"remote::bedrock",
"remote::cerebras",
"remote::databricks",
"remote::runpod",
"remote::watsonx", # watsonx returns 404 when hitting the /openai/v1 endpoint
):

View file

@ -41,6 +41,7 @@ def skip_if_model_doesnt_support_encoding_format_base64(client, model_id):
provider = provider_from_model(client, model_id)
if provider.provider_type in (
"remote::together", # param silently ignored, always returns floats
"remote::databricks", # param silently ignored, always returns floats
):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support encoding_format='base64'.")
@ -50,6 +51,8 @@ def skip_if_model_doesnt_support_variable_dimensions(client_with_models, model_i
if provider.provider_type in (
"remote::together", # returns 400
"inline::sentence-transformers",
# Error code: 400 - {'error_code': 'BAD_REQUEST', 'message': 'Bad request: json: unknown field "dimensions"\n'}
"remote::databricks",
):
pytest.skip(
f"Model {model_id} hosted by {provider.provider_type} does not support variable output embedding dimensions."
@ -73,7 +76,6 @@ def skip_if_model_doesnt_support_openai_embeddings(client, model_id):
"inline::meta-reference",
"remote::bedrock",
"remote::cerebras",
"remote::databricks",
"remote::runpod",
"remote::sambanova",
"remote::tgi",

View file

@ -0,0 +1,728 @@
{
"request": {
"method": "POST",
"url": "__databricks__/serving-endpoints/v1/chat/completions",
"headers": {},
"body": {
"model": "databricks-meta-llama-3-3-70b-instruct",
"messages": [
{
"role": "user",
"content": "Hello, world!"
}
],
"stream": true
},
"endpoint": "/v1/chat/completions",
"model": "databricks-meta-llama-3-3-70b-instruct"
},
"response": {
"body": [
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_7268e4ee-3b8e-461e-80dc-608e76f3801d",
"choices": [
{
"delta": {
"content": "",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326500,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 3,
"prompt_tokens": 14,
"total_tokens": 17,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_7268e4ee-3b8e-461e-80dc-608e76f3801d",
"choices": [
{
"delta": {
"content": "Hello! ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326500,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 3,
"prompt_tokens": 14,
"total_tokens": 17,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_7268e4ee-3b8e-461e-80dc-608e76f3801d",
"choices": [
{
"delta": {
"content": "It's ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326500,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 5,
"prompt_tokens": 14,
"total_tokens": 19,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_7268e4ee-3b8e-461e-80dc-608e76f3801d",
"choices": [
{
"delta": {
"content": "nice ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326500,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 6,
"prompt_tokens": 14,
"total_tokens": 20,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_7268e4ee-3b8e-461e-80dc-608e76f3801d",
"choices": [
{
"delta": {
"content": "to ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326500,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 7,
"prompt_tokens": 14,
"total_tokens": 21,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_7268e4ee-3b8e-461e-80dc-608e76f3801d",
"choices": [
{
"delta": {
"content": "meet ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326500,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 8,
"prompt_tokens": 14,
"total_tokens": 22,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_7268e4ee-3b8e-461e-80dc-608e76f3801d",
"choices": [
{
"delta": {
"content": "you. ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326500,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 10,
"prompt_tokens": 14,
"total_tokens": 24,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_7268e4ee-3b8e-461e-80dc-608e76f3801d",
"choices": [
{
"delta": {
"content": "Is ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326500,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 11,
"prompt_tokens": 14,
"total_tokens": 25,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_7268e4ee-3b8e-461e-80dc-608e76f3801d",
"choices": [
{
"delta": {
"content": "there ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326500,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 12,
"prompt_tokens": 14,
"total_tokens": 26,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_7268e4ee-3b8e-461e-80dc-608e76f3801d",
"choices": [
{
"delta": {
"content": "something ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326500,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 13,
"prompt_tokens": 14,
"total_tokens": 27,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_7268e4ee-3b8e-461e-80dc-608e76f3801d",
"choices": [
{
"delta": {
"content": "I ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326500,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 14,
"prompt_tokens": 14,
"total_tokens": 28,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_7268e4ee-3b8e-461e-80dc-608e76f3801d",
"choices": [
{
"delta": {
"content": "can ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326500,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 15,
"prompt_tokens": 14,
"total_tokens": 29,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_7268e4ee-3b8e-461e-80dc-608e76f3801d",
"choices": [
{
"delta": {
"content": "help ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326500,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 16,
"prompt_tokens": 14,
"total_tokens": 30,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_7268e4ee-3b8e-461e-80dc-608e76f3801d",
"choices": [
{
"delta": {
"content": "you ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326500,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 17,
"prompt_tokens": 14,
"total_tokens": 31,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_7268e4ee-3b8e-461e-80dc-608e76f3801d",
"choices": [
{
"delta": {
"content": "with ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326500,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 18,
"prompt_tokens": 14,
"total_tokens": 32,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_7268e4ee-3b8e-461e-80dc-608e76f3801d",
"choices": [
{
"delta": {
"content": "or ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326500,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 19,
"prompt_tokens": 14,
"total_tokens": 33,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_7268e4ee-3b8e-461e-80dc-608e76f3801d",
"choices": [
{
"delta": {
"content": "would ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326500,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 20,
"prompt_tokens": 14,
"total_tokens": 34,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_7268e4ee-3b8e-461e-80dc-608e76f3801d",
"choices": [
{
"delta": {
"content": "you ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326500,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 21,
"prompt_tokens": 14,
"total_tokens": 35,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_7268e4ee-3b8e-461e-80dc-608e76f3801d",
"choices": [
{
"delta": {
"content": "like ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326500,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 22,
"prompt_tokens": 14,
"total_tokens": 36,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_7268e4ee-3b8e-461e-80dc-608e76f3801d",
"choices": [
{
"delta": {
"content": "to ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326500,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 23,
"prompt_tokens": 14,
"total_tokens": 37,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_7268e4ee-3b8e-461e-80dc-608e76f3801d",
"choices": [
{
"delta": {
"content": "chat?",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326500,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 25,
"prompt_tokens": 14,
"total_tokens": 39,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_7268e4ee-3b8e-461e-80dc-608e76f3801d",
"choices": [
{
"delta": {
"content": "",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": "stop",
"index": 0,
"logprobs": null
}
],
"created": 1758326500,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 25,
"prompt_tokens": 14,
"total_tokens": 39,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
}
],
"is_streaming": true
}
}

View file

@ -0,0 +1,56 @@
{
"request": {
"method": "POST",
"url": "__databricks__/serving-endpoints/v1/chat/completions",
"headers": {},
"body": {
"model": "databricks-meta-llama-3-3-70b-instruct",
"messages": [
{
"role": "user",
"content": "Hello, world!"
}
],
"stream": false
},
"endpoint": "/v1/chat/completions",
"model": "databricks-meta-llama-3-3-70b-instruct"
},
"response": {
"body": {
"__type__": "openai.types.chat.chat_completion.ChatCompletion",
"__data__": {
"id": "chatcmpl_52eec823-4235-473d-b25a-f0af4ebd4837",
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "Hello! It's great to meet you. Is there something I can help you with, or would you like to chat?",
"refusal": null,
"role": "assistant",
"annotations": null,
"audio": null,
"function_call": null,
"tool_calls": null
}
}
],
"created": 1758326506,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 26,
"prompt_tokens": 14,
"total_tokens": 40,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
"is_streaming": false
}
}

View file

@ -0,0 +1,56 @@
{
"request": {
"method": "POST",
"url": "__databricks__/serving-endpoints/v1/chat/completions",
"headers": {},
"body": {
"model": "databricks-meta-llama-3-3-70b-instruct",
"messages": [
{
"role": "user",
"content": "Which planet do humans live on?"
}
],
"stream": false
},
"endpoint": "/v1/chat/completions",
"model": "databricks-meta-llama-3-3-70b-instruct"
},
"response": {
"body": {
"__type__": "openai.types.chat.chat_completion.ChatCompletion",
"__data__": {
"id": "chatcmpl_e846ea96-9636-4eb4-bde4-84510478617b",
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "Humans live on the planet Earth.",
"refusal": null,
"role": "assistant",
"annotations": null,
"audio": null,
"function_call": null,
"tool_calls": null
}
}
],
"created": 1758326497,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 8,
"prompt_tokens": 17,
"total_tokens": 25,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
"is_streaming": false
}
}

View file

@ -0,0 +1,56 @@
{
"request": {
"method": "POST",
"url": "__databricks__/serving-endpoints/v1/chat/completions",
"headers": {},
"body": {
"model": "databricks-meta-llama-3-3-70b-instruct",
"messages": [
{
"role": "user",
"content": "Which planet has rings around it with a name starting with letter S?"
}
],
"stream": false
},
"endpoint": "/v1/chat/completions",
"model": "databricks-meta-llama-3-3-70b-instruct"
},
"response": {
"body": {
"__type__": "openai.types.chat.chat_completion.ChatCompletion",
"__data__": {
"id": "chatcmpl_094a74d8-2e39-45ce-8eb9-64d505bd24e9",
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "The answer is Saturn! Saturn is a planet in our solar system that is known for its stunning ring system. The rings of Saturn are made up of ice and rock particles that range in size from tiny dust grains to massive boulders. They are a beautiful sight to behold, and astronomers and space enthusiasts alike have been fascinated by them for centuries.\n\nSo, the planet with rings around it with a name starting with the letter S is indeed Saturn!",
"refusal": null,
"role": "assistant",
"annotations": null,
"audio": null,
"function_call": null,
"tool_calls": null
}
}
],
"created": 1758326504,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 91,
"prompt_tokens": 24,
"total_tokens": 115,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
"is_streaming": false
}
}

View file

@ -0,0 +1,344 @@
{
"request": {
"method": "POST",
"url": "__databricks__/serving-endpoints/v1/chat/completions",
"headers": {},
"body": {
"model": "databricks-meta-llama-3-3-70b-instruct",
"messages": [
{
"role": "user",
"content": "What's the name of the Sun in latin?"
}
],
"stream": true
},
"endpoint": "/v1/chat/completions",
"model": "databricks-meta-llama-3-3-70b-instruct"
},
"response": {
"body": [
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_2c653de2-afd4-4075-bc8d-8200562a191b",
"choices": [
{
"delta": {
"content": "",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326497,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 2,
"prompt_tokens": 20,
"total_tokens": 22,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_2c653de2-afd4-4075-bc8d-8200562a191b",
"choices": [
{
"delta": {
"content": "The ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326497,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 2,
"prompt_tokens": 20,
"total_tokens": 22,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_2c653de2-afd4-4075-bc8d-8200562a191b",
"choices": [
{
"delta": {
"content": "Latin ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326497,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 3,
"prompt_tokens": 20,
"total_tokens": 23,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_2c653de2-afd4-4075-bc8d-8200562a191b",
"choices": [
{
"delta": {
"content": "name ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326497,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 4,
"prompt_tokens": 20,
"total_tokens": 24,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_2c653de2-afd4-4075-bc8d-8200562a191b",
"choices": [
{
"delta": {
"content": "for ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326497,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 5,
"prompt_tokens": 20,
"total_tokens": 25,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_2c653de2-afd4-4075-bc8d-8200562a191b",
"choices": [
{
"delta": {
"content": "the ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326497,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 6,
"prompt_tokens": 20,
"total_tokens": 26,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_2c653de2-afd4-4075-bc8d-8200562a191b",
"choices": [
{
"delta": {
"content": "Sun ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326497,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 7,
"prompt_tokens": 20,
"total_tokens": 27,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_2c653de2-afd4-4075-bc8d-8200562a191b",
"choices": [
{
"delta": {
"content": "is ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326497,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 8,
"prompt_tokens": 20,
"total_tokens": 28,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_2c653de2-afd4-4075-bc8d-8200562a191b",
"choices": [
{
"delta": {
"content": "\"Sol\".",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326498,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 11,
"prompt_tokens": 20,
"total_tokens": 31,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_2c653de2-afd4-4075-bc8d-8200562a191b",
"choices": [
{
"delta": {
"content": "",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": "stop",
"index": 0,
"logprobs": null
}
],
"created": 1758326498,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 11,
"prompt_tokens": 20,
"total_tokens": 31,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
}
],
"is_streaming": true
}
}

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,83 @@
{
"request": {
"method": "POST",
"url": "__databricks__/serving-endpoints/v1/chat/completions",
"headers": {},
"body": {
"model": "databricks-meta-llama-3-3-70b-instruct",
"messages": [
{
"role": "user",
"content": "What's the weather in Tokyo? Use the get_weather function to get the weather."
}
],
"stream": false,
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather in a given city",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city to get the weather for"
}
}
}
}
}
]
},
"endpoint": "/v1/chat/completions",
"model": "databricks-meta-llama-3-3-70b-instruct"
},
"response": {
"body": {
"__type__": "openai.types.chat.chat_completion.ChatCompletion",
"__data__": {
"id": "chatcmpl_e54eaa97-ace3-4af6-b3a2-b1627bc77488",
"choices": [
{
"finish_reason": "tool_calls",
"index": 0,
"logprobs": null,
"message": {
"content": null,
"refusal": null,
"role": "assistant",
"annotations": null,
"audio": null,
"function_call": null,
"tool_calls": [
{
"id": "call_9c7f9e5f-c6eb-4c3c-a7b3-e9fe0e786b50",
"function": {
"arguments": "{ \"city\": \"Tokyo\" }",
"name": "get_weather"
},
"type": "function"
}
]
}
}
],
"created": 1758326507,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 15,
"prompt_tokens": 682,
"total_tokens": 697,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
"is_streaming": false
}
}

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,168 @@
{
"request": {
"method": "POST",
"url": "__databricks__/serving-endpoints/v1/chat/completions",
"headers": {},
"body": {
"model": "databricks-meta-llama-3-3-70b-instruct",
"messages": [
{
"role": "user",
"content": "What's the weather in Tokyo? Use the get_weather function to get the weather."
}
],
"stream": true,
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather in a given city",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city to get the weather for"
}
}
}
}
}
]
},
"endpoint": "/v1/chat/completions",
"model": "databricks-meta-llama-3-3-70b-instruct"
},
"response": {
"body": [
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_4c3ae1bf-991d-4266-a12d-b1e97ecbb7a0",
"choices": [
{
"delta": {
"content": null,
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": [
{
"index": 0,
"id": "call_87aed80e-f856-468f-9523-52db3018d83d",
"function": {
"arguments": "",
"name": "get_weather"
},
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326502,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 15,
"prompt_tokens": 682,
"total_tokens": 697,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_4c3ae1bf-991d-4266-a12d-b1e97ecbb7a0",
"choices": [
{
"delta": {
"content": null,
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": [
{
"index": 0,
"id": null,
"function": {
"arguments": "{ \"city\": \"Tokyo\" }",
"name": null
},
"type": null
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326502,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 15,
"prompt_tokens": 682,
"total_tokens": 697,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_4c3ae1bf-991d-4266-a12d-b1e97ecbb7a0",
"choices": [
{
"delta": {
"content": null,
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": [
{
"index": 0,
"id": null,
"function": {
"arguments": "",
"name": null
},
"type": null
}
]
},
"finish_reason": "tool_calls",
"index": 0,
"logprobs": null
}
],
"created": 1758326502,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 15,
"prompt_tokens": 682,
"total_tokens": 697,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
}
],
"is_streaming": true
}
}

View file

@ -0,0 +1,536 @@
{
"request": {
"method": "POST",
"url": "__databricks__/serving-endpoints/v1/chat/completions",
"headers": {},
"body": {
"model": "databricks-meta-llama-3-3-70b-instruct",
"messages": [
{
"role": "user",
"content": "What is the name of the US captial?"
}
],
"stream": true
},
"endpoint": "/v1/chat/completions",
"model": "databricks-meta-llama-3-3-70b-instruct"
},
"response": {
"body": [
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_40266680-5422-4e7a-bc40-74eb1efdafbc",
"choices": [
{
"delta": {
"content": "",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326504,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 2,
"prompt_tokens": 20,
"total_tokens": 22,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_40266680-5422-4e7a-bc40-74eb1efdafbc",
"choices": [
{
"delta": {
"content": "The ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326504,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 2,
"prompt_tokens": 20,
"total_tokens": 22,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_40266680-5422-4e7a-bc40-74eb1efdafbc",
"choices": [
{
"delta": {
"content": "capital ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326504,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 3,
"prompt_tokens": 20,
"total_tokens": 23,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_40266680-5422-4e7a-bc40-74eb1efdafbc",
"choices": [
{
"delta": {
"content": "of ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326504,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 4,
"prompt_tokens": 20,
"total_tokens": 24,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_40266680-5422-4e7a-bc40-74eb1efdafbc",
"choices": [
{
"delta": {
"content": "the ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326504,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 5,
"prompt_tokens": 20,
"total_tokens": 25,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_40266680-5422-4e7a-bc40-74eb1efdafbc",
"choices": [
{
"delta": {
"content": "United ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326504,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 6,
"prompt_tokens": 20,
"total_tokens": 26,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_40266680-5422-4e7a-bc40-74eb1efdafbc",
"choices": [
{
"delta": {
"content": "States ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326504,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 7,
"prompt_tokens": 20,
"total_tokens": 27,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_40266680-5422-4e7a-bc40-74eb1efdafbc",
"choices": [
{
"delta": {
"content": "is ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326504,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 8,
"prompt_tokens": 20,
"total_tokens": 28,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_40266680-5422-4e7a-bc40-74eb1efdafbc",
"choices": [
{
"delta": {
"content": "Washington, ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326504,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 10,
"prompt_tokens": 20,
"total_tokens": 30,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_40266680-5422-4e7a-bc40-74eb1efdafbc",
"choices": [
{
"delta": {
"content": "D.C. ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326504,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 13,
"prompt_tokens": 20,
"total_tokens": 33,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_40266680-5422-4e7a-bc40-74eb1efdafbc",
"choices": [
{
"delta": {
"content": "(short ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326504,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 15,
"prompt_tokens": 20,
"total_tokens": 35,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_40266680-5422-4e7a-bc40-74eb1efdafbc",
"choices": [
{
"delta": {
"content": "for ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326504,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 16,
"prompt_tokens": 20,
"total_tokens": 36,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_40266680-5422-4e7a-bc40-74eb1efdafbc",
"choices": [
{
"delta": {
"content": "District ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326504,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 17,
"prompt_tokens": 20,
"total_tokens": 37,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_40266680-5422-4e7a-bc40-74eb1efdafbc",
"choices": [
{
"delta": {
"content": "of ",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326504,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 18,
"prompt_tokens": 20,
"total_tokens": 38,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_40266680-5422-4e7a-bc40-74eb1efdafbc",
"choices": [
{
"delta": {
"content": "Columbia).",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1758326504,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 20,
"prompt_tokens": 20,
"total_tokens": 40,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl_40266680-5422-4e7a-bc40-74eb1efdafbc",
"choices": [
{
"delta": {
"content": "",
"function_call": null,
"refusal": null,
"role": null,
"tool_calls": null
},
"finish_reason": "stop",
"index": 0,
"logprobs": null
}
],
"created": 1758326504,
"model": "meta-llama-3.3-70b-instruct-121024",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": null,
"usage": {
"completion_tokens": 20,
"prompt_tokens": 20,
"total_tokens": 40,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
}
],
"is_streaming": true
}
}

File diff suppressed because it is too large Load diff

View file

@ -108,6 +108,14 @@ SETUP_DEFINITIONS: dict[str, Setup] = {
"embedding_model": "together/togethercomputer/m2-bert-80M-32k-retrieval",
},
),
"databricks": Setup(
name="databricks",
description="Databricks models",
defaults={
"text_model": "databricks/databricks-meta-llama-3-3-70b-instruct",
"embedding_model": "databricks/databricks-bge-large-en",
},
),
}