Merge branch 'main' into add-localize-url-feature-to-openaimixin

This commit is contained in:
Matthew Farrellee 2025-09-26 16:31:59 -04:00
commit 17125fd2cf
421 changed files with 70880 additions and 5915 deletions

View file

@ -8,14 +8,24 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOp
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import AnthropicConfig
from .models import MODEL_ENTRIES
class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
# source: https://docs.claude.com/en/docs/build-with-claude/embeddings
# TODO: add support for voyageai, which is where these models are hosted
# embedding_model_metadata = {
# "voyage-3-large": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
# "voyage-3.5": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
# "voyage-3.5-lite": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
# "voyage-code-3": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
# "voyage-finance-2": {"embedding_dimension": 1024, "context_length": 32000},
# "voyage-law-2": {"embedding_dimension": 1024, "context_length": 16000},
# "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000},
# }
def __init__(self, config: AnthropicConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
MODEL_ENTRIES,
litellm_provider_name="anthropic",
api_key_from_config=config.api_key,
provider_data_api_key_field="anthropic_api_key",

View file

@ -1,40 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models import ModelType
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
)
LLM_MODEL_IDS = [
"claude-3-5-sonnet-latest",
"claude-3-7-sonnet-latest",
"claude-3-5-haiku-latest",
]
SAFETY_MODELS_ENTRIES = []
MODEL_ENTRIES = (
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
+ [
ProviderModelEntry(
provider_model_id="voyage-3",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 1024, "context_length": 32000},
),
ProviderModelEntry(
provider_model_id="voyage-3-lite",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 512, "context_length": 32000},
),
ProviderModelEntry(
provider_model_id="voyage-code-3",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 1024, "context_length": 32000},
),
]
+ SAFETY_MODELS_ENTRIES
)

View file

