Merge branch 'main' into chroma

This commit is contained in:
Bwook (Byoungwook) Kim 2025-09-19 22:53:03 +09:00 committed by GitHub
commit c71bcd5479
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
124 changed files with 25574 additions and 2425 deletions

View file

@ -131,6 +131,15 @@ class ProviderSpec(BaseModel):
""",
)
pip_packages: list[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
provider_data_validator: str | None = Field(
default=None,
)
is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.")
# used internally by the resolver; this is a hack for now
@ -145,45 +154,8 @@ class RoutingTable(Protocol):
async def get_provider_impl(self, routing_key: str) -> Any: ...
# TODO: this can now be inlined into RemoteProviderSpec
@json_schema_type
class AdapterSpec(BaseModel):
adapter_type: str = Field(
...,
description="Unique identifier for this adapter",
)
module: str = Field(
default_factory=str,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
""",
)
pip_packages: list[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
config_class: str = Field(
description="Fully-qualified classname of the config for this provider",
)
provider_data_validator: str | None = Field(
default=None,
)
description: str | None = Field(
default=None,
description="""
A description of the provider. This is used to display in the documentation.
""",
)
@json_schema_type
class InlineProviderSpec(ProviderSpec):
pip_packages: list[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
container_image: str | None = Field(
default=None,
description="""
@ -191,10 +163,6 @@ The container image to use for this implementation. If one is provided, pip_pack
If a provider depends on other providers, the dependencies MUST NOT specify a container image.
""",
)
# module field is inherited from ProviderSpec
provider_data_validator: str | None = Field(
default=None,
)
description: str | None = Field(
default=None,
description="""
@ -223,10 +191,15 @@ class RemoteProviderConfig(BaseModel):
@json_schema_type
class RemoteProviderSpec(ProviderSpec):
adapter: AdapterSpec = Field(
adapter_type: str = Field(
...,
description="Unique identifier for this adapter",
)
description: str | None = Field(
default=None,
description="""
If some code is needed to convert the remote responses into Llama Stack compatible
API responses, specify the adapter here.
A description of the provider. This is used to display in the documentation.
""",
)
@ -234,33 +207,6 @@ API responses, specify the adapter here.
def container_image(self) -> str | None:
return None
# module field is inherited from ProviderSpec
@property
def pip_packages(self) -> list[str]:
return self.adapter.pip_packages
@property
def provider_data_validator(self) -> str | None:
return self.adapter.provider_data_validator
def remote_provider_spec(
api: Api,
adapter: AdapterSpec,
api_dependencies: list[Api] | None = None,
optional_api_dependencies: list[Api] | None = None,
) -> RemoteProviderSpec:
return RemoteProviderSpec(
api=api,
provider_type=f"remote::{adapter.adapter_type}",
config_class=adapter.config_class,
module=adapter.module,
adapter=adapter,
api_dependencies=api_dependencies or [],
optional_api_dependencies=optional_api_dependencies or [],
)
class HealthStatus(StrEnum):
OK = "OK"

View file

@ -75,6 +75,13 @@ class MetaReferenceEvalImpl(
)
self.benchmarks[task_def.identifier] = task_def
async def unregister_benchmark(self, benchmark_id: str) -> None:
if benchmark_id in self.benchmarks:
del self.benchmarks[benchmark_id]
key = f"{EVAL_TASKS_PREFIX}{benchmark_id}"
await self.kvstore.delete(key)
async def run_eval(
self,
benchmark_id: str,

View file

@ -63,6 +63,9 @@ class LlmAsJudgeScoringImpl(
async def register_scoring_function(self, function_def: ScoringFn) -> None:
self.llm_as_judge_fn.register_scoring_fn_def(function_def)
async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
self.llm_as_judge_fn.unregister_scoring_fn_def(scoring_fn_id)
async def score_batch(
self,
dataset_id: str,

View file

@ -6,11 +6,10 @@
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
RemoteProviderSpec,
)
@ -25,28 +24,26 @@ def available_providers() -> list[ProviderSpec]:
api_dependencies=[],
description="Local filesystem-based dataset I/O provider for reading and writing datasets to local storage.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.datasetio,
adapter=AdapterSpec(
adapter_type="huggingface",
pip_packages=[
"datasets>=4.0.0",
],
module="llama_stack.providers.remote.datasetio.huggingface",
config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig",
description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.",
),
adapter_type="huggingface",
provider_type="remote::huggingface",
pip_packages=[
"datasets>=4.0.0",
],
module="llama_stack.providers.remote.datasetio.huggingface",
config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig",
description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.datasetio,
adapter=AdapterSpec(
adapter_type="nvidia",
pip_packages=[
"datasets>=4.0.0",
],
module="llama_stack.providers.remote.datasetio.nvidia",
config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.",
),
adapter_type="nvidia",
provider_type="remote::nvidia",
module="llama_stack.providers.remote.datasetio.nvidia",
config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
pip_packages=[
"datasets>=4.0.0",
],
description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.",
),
]

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
def available_providers() -> list[ProviderSpec]:
@ -25,17 +25,16 @@ def available_providers() -> list[ProviderSpec]:
],
description="Meta's reference implementation of evaluation tasks with support for multiple languages and evaluation metrics.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.eval,
adapter=AdapterSpec(
adapter_type="nvidia",
pip_packages=[
"requests",
],
module="llama_stack.providers.remote.eval.nvidia",
config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig",
description="NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.",
),
adapter_type="nvidia",
pip_packages=[
"requests",
],
provider_type="remote::nvidia",
module="llama_stack.providers.remote.eval.nvidia",
config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig",
description="NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.",
api_dependencies=[
Api.datasetio,
Api.datasets,

View file

@ -4,13 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
)
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages
@ -25,14 +19,13 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig",
description="Local filesystem-based file storage provider for managing files and documents locally.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.files,
adapter=AdapterSpec(
adapter_type="s3",
pip_packages=["boto3"] + sql_store_pip_packages,
module="llama_stack.providers.remote.files.s3",
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
),
provider_type="remote::s3",
adapter_type="s3",
pip_packages=["boto3"] + sql_store_pip_packages,
module="llama_stack.providers.remote.files.s3",
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
),
]

View file

@ -6,11 +6,10 @@
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
RemoteProviderSpec,
)
META_REFERENCE_DEPS = [
@ -49,176 +48,167 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
description="Sentence Transformers inference provider for text embeddings and similarity search.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="cerebras",
pip_packages=[
"cerebras_cloud_sdk",
],
module="llama_stack.providers.remote.inference.cerebras",
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
description="Cerebras inference provider for running models on Cerebras Cloud platform.",
),
adapter_type="cerebras",
provider_type="remote::cerebras",
pip_packages=[
"cerebras_cloud_sdk",
],
module="llama_stack.providers.remote.inference.cerebras",
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
description="Cerebras inference provider for running models on Cerebras Cloud platform.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="ollama",
pip_packages=["ollama", "aiohttp", "h11>=0.16.0"],
config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig",
module="llama_stack.providers.remote.inference.ollama",
description="Ollama inference provider for running local models through the Ollama runtime.",
),
adapter_type="ollama",
provider_type="remote::ollama",
pip_packages=["ollama", "aiohttp", "h11>=0.16.0"],
config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig",
module="llama_stack.providers.remote.inference.ollama",
description="Ollama inference provider for running local models through the Ollama runtime.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="vllm",
pip_packages=[],
module="llama_stack.providers.remote.inference.vllm",
config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig",
description="Remote vLLM inference provider for connecting to vLLM servers.",
),
adapter_type="vllm",
provider_type="remote::vllm",
pip_packages=[],
module="llama_stack.providers.remote.inference.vllm",
config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig",
provider_data_validator="llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
description="Remote vLLM inference provider for connecting to vLLM servers.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="tgi",
pip_packages=["huggingface_hub", "aiohttp"],
module="llama_stack.providers.remote.inference.tgi",
config_class="llama_stack.providers.remote.inference.tgi.TGIImplConfig",
description="Text Generation Inference (TGI) provider for HuggingFace model serving.",
),
adapter_type="tgi",
provider_type="remote::tgi",
pip_packages=["huggingface_hub", "aiohttp"],
module="llama_stack.providers.remote.inference.tgi",
config_class="llama_stack.providers.remote.inference.tgi.TGIImplConfig",
description="Text Generation Inference (TGI) provider for HuggingFace model serving.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="hf::serverless",
pip_packages=["huggingface_hub", "aiohttp"],
module="llama_stack.providers.remote.inference.tgi",
config_class="llama_stack.providers.remote.inference.tgi.InferenceAPIImplConfig",
description="HuggingFace Inference API serverless provider for on-demand model inference.",
),
adapter_type="hf::serverless",
provider_type="remote::hf::serverless",
pip_packages=["huggingface_hub", "aiohttp"],
module="llama_stack.providers.remote.inference.tgi",
config_class="llama_stack.providers.remote.inference.tgi.InferenceAPIImplConfig",
description="HuggingFace Inference API serverless provider for on-demand model inference.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="hf::endpoint",
pip_packages=["huggingface_hub", "aiohttp"],
module="llama_stack.providers.remote.inference.tgi",
config_class="llama_stack.providers.remote.inference.tgi.InferenceEndpointImplConfig",
description="HuggingFace Inference Endpoints provider for dedicated model serving.",
),
provider_type="remote::hf::endpoint",
adapter_type="hf::endpoint",
pip_packages=["huggingface_hub", "aiohttp"],
module="llama_stack.providers.remote.inference.tgi",
config_class="llama_stack.providers.remote.inference.tgi.InferenceEndpointImplConfig",
description="HuggingFace Inference Endpoints provider for dedicated model serving.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="fireworks",
pip_packages=[
"fireworks-ai<=0.17.16",
],
module="llama_stack.providers.remote.inference.fireworks",
config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator",
description="Fireworks AI inference provider for Llama models and other AI models on the Fireworks platform.",
),
adapter_type="fireworks",
provider_type="remote::fireworks",
pip_packages=[
"fireworks-ai<=0.17.16",
],
module="llama_stack.providers.remote.inference.fireworks",
config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator",
description="Fireworks AI inference provider for Llama models and other AI models on the Fireworks platform.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="together",
pip_packages=[
"together",
],
module="llama_stack.providers.remote.inference.together",
config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
description="Together AI inference provider for open-source models and collaborative AI development.",
),
adapter_type="together",
provider_type="remote::together",
pip_packages=[
"together",
],
module="llama_stack.providers.remote.inference.together",
config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
description="Together AI inference provider for open-source models and collaborative AI development.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="bedrock",
pip_packages=["boto3"],
module="llama_stack.providers.remote.inference.bedrock",
config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.",
),
adapter_type="bedrock",
provider_type="remote::bedrock",
pip_packages=["boto3"],
module="llama_stack.providers.remote.inference.bedrock",
config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="databricks",
pip_packages=[],
module="llama_stack.providers.remote.inference.databricks",
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
description="Databricks inference provider for running models on Databricks' unified analytics platform.",
),
adapter_type="databricks",
provider_type="remote::databricks",
pip_packages=[],
module="llama_stack.providers.remote.inference.databricks",
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
description="Databricks inference provider for running models on Databricks' unified analytics platform.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="nvidia",
pip_packages=[],
module="llama_stack.providers.remote.inference.nvidia",
config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.",
),
adapter_type="nvidia",
provider_type="remote::nvidia",
pip_packages=[],
module="llama_stack.providers.remote.inference.nvidia",
config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="runpod",
pip_packages=[],
module="llama_stack.providers.remote.inference.runpod",
config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
description="RunPod inference provider for running models on RunPod's cloud GPU platform.",
),
adapter_type="runpod",
provider_type="remote::runpod",
pip_packages=[],
module="llama_stack.providers.remote.inference.runpod",
config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
description="RunPod inference provider for running models on RunPod's cloud GPU platform.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="openai",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.openai",
config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
description="OpenAI inference provider for accessing GPT models and other OpenAI services.",
),
adapter_type="openai",
provider_type="remote::openai",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.openai",
config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
description="OpenAI inference provider for accessing GPT models and other OpenAI services.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="anthropic",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.anthropic",
config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
description="Anthropic inference provider for accessing Claude models and Anthropic's AI services.",
),
adapter_type="anthropic",
provider_type="remote::anthropic",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.anthropic",
config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
description="Anthropic inference provider for accessing Claude models and Anthropic's AI services.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="gemini",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.gemini",
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
description="Google Gemini inference provider for accessing Gemini models and Google's AI services.",
),
adapter_type="gemini",
provider_type="remote::gemini",
pip_packages=[
"litellm",
],
module="llama_stack.providers.remote.inference.gemini",
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
description="Google Gemini inference provider for accessing Gemini models and Google's AI services.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="vertexai",
pip_packages=["litellm", "google-cloud-aiplatform"],
module="llama_stack.providers.remote.inference.vertexai",
config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",
description="""Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages:
adapter_type="vertexai",
provider_type="remote::vertexai",
pip_packages=[
"litellm",
"google-cloud-aiplatform",
],
module="llama_stack.providers.remote.inference.vertexai",
config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",
description="""Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages:
Enterprise-grade security: Uses Google Cloud's security controls and IAM
Better integration: Seamless integration with other Google Cloud services
@ -238,76 +228,73 @@ Available Models:
- vertex_ai/gemini-2.0-flash
- vertex_ai/gemini-2.5-flash
- vertex_ai/gemini-2.5-pro""",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="groq",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.groq",
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
description="Groq inference provider for ultra-fast inference using Groq's LPU technology.",
),
adapter_type="groq",
provider_type="remote::groq",
pip_packages=[
"litellm",
],
module="llama_stack.providers.remote.inference.groq",
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
description="Groq inference provider for ultra-fast inference using Groq's LPU technology.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="llama-openai-compat",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.llama_openai_compat",
config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.",
),
adapter_type="llama-openai-compat",
provider_type="remote::llama-openai-compat",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.llama_openai_compat",
config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="sambanova",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.sambanova",
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
description="SambaNova inference provider for running models on SambaNova's dataflow architecture.",
),
adapter_type="sambanova",
provider_type="remote::sambanova",
pip_packages=[
"litellm",
],
module="llama_stack.providers.remote.inference.sambanova",
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
description="SambaNova inference provider for running models on SambaNova's dataflow architecture.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="passthrough",
pip_packages=[],
module="llama_stack.providers.remote.inference.passthrough",
config_class="llama_stack.providers.remote.inference.passthrough.PassthroughImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
description="Passthrough inference provider for connecting to any external inference service not directly supported.",
),
adapter_type="passthrough",
provider_type="remote::passthrough",
pip_packages=[],
module="llama_stack.providers.remote.inference.passthrough",
config_class="llama_stack.providers.remote.inference.passthrough.PassthroughImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
description="Passthrough inference provider for connecting to any external inference service not directly supported.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="watsonx",
pip_packages=["ibm_watsonx_ai"],
module="llama_stack.providers.remote.inference.watsonx",
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.",
),
adapter_type="watsonx",
provider_type="remote::watsonx",
pip_packages=["ibm_watsonx_ai"],
module="llama_stack.providers.remote.inference.watsonx",
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="azure",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.azure",
config_class="llama_stack.providers.remote.inference.azure.AzureConfig",
provider_data_validator="llama_stack.providers.remote.inference.azure.config.AzureProviderDataValidator",
description="""
provider_type="remote::azure",
adapter_type="azure",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.azure",
config_class="llama_stack.providers.remote.inference.azure.AzureConfig",
provider_data_validator="llama_stack.providers.remote.inference.azure.config.AzureProviderDataValidator",
description="""
Azure OpenAI inference provider for accessing GPT models and other Azure services.
Provider documentation
https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
""",
),
),
]

View file

@ -7,7 +7,7 @@
from typing import cast
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
# We provide two versions of these providers so that distributions can package the appropriate version of torch.
# The CPU version is used for distributions that don't have GPU support -- they result in smaller container images.
@ -57,14 +57,13 @@ def available_providers() -> list[ProviderSpec]:
],
description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.post_training,
adapter=AdapterSpec(
adapter_type="nvidia",
pip_packages=["requests", "aiohttp"],
module="llama_stack.providers.remote.post_training.nvidia",
config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig",
description="NVIDIA's post-training provider for fine-tuning models on NVIDIA's platform.",
),
adapter_type="nvidia",
provider_type="remote::nvidia",
pip_packages=["requests", "aiohttp"],
module="llama_stack.providers.remote.post_training.nvidia",
config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig",
description="NVIDIA's post-training provider for fine-tuning models on NVIDIA's platform.",
),
]

View file

@ -6,11 +6,10 @@
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
RemoteProviderSpec,
)
@ -48,35 +47,32 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig",
description="Code Scanner safety provider for detecting security vulnerabilities and unsafe code patterns.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.safety,
adapter=AdapterSpec(
adapter_type="bedrock",
pip_packages=["boto3"],
module="llama_stack.providers.remote.safety.bedrock",
config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig",
description="AWS Bedrock safety provider for content moderation using AWS's safety services.",
),
adapter_type="bedrock",
provider_type="remote::bedrock",
pip_packages=["boto3"],
module="llama_stack.providers.remote.safety.bedrock",
config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig",
description="AWS Bedrock safety provider for content moderation using AWS's safety services.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.safety,
adapter=AdapterSpec(
adapter_type="nvidia",
pip_packages=["requests"],
module="llama_stack.providers.remote.safety.nvidia",
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
description="NVIDIA's safety provider for content moderation and safety filtering.",
),
adapter_type="nvidia",
provider_type="remote::nvidia",
pip_packages=["requests"],
module="llama_stack.providers.remote.safety.nvidia",
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
description="NVIDIA's safety provider for content moderation and safety filtering.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.safety,
adapter=AdapterSpec(
adapter_type="sambanova",
pip_packages=["litellm", "requests"],
module="llama_stack.providers.remote.safety.sambanova",
config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig",
provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator",
description="SambaNova's safety provider for content moderation and safety filtering.",
),
adapter_type="sambanova",
provider_type="remote::sambanova",
pip_packages=["litellm", "requests"],
module="llama_stack.providers.remote.safety.sambanova",
config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig",
provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator",
description="SambaNova's safety provider for content moderation and safety filtering.",
),
]

View file

@ -6,11 +6,10 @@
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
RemoteProviderSpec,
)
@ -35,59 +34,54 @@ def available_providers() -> list[ProviderSpec]:
api_dependencies=[Api.vector_io, Api.inference, Api.files],
description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="brave-search",
module="llama_stack.providers.remote.tool_runtime.brave_search",
config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
description="Brave Search tool for web search capabilities with privacy-focused results.",
),
adapter_type="brave-search",
provider_type="remote::brave-search",
module="llama_stack.providers.remote.tool_runtime.brave_search",
config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
description="Brave Search tool for web search capabilities with privacy-focused results.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="bing-search",
module="llama_stack.providers.remote.tool_runtime.bing_search",
config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator",
description="Bing Search tool for web search capabilities using Microsoft's search engine.",
),
adapter_type="bing-search",
provider_type="remote::bing-search",
module="llama_stack.providers.remote.tool_runtime.bing_search",
config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator",
description="Bing Search tool for web search capabilities using Microsoft's search engine.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="tavily-search",
module="llama_stack.providers.remote.tool_runtime.tavily_search",
config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator",
description="Tavily Search tool for AI-optimized web search with structured results.",
),
adapter_type="tavily-search",
provider_type="remote::tavily-search",
module="llama_stack.providers.remote.tool_runtime.tavily_search",
config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator",
description="Tavily Search tool for AI-optimized web search with structured results.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="wolfram-alpha",
module="llama_stack.providers.remote.tool_runtime.wolfram_alpha",
config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator",
description="Wolfram Alpha tool for computational knowledge and mathematical calculations.",
),
adapter_type="wolfram-alpha",
provider_type="remote::wolfram-alpha",
module="llama_stack.providers.remote.tool_runtime.wolfram_alpha",
config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator",
description="Wolfram Alpha tool for computational knowledge and mathematical calculations.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="model-context-protocol",
module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig",
pip_packages=["mcp>=1.8.1"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator",
description="Model Context Protocol (MCP) tool for standardized tool calling and context management.",
),
adapter_type="model-context-protocol",
provider_type="remote::model-context-protocol",
module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig",
pip_packages=["mcp>=1.8.1"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator",
description="Model Context Protocol (MCP) tool for standardized tool calling and context management.",
),
]

View file

@ -6,11 +6,10 @@
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
RemoteProviderSpec,
)
@ -300,14 +299,16 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f
Please refer to the sqlite-vec provider documentation.
""",
),
remote_provider_spec(
Api.vector_io,
AdapterSpec(
adapter_type="chromadb",
pip_packages=["chromadb-client"],
module="llama_stack.providers.remote.vector_io.chroma",
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
description="""
RemoteProviderSpec(
api=Api.vector_io,
adapter_type="chromadb",
provider_type="remote::chromadb",
pip_packages=["chromadb-client"],
module="llama_stack.providers.remote.vector_io.chroma",
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
description="""
[Chroma](https://www.trychroma.com/) is an inline and remote vector
database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.
That means you're not limited to storing vectors in memory or in a separate service.
@ -340,9 +341,6 @@ pip install chromadb
## Documentation
See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introduction) for more details about Chroma in general.
""",
),
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
InlineProviderSpec(
api=Api.vector_io,
@ -387,14 +385,16 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
""",
),
remote_provider_spec(
Api.vector_io,
AdapterSpec(
adapter_type="pgvector",
pip_packages=["psycopg2-binary"],
module="llama_stack.providers.remote.vector_io.pgvector",
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
description="""
RemoteProviderSpec(
api=Api.vector_io,
adapter_type="pgvector",
provider_type="remote::pgvector",
pip_packages=["psycopg2-binary"],
module="llama_stack.providers.remote.vector_io.pgvector",
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
description="""
[PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It
allows you to store and query vectors directly in memory.
That means you'll get fast and efficient vector retrieval.
@ -495,19 +495,18 @@ docker pull pgvector/pgvector:pg17
## Documentation
See [PGVector's documentation](https://github.com/pgvector/pgvector) for more details about PGVector in general.
""",
),
),
RemoteProviderSpec(
api=Api.vector_io,
adapter_type="weaviate",
provider_type="remote::weaviate",
pip_packages=["weaviate-client"],
module="llama_stack.providers.remote.vector_io.weaviate",
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
remote_provider_spec(
Api.vector_io,
AdapterSpec(
adapter_type="weaviate",
pip_packages=["weaviate-client"],
module="llama_stack.providers.remote.vector_io.weaviate",
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
description="""
description="""
[Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack.
It allows you to store and query vectors directly within a Weaviate database.
That means you're not limited to storing vectors in memory or in a separate service.
@ -538,9 +537,6 @@ To install Weaviate see the [Weaviate quickstart documentation](https://weaviate
## Documentation
See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general.
""",
),
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
InlineProviderSpec(
api=Api.vector_io,
@ -594,28 +590,29 @@ docker pull qdrant/qdrant
See the [Qdrant documentation](https://qdrant.tech/documentation/) for more details about Qdrant in general.
""",
),
remote_provider_spec(
Api.vector_io,
AdapterSpec(
adapter_type="qdrant",
pip_packages=["qdrant-client"],
module="llama_stack.providers.remote.vector_io.qdrant",
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
description="""
Please refer to the inline provider documentation.
""",
),
RemoteProviderSpec(
api=Api.vector_io,
adapter_type="qdrant",
provider_type="remote::qdrant",
pip_packages=["qdrant-client"],
module="llama_stack.providers.remote.vector_io.qdrant",
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
description="""
Please refer to the inline provider documentation.
""",
),
remote_provider_spec(
Api.vector_io,
AdapterSpec(
adapter_type="milvus",
pip_packages=["pymilvus>=2.4.10"],
module="llama_stack.providers.remote.vector_io.milvus",
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
description="""
RemoteProviderSpec(
api=Api.vector_io,
adapter_type="milvus",
provider_type="remote::milvus",
pip_packages=["pymilvus>=2.4.10"],
module="llama_stack.providers.remote.vector_io.milvus",
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
description="""
[Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It
allows you to store and query vectors directly within a Milvus database.
That means you're not limited to storing vectors in memory or in a separate service.
@ -806,9 +803,6 @@ See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for m
For more details on TLS configuration, refer to the [TLS setup guide](https://milvus.io/docs/tls.md).
""",
),
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
InlineProviderSpec(
api=Api.vector_io,

View file

@ -51,18 +51,23 @@ class NVIDIAEvalImpl(
async def shutdown(self) -> None: ...
async def _evaluator_get(self, path):
async def _evaluator_get(self, path: str):
"""Helper for making GET requests to the evaluator service."""
response = requests.get(url=f"{self.config.evaluator_url}{path}")
response.raise_for_status()
return response.json()
async def _evaluator_post(self, path, data):
async def _evaluator_post(self, path: str, data: dict[str, Any]):
"""Helper for making POST requests to the evaluator service."""
response = requests.post(url=f"{self.config.evaluator_url}{path}", json=data)
response.raise_for_status()
return response.json()
async def _evaluator_delete(self, path: str) -> None:
"""Helper for making DELETE requests to the evaluator service."""
response = requests.delete(url=f"{self.config.evaluator_url}{path}")
response.raise_for_status()
async def register_benchmark(self, task_def: Benchmark) -> None:
"""Register a benchmark as an evaluation configuration."""
await self._evaluator_post(
@ -75,6 +80,10 @@ class NVIDIAEvalImpl(
},
)
async def unregister_benchmark(self, benchmark_id: str) -> None:
"""Unregister a benchmark evaluation configuration from NeMo Evaluator."""
await self._evaluator_delete(f"/v1/evaluation/configs/{DEFAULT_NAMESPACE}/{benchmark_id}")
async def run_eval(
self,
benchmark_id: str,

View file

@ -7,12 +7,10 @@
import asyncio
import base64
import uuid
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
from ollama import AsyncClient # type: ignore[attr-defined]
from openai import AsyncOpenAI
from ollama import AsyncClient as AsyncOllamaClient
from llama_stack.apis.common.content_types import (
ImageContentItem,
@ -37,9 +35,6 @@ from llama_stack.apis.inference import (
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
@ -64,15 +59,14 @@ from llama_stack.providers.utils.inference.model_registry import (
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
b64_encode_openai_embeddings_response,
get_sampling_options,
prepare_openai_completion_params,
prepare_openai_embeddings_params,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
@ -89,6 +83,7 @@ logger = get_logger(name=__name__, category="inference::ollama")
class OllamaInferenceAdapter(
OpenAIMixin,
InferenceProvider,
ModelsProtocolPrivate,
):
@ -98,23 +93,21 @@ class OllamaInferenceAdapter(
def __init__(self, config: OllamaImplConfig) -> None:
self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
self.config = config
self._clients: dict[asyncio.AbstractEventLoop, AsyncClient] = {}
self._openai_client = None
self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
@property
def client(self) -> AsyncClient:
def ollama_client(self) -> AsyncOllamaClient:
# ollama client attaches itself to the current event loop (sadly?)
loop = asyncio.get_running_loop()
if loop not in self._clients:
self._clients[loop] = AsyncClient(host=self.config.url)
self._clients[loop] = AsyncOllamaClient(host=self.config.url)
return self._clients[loop]
@property
def openai_client(self) -> AsyncOpenAI:
if self._openai_client is None:
url = self.config.url.rstrip("/")
self._openai_client = AsyncOpenAI(base_url=f"{url}/v1", api_key="ollama")
return self._openai_client
def get_api_key(self):
return "NO_KEY"
def get_base_url(self):
return self.config.url.rstrip("/") + "/v1"
async def initialize(self) -> None:
logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
@ -129,7 +122,7 @@ class OllamaInferenceAdapter(
async def list_models(self) -> list[Model] | None:
provider_id = self.__provider_id__
response = await self.client.list()
response = await self.ollama_client.list()
# always add the two embedding models which can be pulled on demand
models = [
@ -189,7 +182,7 @@ class OllamaInferenceAdapter(
HealthResponse: A dictionary containing the health status.
"""
try:
await self.client.ps()
await self.ollama_client.ps()
return HealthResponse(status=HealthStatus.OK)
except Exception as e:
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
@ -238,7 +231,7 @@ class OllamaInferenceAdapter(
params = await self._get_params(request)
async def _generate_and_convert_to_openai_compat():
s = await self.client.generate(**params)
s = await self.ollama_client.generate(**params)
async for chunk in s:
choice = OpenAICompatCompletionChoice(
finish_reason=chunk["done_reason"] if chunk["done"] else None,
@ -254,7 +247,7 @@ class OllamaInferenceAdapter(
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request)
r = await self.client.generate(**params)
r = await self.ollama_client.generate(**params)
choice = OpenAICompatCompletionChoice(
finish_reason=r["done_reason"] if r["done"] else None,
@ -346,9 +339,9 @@ class OllamaInferenceAdapter(
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
if "messages" in params:
r = await self.client.chat(**params)
r = await self.ollama_client.chat(**params)
else:
r = await self.client.generate(**params)
r = await self.ollama_client.generate(**params)
if "message" in r:
choice = OpenAICompatCompletionChoice(
@ -372,9 +365,9 @@ class OllamaInferenceAdapter(
async def _generate_and_convert_to_openai_compat():
if "messages" in params:
s = await self.client.chat(**params)
s = await self.ollama_client.chat(**params)
else:
s = await self.client.generate(**params)
s = await self.ollama_client.generate(**params)
async for chunk in s:
if "message" in chunk:
choice = OpenAICompatCompletionChoice(
@ -407,7 +400,7 @@ class OllamaInferenceAdapter(
assert all(not content_has_media(content) for content in contents), (
"Ollama does not support media for embeddings"
)
response = await self.client.embed(
response = await self.ollama_client.embed(
model=model.provider_resource_id,
input=[interleaved_content_as_str(content) for content in contents],
)
@ -422,14 +415,14 @@ class OllamaInferenceAdapter(
pass # Ignore statically unknown model, will check live listing
if model.model_type == ModelType.embedding:
response = await self.client.list()
response = await self.ollama_client.list()
if model.provider_resource_id not in [m.model for m in response.models]:
await self.client.pull(model.provider_resource_id)
await self.ollama_client.pull(model.provider_resource_id)
# we use list() here instead of ps() -
# - ps() only lists running models, not available models
# - models not currently running are run by the ollama server as needed
response = await self.client.list()
response = await self.ollama_client.list()
available_models = [m.model for m in response.models]
provider_resource_id = model.provider_resource_id
@ -448,90 +441,6 @@ class OllamaInferenceAdapter(
return model
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
model_obj = await self._get_model(model)
if model_obj.provider_resource_id is None:
raise ValueError(f"Model {model} has no provider_resource_id set")
# Note, at the moment Ollama does not support encoding_format, dimensions, and user parameters
params = prepare_openai_embeddings_params(
model=model_obj.provider_resource_id,
input=input,
encoding_format=encoding_format,
dimensions=dimensions,
user=user,
)
response = await self.openai_client.embeddings.create(**params)
data = b64_encode_openai_embeddings_response(response.data, encoding_format)
usage = OpenAIEmbeddingUsage(
prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens,
)
# TODO: Investigate why model_obj.identifier is used instead of response.model
return OpenAIEmbeddingsResponse(
data=data,
model=model_obj.identifier,
usage=usage,
)
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
if not isinstance(prompt, str):
raise ValueError("Ollama does not support non-string prompts for completion")
model_obj = await self._get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
suffix=suffix,
)
return await self.openai_client.completions.create(**params) # type: ignore
async def openai_chat_completion(
self,
model: str,
@ -599,25 +508,7 @@ class OllamaInferenceAdapter(
top_p=top_p,
user=user,
)
response = await self.openai_client.chat.completions.create(**params)
return await self._adjust_ollama_chat_completion_response_ids(response)
async def _adjust_ollama_chat_completion_response_ids(
self,
response: OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk],
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
id = f"chatcmpl-{uuid.uuid4()}"
if isinstance(response, AsyncIterator):
async def stream_with_chunk_ids() -> AsyncIterator[OpenAIChatCompletionChunk]:
async for chunk in response:
chunk.id = id
yield chunk
return stream_with_chunk_ids()
else:
response.id = id
return response
return await OpenAIMixin.openai_chat_completion(self, **params)
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:

View file

@ -8,6 +8,7 @@
from collections.abc import AsyncGenerator
from huggingface_hub import AsyncInferenceClient, HfApi
from pydantic import SecretStr
from llama_stack.apis.common.content_types import (
InterleavedContent,
@ -33,6 +34,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.apis.models.models import ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import ModelsProtocolPrivate
@ -41,16 +43,15 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_model_input_info,
completion_request_to_prompt_model_input_info,
@ -73,26 +74,49 @@ def build_hf_repo_model_entries():
class _HfAdapter(
OpenAIMixin,
Inference,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
ModelsProtocolPrivate,
):
client: AsyncInferenceClient
url: str
api_key: SecretStr
hf_client: AsyncInferenceClient
max_tokens: int
model_id: str
overwrite_completion_id = True # TGI always returns id=""
def __init__(self) -> None:
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
self.huggingface_repo_to_llama_model_id = {
model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo
}
def get_api_key(self):
return self.api_key.get_secret_value()
def get_base_url(self):
return self.url
async def shutdown(self) -> None:
pass
async def list_models(self) -> list[Model] | None:
models = []
async for model in self.client.models.list():
models.append(
Model(
identifier=model.id,
provider_resource_id=model.id,
provider_id=self.__provider_id__,
metadata={},
model_type=ModelType.llm,
)
)
return models
async def register_model(self, model: Model) -> Model:
model = await self.register_helper.register_model(model)
if model.provider_resource_id != self.model_id:
raise ValueError(
f"Model {model.provider_resource_id} does not match the model {self.model_id} served by TGI."
@ -176,7 +200,7 @@ class _HfAdapter(
params = await self._get_params_for_completion(request)
async def _generate_and_convert_to_openai_compat():
s = await self.client.text_generation(**params)
s = await self.hf_client.text_generation(**params)
async for chunk in s:
token_result = chunk.token
finish_reason = None
@ -194,7 +218,7 @@ class _HfAdapter(
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params_for_completion(request)
r = await self.client.text_generation(**params)
r = await self.hf_client.text_generation(**params)
choice = OpenAICompatCompletionChoice(
finish_reason=r.details.finish_reason,
@ -241,7 +265,7 @@ class _HfAdapter(
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
r = await self.client.text_generation(**params)
r = await self.hf_client.text_generation(**params)
choice = OpenAICompatCompletionChoice(
finish_reason=r.details.finish_reason,
@ -256,7 +280,7 @@ class _HfAdapter(
params = await self._get_params(request)
async def _generate_and_convert_to_openai_compat():
s = await self.client.text_generation(**params)
s = await self.hf_client.text_generation(**params)
async for chunk in s:
token_result = chunk.token
@ -308,18 +332,21 @@ class TGIAdapter(_HfAdapter):
if not config.url:
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
log.info(f"Initializing TGI client with url={config.url}")
self.client = AsyncInferenceClient(model=config.url, provider="hf-inference")
endpoint_info = await self.client.get_endpoint_info()
self.hf_client = AsyncInferenceClient(model=config.url, provider="hf-inference")
endpoint_info = await self.hf_client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]
self.model_id = endpoint_info["model_id"]
self.url = f"{config.url.rstrip('/')}/v1"
self.api_key = SecretStr("NO_KEY")
class InferenceAPIAdapter(_HfAdapter):
async def initialize(self, config: InferenceAPIImplConfig) -> None:
self.client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value())
endpoint_info = await self.client.get_endpoint_info()
self.hf_client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value())
endpoint_info = await self.hf_client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]
self.model_id = endpoint_info["model_id"]
# TODO: how do we set url for this?
class InferenceEndpointAdapter(_HfAdapter):
@ -331,6 +358,7 @@ class InferenceEndpointAdapter(_HfAdapter):
endpoint.wait(timeout=60)
# Initialize the adapter
self.client = endpoint.async_client
self.hf_client = endpoint.async_client
self.model_id = endpoint.repository
self.max_tokens = int(endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"])
# TODO: how do we set url for this?

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models import ModelType
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
@ -21,57 +20,84 @@ SAFETY_MODELS_ENTRIES = [
CoreModelId.llama_guard_3_11b_vision.value,
),
]
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.2-3B-Instruct-Turbo",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.3-70B-Instruct-Turbo",
CoreModelId.llama3_3_70b_instruct.value,
),
ProviderModelEntry(
provider_model_id="togethercomputer/m2-bert-80M-8k-retrieval",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 768,
"context_length": 8192,
},
),
ProviderModelEntry(
# source: https://docs.together.ai/docs/serverless-models#embedding-models
EMBEDDING_MODEL_ENTRIES = {
"togethercomputer/m2-bert-80M-32k-retrieval": ProviderModelEntry(
provider_model_id="togethercomputer/m2-bert-80M-32k-retrieval",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 768,
"context_length": 32768,
},
),
build_hf_repo_model_entry(
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
"BAAI/bge-large-en-v1.5": ProviderModelEntry(
provider_model_id="BAAI/bge-large-en-v1.5",
metadata={
"embedding_dimension": 1024,
"context_length": 512,
},
),
build_hf_repo_model_entry(
"meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
"BAAI/bge-base-en-v1.5": ProviderModelEntry(
provider_model_id="BAAI/bge-base-en-v1.5",
metadata={
"embedding_dimension": 768,
"context_length": 512,
},
),
] + SAFETY_MODELS_ENTRIES
"Alibaba-NLP/gte-modernbert-base": ProviderModelEntry(
provider_model_id="Alibaba-NLP/gte-modernbert-base",
metadata={
"embedding_dimension": 768,
"context_length": 8192,
},
),
"intfloat/multilingual-e5-large-instruct": ProviderModelEntry(
provider_model_id="intfloat/multilingual-e5-large-instruct",
metadata={
"embedding_dimension": 1024,
"context_length": 512,
},
),
}
MODEL_ENTRIES = (
[
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.2-3B-Instruct-Turbo",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.3-70B-Instruct-Turbo",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
]
+ SAFETY_MODELS_ENTRIES
+ list(EMBEDDING_MODEL_ENTRIES.values())
)

View file

@ -4,11 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
from collections.abc import AsyncGenerator
from openai import AsyncOpenAI
from openai import NOT_GIVEN, AsyncOpenAI
from together import AsyncTogether
from together.constants import BASE_URL
from llama_stack.apis.common.content_types import (
InterleavedContent,
@ -23,12 +23,7 @@ from llama_stack.apis.inference import (
Inference,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
ResponseFormatType,
SamplingParams,
@ -38,18 +33,20 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import OpenAIEmbeddingUsage
from llama_stack.apis.models import Model, ModelType
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
prepare_openai_completion_params,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
@ -59,15 +56,22 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)
from .config import TogetherImplConfig
from .models import MODEL_ENTRIES
from .models import EMBEDDING_MODEL_ENTRIES, MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference::together")
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
class TogetherInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
def __init__(self, config: TogetherImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
self.config = config
self._model_cache: dict[str, Model] = {}
def get_api_key(self):
return self.config.api_key.get_secret_value()
def get_base_url(self):
return BASE_URL
async def initialize(self) -> None:
pass
@ -255,6 +259,37 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
embeddings = [item.embedding for item in r.data]
return EmbeddingsResponse(embeddings=embeddings)
async def list_models(self) -> list[Model] | None:
self._model_cache = {}
# Together's /v1/models is not compatible with OpenAI's /v1/models. Together support ticket #13355 -> will not fix, use Together's own client
for m in await self._get_client().models.list():
if m.type == "embedding":
if m.id not in EMBEDDING_MODEL_ENTRIES:
logger.warning(f"Unknown embedding dimension for model {m.id}, skipping.")
continue
self._model_cache[m.id] = Model(
provider_id=self.__provider_id__,
provider_resource_id=EMBEDDING_MODEL_ENTRIES[m.id].provider_model_id,
identifier=m.id,
model_type=ModelType.embedding,
metadata=EMBEDDING_MODEL_ENTRIES[m.id].metadata,
)
else:
self._model_cache[m.id] = Model(
provider_id=self.__provider_id__,
provider_resource_id=m.id,
identifier=m.id,
model_type=ModelType.llm,
)
return self._model_cache.values()
async def should_refresh_models(self) -> bool:
return True
async def check_model_availability(self, model):
return model in self._model_cache
async def openai_embeddings(
self,
model: str,
@ -263,125 +298,39 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
"""
Together's OpenAI-compatible embeddings endpoint is not compatible with
the standard OpenAI embeddings endpoint.
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
The endpoint -
- does not return usage information
- does not support user param, returns 400 Unrecognized request arguments supplied: user
- does not support dimensions param, returns 400 Unrecognized request arguments supplied: dimensions
- does not support encoding_format param, always returns floats, never base64
"""
# Together support ticket #13332 -> will not fix
if user is not None:
raise ValueError("Together's embeddings endpoint does not support user param.")
# Together support ticket #13333 -> escalated
if dimensions is not None:
raise ValueError("Together's embeddings endpoint does not support dimensions param.")
# Together support ticket #13331 -> will not fix, compute client side
if encoding_format not in (None, NOT_GIVEN, "float"):
raise ValueError("Together's embeddings endpoint only supports encoding_format='float'.")
response = await self.client.embeddings.create(
model=await self._get_provider_model_id(model),
input=input,
)
return await self._get_openai_client().completions.create(**params) # type: ignore
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
if params.get("stream", False):
return self._stream_openai_chat_completion(params)
return await self._get_openai_client().chat.completions.create(**params) # type: ignore
response.model = model # return the user the same model id they provided, avoid exposing the provider model id
async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
# together.ai sometimes adds usage data to the stream, even if include_usage is False
# This causes an unexpected final chunk with empty choices array to be sent
# to clients that may not handle it gracefully.
include_usage = False
if params.get("stream_options", None):
include_usage = params["stream_options"].get("include_usage", False)
stream = await self._get_openai_client().chat.completions.create(**params)
# Together support ticket #13330 -> escalated
# - togethercomputer/m2-bert-80M-32k-retrieval *does not* return usage information
if not hasattr(response, "usage") or response.usage is None:
logger.warning(
f"Together's embedding endpoint for {model} did not return usage information, substituting -1s."
)
response.usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
seen_finish_reason = False
async for chunk in stream:
# Final usage chunk with no choices that the user didn't request, so discard
if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
break
yield chunk
for choice in chunk.choices:
if choice.finish_reason:
seen_finish_reason = True
break
return response

View file

@ -4,9 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import VLLMInferenceAdapterConfig
class VLLMProviderDataValidator(BaseModel):
vllm_api_token: str | None = None
async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps):
from .vllm import VLLMInferenceAdapter

View file

@ -4,8 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
from urllib.parse import urljoin
import httpx
from openai import APIConnectionError, AsyncOpenAI
@ -55,6 +56,7 @@ from llama_stack.providers.datatypes import (
HealthStatus,
ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
@ -62,6 +64,7 @@ from llama_stack.providers.utils.inference.model_registry import (
from llama_stack.providers.utils.inference.openai_compat import (
UnparseableToolCall,
convert_message_to_openai_dict,
convert_openai_chat_completion_stream,
convert_tool_call,
get_sampling_options,
process_chat_completion_stream_response,
@ -281,15 +284,31 @@ async def _process_vllm_chat_completion_stream_response(
yield c
class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsProtocolPrivate):
# automatically set by the resolver when instantiating the provider
__provider_id__: str
model_store: ModelStore | None = None
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
build_hf_repo_model_entries(),
litellm_provider_name="vllm",
api_key_from_config=config.api_token,
provider_data_api_key_field="vllm_api_token",
openai_compat_api_base=config.url,
)
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
self.config = config
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_base_url(self) -> str:
"""Get the base URL from config."""
if not self.config.url:
raise ValueError("No base URL configured")
return self.config.url
async def initialize(self) -> None:
if not self.config.url:
raise ValueError(
@ -297,6 +316,7 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
)
async def should_refresh_models(self) -> bool:
# Strictly respecting the refresh_models directive
return self.config.refresh_models
async def list_models(self) -> list[Model] | None:
@ -325,13 +345,19 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
Performs a health check by verifying connectivity to the remote vLLM server.
This method is used by the Provider API to verify
that the service is running correctly.
Uses the unauthenticated /health endpoint.
Returns:
HealthResponse: A dictionary containing the health status.
"""
try:
_ = [m async for m in self.client.models.list()] # Ensure the client is initialized
return HealthResponse(status=HealthStatus.OK)
base_url = self.get_base_url()
health_url = urljoin(base_url, "health")
async with httpx.AsyncClient() as client:
response = await client.get(health_url)
response.raise_for_status()
return HealthResponse(status=HealthStatus.OK)
except Exception as e:
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
@ -340,16 +366,10 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
raise ValueError("Model store not set")
return await self.model_store.get_model(model_id)
def get_api_key(self):
return self.config.api_token
def get_base_url(self):
return self.config.url
def get_extra_client_params(self):
return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
async def completion(
async def completion( # type: ignore[override] # Return type more specific than base class which is allows for both streaming and non-streaming responses.
self,
model_id: str,
content: InterleavedContent,
@ -411,13 +431,14 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
tool_config=tool_config,
)
if stream:
return self._stream_chat_completion(request, self.client)
return self._stream_chat_completion_with_client(request, self.client)
else:
return await self._nonstream_chat_completion(request, self.client)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: AsyncOpenAI
) -> ChatCompletionResponse:
assert self.client is not None
params = await self._get_params(request)
r = await client.chat.completions.create(**params)
choice = r.choices[0]
@ -431,9 +452,24 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
)
return result
async def _stream_chat_completion(
async def _stream_chat_completion(self, response: Any) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
# This method is called from LiteLLMOpenAIMixin.chat_completion
# The response parameter contains the litellm response
# We need to convert it to our format
async def _stream_generator():
async for chunk in response:
yield chunk
async for chunk in convert_openai_chat_completion_stream(
_stream_generator(), enable_incremental_tool_calls=True
):
yield chunk
async def _stream_chat_completion_with_client(
self, request: ChatCompletionRequest, client: AsyncOpenAI
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
"""Helper method for streaming with explicit client parameter."""
assert self.client is not None
params = await self._get_params(request)
stream = await client.chat.completions.create(**params)
@ -445,7 +481,8 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
yield chunk
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
assert self.client is not None
if self.client is None:
raise RuntimeError("Client is not initialized")
params = await self._get_params(request)
r = await self.client.completions.create(**params)
return process_completion_response(r)
@ -453,7 +490,8 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
async def _stream_completion(
self, request: CompletionRequest
) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
assert self.client is not None
if self.client is None:
raise RuntimeError("Client is not initialized")
params = await self._get_params(request)
stream = await self.client.completions.create(**params)

View file

@ -26,11 +26,11 @@ class WatsonXConfig(BaseModel):
)
api_key: SecretStr | None = Field(
default_factory=lambda: os.getenv("WATSONX_API_KEY"),
description="The watsonx API key, only needed of using the hosted service",
description="The watsonx API key",
)
project_id: str | None = Field(
default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
description="The Project ID key, only needed of using the hosted service",
description="The Project ID key",
)
timeout: int = Field(
default=60,

View file

@ -38,6 +38,7 @@ from llama_stack.apis.inference import (
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
@ -57,14 +58,29 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from . import WatsonXConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference::watsonx")
# Note on structured output
# WatsonX returns responses with a json embedded into a string.
# Examples:
# ChatCompletionResponse(completion_message=CompletionMessage(content='```json\n{\n
# "first_name": "Michael",\n "last_name": "Jordan",\n'...)
# Not even a valid JSON, but we can still extract the JSON from the content
# CompletionResponse(content=' \nThe best answer is $\\boxed{\\{"name": "Michael Jordan",
# "year_born": "1963", "year_retired": "2003"\\}}$')
# Find the start of the boxed content
class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
def __init__(self, config: WatsonXConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
print(f"Initializing watsonx InferenceAdapter({config.url})...")
logger.info(f"Initializing watsonx InferenceAdapter({config.url})...")
self._config = config
self._openai_client: AsyncOpenAI | None = None
self._project_id = self._config.project_id

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import asyncio
import hashlib
import uuid
from typing import Any
@ -49,10 +50,13 @@ def convert_id(_id: str) -> str:
Converts any string into a UUID string based on a seed.
Qdrant accepts UUID strings and unsigned integers as point ID.
We use a seed to convert each string into a UUID string deterministically.
We use a SHA-256 hash to convert each string into a UUID string deterministically.
This allows us to overwrite the same point with the original ID.
"""
return str(uuid.uuid5(uuid.NAMESPACE_DNS, _id))
hash_input = f"qdrant_id:{_id}".encode()
sha256_hash = hashlib.sha256(hash_input).hexdigest()
# Use the first 32 characters to create a valid UUID
return str(uuid.UUID(sha256_hash[:32]))
class QdrantIndex(EmbeddingIndex):

View file

@ -4,11 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from typing import Any
import openai
from openai import NOT_GIVEN, AsyncOpenAI
from llama_stack.apis.inference import (
@ -22,6 +22,7 @@ from llama_stack.apis.inference import (
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
@ -43,6 +44,16 @@ class OpenAIMixin(ABC):
The model_store is set in routing_tables/common.py during provider initialization.
"""
# Allow subclasses to control whether to overwrite the 'id' field in OpenAI responses
# is overwritten with a client-side generated id.
#
# This is useful for providers that do not return a unique id in the response.
overwrite_completion_id: bool = False
# Cache of available models keyed by model ID
# This is set in list_models() and used in check_model_availability()
_model_cache: dict[str, Model] = {}
@abstractmethod
def get_api_key(self) -> str:
"""
@ -110,6 +121,23 @@ class OpenAIMixin(ABC):
raise ValueError(f"Model {model} has no provider_resource_id")
return model_obj.provider_resource_id
async def _maybe_overwrite_id(self, resp: Any, stream: bool | None) -> Any:
if not self.overwrite_completion_id:
return resp
new_id = f"cltsd-{uuid.uuid4()}"
if stream:
async def _gen():
async for chunk in resp:
chunk.id = new_id
yield chunk
return _gen()
else:
resp.id = new_id
return resp
async def openai_completion(
self,
model: str,
@ -147,7 +175,7 @@ class OpenAIMixin(ABC):
extra_body["guided_choice"] = guided_choice
# TODO: fix openai_completion to return type compatible with OpenAI's API response
return await self.client.completions.create( # type: ignore[no-any-return]
resp = await self.client.completions.create(
**await prepare_openai_completion_params(
model=await self._get_provider_model_id(model),
prompt=prompt,
@ -171,6 +199,8 @@ class OpenAIMixin(ABC):
extra_body=extra_body,
)
return await self._maybe_overwrite_id(resp, stream) # type: ignore[no-any-return]
async def openai_chat_completion(
self,
model: str,
@ -200,8 +230,7 @@ class OpenAIMixin(ABC):
"""
Direct OpenAI chat completion API call.
"""
# Type ignore because return types are compatible
return await self.client.chat.completions.create( # type: ignore[no-any-return]
resp = await self.client.chat.completions.create(
**await prepare_openai_completion_params(
model=await self._get_provider_model_id(model),
messages=messages,
@ -229,6 +258,8 @@ class OpenAIMixin(ABC):
)
)
return await self._maybe_overwrite_id(resp, stream) # type: ignore[no-any-return]
async def openai_embeddings(
self,
model: str,
@ -269,22 +300,35 @@ class OpenAIMixin(ABC):
usage=usage,
)
async def list_models(self) -> list[Model] | None:
"""
List available models from the provider's /v1/models endpoint.
Also, caches the models in self._model_cache for use in check_model_availability().
:return: A list of Model instances representing available models.
"""
self._model_cache = {
m.id: Model(
# __provider_id__ is dynamically added by instantiate_provider in resolver.py
provider_id=self.__provider_id__, # type: ignore[attr-defined]
provider_resource_id=m.id,
identifier=m.id,
model_type=ModelType.llm,
)
async for m in self.client.models.list()
}
return list(self._model_cache.values())
async def check_model_availability(self, model: str) -> bool:
"""
Check if a specific model is available from OpenAI.
Check if a specific model is available from the provider's /v1/models.
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
"""
try:
# Direct model lookup - returns model or raises NotFoundError
await self.client.models.retrieve(model)
return True
except openai.NotFoundError:
# Model doesn't exist - this is expected for unavailable models
pass
except Exception as e:
# All other errors (auth, rate limit, network, etc.)
logger.warning(f"Failed to check model availability for {model}: {e}")
if not self._model_cache:
await self.list_models()
return False
return model in self._model_cache

View file

@ -12,14 +12,12 @@ import uuid
def generate_chunk_id(document_id: str, chunk_text: str, chunk_window: str | None = None) -> str:
"""
Generate a unique chunk ID using a hash of the document ID and chunk text.
Note: MD5 is used only to calculate an identifier, not for security purposes.
Adding usedforsecurity=False for compatibility with FIPS environments.
Then use the first 32 characters of the hash to create a UUID.
"""
hash_input = f"{document_id}:{chunk_text}".encode()
if chunk_window:
hash_input += f":{chunk_window}".encode()
return str(uuid.UUID(hashlib.md5(hash_input, usedforsecurity=False).hexdigest()))
return str(uuid.UUID(hashlib.sha256(hash_input).hexdigest()[:32]))
def proper_case(s: str) -> str: