chore: turn OpenAIMixin into a pydantic.BaseModel (#3671)

# What does this PR do?

- implement `get_api_key` instead of relying on `LiteLLMOpenAIMixin.get_api_key`
- remove use of `LiteLLMOpenAIMixin`
- add default `initialize`/`shutdown` methods to `OpenAIMixin`
- remove `__init__`s to allow proper pydantic construction (see the sketch after this list)
- remove dead code from the vLLM adapter and the associated/duplicate unit tests
- update the vLLM adapter to use `OpenAIMixin` for model registration
- remove `ModelRegistryHelper` from the Fireworks & Together adapters
- remove `Inference` from the NVIDIA adapter
- complete type hints on `embedding_model_metadata`
- allow extra fields on `OpenAIMixin`, for `model_store`, `__provider_id__`, etc.
- new recordings for Ollama
- enhance list-models error handling
- update the Cerebras (remove `cerebras-cloud-sdk`) and Anthropic (custom model listing) inference adapters
- parametrize `test_inference_client_caching`
- remove cerebras, databricks, fireworks, and together from the blanket mypy exclude
- remove unnecessary litellm deps
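
To make the net effect concrete, here is a minimal sketch of the adapter shape this PR moves to, assuming the `OpenAIMixin` interface visible in the diffs below (pydantic fields for `config` and `provider_data_api_key_field`, overridable `get_api_key`/`get_base_url`, default `initialize`/`shutdown`). `ExampleConfig` and `ExampleInferenceAdapter` are hypothetical names used only for illustration; they are not part of this PR.

```python
# Hypothetical adapter built on the pydantic-based OpenAIMixin; the "Example"
# names are illustrative, while the mixin fields/methods mirror the diffs below.
from pydantic import BaseModel, SecretStr

from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


class ExampleConfig(BaseModel):
    api_key: SecretStr | None = None
    base_url: str = "https://api.example.com"


class ExampleInferenceAdapter(OpenAIMixin):
    # declared as pydantic fields instead of being set in a custom __init__
    config: ExampleConfig
    provider_data_api_key_field: str = "example_api_key"
    embedding_model_metadata: dict[str, dict[str, int]] = {}

    def get_api_key(self) -> str:
        return self.config.api_key.get_secret_value() if self.config.api_key else ""

    def get_base_url(self) -> str:
        return f"{self.config.base_url}/v1"


async def get_adapter_impl(config: ExampleConfig, _deps):
    # keyword construction replaces the removed __init__s
    impl = ExampleInferenceAdapter(config=config)
    await impl.initialize()  # default no-op provided by OpenAIMixin
    return impl
```

The same declarative pattern appears in the anthropic, gemini, groq, and vLLM adapters below, which drop their LiteLLM `__init__` plumbing in favor of these fields.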

## Test Plan

CI
Matthew Farrellee 2025-10-06 11:33:19 -04:00 committed by GitHub
parent 724dac498c
commit d23ed26238
131 changed files with 83634 additions and 1760 deletions

View file

@ -52,9 +52,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter_type="cerebras",
provider_type="remote::cerebras",
pip_packages=[
"cerebras_cloud_sdk",
],
pip_packages=[],
module="llama_stack.providers.remote.inference.cerebras",
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
description="Cerebras inference provider for running models on Cerebras Cloud platform.",
@ -169,7 +167,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter_type="openai",
provider_type="remote::openai",
pip_packages=["litellm"],
pip_packages=[],
module="llama_stack.providers.remote.inference.openai",
config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
@ -179,7 +177,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter_type="anthropic",
provider_type="remote::anthropic",
pip_packages=["litellm"],
pip_packages=["anthropic"],
module="llama_stack.providers.remote.inference.anthropic",
config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
@ -189,9 +187,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter_type="gemini",
provider_type="remote::gemini",
pip_packages=[
"litellm",
],
pip_packages=[],
module="llama_stack.providers.remote.inference.gemini",
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
@ -202,7 +198,6 @@ def available_providers() -> list[ProviderSpec]:
adapter_type="vertexai",
provider_type="remote::vertexai",
pip_packages=[
"litellm",
"google-cloud-aiplatform",
],
module="llama_stack.providers.remote.inference.vertexai",
@ -233,9 +228,7 @@ Available Models:
api=Api.inference,
adapter_type="groq",
provider_type="remote::groq",
pip_packages=[
"litellm",
],
pip_packages=[],
module="llama_stack.providers.remote.inference.groq",
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
@ -245,7 +238,7 @@ Available Models:
api=Api.inference,
adapter_type="llama-openai-compat",
provider_type="remote::llama-openai-compat",
pip_packages=["litellm"],
pip_packages=[],
module="llama_stack.providers.remote.inference.llama_openai_compat",
config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
@ -255,9 +248,7 @@ Available Models:
api=Api.inference,
adapter_type="sambanova",
provider_type="remote::sambanova",
pip_packages=[
"litellm",
],
pip_packages=[],
module="llama_stack.providers.remote.inference.sambanova",
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
@ -287,7 +278,7 @@ Available Models:
api=Api.inference,
provider_type="remote::azure",
adapter_type="azure",
pip_packages=["litellm"],
pip_packages=[],
module="llama_stack.providers.remote.inference.azure",
config_class="llama_stack.providers.remote.inference.azure.AzureConfig",
provider_data_validator="llama_stack.providers.remote.inference.azure.config.AzureProviderDataValidator",

View file

@ -10,6 +10,6 @@ from .config import AnthropicConfig
async def get_adapter_impl(config: AnthropicConfig, _deps):
from .anthropic import AnthropicInferenceAdapter
impl = AnthropicInferenceAdapter(config)
impl = AnthropicInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -4,13 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from collections.abc import Iterable
from anthropic import AsyncAnthropic
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import AnthropicConfig
class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
class AnthropicInferenceAdapter(OpenAIMixin):
config: AnthropicConfig
provider_data_api_key_field: str = "anthropic_api_key"
# source: https://docs.claude.com/en/docs/build-with-claude/embeddings
# TODO: add support for voyageai, which is where these models are hosted
# embedding_model_metadata = {
@ -23,22 +29,11 @@ class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
# "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000},
# }
def __init__(self, config: AnthropicConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="anthropic",
api_key_from_config=config.api_key,
provider_data_api_key_field="anthropic_api_key",
)
self.config = config
async def initialize(self) -> None:
await super().initialize()
async def shutdown(self) -> None:
await super().shutdown()
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_api_key(self) -> str:
return self.config.api_key or ""
def get_base_url(self):
return "https://api.anthropic.com/v1"
async def list_provider_model_ids(self) -> Iterable[str]:
return [m.id async for m in AsyncAnthropic(api_key=self.get_api_key()).models.list()]

View file

@ -10,6 +10,6 @@ from .config import AzureConfig
async def get_adapter_impl(config: AzureConfig, _deps):
from .azure import AzureInferenceAdapter
impl = AzureInferenceAdapter(config)
impl = AzureInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -4,31 +4,20 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from urllib.parse import urljoin
from llama_stack.apis.inference import ChatCompletionRequest
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
LiteLLMOpenAIMixin,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import AzureConfig
class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
def __init__(self, config: AzureConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="azure",
api_key_from_config=config.api_key.get_secret_value(),
provider_data_api_key_field="azure_api_key",
openai_compat_api_base=str(config.api_base),
)
self.config = config
class AzureInferenceAdapter(OpenAIMixin):
config: AzureConfig
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
provider_data_api_key_field: str = "azure_api_key"
def get_api_key(self) -> str:
return self.config.api_key.get_secret_value()
def get_base_url(self) -> str:
"""
@ -37,26 +26,3 @@ class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
Returns the Azure API base URL from the configuration.
"""
return urljoin(str(self.config.api_base), "/openai/v1")
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
# Get base parameters from parent
params = await super()._get_params(request)
# Add Azure specific parameters
provider_data = self.get_request_provider_data()
if provider_data:
if getattr(provider_data, "azure_api_key", None):
params["api_key"] = provider_data.azure_api_key
if getattr(provider_data, "azure_api_base", None):
params["api_base"] = provider_data.azure_api_base
if getattr(provider_data, "azure_api_version", None):
params["api_version"] = provider_data.azure_api_version
if getattr(provider_data, "azure_api_type", None):
params["api_type"] = provider_data.azure_api_type
else:
params["api_key"] = self.config.api_key.get_secret_value()
params["api_base"] = str(self.config.api_base)
params["api_version"] = self.config.api_version
params["api_type"] = self.config.api_type
return params

View file

@ -12,7 +12,7 @@ async def get_adapter_impl(config: CerebrasImplConfig, _deps):
assert isinstance(config, CerebrasImplConfig), f"Unexpected config type: {type(config)}"
impl = CerebrasInferenceAdapter(config)
impl = CerebrasInferenceAdapter(config=config)
await impl.initialize()

View file

@ -6,39 +6,14 @@
from urllib.parse import urljoin
from cerebras.cloud.sdk import AsyncCerebras
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
Inference,
OpenAIEmbeddingsResponse,
TopKSamplingStrategy,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
)
from llama_stack.apis.inference import OpenAIEmbeddingsResponse
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
)
from .config import CerebrasImplConfig
class CerebrasInferenceAdapter(
OpenAIMixin,
Inference,
):
def __init__(self, config: CerebrasImplConfig) -> None:
self.config = config
# TODO: make this use provider data, etc. like other providers
self._cerebras_client = AsyncCerebras(
base_url=self.config.base_url,
api_key=self.config.api_key.get_secret_value(),
)
class CerebrasInferenceAdapter(OpenAIMixin):
config: CerebrasImplConfig
def get_api_key(self) -> str:
return self.config.api_key.get_secret_value()
@ -46,31 +21,6 @@ class CerebrasInferenceAdapter(
def get_base_url(self) -> str:
return urljoin(self.config.base_url, "v1")
async def initialize(self) -> None:
return
async def shutdown(self) -> None:
pass
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
if request.sampling_params and isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
raise ValueError("`top_k` not supported by Cerebras")
prompt = ""
if isinstance(request, ChatCompletionRequest):
prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
elif isinstance(request, CompletionRequest):
prompt = await completion_request_to_prompt(request)
else:
raise ValueError(f"Unknown request type {type(request)}")
return {
"model": request.model,
"prompt": prompt,
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
async def openai_embeddings(
self,
model: str,

View file

@ -22,7 +22,7 @@ class CerebrasImplConfig(RemoteInferenceProviderConfig):
description="Base URL for the Cerebras API",
)
api_key: SecretStr = Field(
default=SecretStr(os.environ.get("CEREBRAS_API_KEY")),
default=SecretStr(os.environ.get("CEREBRAS_API_KEY")), # type: ignore[arg-type]
description="Cerebras API Key",
)

View file

@ -11,6 +11,6 @@ async def get_adapter_impl(config: DatabricksImplConfig, _deps):
from .databricks import DatabricksInferenceAdapter
assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
impl = DatabricksInferenceAdapter(config)
impl = DatabricksInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -14,12 +14,12 @@ from llama_stack.schema_utils import json_schema_type
@json_schema_type
class DatabricksImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
url: str | None = Field(
default=None,
description="The URL for the Databricks model serving endpoint",
)
api_token: SecretStr = Field(
default=SecretStr(None),
default=SecretStr(None), # type: ignore[arg-type]
description="The Databricks API token",
)

View file

@ -9,10 +9,7 @@ from typing import Any
from databricks.sdk import WorkspaceClient
from llama_stack.apis.inference import (
Inference,
OpenAICompletion,
)
from llama_stack.apis.inference import OpenAICompletion
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -21,30 +18,31 @@ from .config import DatabricksImplConfig
logger = get_logger(name=__name__, category="inference::databricks")
class DatabricksInferenceAdapter(
OpenAIMixin,
Inference,
):
class DatabricksInferenceAdapter(OpenAIMixin):
config: DatabricksImplConfig
# source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
embedding_model_metadata = {
embedding_model_metadata: dict[str, dict[str, int]] = {
"databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},
"databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512},
}
def __init__(self, config: DatabricksImplConfig) -> None:
self.config = config
def get_api_key(self) -> str:
return self.config.api_token.get_secret_value()
def get_base_url(self) -> str:
return f"{self.config.url}/serving-endpoints"
async def initialize(self) -> None:
return
async def list_provider_model_ids(self) -> Iterable[str]:
return [
endpoint.name
for endpoint in WorkspaceClient(
host=self.config.url, token=self.get_api_key()
).serving_endpoints.list() # TODO: this is not async
]
async def shutdown(self) -> None:
pass
async def should_refresh_models(self) -> bool:
return False
async def openai_completion(
self,
@ -70,14 +68,3 @@ class DatabricksInferenceAdapter(
suffix: str | None = None,
) -> OpenAICompletion:
raise NotImplementedError()
async def list_provider_model_ids(self) -> Iterable[str]:
return [
endpoint.name
for endpoint in WorkspaceClient(
host=self.config.url, token=self.get_api_key()
).serving_endpoints.list() # TODO: this is not async
]
async def should_refresh_models(self) -> bool:
return False

View file

@ -17,6 +17,6 @@ async def get_adapter_impl(config: FireworksImplConfig, _deps):
from .fireworks import FireworksInferenceAdapter
assert isinstance(config, FireworksImplConfig), f"Unexpected config type: {type(config)}"
impl = FireworksInferenceAdapter(config)
impl = FireworksInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -5,124 +5,26 @@
# the root directory of this source tree.
from fireworks.client import Fireworks
from llama_stack.apis.inference import (
ChatCompletionRequest,
Inference,
LogProbConfig,
ResponseFormat,
ResponseFormatType,
SamplingParams,
)
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
request_has_media,
)
from .config import FireworksImplConfig
logger = get_logger(name=__name__, category="inference::fireworks")
class FireworksInferenceAdapter(OpenAIMixin, Inference, NeedsRequestProviderData):
embedding_model_metadata = {
class FireworksInferenceAdapter(OpenAIMixin):
config: FireworksImplConfig
embedding_model_metadata: dict[str, dict[str, int]] = {
"nomic-ai/nomic-embed-text-v1.5": {"embedding_dimension": 768, "context_length": 8192},
"accounts/fireworks/models/qwen3-embedding-8b": {"embedding_dimension": 4096, "context_length": 40960},
}
def __init__(self, config: FireworksImplConfig) -> None:
ModelRegistryHelper.__init__(self)
self.config = config
self.allowed_models = config.allowed_models
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
provider_data_api_key_field: str = "fireworks_api_key"
def get_api_key(self) -> str:
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
if config_api_key:
return config_api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.fireworks_api_key:
raise ValueError(
'Pass Fireworks API Key in the header X-LlamaStack-Provider-Data as { "fireworks_api_key": <your api key>}'
)
return provider_data.fireworks_api_key
return self.config.api_key.get_secret_value() if self.config.api_key else None # type: ignore[return-value]
def get_base_url(self) -> str:
return "https://api.fireworks.ai/inference/v1"
def _get_client(self) -> Fireworks:
fireworks_api_key = self.get_api_key()
return Fireworks(api_key=fireworks_api_key)
def _build_options(
self,
sampling_params: SamplingParams | None,
fmt: ResponseFormat | None,
logprobs: LogProbConfig | None,
) -> dict:
options = get_sampling_options(sampling_params)
options.setdefault("max_tokens", 512)
if fmt:
if fmt.type == ResponseFormatType.json_schema.value:
options["response_format"] = {
"type": "json_object",
"schema": fmt.json_schema,
}
elif fmt.type == ResponseFormatType.grammar.value:
options["response_format"] = {
"type": "grammar",
"grammar": fmt.bnf,
}
else:
raise ValueError(f"Unknown response format {fmt.type}")
if logprobs and logprobs.top_k:
options["logprobs"] = logprobs.top_k
if options["logprobs"] <= 0 or options["logprobs"] >= 5:
raise ValueError("Required range: 0 < top_k < 5")
return options
async def _get_params(self, request: ChatCompletionRequest) -> dict:
input_dict = {}
media_present = request_has_media(request)
llama_model = self.get_llama_model(request.model)
# TODO: tools are never added to the request, so we need to add them here
if media_present or not llama_model:
input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
else:
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
# Fireworks always prepends with BOS
if "prompt" in input_dict:
if input_dict["prompt"].startswith("<|begin_of_text|>"):
input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
params = {
"model": request.model,
**input_dict,
"stream": bool(request.stream),
**self._build_options(request.sampling_params, request.response_format, request.logprobs),
}
logger.debug(f"params to fireworks: {params}")
return params

View file

@ -10,6 +10,6 @@ from .config import GeminiConfig
async def get_adapter_impl(config: GeminiConfig, _deps):
from .gemini import GeminiInferenceAdapter
impl = GeminiInferenceAdapter(config)
impl = GeminiInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -4,33 +4,21 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import GeminiConfig
class GeminiInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
embedding_model_metadata = {
class GeminiInferenceAdapter(OpenAIMixin):
config: GeminiConfig
provider_data_api_key_field: str = "gemini_api_key"
embedding_model_metadata: dict[str, dict[str, int]] = {
"text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
}
def __init__(self, config: GeminiConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="gemini",
api_key_from_config=config.api_key,
provider_data_api_key_field="gemini_api_key",
)
self.config = config
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_api_key(self) -> str:
return self.config.api_key or ""
def get_base_url(self):
return "https://generativelanguage.googleapis.com/v1beta/openai/"
async def initialize(self) -> None:
await super().initialize()
async def shutdown(self) -> None:
await super().shutdown()

View file

@ -11,5 +11,5 @@ async def get_adapter_impl(config: GroqConfig, _deps):
# import dynamically so the import is used only when it is needed
from .groq import GroqInferenceAdapter
adapter = GroqInferenceAdapter(config)
adapter = GroqInferenceAdapter(config=config)
return adapter

View file

@ -6,30 +6,16 @@
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
class GroqInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
_config: GroqConfig
class GroqInferenceAdapter(OpenAIMixin):
config: GroqConfig
def __init__(self, config: GroqConfig):
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="groq",
api_key_from_config=config.api_key,
provider_data_api_key_field="groq_api_key",
)
self.config = config
provider_data_api_key_field: str = "groq_api_key"
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_api_key(self) -> str:
return self.config.api_key or ""
def get_base_url(self) -> str:
return f"{self.config.url}/openai/v1"
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()

View file

@ -4,14 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import InferenceProvider
from .config import LlamaCompatConfig
async def get_adapter_impl(config: LlamaCompatConfig, _deps) -> InferenceProvider:
async def get_adapter_impl(config: LlamaCompatConfig, _deps):
# import dynamically so the import is used only when it is needed
from .llama import LlamaCompatInferenceAdapter
adapter = LlamaCompatInferenceAdapter(config)
adapter = LlamaCompatInferenceAdapter(config=config)
return adapter

View file

@ -8,38 +8,21 @@ from typing import Any
from llama_stack.apis.inference.inference import OpenAICompletion
from llama_stack.log import get_logger
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
logger = get_logger(name=__name__, category="inference::llama_openai_compat")
class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
class LlamaCompatInferenceAdapter(OpenAIMixin):
config: LlamaCompatConfig
provider_data_api_key_field: str = "llama_api_key"
"""
Llama API Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
is used instead of ModelRegistryHelper.check_model_availability().
- OpenAIMixin.check_model_availability() queries the Llama API to check if a model exists
- ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
"""
_config: LlamaCompatConfig
def __init__(self, config: LlamaCompatConfig):
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="meta_llama",
api_key_from_config=config.api_key,
provider_data_api_key_field="llama_api_key",
openai_compat_api_base=config.openai_compat_api_base,
)
self.config = config
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_api_key(self) -> str:
return self.config.api_key or ""
def get_base_url(self) -> str:
"""
@ -49,12 +32,6 @@ class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
"""
return self.config.openai_compat_api_base
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()
async def openai_completion(
self,
model: str,

View file

@ -15,7 +15,8 @@ async def get_adapter_impl(config: NVIDIAConfig, _deps) -> Inference:
if not isinstance(config, NVIDIAConfig):
raise RuntimeError(f"Unexpected config type: {type(config)}")
adapter = NVIDIAInferenceAdapter(config)
adapter = NVIDIAInferenceAdapter(config=config)
await adapter.initialize()
return adapter

View file

@ -8,7 +8,6 @@
from openai import NOT_GIVEN
from llama_stack.apis.inference import (
Inference,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
@ -22,7 +21,9 @@ from .utils import _is_nvidia_hosted
logger = get_logger(name=__name__, category="inference::nvidia")
class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
class NVIDIAInferenceAdapter(OpenAIMixin):
config: NVIDIAConfig
"""
NVIDIA Inference Adapter for Llama Stack.
@ -37,32 +38,21 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
"""
# source: https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
embedding_model_metadata = {
embedding_model_metadata: dict[str, dict[str, int]] = {
"nvidia/llama-3.2-nv-embedqa-1b-v2": {"embedding_dimension": 2048, "context_length": 8192},
"nvidia/nv-embedqa-e5-v5": {"embedding_dimension": 512, "context_length": 1024},
"nvidia/nv-embedqa-mistral-7b-v2": {"embedding_dimension": 512, "context_length": 4096},
"snowflake/arctic-embed-l": {"embedding_dimension": 512, "context_length": 1024},
}
def __init__(self, config: NVIDIAConfig) -> None:
logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")
async def initialize(self) -> None:
logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.url})...")
if _is_nvidia_hosted(config):
if not config.api_key:
if _is_nvidia_hosted(self.config):
if not self.config.api_key:
raise RuntimeError(
"API key is required for hosted NVIDIA NIM. Either provide an API key or use a self-hosted NIM."
)
# elif self._config.api_key:
#
# we don't raise this warning because a user may have deployed their
# self-hosted NIM with an API key requirement.
#
# warnings.warn(
# "API key is not required for self-hosted NVIDIA NIM. "
# "Consider removing the api_key from the configuration."
# )
self._config = config
def get_api_key(self) -> str:
"""
@ -70,7 +60,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
:return: The NVIDIA API key
"""
return self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"
return self.config.api_key.get_secret_value() if self.config.api_key else "NO KEY"
def get_base_url(self) -> str:
"""
@ -78,7 +68,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
:return: The NVIDIA API base URL
"""
return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
async def openai_embeddings(
self,

View file

@ -10,6 +10,6 @@ from .config import OllamaImplConfig
async def get_adapter_impl(config: OllamaImplConfig, _deps):
from .ollama import OllamaInferenceAdapter
impl = OllamaInferenceAdapter(config)
impl = OllamaInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -6,58 +6,29 @@
import asyncio
from typing import Any
from ollama import AsyncClient as AsyncOllamaClient
from llama_stack.apis.common.content_types import (
ImageContentItem,
TextContentItem,
)
from llama_stack.apis.common.errors import UnsupportedModelError
from llama_stack.apis.inference import (
ChatCompletionRequest,
GrammarResponseFormat,
InferenceProvider,
JsonSchemaResponseFormat,
Message,
)
from llama_stack.apis.models import Model
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.datatypes import (
HealthResponse,
HealthStatus,
ModelsProtocolPrivate,
)
from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
convert_image_content_to_url,
request_has_media,
)
logger = get_logger(name=__name__, category="inference::ollama")
class OllamaInferenceAdapter(
OpenAIMixin,
ModelRegistryHelper,
InferenceProvider,
ModelsProtocolPrivate,
):
class OllamaInferenceAdapter(OpenAIMixin):
config: OllamaImplConfig
# automatically set by the resolver when instantiating the provider
__provider_id__: str
embedding_model_metadata = {
embedding_model_metadata: dict[str, dict[str, int]] = {
"all-minilm:l6-v2": {
"embedding_dimension": 384,
"context_length": 512,
@ -76,29 +47,8 @@ class OllamaInferenceAdapter(
},
}
def __init__(self, config: OllamaImplConfig) -> None:
# TODO: remove ModelRegistryHelper.__init__ when completion and
# chat_completion are. this exists to satisfy the input /
# output processing for llama models. specifically,
# tool_calling is handled by raw template processing,
# instead of using the /api/chat endpoint w/ tools=...
ModelRegistryHelper.__init__(
self,
model_entries=[
build_hf_repo_model_entry(
"llama3.2:3b-instruct-fp16",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"llama-guard3:1b",
CoreModelId.llama_guard_3_1b.value,
),
],
)
self.config = config
# Ollama does not support image urls, so we need to download the image and convert it to base64
self.download_images = True
self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
download_images: bool = True
_clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
@property
def ollama_client(self) -> AsyncOllamaClient:
@ -142,50 +92,6 @@ class OllamaInferenceAdapter(
async def shutdown(self) -> None:
self._clients.clear()
async def _get_model(self, model_id: str) -> Model:
if not self.model_store:
raise ValueError("Model store not set")
return await self.model_store.get_model(model_id)
async def _get_params(self, request: ChatCompletionRequest) -> dict:
sampling_options = get_sampling_options(request.sampling_params)
# This is needed since the Ollama API expects num_predict to be set
# for early truncation instead of max_tokens.
if sampling_options.get("max_tokens") is not None:
sampling_options["num_predict"] = sampling_options["max_tokens"]
input_dict: dict[str, Any] = {}
media_present = request_has_media(request)
llama_model = self.get_llama_model(request.model)
if media_present or not llama_model:
contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
# flatten the list of lists
input_dict["messages"] = [item for sublist in contents for item in sublist]
else:
input_dict["raw"] = True
input_dict["prompt"] = await chat_completion_request_to_prompt(
request,
llama_model,
)
if fmt := request.response_format:
if isinstance(fmt, JsonSchemaResponseFormat):
input_dict["format"] = fmt.json_schema
elif isinstance(fmt, GrammarResponseFormat):
raise NotImplementedError("Grammar response format is not supported")
else:
raise ValueError(f"Unknown response format type: {fmt.type}")
params = {
"model": request.model,
**input_dict,
"options": sampling_options,
"stream": request.stream,
}
logger.debug(f"params to ollama: {params}")
return params
async def register_model(self, model: Model) -> Model:
if await self.check_model_availability(model.provider_model_id):
return model
@ -197,24 +103,3 @@ class OllamaInferenceAdapter(
return model
raise UnsupportedModelError(model.provider_model_id, list(self._model_cache.keys()))
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
async def _convert_content(content) -> dict:
if isinstance(content, ImageContentItem):
return {
"role": message.role,
"images": [await convert_image_content_to_url(content, download=True, include_format=False)],
}
else:
text = content.text if isinstance(content, TextContentItem) else content
assert isinstance(text, str)
return {
"role": message.role,
"content": text,
}
if isinstance(message.content, list):
return [await _convert_content(c) for c in message.content]
else:
return [await _convert_content(message.content)]

View file

@ -10,6 +10,6 @@ from .config import OpenAIConfig
async def get_adapter_impl(config: OpenAIConfig, _deps):
from .openai import OpenAIInferenceAdapter
impl = OpenAIInferenceAdapter(config)
impl = OpenAIInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import OpenAIConfig
@ -14,52 +13,24 @@ logger = get_logger(name=__name__, category="inference::openai")
#
# This OpenAI adapter implements Inference methods using two mixins -
# This OpenAI adapter implements Inference methods using OpenAIMixin
#
# | Inference Method | Implementation Source |
# |----------------------------|--------------------------|
# | completion | LiteLLMOpenAIMixin |
# | chat_completion | LiteLLMOpenAIMixin |
# | embedding | LiteLLMOpenAIMixin |
# | openai_completion | OpenAIMixin |
# | openai_chat_completion | OpenAIMixin |
# | openai_embeddings | OpenAIMixin |
#
class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
class OpenAIInferenceAdapter(OpenAIMixin):
"""
OpenAI Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
is used instead of ModelRegistryHelper.check_model_availability().
- OpenAIMixin.check_model_availability() queries the OpenAI API to check if a model exists
- ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
"""
embedding_model_metadata = {
config: OpenAIConfig
provider_data_api_key_field: str = "openai_api_key"
embedding_model_metadata: dict[str, dict[str, int]] = {
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
"text-embedding-3-large": {"embedding_dimension": 3072, "context_length": 8192},
}
def __init__(self, config: OpenAIConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="openai",
api_key_from_config=config.api_key,
provider_data_api_key_field="openai_api_key",
)
self.config = config
# we set is_openai_compat so users can use the canonical
# openai model names like "gpt-4" or "gpt-3.5-turbo"
# and the model name will be translated to litellm's
# "openai/gpt-4" or "openai/gpt-3.5-turbo" transparently.
# if we do not set this, users will be exposed to the
# litellm specific model names, an abstraction leak.
self.is_openai_compat = True
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_api_key(self) -> str:
return self.config.api_key or ""
def get_base_url(self) -> str:
"""
@ -68,9 +39,3 @@ class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
Returns the OpenAI API base URL from the configuration.
"""
return self.config.base_url
async def initialize(self) -> None:
await super().initialize()
async def shutdown(self) -> None:
await super().shutdown()

View file

@ -31,12 +31,6 @@ class PassthroughInferenceAdapter(Inference):
ModelRegistryHelper.__init__(self)
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def unregister_model(self, model_id: str) -> None:
pass

View file

@ -53,12 +53,6 @@ class RunpodInferenceAdapter(
ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
self.config = config
async def initialize(self) -> None:
return
async def shutdown(self) -> None:
pass
def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
"model": self.map_to_provider_model(request.model),

View file

@ -11,6 +11,6 @@ async def get_adapter_impl(config: SambaNovaImplConfig, _deps):
from .sambanova import SambaNovaInferenceAdapter
assert isinstance(config, SambaNovaImplConfig), f"Unexpected config type: {type(config)}"
impl = SambaNovaInferenceAdapter(config)
impl = SambaNovaInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -5,39 +5,22 @@
# the root directory of this source tree.
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import SambaNovaImplConfig
class SambaNovaInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
class SambaNovaInferenceAdapter(OpenAIMixin):
config: SambaNovaImplConfig
provider_data_api_key_field: str = "sambanova_api_key"
download_images: bool = True  # SambaNova does not support image downloads server-side, perform them on the client
"""
SambaNova Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
is used instead of LiteLLMOpenAIMixin.check_model_availability().
- OpenAIMixin.check_model_availability() queries the /v1/models to check if a model exists
- LiteLLMOpenAIMixin.check_model_availability() checks the static registry within LiteLLM
"""
def __init__(self, config: SambaNovaImplConfig):
self.config = config
self.environment_available_models: list[str] = []
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="sambanova",
api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
provider_data_api_key_field="sambanova_api_key",
openai_compat_api_base=self.config.url,
download_images=True, # SambaNova requires base64 image encoding
json_schema_strict=False, # SambaNova doesn't support strict=True yet
)
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_api_key(self) -> str:
return self.config.api_key.get_secret_value() if self.config.api_key else ""
def get_base_url(self) -> str:
"""

View file

@ -5,53 +5,21 @@
# the root directory of this source tree.
from collections.abc import Iterable
from huggingface_hub import AsyncInferenceClient, HfApi
from pydantic import SecretStr
from llama_stack.apis.inference import (
ChatCompletionRequest,
Inference,
OpenAIEmbeddingsResponse,
ResponseFormat,
ResponseFormatType,
SamplingParams,
)
from llama_stack.apis.models import Model
from llama_stack.apis.models.models import ModelType
from llama_stack.apis.inference import OpenAIEmbeddingsResponse
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_model_input_info,
)
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
log = get_logger(name=__name__, category="inference::tgi")
def build_hf_repo_model_entries():
return [
build_hf_repo_model_entry(
model.huggingface_repo,
model.descriptor(),
)
for model in all_registered_models()
if model.huggingface_repo
]
class _HfAdapter(
OpenAIMixin,
Inference,
):
class _HfAdapter(OpenAIMixin):
url: str
api_key: SecretStr
@ -61,90 +29,14 @@ class _HfAdapter(
overwrite_completion_id = True # TGI always returns id=""
def __init__(self) -> None:
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
self.huggingface_repo_to_llama_model_id = {
model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo
}
def get_api_key(self):
return self.api_key.get_secret_value()
def get_base_url(self):
return self.url
async def shutdown(self) -> None:
pass
async def list_models(self) -> list[Model] | None:
models = []
async for model in self.client.models.list():
models.append(
Model(
identifier=model.id,
provider_resource_id=model.id,
provider_id=self.__provider_id__,
metadata={},
model_type=ModelType.llm,
)
)
return models
async def register_model(self, model: Model) -> Model:
if model.provider_resource_id != self.model_id:
raise ValueError(
f"Model {model.provider_resource_id} does not match the model {self.model_id} served by TGI."
)
return model
async def unregister_model(self, model_id: str) -> None:
pass
def _get_max_new_tokens(self, sampling_params, input_tokens):
return min(
sampling_params.max_tokens or (self.max_tokens - input_tokens),
self.max_tokens - input_tokens - 1,
)
def _build_options(
self,
sampling_params: SamplingParams | None = None,
fmt: ResponseFormat = None,
):
options = get_sampling_options(sampling_params)
# TGI does not support temperature=0 when using greedy sampling
# We set it to 1e-3 instead, anything lower outputs garbage from TGI
# We can use top_p sampling strategy to specify lower temperature
if abs(options["temperature"]) < 1e-10:
options["temperature"] = 1e-3
# delete key "max_tokens" from options since its not supported by the API
options.pop("max_tokens", None)
if fmt:
if fmt.type == ResponseFormatType.json_schema.value:
options["grammar"] = {
"type": "json",
"value": fmt.json_schema,
}
elif fmt.type == ResponseFormatType.grammar.value:
raise ValueError("Grammar response format not supported yet")
else:
raise ValueError(f"Unexpected response format: {fmt.type}")
return options
async def _get_params(self, request: ChatCompletionRequest) -> dict:
prompt, input_tokens = await chat_completion_request_to_model_input_info(
request, self.register_helper.get_llama_model(request.model)
)
return dict(
prompt=prompt,
stream=request.stream,
details=True,
max_new_tokens=self._get_max_new_tokens(request.sampling_params, input_tokens),
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
**self._build_options(request.sampling_params, request.response_format),
)
async def list_provider_model_ids(self) -> Iterable[str]:
return [self.model_id]
async def openai_embeddings(
self,

View file

@ -17,6 +17,6 @@ async def get_adapter_impl(config: TogetherImplConfig, _deps):
from .together import TogetherInferenceAdapter
assert isinstance(config, TogetherImplConfig), f"Unexpected config type: {type(config)}"
impl = TogetherInferenceAdapter(config)
impl = TogetherInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -5,41 +5,29 @@
# the root directory of this source tree.
from openai import AsyncOpenAI
from collections.abc import Iterable
from together import AsyncTogether
from together.constants import BASE_URL
from llama_stack.apis.inference import (
ChatCompletionRequest,
Inference,
LogProbConfig,
OpenAIEmbeddingsResponse,
ResponseFormat,
ResponseFormatType,
SamplingParams,
)
from llama_stack.apis.inference.inference import OpenAIEmbeddingUsage
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.models import Model
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
request_has_media,
)
from .config import TogetherImplConfig
logger = get_logger(name=__name__, category="inference::together")
class TogetherInferenceAdapter(OpenAIMixin, Inference, NeedsRequestProviderData):
embedding_model_metadata = {
class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
config: TogetherImplConfig
embedding_model_metadata: dict[str, dict[str, int]] = {
"togethercomputer/m2-bert-80M-32k-retrieval": {"embedding_dimension": 768, "context_length": 32768},
"BAAI/bge-large-en-v1.5": {"embedding_dimension": 1024, "context_length": 512},
"BAAI/bge-base-en-v1.5": {"embedding_dimension": 768, "context_length": 512},
@ -47,24 +35,16 @@ class TogetherInferenceAdapter(OpenAIMixin, Inference, NeedsRequestProviderData)
"intfloat/multilingual-e5-large-instruct": {"embedding_dimension": 1024, "context_length": 512},
}
def __init__(self, config: TogetherImplConfig) -> None:
ModelRegistryHelper.__init__(self)
self.config = config
self.allowed_models = config.allowed_models
self._model_cache: dict[str, Model] = {}
_model_cache: dict[str, Model] = {}
provider_data_api_key_field: str = "together_api_key"
def get_api_key(self):
return self.config.api_key.get_secret_value()
return self.config.api_key.get_secret_value() if self.config.api_key else None
def get_base_url(self):
return BASE_URL
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
def _get_client(self) -> AsyncTogether:
together_api_key = None
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
@ -79,83 +59,9 @@ class TogetherInferenceAdapter(OpenAIMixin, Inference, NeedsRequestProviderData)
together_api_key = provider_data.together_api_key
return AsyncTogether(api_key=together_api_key)
def _get_openai_client(self) -> AsyncOpenAI:
together_client = self._get_client().client
return AsyncOpenAI(
base_url=together_client.base_url,
api_key=together_client.api_key,
)
def _build_options(
self,
sampling_params: SamplingParams | None,
logprobs: LogProbConfig | None,
fmt: ResponseFormat,
) -> dict:
options = get_sampling_options(sampling_params)
if fmt:
if fmt.type == ResponseFormatType.json_schema.value:
options["response_format"] = {
"type": "json_object",
"schema": fmt.json_schema,
}
elif fmt.type == ResponseFormatType.grammar.value:
raise NotImplementedError("Grammar response format not supported yet")
else:
raise ValueError(f"Unknown response format {fmt.type}")
if logprobs and logprobs.top_k:
if logprobs.top_k != 1:
raise ValueError(
f"Unsupported value: Together only supports logprobs top_k=1. {logprobs.top_k} was provided",
)
options["logprobs"] = 1
return options
async def _get_params(self, request: ChatCompletionRequest) -> dict:
input_dict = {}
media_present = request_has_media(request)
llama_model = self.get_llama_model(request.model)
if media_present or not llama_model:
input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
else:
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
params = {
"model": request.model,
**input_dict,
"stream": request.stream,
**self._build_options(request.sampling_params, request.logprobs, request.response_format),
}
logger.debug(f"params to together: {params}")
return params
async def list_models(self) -> list[Model] | None:
self._model_cache = {}
async def list_provider_model_ids(self) -> Iterable[str]:
# Together's /v1/models is not compatible with OpenAI's /v1/models. Together support ticket #13355 -> will not fix, use Together's own client
for m in await self._get_client().models.list():
if m.type == "embedding":
if m.id not in self.embedding_model_metadata:
logger.warning(f"Unknown embedding dimension for model {m.id}, skipping.")
continue
metadata = self.embedding_model_metadata[m.id]
self._model_cache[m.id] = Model(
provider_id=self.__provider_id__,
provider_resource_id=m.id,
identifier=m.id,
model_type=ModelType.embedding,
metadata=metadata,
)
else:
self._model_cache[m.id] = Model(
provider_id=self.__provider_id__,
provider_resource_id=m.id,
identifier=m.id,
model_type=ModelType.llm,
)
return self._model_cache.values()
return [m.id for m in await self._get_client().models.list()]
async def should_refresh_models(self) -> bool:
return True
@ -203,4 +109,4 @@ class TogetherInferenceAdapter(OpenAIMixin, Inference, NeedsRequestProviderData)
)
response.usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
return response
return response # type: ignore[no-any-return]

View file

@ -10,6 +10,6 @@ from .config import VertexAIConfig
async def get_adapter_impl(config: VertexAIConfig, _deps):
from .vertexai import VertexAIInferenceAdapter
impl = VertexAIInferenceAdapter(config)
impl = VertexAIInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -4,29 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
import google.auth.transport.requests
from google.auth import default
from llama_stack.apis.inference import ChatCompletionRequest
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
LiteLLMOpenAIMixin,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import VertexAIConfig
class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
def __init__(self, config: VertexAIConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="vertex_ai",
api_key_from_config=None, # Vertex AI uses ADC, not API keys
provider_data_api_key_field="vertex_project", # Use project for validation
)
self.config = config
class VertexAIInferenceAdapter(OpenAIMixin):
config: VertexAIConfig
provider_data_api_key_field: str = "vertex_project"
def get_api_key(self) -> str:
"""
@ -41,8 +31,7 @@ class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
credentials.refresh(google.auth.transport.requests.Request())
return str(credentials.token)
except Exception:
# If we can't get credentials, return empty string to let LiteLLM handle it
# This allows the LiteLLM mixin to work with ADC directly
# If we can't get credentials, return empty string to let the env work with ADC directly
return ""
def get_base_url(self) -> str:
@ -53,23 +42,3 @@ class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
Source: https://cloud.google.com/vertex-ai/generative-ai/docs/start/openai
"""
return f"https://{self.config.location}-aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/{self.config.location}/endpoints/openapi"
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
# Get base parameters from parent
params = await super()._get_params(request)
# Add Vertex AI specific parameters
provider_data = self.get_request_provider_data()
if provider_data:
if getattr(provider_data, "vertex_project", None):
params["vertex_project"] = provider_data.vertex_project
if getattr(provider_data, "vertex_location", None):
params["vertex_location"] = provider_data.vertex_location
else:
params["vertex_project"] = self.config.project
params["vertex_location"] = self.config.location
# Remove api_key since Vertex AI uses ADC
params.pop("api_key", None)
return params

View file

@ -17,6 +17,6 @@ async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps):
from .vllm import VLLMInferenceAdapter
assert isinstance(config, VLLMInferenceAdapterConfig), f"Unexpected config type: {type(config)}"
impl = VLLMInferenceAdapter(config)
impl = VLLMInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -3,56 +3,27 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import AsyncIterator
from typing import Any
from urllib.parse import urljoin
import httpx
from openai import APIConnectionError
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk as OpenAIChatCompletionChunk,
)
from pydantic import ConfigDict
from llama_stack.apis.common.content_types import (
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
GrammarResponseFormat,
Inference,
JsonSchemaResponseFormat,
ModelStore,
OpenAIChatCompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ToolChoice,
ToolDefinition,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import (
HealthResponse,
HealthStatus,
ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
UnparseableToolCall,
convert_message_to_openai_dict,
convert_tool_call,
get_sampling_options,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -61,210 +32,15 @@ from .config import VLLMInferenceAdapterConfig
log = get_logger(name=__name__, category="inference::vllm")
def build_hf_repo_model_entries():
return [
build_hf_repo_model_entry(
model.huggingface_repo,
model.descriptor(),
)
for model in all_registered_models()
if model.huggingface_repo
]
class VLLMInferenceAdapter(OpenAIMixin):
config: VLLMInferenceAdapterConfig
model_config = ConfigDict(arbitrary_types_allowed=True)
def _convert_to_vllm_tool_calls_in_response(
tool_calls,
) -> list[ToolCall]:
if not tool_calls:
return []
provider_data_api_key_field: str = "vllm_api_token"
return [
ToolCall(
call_id=call.id,
tool_name=call.function.name,
arguments=call.function.arguments,
)
for call in tool_calls
]
def _convert_to_vllm_tools_in_request(tools: list[ToolDefinition]) -> list[dict]:
compat_tools = []
for tool in tools:
# The tool.tool_name can be a str or a BuiltinTool enum. If
# it's the latter, convert to a string.
tool_name = tool.tool_name
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value
compat_tool = {
"type": "function",
"function": {
"name": tool_name,
"description": tool.description,
"parameters": tool.input_schema
or {
"type": "object",
"properties": {},
"required": [],
},
},
}
compat_tools.append(compat_tool)
return compat_tools
def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
return {
"stop": StopReason.end_of_turn,
"length": StopReason.out_of_tokens,
"tool_calls": StopReason.end_of_message,
}.get(finish_reason, StopReason.end_of_turn)
def _process_vllm_chat_completion_end_of_stream(
finish_reason: str | None,
last_chunk_content: str | None,
current_event_type: ChatCompletionResponseEventType,
tool_call_bufs: dict[str, UnparseableToolCall] | None = None,
) -> list[OpenAIChatCompletionChunk]:
chunks = []
if finish_reason is not None:
stop_reason = _convert_to_vllm_finish_reason(finish_reason)
else:
stop_reason = StopReason.end_of_message
tool_call_bufs = tool_call_bufs or {}
for _index, tool_call_buf in sorted(tool_call_bufs.items()):
args_str = tool_call_buf.arguments or "{}"
try:
chunks.append(
ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=current_event_type,
delta=ToolCallDelta(
tool_call=ToolCall(
call_id=tool_call_buf.call_id,
tool_name=tool_call_buf.tool_name,
arguments=args_str,
),
parse_status=ToolCallParseStatus.succeeded,
),
)
)
)
except Exception as e:
log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}")
chunks.append(
ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call=str(tool_call_buf),
parse_status=ToolCallParseStatus.failed,
),
)
)
)
chunks.append(
ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=TextDelta(text=last_chunk_content or ""),
logprobs=None,
stop_reason=stop_reason,
)
)
)
return chunks
async def _process_vllm_chat_completion_stream_response(
stream: AsyncGenerator[OpenAIChatCompletionChunk, None],
) -> AsyncGenerator:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta=TextDelta(text=""),
)
)
event_type = ChatCompletionResponseEventType.progress
tool_call_bufs: dict[str, UnparseableToolCall] = {}
end_of_stream_processed = False
async for chunk in stream:
if not chunk.choices:
log.warning("vLLM failed to generation any completions - check the vLLM server logs for an error.")
return
choice = chunk.choices[0]
if choice.delta.tool_calls:
for delta_tool_call in choice.delta.tool_calls:
tool_call = convert_tool_call(delta_tool_call)
if delta_tool_call.index not in tool_call_bufs:
tool_call_bufs[delta_tool_call.index] = UnparseableToolCall()
tool_call_buf = tool_call_bufs[delta_tool_call.index]
tool_call_buf.tool_name += str(tool_call.tool_name)
tool_call_buf.call_id += tool_call.call_id
tool_call_buf.arguments += (
tool_call.arguments if isinstance(tool_call.arguments, str) else json.dumps(tool_call.arguments)
)
if choice.finish_reason:
chunks = _process_vllm_chat_completion_end_of_stream(
finish_reason=choice.finish_reason,
last_chunk_content=choice.delta.content,
current_event_type=event_type,
tool_call_bufs=tool_call_bufs,
)
for c in chunks:
yield c
end_of_stream_processed = True
elif not choice.delta.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=TextDelta(text=choice.delta.content or ""),
logprobs=None,
)
)
event_type = ChatCompletionResponseEventType.progress
if end_of_stream_processed:
return
# the stream ended without a chunk containing finish_reason - we have to generate the
# respective completion chunks manually
chunks = _process_vllm_chat_completion_end_of_stream(
finish_reason=None, last_chunk_content=None, current_event_type=event_type, tool_call_bufs=tool_call_bufs
)
for c in chunks:
yield c
class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsProtocolPrivate):
# automatically set by the resolver when instantiating the provider
__provider_id__: str
model_store: ModelStore | None = None
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
model_entries=build_hf_repo_model_entries(),
litellm_provider_name="vllm",
api_key_from_config=config.api_token,
provider_data_api_key_field="vllm_api_token",
openai_compat_api_base=config.url,
)
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
self.config = config
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_api_key(self) -> str:
return self.config.api_token or ""
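# Note: a per-request key supplied via provider data (the "vllm_api_token" field declared
# above through provider_data_api_key_field) takes precedence inside OpenAIMixin; this
# method is the config-backed fallback.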
def get_base_url(self) -> str:
"""Get the base URL from config."""
@ -290,19 +66,13 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
Model(
identifier=m.id,
provider_resource_id=m.id,
provider_id=self.__provider_id__,
provider_id=self.__provider_id__, # type: ignore[attr-defined]
metadata={},
model_type=model_type,
)
)
return models
async def shutdown(self) -> None:
pass
async def unregister_model(self, model_id: str) -> None:
pass
async def health(self) -> HealthResponse:
"""
Performs a health check by verifying connectivity to the remote vLLM server.
@ -324,63 +94,9 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
except Exception as e:
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
async def _get_model(self, model_id: str) -> Model:
if not self.model_store:
raise ValueError("Model store not set")
return await self.model_store.get_model(model_id)
def get_extra_client_params(self):
return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
async def register_model(self, model: Model) -> Model:
try:
model = await self.register_helper.register_model(model)
except ValueError:
pass # Ignore statically unknown model, will check live listing
try:
res = self.client.models.list()
except APIConnectionError as e:
raise ValueError(
f"Failed to connect to vLLM at {self.config.url}. Please check if vLLM is running and accessible at that URL."
) from e
available_models = [m.id async for m in res]
if model.provider_resource_id not in available_models:
raise ValueError(
f"Model {model.provider_resource_id} is not being served by vLLM. "
f"Available models: {', '.join(available_models)}"
)
return model
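# Flow of register_model above: statically known models are normalized via the registry
# helper, unknown ones are tolerated, and every model is then verified against the vLLM
# server's live model listing, with a clear error when it is not being served.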
async def _get_params(self, request: ChatCompletionRequest) -> dict:
options = get_sampling_options(request.sampling_params)
if "max_tokens" not in options:
options["max_tokens"] = self.config.max_tokens
input_dict: dict[str, Any] = {}
# Only include the 'tools' param if there are any. Sending an empty list can break the vLLM server.
if isinstance(request, ChatCompletionRequest) and request.tools:
input_dict = {"tools": _convert_to_vllm_tools_in_request(request.tools)}
input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
if fmt := request.response_format:
if isinstance(fmt, JsonSchemaResponseFormat):
input_dict["extra_body"] = {"guided_json": fmt.json_schema}
elif isinstance(fmt, GrammarResponseFormat):
raise NotImplementedError("Grammar response format not supported yet")
else:
raise ValueError(f"Unknown response format {fmt.type}")
if request.logprobs and request.logprobs.top_k:
input_dict["logprobs"] = request.logprobs.top_k
return {
"model": request.model,
**input_dict,
"stream": request.stream,
**options,
}
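# Illustrative shape of the params returned above for a hypothetical streaming request that
# carries tools and a JSON-schema response format (values are examples only):
#   {"model": "my-model", "messages": [...], "tools": [...],
#    "extra_body": {"guided_json": {...}}, "stream": True, "max_tokens": 4096}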
async def openai_chat_completion(
self,
model: str,
@ -65,12 +65,6 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
self._project_id = self._config.project_id
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
def _get_client(self, model_id) -> Model:
config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
config_url = self._config.url
@ -11,6 +11,7 @@ from collections.abc import AsyncIterator, Iterable
from typing import Any
from openai import NOT_GIVEN, AsyncOpenAI
from pydantic import BaseModel, ConfigDict
from llama_stack.apis.inference import (
Model,
@ -26,14 +27,14 @@ from llama_stack.apis.inference import (
from llama_stack.apis.models import ModelType
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
from llama_stack.providers.utils.inference.prompt_adapter import localize_image_content
logger = get_logger(name=__name__, category="providers::utils")
class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
"""
Mixin class that provides OpenAI-specific functionality for inference providers.
This class handles direct OpenAI API calls using the AsyncOpenAI client.
@ -42,12 +43,25 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
- get_api_key(): Method to retrieve the API key
- get_base_url(): Method to retrieve the OpenAI-compatible API base URL
The behavior of this class can be customized by child classes in the following ways:
- overwrite_completion_id: If True, overwrites the 'id' field in OpenAI responses
- download_images: If True, downloads images and converts to base64 for providers that require it
- embedding_model_metadata: A dictionary mapping model IDs to their embedding metadata
- provider_data_api_key_field: Optional field name in provider data to look for API key
- list_provider_model_ids: Method to list available models from the provider
- get_extra_client_params: Method to provide extra parameters to the AsyncOpenAI client
Expected Dependencies:
- self.model_store: Injected by the Llama Stack distribution system at runtime.
This provides model registry functionality for looking up registered models.
The model_store is set in routing_tables/common.py during provider initialization.
"""
# Allow extra fields so the routing infra can inject model_store, __provider_id__, etc.
model_config = ConfigDict(extra="allow")
config: RemoteInferenceProviderConfig
# Allow subclasses to control whether the 'id' field in OpenAI responses
# is overwritten with a client-side generated id.
#
@ -73,9 +87,6 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
# Optional field name in provider data to look for API key, which takes precedence
provider_data_api_key_field: str | None = None
# automatically set by the resolver when instantiating the provider
__provider_id__: str
@abstractmethod
def get_api_key(self) -> str:
"""
@ -123,6 +134,26 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
"""
return [m.id async for m in self.client.models.list()]
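# Providers whose native SDK is not OpenAI-compatible for listing can override this hook.
# A sketch, with a purely hypothetical _native_client attribute:
#
#   async def list_provider_model_ids(self) -> Iterable[str]:
#       return [m.id for m in await self._native_client.models.list()]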
async def initialize(self) -> None:
"""
Initialize the OpenAI mixin.
This method provides a default implementation that does nothing.
Subclasses can override this method to perform initialization tasks
such as setting up clients, validating configurations, etc.
"""
pass
async def shutdown(self) -> None:
"""
Shutdown the OpenAI mixin.
This method provides a default implementation that does nothing.
Subclasses can override this method to perform cleanup tasks
such as closing connections, releasing resources, etc.
"""
pass
@property
def client(self) -> AsyncOpenAI:
"""
@ -383,7 +414,7 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
async def register_model(self, model: Model) -> Model:
if not await self.check_model_availability(model.provider_model_id):
raise ValueError(f"Model {model.provider_model_id} is not available from provider {self.__provider_id__}")
raise ValueError(f"Model {model.provider_model_id} is not available from provider {self.__provider_id__}") # type: ignore[attr-defined]
return model
async def unregister_model(self, model_id: str) -> None:
@ -399,17 +430,23 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
"""
self._model_cache = {}
# give subclasses a chance to provide custom model listing
iterable = await self.list_provider_model_ids()
try:
iterable = await self.list_provider_model_ids()
except Exception as e:
logger.error(f"{self.__class__.__name__}.list_provider_model_ids() failed with: {e}")
raise
if not hasattr(iterable, "__iter__"):
raise TypeError(
f"Failed to list models: {self.__class__.__name__}.list_provider_model_ids() must return an iterable of "
f"strings or None, but returned {type(iterable).__name__}"
f"strings, but returned {type(iterable).__name__}"
)
provider_models_ids = list(iterable)
logger.info(f"{self.__class__.__name__}.list_provider_model_ids() returned {len(provider_models_ids)} models")
for provider_model_id in provider_models_ids:
if not isinstance(provider_model_id, str):
raise ValueError(f"Model ID {provider_model_id} from list_provider_model_ids() is not a string")
if self.allowed_models and provider_model_id not in self.allowed_models:
logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list")
continue
@ -445,3 +482,29 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
async def should_refresh_models(self) -> bool:
return False
#
# The model_dump overrides avoid serializing the extra fields,
# e.g. model_store, which are not pydantic fields.
#
def _filter_fields(self, **kwargs):
"""Helper to exclude extra fields from serialization."""
# Exclude any extra fields stored in __pydantic_extra__
if hasattr(self, "__pydantic_extra__") and self.__pydantic_extra__:
exclude = kwargs.get("exclude", set())
if not isinstance(exclude, set):
exclude = set(exclude) if exclude else set()
exclude.update(self.__pydantic_extra__.keys())
kwargs["exclude"] = exclude
return kwargs
def model_dump(self, **kwargs):
"""Override to exclude extra fields from serialization."""
kwargs = self._filter_fields(**kwargs)
return super().model_dump(**kwargs)
def model_dump_json(self, **kwargs):
"""Override to exclude extra fields from JSON serialization."""
kwargs = self._filter_fields(**kwargs)
return super().model_dump_json(**kwargs)
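# Net effect (illustrative): if the routing layer has injected model_store as an extra
# attribute, adapter.model_dump() and adapter.model_dump_json() return only the declared
# pydantic fields, so the injected store is never serialized.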