@ -14,14 +14,12 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import (
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import AzureConfig
from .models import MODEL_ENTRIES
class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
def __init__(self, config: AzureConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
MODEL_ENTRIES,
litellm_provider_name="azure",
api_key_from_config=config.api_key.get_secret_value(),
provider_data_api_key_field="azure_api_key",

View file

@ -1,28 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
)
# https://learn.microsoft.com/en-us/azure/ai-foundry/openai/concepts/models?tabs=global-standard%2Cstandard-chat-completions
LLM_MODEL_IDS = [
"gpt-5",
"gpt-5-mini",
"gpt-5-nano",
"gpt-5-chat",
"o1",
"o1-mini",
"o3-mini",
"o4-mini",
"gpt-4.1",
"gpt-4.1-mini",
"gpt-4.1-nano",
]
SAFETY_MODELS_ENTRIES = list[ProviderModelEntry]()
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + SAFETY_MODELS_ENTRIES

View file

@ -98,7 +98,7 @@ class BedrockInferenceAdapter(
OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: BedrockConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
self._config = config
self._client = None

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from urllib.parse import urljoin
from cerebras.cloud.sdk import AsyncCerebras
@ -35,42 +36,41 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
)
from .config import CerebrasImplConfig
from .models import MODEL_ENTRIES
class CerebrasInferenceAdapter(
OpenAIMixin,
ModelRegistryHelper,
Inference,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: CerebrasImplConfig) -> None:
ModelRegistryHelper.__init__(
self,
model_entries=MODEL_ENTRIES,
)
self.config = config
# TODO: make this use provider data, etc. like other providers
self.client = AsyncCerebras(
self._cerebras_client = AsyncCerebras(
base_url=self.config.base_url,
api_key=self.config.api_key.get_secret_value(),
)
def get_api_key(self) -> str:
return self.config.api_key.get_secret_value()
def get_base_url(self) -> str:
return urljoin(self.config.base_url, "v1")
async def initialize(self) -> None:
return
@ -107,14 +107,14 @@ class CerebrasInferenceAdapter(
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request)
r = await self.client.completions.create(**params)
r = await self._cerebras_client.completions.create(**params)
return process_completion_response(r)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
stream = await self.client.completions.create(**params)
stream = await self._cerebras_client.completions.create(**params)
async for chunk in process_completion_stream_response(stream):
yield chunk
@ -156,14 +156,14 @@ class CerebrasInferenceAdapter(
async def _nonstream_chat_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request)
r = await self.client.completions.create(**params)
r = await self._cerebras_client.completions.create(**params)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
stream = await self.client.completions.create(**params)
stream = await self._cerebras_client.completions.create(**params)
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk

View file

@ -20,8 +20,8 @@ class CerebrasImplConfig(BaseModel):
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
description="Base URL for the Cerebras API",
)
api_key: SecretStr | None = Field(
default=os.environ.get("CEREBRAS_API_KEY"),
api_key: SecretStr = Field(
default=SecretStr(os.environ.get("CEREBRAS_API_KEY")),
description="Cerebras API Key",
)

View file

@ -1,28 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = []
# https://inference-docs.cerebras.ai/models
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"llama3.1-8b",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"llama-3.3-70b",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"llama-4-scout-17b-16e-instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
] + SAFETY_MODELS_ENTRIES

View file

@ -5,10 +5,11 @@
# the root directory of this source tree.
from .config import DatabricksImplConfig
from .databricks import DatabricksInferenceAdapter
async def get_adapter_impl(config: DatabricksImplConfig, _deps):
from .databricks import DatabricksInferenceAdapter
assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
impl = DatabricksInferenceAdapter(config)
await impl.initialize()

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, SecretStr
from llama_stack.schema_utils import json_schema_type
@ -17,16 +17,16 @@ class DatabricksImplConfig(BaseModel):
default=None,
description="The URL for the Databricks model serving endpoint",
)
api_token: str = Field(
default=None,
api_token: SecretStr = Field(
default=SecretStr(None),
description="The Databricks API token",
)
@classmethod
def sample_run_config(
cls,
url: str = "${env.DATABRICKS_URL:=}",
api_token: str = "${env.DATABRICKS_API_TOKEN:=}",
url: str = "${env.DATABRICKS_HOST:=}",
api_token: str = "${env.DATABRICKS_TOKEN:=}",
**kwargs: Any,
) -> dict[str, Any]:
return {

View file

@ -4,23 +4,27 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from collections.abc import AsyncIterator
from typing import Any
from openai import OpenAI
from databricks.sdk import WorkspaceClient
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
Model,
OpenAICompletion,
ResponseFormat,
SamplingParams,
TextTruncation,
@ -29,49 +33,34 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
)
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import DatabricksImplConfig
SAFETY_MODELS_ENTRIES = []
# https://docs.databricks.com/aws/en/machine-learning/model-serving/foundation-model-overview
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"databricks-meta-llama-3-1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"databricks-meta-llama-3-1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
] + SAFETY_MODELS_ENTRIES
logger = get_logger(name=__name__, category="inference::databricks")
class DatabricksInferenceAdapter(
ModelRegistryHelper,
OpenAIMixin,
Inference,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
):
# source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
embedding_model_metadata = {
"databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},
"databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512},
}
def __init__(self, config: DatabricksImplConfig) -> None:
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
self.config = config
def get_api_key(self) -> str:
return self.config.api_token.get_secret_value()
def get_base_url(self) -> str:
return f"{self.config.url}/serving-endpoints"
async def initialize(self) -> None:
return
@ -80,72 +69,54 @@ class DatabricksInferenceAdapter(
async def completion(
self,
model: str,
model_id: str,
content: InterleavedContent,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
raise NotImplementedError()
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
raise NotImplementedError()
async def chat_completion(
self,
model: str,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
if stream:
return self._stream_chat_completion(request, client)
else:
return await self._nonstream_chat_completion(request, client)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
) -> ChatCompletionResponse:
params = self._get_params(request)
r = client.completions.create(**params)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
params = self._get_params(request)
async def _to_async_generator():
s = client.completions.create(**params)
for chunk in s:
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
"model": request.model,
"prompt": chat_completion_request_to_prompt(request, self.get_llama_model(request.model)),
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
raise NotImplementedError()
async def embeddings(
self,
@ -157,12 +128,31 @@ class DatabricksInferenceAdapter(
) -> EmbeddingsResponse:
raise NotImplementedError()
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def list_models(self) -> list[Model] | None:
self._model_cache = {} # from OpenAIMixin
ws_client = WorkspaceClient(host=self.config.url, token=self.get_api_key()) # TODO: this is not async
endpoints = ws_client.serving_endpoints.list()
for endpoint in endpoints:
model = Model(
provider_id=self.__provider_id__,
provider_resource_id=endpoint.name,
identifier=endpoint.name,
)
if endpoint.task == "llm/v1/chat":
model.model_type = ModelType.llm # this is redundant, but informative
elif endpoint.task == "llm/v1/embeddings":
if endpoint.name not in self.embedding_model_metadata:
logger.warning(f"No metadata information available for embedding model {endpoint.name}, skipping.")
continue
model.model_type = ModelType.embedding
model.metadata = self.embedding_model_metadata[endpoint.name]
else:
logger.warning(f"Unknown model type, skipping: {endpoint}")
continue
self._model_cache[endpoint.name] = model
return list(self._model_cache.values())
async def should_refresh_models(self) -> bool:
return False

View file

@ -4,11 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
from collections.abc import AsyncGenerator
from fireworks.client import Fireworks
from openai import AsyncOpenAI
from llama_stack.apis.common.content_types import (
InterleavedContent,
@ -24,12 +22,6 @@ from llama_stack.apis.inference import (
Inference,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
ResponseFormatType,
SamplingParams,
@ -45,15 +37,14 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
convert_message_to_openai_dict,
get_sampling_options,
prepare_openai_completion_params,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
@ -63,15 +54,19 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)
from .config import FireworksImplConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference::fireworks")
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
embedding_model_metadata = {
"nomic-ai/nomic-embed-text-v1.5": {"embedding_dimension": 768, "context_length": 8192},
}
def __init__(self, config: FireworksImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
ModelRegistryHelper.__init__(self)
self.config = config
self.allowed_models = config.allowed_models
async def initialize(self) -> None:
pass
@ -79,7 +74,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
async def shutdown(self) -> None:
pass
def _get_api_key(self) -> str:
def get_api_key(self) -> str:
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
if config_api_key:
return config_api_key
@ -91,15 +86,18 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
)
return provider_data.fireworks_api_key
def _get_base_url(self) -> str:
def get_base_url(self) -> str:
return "https://api.fireworks.ai/inference/v1"
def _get_client(self) -> Fireworks:
fireworks_api_key = self._get_api_key()
fireworks_api_key = self.get_api_key()
return Fireworks(api_key=fireworks_api_key)
def _get_openai_client(self) -> AsyncOpenAI:
return AsyncOpenAI(base_url=self._get_base_url(), api_key=self._get_api_key())
def _preprocess_prompt_for_fireworks(self, prompt: str) -> str:
"""Remove BOS token as Fireworks automatically prepends it"""
if prompt.startswith("<|begin_of_text|>"):
return prompt[len("<|begin_of_text|>") :]
return prompt
async def completion(
self,
@ -285,153 +283,3 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
embeddings = [data.embedding for data in response.data]
return EmbeddingsResponse(embeddings=embeddings)
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
# Fireworks always prepends with BOS
if isinstance(prompt, str) and prompt.startswith("<|begin_of_text|>"):
prompt = prompt[len("<|begin_of_text|>") :]
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
)
return await self._get_openai_client().completions.create(**params)
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_obj = await self.model_store.get_model(model)
# Divert Llama Models through Llama Stack inference APIs because
# Fireworks chat completions OpenAI-compatible API does not support
# tool calls properly.
llama_model = self.get_llama_model(model_obj.provider_resource_id)
if llama_model:
return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
self,
model=model,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
params = await prepare_openai_completion_params(
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
logger.debug(f"fireworks params: {params}")
return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)

View file

@ -1,70 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models import ModelType
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = [
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-guard-3-8b",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-guard-3-11b-vision",
CoreModelId.llama_guard_3_11b_vision.value,
),
]
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p1-8b-instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p2-3b-instruct",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p2-11b-vision-instruct",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p3-70b-instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama4-scout-instruct-basic",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama4-maverick-instruct-basic",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
ProviderModelEntry(
provider_model_id="nomic-ai/nomic-embed-text-v1.5",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 768,
"context_length": 8192,
},
),
] + SAFETY_MODELS_ENTRIES

View file

@ -4,15 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import GeminiConfig
class GeminiProviderDataValidator(BaseModel):
gemini_api_key: str | None = None
async def get_adapter_impl(config: GeminiConfig, _deps):
from .gemini import GeminiInferenceAdapter

View file

@ -8,14 +8,16 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOp
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import GeminiConfig
from .models import MODEL_ENTRIES
class GeminiInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
embedding_model_metadata = {
"text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
}
def __init__(self, config: GeminiConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
MODEL_ENTRIES,
litellm_provider_name="gemini",
api_key_from_config=config.api_key,
provider_data_api_key_field="gemini_api_key",

View file

@ -1,34 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models import ModelType
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
)
LLM_MODEL_IDS = [
"gemini-1.5-flash",
"gemini-1.5-pro",
"gemini-2.0-flash",
"gemini-2.0-flash-lite",
"gemini-2.5-flash",
"gemini-2.5-flash-lite",
"gemini-2.5-pro",
]
SAFETY_MODELS_ENTRIES = []
MODEL_ENTRIES = (
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
+ [
ProviderModelEntry(
provider_model_id="text-embedding-004",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 768, "context_length": 2048},
),
]
+ SAFETY_MODELS_ENTRIES
)

