Merge branch 'main' into use-openai-for-databricks

Matthew Farrellee 2025-09-23 07:32:33 -04:00
commit 46ae101ca1
13 changed files with 815 additions and 1140 deletions

View file

@@ -24,7 +24,7 @@ jobs:
         uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
       - name: Install uv
-        uses: astral-sh/setup-uv@557e51de59eb14aaaba2ed9621916900a91d50c6 # v6.6.1
+        uses: astral-sh/setup-uv@b75a909f75acd358c2196fb9a5f1299a9a8868a4 # v6.7.0
         with:
           python-version: ${{ matrix.python-version }}
           activate-environment: true

View file

@@ -147,7 +147,7 @@ WORKDIR /app
 RUN dnf -y update && dnf install -y iputils git net-tools wget \
     vim-minimal python3.12 python3.12-pip python3.12-wheel \
-    python3.12-setuptools python3.12-devel gcc make && \
+    python3.12-setuptools python3.12-devel gcc gcc-c++ make && \
     ln -s /bin/pip3.12 /bin/pip && ln -s /bin/python3.12 /bin/python && dnf clean all
 ENV UV_SYSTEM_PYTHON=1
@@ -164,7 +164,7 @@ RUN apt-get update && apt-get install -y \
     procps psmisc lsof \
     traceroute \
     bubblewrap \
-    gcc \
+    gcc g++ \
     && rm -rf /var/lib/apt/lists/*
 ENV UV_SYSTEM_PYTHON=1
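Both images now install a C++ compiler next to gcc, presumably so that Python dependencies with native C++ extensions can be built from source during the image build. A hypothetical in-container sanity check (not part of this change) could confirm the toolchain is present:

    # Hypothetical check, not from the repo: verify the build toolchain inside the image.
    import shutil
    import sys

    missing = [tool for tool in ("gcc", "g++", "make") if shutil.which(tool) is None]
    if missing:
        sys.exit(f"missing build tools: {', '.join(missing)}")
    print("C/C++ build toolchain present")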

View file

@@ -33,7 +33,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
            try:
                models = await provider.list_models()
            except Exception as e:
-                logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
+                logger.warning(f"Model refresh failed for provider {provider_id}: {e}")
                continue
            self.listed_providers.add(provider_id)
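Downgrading `logger.exception` to `logger.warning` keeps a failing provider from flooding the logs while the refresh loop continues: `logger.exception` logs at ERROR level with the full stack trace attached, whereas `logger.warning` records only the message. A standalone sketch of the difference (not repo code):

    # Standalone sketch contrasting the two logging calls.
    import logging

    logging.basicConfig(level=logging.DEBUG)
    log = logging.getLogger("demo")

    try:
        raise RuntimeError("provider unreachable")
    except Exception as e:
        log.warning(f"Model refresh failed for provider p: {e}")    # WARNING, message only
        log.exception(f"Model refresh failed for provider p: {e}")  # ERROR, plus traceback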

View file

@@ -4,15 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from pydantic import BaseModel
 from .config import AnthropicConfig
-class AnthropicProviderDataValidator(BaseModel):
-    anthropic_api_key: str | None = None
 async def get_adapter_impl(config: AnthropicConfig, _deps):
     from .anthropic import AnthropicInferenceAdapter

View file

@@ -4,11 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from collections.abc import AsyncGenerator, AsyncIterator
-from typing import Any
+from collections.abc import AsyncGenerator
 from fireworks.client import Fireworks
-from openai import AsyncOpenAI
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -24,12 +22,6 @@ from llama_stack.apis.inference import (
     Inference,
     LogProbConfig,
     Message,
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAICompletion,
-    OpenAIEmbeddingsResponse,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
     ResponseFormat,
     ResponseFormatType,
     SamplingParams,
@@ -45,15 +37,14 @@ from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
-    OpenAIChatCompletionToLlamaStackMixin,
     convert_message_to_openai_dict,
     get_sampling_options,
-    prepare_openai_completion_params,
     process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
@@ -68,7 +59,7 @@ from .models import MODEL_ENTRIES
 logger = get_logger(name=__name__, category="inference::fireworks")
-class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
+class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
     def __init__(self, config: FireworksImplConfig) -> None:
         ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
         self.config = config
@@ -79,7 +70,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
     async def shutdown(self) -> None:
         pass
-    def _get_api_key(self) -> str:
+    def get_api_key(self) -> str:
         config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
         if config_api_key:
             return config_api_key
@@ -91,15 +82,18 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
             )
         return provider_data.fireworks_api_key
-    def _get_base_url(self) -> str:
+    def get_base_url(self) -> str:
         return "https://api.fireworks.ai/inference/v1"
     def _get_client(self) -> Fireworks:
-        fireworks_api_key = self._get_api_key()
+        fireworks_api_key = self.get_api_key()
         return Fireworks(api_key=fireworks_api_key)
-    def _get_openai_client(self) -> AsyncOpenAI:
-        return AsyncOpenAI(base_url=self._get_base_url(), api_key=self._get_api_key())
+    def _preprocess_prompt_for_fireworks(self, prompt: str) -> str:
+        """Remove BOS token as Fireworks automatically prepends it"""
+        if prompt.startswith("<|begin_of_text|>"):
+            return prompt[len("<|begin_of_text|>") :]
+        return prompt
     async def completion(
         self,
@@ -285,153 +279,3 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         embeddings = [data.embedding for data in response.data]
         return EmbeddingsResponse(embeddings=embeddings)
-    async def openai_embeddings(
-        self,
-        model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
-        user: str | None = None,
-    ) -> OpenAIEmbeddingsResponse:
-        raise NotImplementedError()
-    async def openai_completion(
-        self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
-    ) -> OpenAICompletion:
-        model_obj = await self.model_store.get_model(model)
-        # Fireworks always prepends with BOS
-        if isinstance(prompt, str) and prompt.startswith("<|begin_of_text|>"):
-            prompt = prompt[len("<|begin_of_text|>") :]
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            prompt=prompt,
-            best_of=best_of,
-            echo=echo,
-            frequency_penalty=frequency_penalty,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_tokens=max_tokens,
-            n=n,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            top_p=top_p,
-            user=user,
-        )
-        return await self._get_openai_client().completions.create(**params)
-    async def openai_chat_completion(
-        self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        model_obj = await self.model_store.get_model(model)
-        # Divert Llama Models through Llama Stack inference APIs because
-        # Fireworks chat completions OpenAI-compatible API does not support
-        # tool calls properly.
-        llama_model = self.get_llama_model(model_obj.provider_resource_id)
-        if llama_model:
-            return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
-                self,
-                model=model,
-                messages=messages,
-                frequency_penalty=frequency_penalty,
-                function_call=function_call,
-                functions=functions,
-                logit_bias=logit_bias,
-                logprobs=logprobs,
-                max_completion_tokens=max_completion_tokens,
-                max_tokens=max_tokens,
-                n=n,
-                parallel_tool_calls=parallel_tool_calls,
-                presence_penalty=presence_penalty,
-                response_format=response_format,
-                seed=seed,
-                stop=stop,
-                stream=stream,
-                stream_options=stream_options,
-                temperature=temperature,
-                tool_choice=tool_choice,
-                tools=tools,
-                top_logprobs=top_logprobs,
-                top_p=top_p,
-                user=user,
-            )
-        params = await prepare_openai_completion_params(
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
-        logger.debug(f"fireworks params: {params}")
-        return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)
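With OpenAIMixin added at the front of the class bases, the adapter drops its private _get_openai_client and the hand-rolled openai_embeddings / openai_completion / openai_chat_completion methods, and instead exposes public get_api_key() and get_base_url() hooks for the mixin to consume. A rough sketch of that pattern, assuming the mixin builds an OpenAI-compatible client from those two hooks (the real OpenAIMixin may differ in detail):

    # Simplified sketch of the assumed mixin contract; not the actual OpenAIMixin.
    from openai import AsyncOpenAI

    class SketchOpenAIMixin:
        def get_api_key(self) -> str:        # supplied by the concrete adapter
            raise NotImplementedError

        def get_base_url(self) -> str:       # supplied by the concrete adapter
            raise NotImplementedError

        @property
        def client(self) -> AsyncOpenAI:
            # One OpenAI-compatible client pointed at the provider's endpoint.
            return AsyncOpenAI(base_url=self.get_base_url(), api_key=self.get_api_key())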

View file

@@ -4,15 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from pydantic import BaseModel
 from .config import GeminiConfig
-class GeminiProviderDataValidator(BaseModel):
-    gemini_api_key: str | None = None
 async def get_adapter_impl(config: GeminiConfig, _deps):
     from .gemini import GeminiInferenceAdapter

View file

@@ -4,15 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from pydantic import BaseModel
 from .config import OpenAIConfig
-class OpenAIProviderDataValidator(BaseModel):
-    openai_api_key: str | None = None
 async def get_adapter_impl(config: OpenAIConfig, _deps):
     from .openai import OpenAIInferenceAdapter

View file

@@ -103,7 +103,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
                Model(
                    identifier=id,
                    provider_resource_id=entry.provider_model_id,
-                    model_type=ModelType.llm,
+                    model_type=entry.model_type,
                    metadata=entry.metadata,
                    provider_id=self.__provider_id__,
                )
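Taking model_type from the entry instead of hard-coding ModelType.llm means models registered through the helper (for example embedding models) are no longer all labeled as LLMs. A hypothetical illustration with stand-in classes, not the repo's actual entry type:

    # Stand-in types for illustration only.
    from dataclasses import dataclass, field
    from enum import Enum

    class ModelType(str, Enum):
        llm = "llm"
        embedding = "embedding"

    @dataclass
    class Entry:
        provider_model_id: str
        model_type: ModelType = ModelType.llm
        metadata: dict = field(default_factory=dict)

    entries = [
        Entry("accounts/fireworks/models/llama-v3p1-8b-instruct"),
        Entry("nomic-ai/nomic-embed-text-v1.5", ModelType.embedding, {"embedding_dimension": 768}),
    ]
    for e in entries:
        # Previously every registered model got ModelType.llm; now the entry decides.
        print(e.provider_model_id, e.model_type.value)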

View file

@@ -203,6 +203,11 @@ def _model_identifiers_digest(endpoint: str, response: dict[str, Any]) -> str:
         - '/v1/models' (OpenAI): response body is: [ { id: ... }, ... ]
         Returns a list of unique identifiers or None if structure doesn't match.
         """
-        items = response["body"]
+        if "models" in response["body"]:
+            # ollama
+            items = response["body"]["models"]
+        else:
+            # openai
+            items = response["body"]
         idents = [m.model if endpoint == "/api/tags" else m.id for m in items]
         return sorted(set(idents))
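The identifier digest now accepts both list shapes: Ollama's /api/tags wraps the model list in a "models" key, while OpenAI-style /v1/models bodies are the list itself. A standalone illustration of the branching with hypothetical payloads (the recorded bodies actually hold client objects with .model / .id attributes):

    # Hypothetical payloads; mirrors the branching added above.
    from types import SimpleNamespace

    ollama_body = {"models": [SimpleNamespace(model="llama3.2:3b-instruct-fp16")]}  # /api/tags
    openai_body = [SimpleNamespace(id="gpt-4o-mini")]                               # /v1/models

    def identifiers(endpoint: str, body) -> list[str]:
        items = body["models"] if isinstance(body, dict) and "models" in body else body
        return sorted({m.model if endpoint == "/api/tags" else m.id for m in items})

    print(identifiers("/api/tags", ollama_body))
    print(identifiers("/v1/models", openai_body))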

File diff suppressed because it is too large.

View file

@@ -14,7 +14,7 @@
   },
   "dependencies": {
     "@radix-ui/react-collapsible": "^1.1.12",
-    "@radix-ui/react-dialog": "^1.1.13",
+    "@radix-ui/react-dialog": "^1.1.15",
     "@radix-ui/react-dropdown-menu": "^2.1.16",
     "@radix-ui/react-select": "^2.2.6",
     "@radix-ui/react-separator": "^1.1.7",
@@ -32,7 +32,7 @@
     "react-dom": "^19.1.1",
     "react-markdown": "^10.1.0",
     "remark-gfm": "^4.0.1",
-    "remeda": "^2.30.0",
+    "remeda": "^2.32.0",
     "shiki": "^1.29.2",
     "sonner": "^2.0.7",
     "tailwind-merge": "^3.3.1"
@@ -52,7 +52,7 @@
     "eslint-config-prettier": "^10.1.8",
     "eslint-plugin-prettier": "^5.5.4",
     "jest": "^29.7.0",
-    "jest-environment-jsdom": "^29.7.0",
+    "jest-environment-jsdom": "^30.1.2",
     "prettier": "3.6.2",
     "tailwindcss": "^4",
     "ts-node": "^10.9.2",

View file

@@ -33,6 +33,7 @@ def skip_if_model_doesnt_support_user_param(client, model_id):
     provider = provider_from_model(client, model_id)
     if provider.provider_type in (
         "remote::together",  # service returns 400
+        "remote::fireworks",  # service returns 400 malformed input
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support user param.")
@@ -42,6 +43,7 @@ def skip_if_model_doesnt_support_encoding_format_base64(client, model_id):
     if provider.provider_type in (
         "remote::together",  # param silently ignored, always returns floats
         "remote::databricks",  # param silently ignored, always returns floats
+        "remote::fireworks",  # param silently ignored, always returns list of floats
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support encoding_format='base64'.")
@@ -289,7 +291,6 @@ def test_openai_embeddings_base64_batch_processing(compat_client, client_with_mo
         input=input_texts,
         encoding_format="base64",
     )
-
     # Validate response structure
     assert response.object == "list"
     assert response.model == embedding_model_id
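The base64 cases are skipped for providers that silently return float lists; when a provider does honor encoding_format="base64", each embedding comes back as a base64 string of packed little-endian float32 values rather than a JSON array. A generic decode sketch, not taken from the test file:

    # Generic helper for decoding base64-encoded embeddings.
    import base64
    import struct

    def decode_embedding(data: str) -> list[float]:
        raw = base64.b64decode(data)
        return list(struct.unpack(f"<{len(raw) // 4}f", raw))

    # Tiny fabricated example:
    encoded = base64.b64encode(struct.pack("<3f", 0.1, -0.2, 0.3)).decode()
    print(decode_embedding(encoded))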

View file

@@ -116,6 +116,15 @@ SETUP_DEFINITIONS: dict[str, Setup] = {
            "embedding_model": "databricks/databricks-bge-large-en",
        },
    ),
+    "fireworks": Setup(
+        name="fireworks",
+        description="Fireworks provider with a text model",
+        defaults={
+            "text_model": "accounts/fireworks/models/llama-v3p1-8b-instruct",
+            "vision_model": "accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
+            "embedding_model": "nomic-ai/nomic-embed-text-v1.5",
+        },
+    ),
 }