View file

@ -4,12 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from .config import GroqConfig
async def get_adapter_impl(config: GroqConfig, _deps) -> Inference:
async def get_adapter_impl(config: GroqConfig, _deps):
# import dynamically so the import is used only when it is needed
from .groq import GroqInferenceAdapter

View file

@ -9,8 +9,6 @@ from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .models import MODEL_ENTRIES
class GroqInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
_config: GroqConfig
@ -18,7 +16,6 @@ class GroqInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
def __init__(self, config: GroqConfig):
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
litellm_provider_name="groq",
api_key_from_config=config.api_key,
provider_data_api_key_field="groq_api_key",

View file

@ -1,48 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.sku_list import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
build_model_entry,
)
SAFETY_MODELS_ENTRIES = []
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"llama3-8b-8192",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_entry(
"llama-3.1-8b-instant",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"llama3-70b-8192",
CoreModelId.llama3_70b_instruct.value,
),
build_hf_repo_model_entry(
"llama-3.3-70b-versatile",
CoreModelId.llama3_3_70b_instruct.value,
),
# Groq only contains a preview version for llama-3.2-3b
# Preview models aren't recommended for production use, but we include this one
# to pass the test fixture
# TODO(aidand): Replace this with a stable model once Groq supports it
build_hf_repo_model_entry(
"llama-3.2-3b-preview",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-4-scout-17b-16e-instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-4-maverick-17b-128e-instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
] + SAFETY_MODELS_ENTRIES

View file

@ -8,8 +8,6 @@ from llama_stack.providers.remote.inference.llama_openai_compat.config import Ll
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference::llama_openai_compat")
@ -30,7 +28,6 @@ class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
def __init__(self, config: LlamaCompatConfig):
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
litellm_provider_name="meta_llama",
api_key_from_config=config.api_key,
provider_data_api_key_field="llama_api_key",

View file

@ -1,25 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"Llama-3.3-70B-Instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"Llama-4-Scout-17B-16E-Instruct-FP8",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"Llama-4-Maverick-17B-128E-Instruct-FP8",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
]

View file

@ -1,109 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models import ModelType
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = []
# https://docs.nvidia.com/nim/large-language-models/latest/supported-llm-agnostic-architectures.html
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta/llama3-8b-instruct",
CoreModelId.llama3_8b_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama3-70b-instruct",
CoreModelId.llama3_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.1-8b-instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.2-1b-instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.2-3b-instruct",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.2-11b-vision-instruct",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.2-90b-vision-instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.3-70b-instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
ProviderModelEntry(
provider_model_id="nvidia/vila",
model_type=ModelType.llm,
),
# NeMo Retriever Text Embedding models -
#
# https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
#
# +-----------------------------------+--------+-----------+-----------+------------+
# | Model ID | Max | Publisher | Embedding | Dynamic |
# | | Tokens | | Dimension | Embeddings |
# +-----------------------------------+--------+-----------+-----------+------------+
# | nvidia/llama-3.2-nv-embedqa-1b-v2 | 8192 | NVIDIA | 2048 | Yes |
# | nvidia/nv-embedqa-e5-v5 | 512 | NVIDIA | 1024 | No |
# | nvidia/nv-embedqa-mistral-7b-v2 | 512 | NVIDIA | 4096 | No |
# | snowflake/arctic-embed-l | 512 | Snowflake | 1024 | No |
# +-----------------------------------+--------+-----------+-----------+------------+
ProviderModelEntry(
provider_model_id="nvidia/llama-3.2-nv-embedqa-1b-v2",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 2048,
"context_length": 8192,
},
),
ProviderModelEntry(
provider_model_id="nvidia/nv-embedqa-e5-v5",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 1024,
"context_length": 512,
},
),
ProviderModelEntry(
provider_model_id="nvidia/nv-embedqa-mistral-7b-v2",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 4096,
"context_length": 512,
},
),
ProviderModelEntry(
provider_model_id="snowflake/arctic-embed-l",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 1024,
"context_length": 512,
},
),
# TODO(mf): how do we handle Nemotron models?
# "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
] + SAFETY_MODELS_ENTRIES

View file

@ -37,9 +37,6 @@ from llama_stack.apis.inference import (
)
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
convert_openai_chat_completion_choice,
convert_openai_chat_completion_stream,
@ -48,7 +45,6 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
from . import NVIDIAConfig
from .models import MODEL_ENTRIES
from .openai_utils import (
convert_chat_completion_request,
convert_completion_request,
@ -60,7 +56,7 @@ from .utils import _is_nvidia_hosted
logger = get_logger(name=__name__, category="inference::nvidia")
class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
"""
NVIDIA Inference Adapter for Llama Stack.
@ -74,10 +70,15 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
- ModelRegistryHelper.check_model_availability() just returns False and shows a warning
"""
def __init__(self, config: NVIDIAConfig) -> None:
# TODO(mf): filter by available models
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
# source: https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
embedding_model_metadata = {
"nvidia/llama-3.2-nv-embedqa-1b-v2": {"embedding_dimension": 2048, "context_length": 8192},
"nvidia/nv-embedqa-e5-v5": {"embedding_dimension": 512, "context_length": 1024},
"nvidia/nv-embedqa-mistral-7b-v2": {"embedding_dimension": 512, "context_length": 4096},
"snowflake/arctic-embed-l": {"embedding_dimension": 512, "context_length": 1024},
}
def __init__(self, config: NVIDIAConfig) -> None:
logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")
if _is_nvidia_hosted(config):

View file

@ -1,106 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models import ModelType
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,
build_model_entry,
)
SAFETY_MODELS_ENTRIES = [
# The Llama Guard models don't have their full fp16 versions
# so we are going to alias their default version to the canonical SKU
build_hf_repo_model_entry(
"llama-guard3:8b",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"llama-guard3:1b",
CoreModelId.llama_guard_3_1b.value,
),
]
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"llama3.1:8b-instruct-fp16",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_entry(
"llama3.1:8b",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"llama3.1:70b-instruct-fp16",
CoreModelId.llama3_1_70b_instruct.value,
),
build_model_entry(
"llama3.1:70b",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"llama3.1:405b-instruct-fp16",
CoreModelId.llama3_1_405b_instruct.value,
),
build_model_entry(
"llama3.1:405b",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"llama3.2:1b-instruct-fp16",
CoreModelId.llama3_2_1b_instruct.value,
),
build_model_entry(
"llama3.2:1b",
CoreModelId.llama3_2_1b_instruct.value,
),
build_hf_repo_model_entry(
"llama3.2:3b-instruct-fp16",
CoreModelId.llama3_2_3b_instruct.value,
),
build_model_entry(
"llama3.2:3b",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"llama3.2-vision:11b-instruct-fp16",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_model_entry(
"llama3.2-vision:latest",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"llama3.2-vision:90b-instruct-fp16",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_model_entry(
"llama3.2-vision:90b",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"llama3.3:70b",
CoreModelId.llama3_3_70b_instruct.value,
),
ProviderModelEntry(
provider_model_id="all-minilm:l6-v2",
aliases=["all-minilm"],
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 384,
"context_length": 512,
},
),
ProviderModelEntry(
provider_model_id="nomic-embed-text",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 768,
"context_length": 8192,
},
),
] + SAFETY_MODELS_ENTRIES

View file

@ -40,8 +40,9 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.models import Model
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.datatypes import (
HealthResponse,
HealthStatus,
@ -50,6 +51,7 @@ from llama_stack.providers.datatypes import (
from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
@ -70,8 +72,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
request_has_media,
)
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference::ollama")
@ -84,8 +84,44 @@ class OllamaInferenceAdapter(
# automatically set by the resolver when instantiating the provider
__provider_id__: str
embedding_model_metadata = {
"all-minilm:l6-v2": {
"embedding_dimension": 384,
"context_length": 512,
},
"nomic-embed-text:latest": {
"embedding_dimension": 768,
"context_length": 8192,
},
"nomic-embed-text:v1.5": {
"embedding_dimension": 768,
"context_length": 8192,
},
"nomic-embed-text:137m-v1.5-fp16": {
"embedding_dimension": 768,
"context_length": 8192,
},
}
def __init__(self, config: OllamaImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
# TODO: remove ModelRegistryHelper.__init__ when completion and
# chat_completion are. this exists to satisfy the input /
# output processing for llama models. specifically,
# tool_calling is handled by raw template processing,
# instead of using the /api/chat endpoint w/ tools=...
ModelRegistryHelper.__init__(
self,
model_entries=[
build_hf_repo_model_entry(
"llama3.2:3b-instruct-fp16",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"llama-guard3:1b",
CoreModelId.llama_guard_3_1b.value,
),
],
)
self.config = config
# Ollama does not support image urls, so we need to download the image and convert it to base64
self.download_images = True
@ -116,60 +152,6 @@ class OllamaInferenceAdapter(
async def should_refresh_models(self) -> bool:
return self.config.refresh_models
async def list_models(self) -> list[Model] | None:
provider_id = self.__provider_id__
response = await self.ollama_client.list()
# always add the two embedding models which can be pulled on demand
models = [
Model(
identifier="all-minilm:l6-v2",
provider_resource_id="all-minilm:l6-v2",
provider_id=provider_id,
metadata={
"embedding_dimension": 384,
"context_length": 512,
},
model_type=ModelType.embedding,
),
# add all-minilm alias
Model(
identifier="all-minilm",
provider_resource_id="all-minilm:l6-v2",
provider_id=provider_id,
metadata={
"embedding_dimension": 384,
"context_length": 512,
},
model_type=ModelType.embedding,
),
Model(
identifier="nomic-embed-text",
provider_resource_id="nomic-embed-text:latest",
provider_id=provider_id,
metadata={
"embedding_dimension": 768,
"context_length": 8192,
},
model_type=ModelType.embedding,
),
]
for m in response.models:
# kill embedding models since we don't know dimensions for them
if "bert" in m.details.family:
continue
models.append(
Model(
identifier=m.model,
provider_resource_id=m.model,
provider_id=provider_id,
metadata={},
model_type=ModelType.llm,
)
)
self._model_cache = {m.identifier: m for m in models} # for fast check_model_availability
return models
async def health(self) -> HealthResponse:
"""
Performs a health check by verifying connectivity to the Ollama server.
@ -403,37 +385,16 @@ class OllamaInferenceAdapter(
return EmbeddingsResponse(embeddings=embeddings)
async def register_model(self, model: Model) -> Model:
try:
model = await super().register_model(model)
except ValueError:
pass # Ignore statically unknown model, will check live listing
if await self.check_model_availability(model.provider_model_id):
return model
elif await self.check_model_availability(f"{model.provider_model_id}:latest"):
model.provider_resource_id = f"{model.provider_model_id}:latest"
logger.warning(
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_model_id}'"
)
return model
if model.model_type == ModelType.embedding:
response = await self.ollama_client.list()
if model.provider_resource_id not in [m.model for m in response.models]:
await self.ollama_client.pull(model.provider_resource_id)
# we use list() here instead of ps() -
# - ps() only lists running models, not available models
# - models not currently running are run by the ollama server as needed
response = await self.ollama_client.list()
available_models = [m.model for m in response.models]
provider_resource_id = model.provider_resource_id
assert provider_resource_id is not None # mypy
if provider_resource_id not in available_models:
available_models_latest = [m.model.split(":latest")[0] for m in response.models]
if provider_resource_id in available_models_latest:
logger.warning(
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
)
return model
raise UnsupportedModelError(provider_resource_id, available_models)
# mutating this should be considered an anti-pattern
model.provider_resource_id = provider_resource_id
return model
raise UnsupportedModelError(model.provider_model_id, list(self._model_cache.keys()))
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:

View file

@ -4,15 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import OpenAIConfig
class OpenAIProviderDataValidator(BaseModel):
openai_api_key: str | None = None
async def get_adapter_impl(config: OpenAIConfig, _deps):
from .openai import OpenAIInferenceAdapter

View file

@ -1,60 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from dataclasses import dataclass
from llama_stack.apis.models import ModelType
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
)
LLM_MODEL_IDS = [
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo",
"gpt-3.5-turbo-instruct",
"gpt-4",
"gpt-4-turbo",
"gpt-4o",
"gpt-4o-2024-08-06",
"gpt-4o-mini",
"gpt-4o-audio-preview",
"chatgpt-4o-latest",
"o1",
"o1-mini",
"o3-mini",
"o4-mini",
]
@dataclass
class EmbeddingModelInfo:
"""Structured representation of embedding model information."""
embedding_dimension: int
context_length: int
EMBEDDING_MODEL_IDS: dict[str, EmbeddingModelInfo] = {
"text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
"text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
}
SAFETY_MODELS_ENTRIES = []
MODEL_ENTRIES = (
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
+ [
ProviderModelEntry(
provider_model_id=model_id,
model_type=ModelType.embedding,
metadata={
"embedding_dimension": model_info.embedding_dimension,
"context_length": model_info.context_length,
},
)
for model_id, model_info in EMBEDDING_MODEL_IDS.items()
]
+ SAFETY_MODELS_ENTRIES
)

View file

@ -9,7 +9,6 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOp
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import OpenAIConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference::openai")
@ -40,10 +39,14 @@ class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
- ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
"""
embedding_model_metadata = {
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
"text-embedding-3-large": {"embedding_dimension": 3072, "context_length": 8192},
}
def __init__(self, config: OpenAIConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
MODEL_ENTRIES,
litellm_provider_name="openai",
api_key_from_config=config.api_key,
provider_data_api_key_field="openai_api_key",

View file

@ -43,7 +43,7 @@ from .config import PassthroughImplConfig
class PassthroughInferenceAdapter(Inference):
def __init__(self, config: PassthroughImplConfig) -> None:
ModelRegistryHelper.__init__(self, [])
ModelRegistryHelper.__init__(self)
self.config = config
async def initialize(self) -> None:

View file

@ -4,12 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from .config import SambaNovaImplConfig
async def get_adapter_impl(config: SambaNovaImplConfig, _deps) -> Inference:
async def get_adapter_impl(config: SambaNovaImplConfig, _deps):
from .sambanova import SambaNovaInferenceAdapter
assert isinstance(config, SambaNovaImplConfig), f"Unexpected config type: {type(config)}"

View file

@ -1,28 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = []
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"Meta-Llama-3.1-8B-Instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"Meta-Llama-3.3-70B-Instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"Llama-4-Maverick-17B-128E-Instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
] + SAFETY_MODELS_ENTRIES

View file

@ -9,7 +9,6 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOp
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import SambaNovaImplConfig
from .models import MODEL_ENTRIES
class SambaNovaInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
@ -26,10 +25,9 @@ class SambaNovaInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
def __init__(self, config: SambaNovaImplConfig):
self.config = config
self.environment_available_models = []
self.environment_available_models: list[str] = []
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
litellm_provider_name="sambanova",
api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
provider_data_api_key_field="sambanova_api_key",

View file

@ -1,103 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = [
build_hf_repo_model_entry(
"meta-llama/Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-Guard-3-11B-Vision-Turbo",
CoreModelId.llama_guard_3_11b_vision.value,
),
]
# source: https://docs.together.ai/docs/serverless-models#embedding-models
EMBEDDING_MODEL_ENTRIES = {
"togethercomputer/m2-bert-80M-32k-retrieval": ProviderModelEntry(
provider_model_id="togethercomputer/m2-bert-80M-32k-retrieval",
metadata={
"embedding_dimension": 768,
"context_length": 32768,
},
),
"BAAI/bge-large-en-v1.5": ProviderModelEntry(
provider_model_id="BAAI/bge-large-en-v1.5",
metadata={
"embedding_dimension": 1024,
"context_length": 512,
},
),
"BAAI/bge-base-en-v1.5": ProviderModelEntry(
provider_model_id="BAAI/bge-base-en-v1.5",
metadata={
"embedding_dimension": 768,
"context_length": 512,
},
),
"Alibaba-NLP/gte-modernbert-base": ProviderModelEntry(
provider_model_id="Alibaba-NLP/gte-modernbert-base",
metadata={
"embedding_dimension": 768,
"context_length": 8192,
},
),
"intfloat/multilingual-e5-large-instruct": ProviderModelEntry(
provider_model_id="intfloat/multilingual-e5-large-instruct",
metadata={
"embedding_dimension": 1024,
"context_length": 512,
},
),
}
MODEL_ENTRIES = (
[
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.2-3B-Instruct-Turbo",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.3-70B-Instruct-Turbo",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
]
+ SAFETY_MODELS_ENTRIES
+ list(EMBEDDING_MODEL_ENTRIES.values())
)

View file

@ -6,7 +6,7 @@
from collections.abc import AsyncGenerator
from openai import NOT_GIVEN, AsyncOpenAI
from openai import AsyncOpenAI
from together import AsyncTogether
from together.constants import BASE_URL
@ -56,15 +56,23 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)
from .config import TogetherImplConfig
from .models import EMBEDDING_MODEL_ENTRIES, MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference::together")
class TogetherInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
embedding_model_metadata = {
"togethercomputer/m2-bert-80M-32k-retrieval": {"embedding_dimension": 768, "context_length": 32768},
"BAAI/bge-large-en-v1.5": {"embedding_dimension": 1024, "context_length": 512},
"BAAI/bge-base-en-v1.5": {"embedding_dimension": 768, "context_length": 512},
"Alibaba-NLP/gte-modernbert-base": {"embedding_dimension": 768, "context_length": 8192},
"intfloat/multilingual-e5-large-instruct": {"embedding_dimension": 1024, "context_length": 512},
}
def __init__(self, config: TogetherImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
ModelRegistryHelper.__init__(self)
self.config = config
self.allowed_models = config.allowed_models
self._model_cache: dict[str, Model] = {}
def get_api_key(self):
@ -264,15 +272,16 @@ class TogetherInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Need
# Together's /v1/models is not compatible with OpenAI's /v1/models. Together support ticket #13355 -> will not fix, use Together's own client
for m in await self._get_client().models.list():
if m.type == "embedding":
if m.id not in EMBEDDING_MODEL_ENTRIES:
if m.id not in self.embedding_model_metadata:
logger.warning(f"Unknown embedding dimension for model {m.id}, skipping.")
continue
metadata = self.embedding_model_metadata[m.id]
self._model_cache[m.id] = Model(
provider_id=self.__provider_id__,
provider_resource_id=EMBEDDING_MODEL_ENTRIES[m.id].provider_model_id,
provider_resource_id=m.id,
identifier=m.id,
model_type=ModelType.embedding,
metadata=EMBEDDING_MODEL_ENTRIES[m.id].metadata,
metadata=metadata,
)
else:
self._model_cache[m.id] = Model(
@ -303,10 +312,9 @@ class TogetherInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Need
the standard OpenAI embeddings endpoint.
The endpoint -
- does not return usage information
- not all models return usage information
- does not support user param, returns 400 Unrecognized request arguments supplied: user
- does not support dimensions param, returns 400 Unrecognized request arguments supplied: dimensions
- does not support encoding_format param, always returns floats, never base64
"""
# Together support ticket #13332 -> will not fix
if user is not None:
@ -314,13 +322,11 @@ class TogetherInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Need
# Together support ticket #13333 -> escalated
if dimensions is not None:
raise ValueError("Together's embeddings endpoint does not support dimensions param.")
# Together support ticket #13331 -> will not fix, compute client side
if encoding_format not in (None, NOT_GIVEN, "float"):
raise ValueError("Together's embeddings endpoint only supports encoding_format='float'.")
response = await self.client.embeddings.create(
model=await self._get_provider_model_id(model),
input=input,
encoding_format=encoding_format,
)
response.model = model # return the user the same model id they provided, avoid exposing the provider model id

View file

@ -1,20 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
)
# Vertex AI model IDs with vertex_ai/ prefix as required by litellm
LLM_MODEL_IDS = [
"vertex_ai/gemini-2.0-flash",
"vertex_ai/gemini-2.5-flash",
"vertex_ai/gemini-2.5-pro",
]
SAFETY_MODELS_ENTRIES = list[ProviderModelEntry]()
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + SAFETY_MODELS_ENTRIES

View file

@ -16,14 +16,12 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import (
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import VertexAIConfig
from .models import MODEL_ENTRIES
class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
def __init__(self, config: VertexAIConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
MODEL_ENTRIES,
litellm_provider_name="vertex_ai",
api_key_from_config=None, # Vertex AI uses ADC, not API keys
provider_data_api_key_field="vertex_project", # Use project for validation

View file

@ -292,7 +292,7 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
build_hf_repo_model_entries(),
model_entries=build_hf_repo_model_entries(),
litellm_provider_name="vllm",
api_key_from_config=config.api_token,
provider_data_api_key_field="vllm_api_token",
@ -504,7 +504,7 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
except ValueError:
pass # Ignore statically unknown model, will check live listing
try:
res = await self.client.models.list()
res = self.client.models.list()
except APIConnectionError as e:
raise ValueError(
f"Failed to connect to vLLM at {self.config.url}. Please check if vLLM is running and accessible at that URL."

View file

@ -76,7 +76,7 @@ logger = get_logger(name=__name__, category="inference::watsonx")
class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
def __init__(self, config: WatsonXConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
logger.info(f"Initializing watsonx InferenceAdapter({config.url})...")
self._config = config