Merge origin/main into add-missing-provider-data-impls

Resolved conflicts in:
- benchmarking/k8s-benchmark/stack_run_config.yaml (accepted new storage schema)
- llama_stack/providers/remote/inference/cerebras/cerebras.py (kept provider data support)
- llama_stack/providers/remote/inference/cerebras/config.py (kept provider data support)
- llama_stack/providers/remote/inference/nvidia/config.py (kept provider data support)
- llama_stack/providers/remote/inference/runpod/config.py (merged imports)
- pyproject.toml (kept databricks-sdk dependency)
This commit is contained in:
Ashwin Bharambe 2025-10-27 11:39:00 -07:00
commit 9eb9a37ee4
1880 changed files with 804868 additions and 70533 deletions

View file

@ -7,20 +7,17 @@ from typing import Any
from pydantic import BaseModel
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from llama_stack.core.storage.datatypes import KVStoreReference
class HuggingfaceDatasetIOConfig(BaseModel):
kvstore: KVStoreConfig
kvstore: KVStoreReference
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="huggingface_datasetio.db",
)
"kvstore": KVStoreReference(
backend="kv_default",
namespace="datasetio::huggingface",
).model_dump(exclude_none=True)
}

View file

@ -20,7 +20,7 @@ This provider enables dataset management using NVIDIA's NeMo Customizer service.
Build the NVIDIA environment:
```bash
llama stack build --distro nvidia --image-type venv
uv run llama stack list-deps nvidia | xargs -L1 uv pip install
```
### Basic Usage using the LlamaStack Python Client

View file

@ -8,7 +8,7 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
from llama_stack.core.storage.datatypes import SqlStoreReference
class S3FilesImplConfig(BaseModel):
@ -24,7 +24,7 @@ class S3FilesImplConfig(BaseModel):
auto_create_bucket: bool = Field(
default=False, description="Automatically create the S3 bucket if it doesn't exist"
)
metadata_store: SqlStoreConfig = Field(description="SQL store configuration for file metadata")
metadata_store: SqlStoreReference = Field(description="SQL store configuration for file metadata")
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
@ -35,8 +35,8 @@ class S3FilesImplConfig(BaseModel):
"aws_secret_access_key": "${env.AWS_SECRET_ACCESS_KEY:=}",
"endpoint_url": "${env.S3_ENDPOINT_URL:=}",
"auto_create_bucket": "${env.S3_AUTO_CREATE_BUCKET:=false}",
"metadata_store": SqliteSqlStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="s3_files_metadata.db",
),
"metadata_store": SqlStoreReference(
backend="sql_default",
table_name="s3_files_metadata",
).model_dump(exclude_none=True),
}

View file

@ -23,6 +23,7 @@ from llama_stack.apis.files import (
OpenAIFilePurpose,
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.id_generation import generate_object_id
from llama_stack.providers.utils.files.form_data import parse_expires_after
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
@ -198,7 +199,7 @@ class S3FilesImpl(Files):
purpose: Annotated[OpenAIFilePurpose, Form()],
expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
) -> OpenAIFileObject:
file_id = f"file-{uuid.uuid4().hex}"
file_id = generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
filename = getattr(file, "filename", None) or "uploaded_file"

View file

@ -29,9 +29,6 @@ class AnthropicInferenceAdapter(OpenAIMixin):
# "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000},
# }
def get_api_key(self) -> str:
return self.config.api_key or ""
def get_base_url(self):
return "https://api.anthropic.com/v1"

View file

@ -21,11 +21,6 @@ class AnthropicProviderDataValidator(BaseModel):
@json_schema_type
class AnthropicConfig(RemoteInferenceProviderConfig):
api_key: str | None = Field(
default=None,
description="API key for Anthropic models",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {

View file

@ -16,9 +16,6 @@ class AzureInferenceAdapter(OpenAIMixin):
provider_data_api_key_field: str = "azure_api_key"
def get_api_key(self) -> str:
return self.config.api_key.get_secret_value()
def get_base_url(self) -> str:
"""
Get the Azure API base URL.

View file

@ -32,9 +32,6 @@ class AzureProviderDataValidator(BaseModel):
@json_schema_type
class AzureConfig(RemoteInferenceProviderConfig):
api_key: SecretStr = Field(
description="Azure API key for Azure",
)
api_base: HttpUrl = Field(
description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
)

View file

@ -6,21 +6,21 @@
import json
from collections.abc import AsyncIterator
from typing import Any
from botocore.client import BaseClient
from llama_stack.apis.inference import (
ChatCompletionRequest,
Inference,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletionRequestWithExtraBody,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
@ -125,66 +125,18 @@ class BedrockInferenceAdapter(
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def openai_completion(
self,
# Standard OpenAI completion parameters
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
# vLLM-specific parameters
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
# for fill-in-the-middle type completion
suffix: str | None = None,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by the Bedrock provider")
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")

View file

@ -25,8 +25,9 @@ class DatabricksImplConfig(RemoteInferenceProviderConfig):
default=None,
description="The URL for the Databricks model serving endpoint",
)
api_token: SecretStr = Field(
default=SecretStr(None), # type: ignore[arg-type]
auth_credential: SecretStr | None = Field(
default=None,
alias="api_token",
description="The Databricks API token",
)

View file

@ -5,11 +5,10 @@
# the root directory of this source tree.
from collections.abc import Iterable
from typing import Any
from databricks.sdk import WorkspaceClient
from llama_stack.apis.inference import OpenAICompletion
from llama_stack.apis.inference import OpenAICompletion, OpenAICompletionRequestWithExtraBody
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -29,9 +28,6 @@ class DatabricksInferenceAdapter(OpenAIMixin):
"databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512},
}
def get_api_key(self) -> str:
return self.config.api_token.get_secret_value()
def get_base_url(self) -> str:
return f"{self.config.url}/serving-endpoints"
@ -45,25 +41,6 @@ class DatabricksInferenceAdapter(OpenAIMixin):
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
raise NotImplementedError()

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import Field, SecretStr
from pydantic import Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -18,10 +18,6 @@ class FireworksImplConfig(RemoteInferenceProviderConfig):
default="https://api.fireworks.ai/inference/v1",
description="The URL for the Fireworks server",
)
api_key: SecretStr | None = Field(
default=None,
description="The Fireworks.ai API Key",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY:=}", **kwargs) -> dict[str, Any]:

View file

@ -23,8 +23,5 @@ class FireworksInferenceAdapter(OpenAIMixin):
provider_data_api_key_field: str = "fireworks_api_key"
def get_api_key(self) -> str:
return self.config.api_key.get_secret_value() if self.config.api_key else None # type: ignore[return-value]
def get_base_url(self) -> str:
return "https://api.fireworks.ai/inference/v1"

View file

@ -21,11 +21,6 @@ class GeminiProviderDataValidator(BaseModel):
@json_schema_type
class GeminiConfig(RemoteInferenceProviderConfig):
api_key: str | None = Field(
default=None,
description="API key for Gemini models",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {

View file

@ -4,6 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from openai import NOT_GIVEN
from llama_stack.apis.inference import (
OpenAIEmbeddingData,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import GeminiConfig
@ -14,11 +22,61 @@ class GeminiInferenceAdapter(OpenAIMixin):
provider_data_api_key_field: str = "gemini_api_key"
embedding_model_metadata: dict[str, dict[str, int]] = {
"text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
"models/text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
"models/gemini-embedding-001": {"embedding_dimension": 3072, "context_length": 2048},
}
def get_api_key(self) -> str:
return self.config.api_key or ""
def get_base_url(self):
return "https://generativelanguage.googleapis.com/v1beta/openai/"
async def openai_embeddings(
self,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
"""
Override embeddings method to handle Gemini's missing usage statistics.
Gemini's embedding API doesn't return usage information, so we provide default values.
"""
# Prepare request parameters
request_params = {
"model": await self._get_provider_model_id(params.model),
"input": params.input,
"encoding_format": params.encoding_format if params.encoding_format is not None else NOT_GIVEN,
"dimensions": params.dimensions if params.dimensions is not None else NOT_GIVEN,
"user": params.user if params.user is not None else NOT_GIVEN,
}
# Add extra_body if present
extra_body = params.model_extra
if extra_body:
request_params["extra_body"] = extra_body
# Call OpenAI embeddings API with properly typed parameters
response = await self.client.embeddings.create(**request_params)
data = []
for i, embedding_data in enumerate(response.data):
data.append(
OpenAIEmbeddingData(
embedding=embedding_data.embedding,
index=i,
)
)
# Gemini doesn't return usage statistics - use default values
if hasattr(response, "usage") and response.usage:
usage = OpenAIEmbeddingUsage(
prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens,
)
else:
usage = OpenAIEmbeddingUsage(
prompt_tokens=0,
total_tokens=0,
)
return OpenAIEmbeddingsResponse(
data=data,
model=params.model,
usage=usage,
)

View file

@ -21,12 +21,6 @@ class GroqProviderDataValidator(BaseModel):
@json_schema_type
class GroqConfig(RemoteInferenceProviderConfig):
api_key: str | None = Field(
# The Groq client library loads the GROQ_API_KEY environment variable by default
default=None,
description="The Groq API key",
)
url: str = Field(
default="https://api.groq.com",
description="The URL for the Groq AI server",

View file

@ -14,8 +14,5 @@ class GroqInferenceAdapter(OpenAIMixin):
provider_data_api_key_field: str = "groq_api_key"
def get_api_key(self) -> str:
return self.config.api_key or ""
def get_base_url(self) -> str:
return f"{self.config.url}/openai/v1"

View file

@ -21,11 +21,6 @@ class LlamaProviderDataValidator(BaseModel):
@json_schema_type
class LlamaCompatConfig(RemoteInferenceProviderConfig):
api_key: str | None = Field(
default=None,
description="The Llama API key",
)
openai_compat_api_base: str = Field(
default="https://api.llama.com/compat/v1/",
description="The URL for the Llama API server",

View file

@ -3,9 +3,13 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.inference.inference import OpenAICompletion, OpenAIEmbeddingsResponse
from llama_stack.apis.inference.inference import (
OpenAICompletion,
OpenAICompletionRequestWithExtraBody,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
)
from llama_stack.log import get_logger
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -21,9 +25,6 @@ class LlamaCompatInferenceAdapter(OpenAIMixin):
Llama API Inference Adapter for Llama Stack.
"""
def get_api_key(self) -> str:
return self.config.api_key or ""
def get_base_url(self) -> str:
"""
Get the base URL for OpenAI mixin.
@ -34,35 +35,12 @@ class LlamaCompatInferenceAdapter(OpenAIMixin):
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
raise NotImplementedError()
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()

View file

@ -18,7 +18,7 @@ This provider enables running inference using NVIDIA NIM.
Build the NVIDIA environment:
```bash
llama stack build --distro nvidia --image-type venv
uv run llama stack list-deps nvidia | xargs -L1 uv pip install
```
### Basic Usage using the LlamaStack Python Client
@ -45,7 +45,7 @@ The following example shows how to create a chat completion for an NVIDIA NIM.
```python
response = client.chat.completions.create(
model="meta-llama/Llama-3.1-8B-Instruct",
model="nvidia/meta/llama-3.1-8b-instruct",
messages=[
{
"role": "system",
@ -67,37 +67,40 @@ print(f"Response: {response.choices[0].message.content}")
The following example shows how to do tool calling for an NVIDIA NIM.
```python
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
tool_definition = ToolDefinition(
tool_name="get_weather",
description="Get current weather information for a location",
parameters={
"location": ToolParamDefinition(
param_type="string",
description="The city and state, e.g. San Francisco, CA",
required=True,
),
"unit": ToolParamDefinition(
param_type="string",
description="Temperature unit (celsius or fahrenheit)",
required=False,
default="celsius",
),
tool_definition = {
"type": "function",
"function": {
"name": "get_weather",
"description": "Get current weather information for a location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {
"type": "string",
"description": "Temperature unit (celsius or fahrenheit)",
"default": "celsius",
},
},
"required": ["location"],
},
},
)
}
tool_response = client.chat.completions.create(
model="meta-llama/Llama-3.1-8B-Instruct",
model="nvidia/meta/llama-3.1-8b-instruct",
messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
tools=[tool_definition],
)
print(f"Tool Response: {tool_response.choices[0].message.content}")
print(f"Response content: {tool_response.choices[0].message.content}")
if tool_response.choices[0].message.tool_calls:
for tool_call in tool_response.choices[0].message.tool_calls:
print(f"Tool Called: {tool_call.tool_name}")
print(f"Arguments: {tool_call.arguments}")
print(f"Tool Called: {tool_call.function.name}")
print(f"Arguments: {tool_call.function.arguments}")
```
### Structured Output Example
@ -105,33 +108,26 @@ if tool_response.choices[0].message.tool_calls:
The following example shows how to do structured output for an NVIDIA NIM.
```python
from llama_stack.apis.inference import JsonSchemaResponseFormat, ResponseFormatType
person_schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"},
"age": {"type": "number"},
"occupation": {"type": "string"},
},
"required": ["name", "age", "occupation"],
}
response_format = JsonSchemaResponseFormat(
type=ResponseFormatType.json_schema, json_schema=person_schema
)
structured_response = client.chat.completions.create(
model="meta-llama/Llama-3.1-8B-Instruct",
model="nvidia/meta/llama-3.1-8b-instruct",
messages=[
{
"role": "user",
"content": "Create a profile for a fictional person named Alice who is 30 years old and is a software engineer. ",
}
],
response_format=response_format,
extra_body={"nvext": {"guided_json": person_schema}},
)
print(f"Structured Response: {structured_response.choices[0].message.content}")
```
@ -139,16 +135,13 @@ print(f"Structured Response: {structured_response.choices[0].message.content}")
The following example shows how to create embeddings for an NVIDIA NIM.
> [!NOTE]
> NVIDIA asymmetric embedding models (e.g., `nvidia/llama-3.2-nv-embedqa-1b-v2`) require an `input_type` parameter not present in the standard OpenAI embeddings API. The NVIDIA Inference Adapter automatically sets `input_type="query"` when using the OpenAI-compatible embeddings endpoint for NVIDIA. For passage embeddings, use the `embeddings` API with `task_type="document"`.
```python
response = client.inference.embeddings(
model_id="nvidia/llama-3.2-nv-embedqa-1b-v2",
contents=["What is the capital of France?"],
task_type="query",
response = client.embeddings.create(
model="nvidia/nvidia/llama-3.2-nv-embedqa-1b-v2",
input=["What is the capital of France?"],
extra_body={"input_type": "query"},
)
print(f"Embeddings: {response.embeddings}")
print(f"Embeddings: {response.data}")
```
### Vision Language Models Example
@ -166,15 +159,15 @@ image_path = {path_to_the_image}
demo_image_b64 = load_image_as_base64(image_path)
vlm_response = client.chat.completions.create(
model="nvidia/vila",
model="nvidia/meta/llama-3.2-11b-vision-instruct",
messages=[
{
"role": "user",
"content": [
{
"type": "image",
"image": {
"data": demo_image_b64,
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{demo_image_b64}",
},
},
{

View file

@ -10,7 +10,7 @@ from .config import NVIDIAConfig
async def get_adapter_impl(config: NVIDIAConfig, _deps) -> Inference:
# import dynamically so `llama stack build` does not fail due to missing dependencies
# import dynamically so `llama stack list-deps` does not fail due to missing dependencies
from .nvidia import NVIDIAInferenceAdapter
if not isinstance(config, NVIDIAConfig):

View file

@ -5,13 +5,6 @@
# the root directory of this source tree.
from openai import NOT_GIVEN
from llama_stack.apis.inference import (
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -28,15 +21,6 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
"""
NVIDIA Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
ModelRegistryHelper to ensure that OpenAIMixin.check_model_availability()
is used instead of ModelRegistryHelper.check_model_availability(). It also
must come before Inference to ensure that OpenAIMixin methods are available
in the Inference interface.
- OpenAIMixin.check_model_availability() queries the NVIDIA API to check if a model exists
- ModelRegistryHelper.check_model_availability() just returns False and shows a warning
"""
# source: https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
@ -51,7 +35,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.url})...")
if _is_nvidia_hosted(self.config):
if not self.config.api_key:
if not self.config.auth_credential:
raise RuntimeError(
"API key is required for hosted NVIDIA NIM. Either provide an API key or use a self-hosted NIM."
)
@ -62,7 +46,13 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
:return: The NVIDIA API key
"""
return self.config.api_key.get_secret_value() if self.config.api_key else "NO KEY"
if self.config.auth_credential:
return self.config.auth_credential.get_secret_value()
if not _is_nvidia_hosted(self.config):
return "NO KEY REQUIRED"
return None
def get_base_url(self) -> str:
"""
@ -71,54 +61,3 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
:return: The NVIDIA API base URL
"""
return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
"""
OpenAI-compatible embeddings for NVIDIA NIM.
Note: NVIDIA NIM asymmetric embedding models require an "input_type" field not present in the standard OpenAI embeddings API.
We default this to "query" to ensure requests succeed when using the
OpenAI-compatible endpoint. For passage embeddings, use the embeddings API with
`task_type='document'`.
"""
extra_body: dict[str, object] = {"input_type": "query"}
logger.warning(
"NVIDIA OpenAI-compatible embeddings: defaulting to input_type='query'. "
"For passage embeddings, use the embeddings API with task_type='document'."
)
response = await self.client.embeddings.create(
model=await self._get_provider_model_id(model),
input=input,
encoding_format=encoding_format if encoding_format is not None else NOT_GIVEN,
dimensions=dimensions if dimensions is not None else NOT_GIVEN,
user=user if user is not None else NOT_GIVEN,
extra_body=extra_body,
)
data = []
for i, embedding_data in enumerate(response.data):
data.append(
OpenAIEmbeddingData(
embedding=embedding_data.embedding,
index=i,
)
)
usage = OpenAIEmbeddingUsage(
prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens,
)
return OpenAIEmbeddingsResponse(
data=data,
model=response.model,
usage=usage,
)

View file

@ -6,12 +6,16 @@
from typing import Any
from pydantic import Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
DEFAULT_OLLAMA_URL = "http://localhost:11434"
class OllamaImplConfig(RemoteInferenceProviderConfig):
auth_credential: SecretStr | None = Field(default=None, exclude=True)
url: str = DEFAULT_OLLAMA_URL
@classmethod

View file

@ -59,7 +59,7 @@ class OllamaInferenceAdapter(OpenAIMixin):
return self._clients[loop]
def get_api_key(self):
return "NO_KEY"
return "NO KEY REQUIRED"
def get_base_url(self):
return self.config.url.rstrip("/") + "/v1"

View file

@ -21,10 +21,6 @@ class OpenAIProviderDataValidator(BaseModel):
@json_schema_type
class OpenAIConfig(RemoteInferenceProviderConfig):
api_key: str | None = Field(
default=None,
description="API key for OpenAI models",
)
base_url: str = Field(
default="https://api.openai.com/v1",
description="Base URL for OpenAI API",

View file

@ -29,9 +29,6 @@ class OpenAIInferenceAdapter(OpenAIMixin):
"text-embedding-3-large": {"embedding_dimension": 3072, "context_length": 8192},
}
def get_api_key(self) -> str:
return self.config.api_key or ""
def get_base_url(self) -> str:
"""
Get the OpenAI API base URL.

View file

@ -13,15 +13,15 @@ from llama_stack.apis.inference import (
Inference,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletion,
OpenAICompletionRequestWithExtraBody,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model
from llama_stack.core.library_client import convert_pydantic_to_json_value
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
from .config import PassthroughImplConfig
@ -70,120 +70,37 @@ class PassthroughInferenceAdapter(Inference):
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
client = self._get_client()
model_obj = await self.model_store.get_model(model)
model_obj = await self.model_store.get_model(params.model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
guided_choice=guided_choice,
prompt_logprobs=prompt_logprobs,
)
params = params.model_copy()
params.model = model_obj.provider_resource_id
return await client.inference.openai_completion(**params)
request_params = params.model_dump(exclude_none=True)
return await client.inference.openai_completion(**request_params)
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
client = self._get_client()
model_obj = await self.model_store.get_model(model)
model_obj = await self.model_store.get_model(params.model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
params = params.model_copy()
params.model = model_obj.provider_resource_id
return await client.inference.openai_chat_completion(**params)
request_params = params.model_dump(exclude_none=True)
return await client.inference.openai_chat_completion(**request_params)
def cast_value_to_json_dict(self, request_params: dict[str, Any]) -> dict[str, Any]:
json_params = {}

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -25,8 +25,9 @@ class RunpodImplConfig(RemoteInferenceProviderConfig):
default=None,
description="The URL for the Runpod model serving endpoint",
)
api_token: str | None = Field(
auth_credential: SecretStr | None = Field(
default=None,
alias="api_token",
description="The API token",
)

View file

@ -4,11 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from collections.abc import AsyncIterator
from llama_stack.apis.inference import (
OpenAIMessageParam,
OpenAIResponseFormatParam,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -25,66 +26,18 @@ class RunpodInferenceAdapter(OpenAIMixin):
config: RunpodImplConfig
provider_data_api_key_field: str = "runpod_api_token"
def get_api_key(self) -> str:
"""Get API key for OpenAI client."""
return self.config.api_token
def get_base_url(self) -> str:
"""Get base URL for OpenAI client."""
return self.config.url
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
):
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
"""Override to add RunPod-specific stream_options requirement."""
if stream and not stream_options:
stream_options = {"include_usage": True}
params = params.model_copy()
return await super().openai_chat_completion(
model=model,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
if params.stream and not params.stream_options:
params.stream_options = {"include_usage": True}
return await super().openai_chat_completion(params)

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -25,10 +25,6 @@ class SambaNovaImplConfig(RemoteInferenceProviderConfig):
default="https://api.sambanova.ai/v1",
description="The URL for the SambaNova AI server",
)
api_key: SecretStr | None = Field(
default=None,
description="The SambaNova cloud API Key",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:

View file

@ -19,9 +19,6 @@ class SambaNovaInferenceAdapter(OpenAIMixin):
SambaNova Inference Adapter for Llama Stack.
"""
def get_api_key(self) -> str:
return self.config.api_key.get_secret_value() if self.config.api_key else ""
def get_base_url(self) -> str:
"""
Get the base URL for OpenAI mixin.

View file

@ -13,6 +13,8 @@ from llama_stack.schema_utils import json_schema_type
@json_schema_type
class TGIImplConfig(RemoteInferenceProviderConfig):
auth_credential: SecretStr | None = Field(default=None, exclude=True)
url: str = Field(
description="The URL for the TGI serving endpoint",
)

View file

@ -10,7 +10,10 @@ from collections.abc import Iterable
from huggingface_hub import AsyncInferenceClient, HfApi
from pydantic import SecretStr
from llama_stack.apis.inference import OpenAIEmbeddingsResponse
from llama_stack.apis.inference import (
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -30,7 +33,7 @@ class _HfAdapter(OpenAIMixin):
overwrite_completion_id = True # TGI always returns id=""
def get_api_key(self):
return self.api_key.get_secret_value()
return "NO KEY REQUIRED"
def get_base_url(self):
return self.url
@ -40,11 +43,7 @@ class _HfAdapter(OpenAIMixin):
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import Field, SecretStr
from pydantic import Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -18,10 +18,6 @@ class TogetherImplConfig(RemoteInferenceProviderConfig):
default="https://api.together.xyz/v1",
description="The URL for the Together AI server",
)
api_key: SecretStr | None = Field(
default=None,
description="The Together AI API Key",
)
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:

View file

@ -11,6 +11,7 @@ from together import AsyncTogether
from together.constants import BASE_URL
from llama_stack.apis.inference import (
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
)
from llama_stack.apis.inference.inference import OpenAIEmbeddingUsage
@ -39,15 +40,12 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
provider_data_api_key_field: str = "together_api_key"
def get_api_key(self):
return self.config.api_key.get_secret_value() if self.config.api_key else None
def get_base_url(self):
return BASE_URL
def _get_client(self) -> AsyncTogether:
together_api_key = None
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
config_api_key = self.config.auth_credential.get_secret_value() if self.config.auth_credential else None
if config_api_key:
together_api_key = config_api_key
else:
@ -65,11 +63,7 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
"""
Together's OpenAI-compatible embeddings endpoint is not compatible with
@ -81,25 +75,27 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
- does not support dimensions param, returns 400 Unrecognized request arguments supplied: dimensions
"""
# Together support ticket #13332 -> will not fix
if user is not None:
if params.user is not None:
raise ValueError("Together's embeddings endpoint does not support user param.")
# Together support ticket #13333 -> escalated
if dimensions is not None:
if params.dimensions is not None:
raise ValueError("Together's embeddings endpoint does not support dimensions param.")
response = await self.client.embeddings.create(
model=await self._get_provider_model_id(model),
input=input,
encoding_format=encoding_format,
model=await self._get_provider_model_id(params.model),
input=params.input,
encoding_format=params.encoding_format,
)
response.model = model # return the user the same model id they provided, avoid exposing the provider model id
response.model = (
params.model
) # return the user the same model id they provided, avoid exposing the provider model id
# Together support ticket #13330 -> escalated
# - togethercomputer/m2-bert-80M-32k-retrieval *does not* return usage information
if not hasattr(response, "usage") or response.usage is None:
logger.warning(
f"Together's embedding endpoint for {model} did not return usage information, substituting -1s."
f"Together's embedding endpoint for {params.model} did not return usage information, substituting -1s."
)
response.usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -25,6 +25,8 @@ class VertexAIProviderDataValidator(BaseModel):
@json_schema_type
class VertexAIConfig(RemoteInferenceProviderConfig):
auth_credential: SecretStr | None = Field(default=None, exclude=True)
project: str = Field(
description="Google Cloud project ID for Vertex AI",
)

View file

@ -6,7 +6,7 @@
from pathlib import Path
from pydantic import Field, field_validator
from pydantic import Field, SecretStr, field_validator
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -22,8 +22,9 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
default=4096,
description="Maximum number of tokens to generate.",
)
api_token: str | None = Field(
default="fake",
auth_credential: SecretStr | None = Field(
default=None,
alias="api_token",
description="The API token",
)
tls_verify: bool | str = Field(

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncIterator
from typing import Any
from urllib.parse import urljoin
import httpx
@ -15,8 +14,7 @@ from pydantic import ConfigDict
from llama_stack.apis.inference import (
OpenAIChatCompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
OpenAIChatCompletionRequestWithExtraBody,
ToolChoice,
)
from llama_stack.log import get_logger
@ -38,8 +36,10 @@ class VLLMInferenceAdapter(OpenAIMixin):
provider_data_api_key_field: str = "vllm_api_token"
def get_api_key(self) -> str:
return self.config.api_token or ""
def get_api_key(self) -> str | None:
if self.config.auth_credential:
return self.config.auth_credential.get_secret_value()
return "NO KEY REQUIRED"
def get_base_url(self) -> str:
"""Get the base URL from config."""
@ -77,63 +77,35 @@ class VLLMInferenceAdapter(OpenAIMixin):
def get_extra_client_params(self):
return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
async def check_model_availability(self, model: str) -> bool:
"""
Skip the check when running without authentication.
"""
if not self.config.auth_credential:
model_ids = []
async for m in self.client.models.list():
if m.id == model: # Found exact match
return True
model_ids.append(m.id)
raise ValueError(f"Model '{model}' not found. Available models: {model_ids}")
log.warning(f"Not checking model availability for {model} as API token may trigger OAuth workflow")
return True
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
max_tokens = max_tokens or self.config.max_tokens
params = params.model_copy()
# Apply vLLM-specific defaults
if params.max_tokens is None and self.config.max_tokens:
params.max_tokens = self.config.max_tokens
# This is to be consistent with OpenAI API and support vLLM <= v0.6.3
# References:
# * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
# * https://github.com/vllm-project/vllm/pull/10000
if not tools and tool_choice is not None:
tool_choice = ToolChoice.none.value
if not params.tools and params.tool_choice is not None:
params.tool_choice = ToolChoice.none.value
return await super().openai_chat_completion(
model=model,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
return await super().openai_chat_completion(params)

View file

@ -7,18 +7,18 @@
import os
from typing import Any
from pydantic import BaseModel, ConfigDict, Field, SecretStr
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
class WatsonXProviderDataValidator(BaseModel):
model_config = ConfigDict(
from_attributes=True,
extra="forbid",
watsonx_project_id: str | None = Field(
default=None,
description="IBM WatsonX project ID",
)
watsonx_api_key: str | None
watsonx_api_key: str | None = None
@json_schema_type
@ -27,14 +27,6 @@ class WatsonXConfig(RemoteInferenceProviderConfig):
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
description="A base url for accessing the watsonx.ai",
)
# This seems like it should be required, but none of the other remote inference
# providers require it, so this is optional here too for consistency.
# The OpenAIConfig uses default=None instead, so this is following that precedent.
api_key: SecretStr | None = Field(
default=None,
description="The watsonx.ai API key",
)
# As above, this is optional here too for consistency.
project_id: str | None = Field(
default=None,
description="The watsonx.ai project ID",

View file

@ -4,42 +4,259 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncIterator
from typing import Any
import litellm
import requests
from llama_stack.apis.inference import ChatCompletionRequest
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIChatCompletionUsage,
OpenAICompletion,
OpenAICompletionRequestWithExtraBody,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
)
from llama_stack.apis.models import Model
from llama_stack.apis.models.models import ModelType
from llama_stack.core.telemetry.tracing import get_current_span
from llama_stack.log import get_logger
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
logger = get_logger(name=__name__, category="providers::remote::watsonx")
class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
_model_cache: dict[str, Model] = {}
provider_data_api_key_field: str = "watsonx_api_key"
def __init__(self, config: WatsonXConfig):
self.available_models = None
self.config = config
api_key = config.auth_credential.get_secret_value() if config.auth_credential else None
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="watsonx",
api_key_from_config=config.api_key.get_secret_value() if config.api_key else None,
api_key_from_config=api_key,
provider_data_api_key_field="watsonx_api_key",
openai_compat_api_base=self.get_base_url(),
)
async def openai_chat_completion(
self,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
"""
Override parent method to add timeout and inject usage object when missing.
This works around a LiteLLM defect where usage block is sometimes dropped.
"""
# Add usage tracking for streaming when telemetry is active
stream_options = params.stream_options
if params.stream and get_current_span() is not None:
if stream_options is None:
stream_options = {"include_usage": True}
elif "include_usage" not in stream_options:
stream_options = {**stream_options, "include_usage": True}
model_obj = await self.model_store.get_model(params.model)
request_params = await prepare_openai_completion_params(
model=self.get_litellm_model_name(model_obj.provider_resource_id),
messages=params.messages,
frequency_penalty=params.frequency_penalty,
function_call=params.function_call,
functions=params.functions,
logit_bias=params.logit_bias,
logprobs=params.logprobs,
max_completion_tokens=params.max_completion_tokens,
max_tokens=params.max_tokens,
n=params.n,
parallel_tool_calls=params.parallel_tool_calls,
presence_penalty=params.presence_penalty,
response_format=params.response_format,
seed=params.seed,
stop=params.stop,
stream=params.stream,
stream_options=stream_options,
temperature=params.temperature,
tool_choice=params.tool_choice,
tools=params.tools,
top_logprobs=params.top_logprobs,
top_p=params.top_p,
user=params.user,
api_key=self.get_api_key(),
api_base=self.api_base,
# These are watsonx-specific parameters
timeout=self.config.timeout,
project_id=self.config.project_id,
)
result = await litellm.acompletion(**request_params)
# If not streaming, check and inject usage if missing
if not params.stream:
# Use getattr to safely handle cases where usage attribute might not exist
if getattr(result, "usage", None) is None:
# Create usage object with zeros
usage_obj = OpenAIChatCompletionUsage(
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
)
# Use model_copy to create a new response with the usage injected
result = result.model_copy(update={"usage": usage_obj})
return result
# For streaming, wrap the iterator to normalize chunks
return self._normalize_stream(result)
def _normalize_chunk(self, chunk: OpenAIChatCompletionChunk) -> OpenAIChatCompletionChunk:
"""
Normalize a chunk to ensure it has all expected attributes.
This works around LiteLLM not always including all expected attributes.
"""
# Ensure chunk has usage attribute with zeros if missing
if not hasattr(chunk, "usage") or chunk.usage is None:
usage_obj = OpenAIChatCompletionUsage(
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
)
chunk = chunk.model_copy(update={"usage": usage_obj})
# Ensure all delta objects in choices have expected attributes
if hasattr(chunk, "choices") and chunk.choices:
normalized_choices = []
for choice in chunk.choices:
if hasattr(choice, "delta") and choice.delta:
delta = choice.delta
# Build update dict for missing attributes
delta_updates = {}
if not hasattr(delta, "refusal"):
delta_updates["refusal"] = None
if not hasattr(delta, "reasoning_content"):
delta_updates["reasoning_content"] = None
# If we need to update delta, create a new choice with updated delta
if delta_updates:
new_delta = delta.model_copy(update=delta_updates)
new_choice = choice.model_copy(update={"delta": new_delta})
normalized_choices.append(new_choice)
else:
normalized_choices.append(choice)
else:
normalized_choices.append(choice)
# If we modified any choices, create a new chunk with updated choices
if any(normalized_choices[i] is not chunk.choices[i] for i in range(len(chunk.choices))):
chunk = chunk.model_copy(update={"choices": normalized_choices})
return chunk
async def _normalize_stream(
self, stream: AsyncIterator[OpenAIChatCompletionChunk]
) -> AsyncIterator[OpenAIChatCompletionChunk]:
"""
Normalize all chunks in the stream to ensure they have expected attributes.
This works around LiteLLM sometimes not including expected attributes.
"""
try:
async for chunk in stream:
# Normalize and yield each chunk immediately
yield self._normalize_chunk(chunk)
except Exception as e:
logger.error(f"Error normalizing stream: {e}", exc_info=True)
raise
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
"""
Override parent method to add watsonx-specific parameters.
"""
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
model_obj = await self.model_store.get_model(params.model)
request_params = await prepare_openai_completion_params(
model=self.get_litellm_model_name(model_obj.provider_resource_id),
prompt=params.prompt,
best_of=params.best_of,
echo=params.echo,
frequency_penalty=params.frequency_penalty,
logit_bias=params.logit_bias,
logprobs=params.logprobs,
max_tokens=params.max_tokens,
n=params.n,
presence_penalty=params.presence_penalty,
seed=params.seed,
stop=params.stop,
stream=params.stream,
stream_options=params.stream_options,
temperature=params.temperature,
top_p=params.top_p,
user=params.user,
suffix=params.suffix,
api_key=self.get_api_key(),
api_base=self.api_base,
# These are watsonx-specific parameters
timeout=self.config.timeout,
project_id=self.config.project_id,
)
return await litellm.atext_completion(**request_params)
async def openai_embeddings(
self,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
"""
Override parent method to add watsonx-specific parameters.
"""
model_obj = await self.model_store.get_model(params.model)
# Convert input to list if it's a string
input_list = [params.input] if isinstance(params.input, str) else params.input
# Call litellm embedding function with watsonx-specific parameters
response = litellm.embedding(
model=self.get_litellm_model_name(model_obj.provider_resource_id),
input=input_list,
api_key=self.get_api_key(),
api_base=self.api_base,
dimensions=params.dimensions,
# These are watsonx-specific parameters
timeout=self.config.timeout,
project_id=self.config.project_id,
)
# Convert response to OpenAI format
from llama_stack.apis.inference import OpenAIEmbeddingUsage
from llama_stack.providers.utils.inference.litellm_openai_mixin import b64_encode_openai_embeddings_response
data = b64_encode_openai_embeddings_response(response.data, params.encoding_format)
usage = OpenAIEmbeddingUsage(
prompt_tokens=response["usage"]["prompt_tokens"],
total_tokens=response["usage"]["total_tokens"],
)
return OpenAIEmbeddingsResponse(
data=data,
model=model_obj.provider_resource_id,
usage=usage,
)
self.available_models = None
self.config = config
def get_base_url(self) -> str:
return self.config.url
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
# Get base parameters from parent
params = await super()._get_params(request)
# Add watsonx.ai specific parameters
params["project_id"] = self.config.project_id
params["time_limit"] = self.config.timeout
return params
# Copied from OpenAIMixin
async def check_model_availability(self, model: str) -> bool:
"""

View file

@ -22,7 +22,7 @@ This provider enables fine-tuning of LLMs using NVIDIA's NeMo Customizer service
Build the NVIDIA environment:
```bash
llama stack build --distro nvidia --image-type venv
uv run llama stack list-deps nvidia | xargs -L1 uv pip install
```
### Basic Usage using the LlamaStack Python Client

View file

@ -7,7 +7,7 @@
import json
from typing import Any
from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
@ -56,7 +56,7 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
pass
async def run_shield(
self, shield_id: str, messages: list[Message], params: dict[str, Any] = None
self, shield_id: str, messages: list[OpenAIMessageParam], params: dict[str, Any] = None
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
if not shield:

View file

@ -19,7 +19,7 @@ This provider enables safety checks and guardrails for LLM interactions using NV
Build the NVIDIA environment:
```bash
llama stack build --distro nvidia --image-type venv
uv run llama stack list-deps nvidia | xargs -L1 uv pip install
```
### Basic Usage using the LlamaStack Python Client

View file

@ -8,12 +8,11 @@ from typing import Any
import requests
from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import ModerationObject, RunShieldResponse, Safety, SafetyViolation, ViolationLevel
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
from .config import NVIDIASafetyConfig
@ -44,7 +43,7 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
pass
async def run_shield(
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
self, shield_id: str, messages: list[OpenAIMessageParam], params: dict[str, Any] | None = None
) -> RunShieldResponse:
"""
Run a safety shield check against the provided messages.
@ -67,7 +66,7 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
self.shield = NeMoGuardrails(self.config, shield.shield_id)
return await self.shield.run(messages)
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
raise NotImplementedError("NVIDIA safety provider currently does not implement run_moderation")
@ -118,7 +117,7 @@ class NeMoGuardrails:
response.raise_for_status()
return response.json()
async def run(self, messages: list[Message]) -> RunShieldResponse:
async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
"""
Queries the /v1/guardrails/checks endpoint of the NeMo guardrails deployed API.
@ -132,10 +131,9 @@ class NeMoGuardrails:
Raises:
requests.HTTPError: If the POST request fails.
"""
request_messages = [await convert_message_to_openai_dict_new(message) for message in messages]
request_data = {
"model": self.model,
"messages": request_messages,
"messages": [{"role": message.role, "content": message.content} for message in messages],
"temperature": self.temperature,
"top_p": 1,
"frequency_penalty": 0,

View file

@ -4,13 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Any
import litellm
import requests
from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
@ -21,7 +20,6 @@ from llama_stack.apis.shields import Shield
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
from .config import SambaNovaSafetyConfig
@ -72,7 +70,7 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
pass
async def run_shield(
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
self, shield_id: str, messages: list[OpenAIMessageParam], params: dict[str, Any] | None = None
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
if not shield:
@ -80,12 +78,8 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
shield_params = shield.params
logger.debug(f"run_shield::{shield_params}::messages={messages}")
content_messages = [await convert_message_to_openai_dict_new(m) for m in messages]
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
response = litellm.completion(
model=shield.provider_resource_id, messages=content_messages, api_key=self._get_api_key()
)
response = litellm.completion(model=shield.provider_resource_id, messages=messages, api_key=self._get_api_key())
shield_message = response.choices[0].message.content
if "unsafe" in shield_message.lower():

View file

@ -12,24 +12,16 @@ import chromadb
from numpy.typing import NDArray
from llama_stack.apis.files import Files
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
VectorIO,
)
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
@ -38,7 +30,7 @@ log = get_logger(name=__name__, category="vector_io::chroma")
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:chroma:{VERSION}::"
VECTOR_DBS_PREFIX = f"vector_stores:chroma:{VERSION}::"
VECTOR_INDEX_PREFIX = f"vector_index:chroma:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:chroma:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:chroma:{VERSION}::"
@ -68,19 +60,13 @@ class ChromaIndex(EmbeddingIndex):
ids = [f"{c.metadata.get('document_id', '')}:{c.chunk_id}" for c in chunks]
await maybe_await(
self.collection.add(
documents=[chunk.model_dump_json() for chunk in chunks],
embeddings=embeddings,
ids=ids,
)
self.collection.add(documents=[chunk.model_dump_json() for chunk in chunks], embeddings=embeddings, ids=ids)
)
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
results = await maybe_await(
self.collection.query(
query_embeddings=[embedding.tolist()],
n_results=k,
include=["documents", "distances"],
query_embeddings=[embedding.tolist()], n_results=k, include=["documents", "distances"]
)
)
distances = results["distances"][0]
@ -108,12 +94,7 @@ class ChromaIndex(EmbeddingIndex):
async def delete(self):
await maybe_await(self.client.delete_collection(self.collection.name))
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Chroma")
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
@ -133,11 +114,11 @@ class ChromaIndex(EmbeddingIndex):
raise NotImplementedError("Hybrid search is not supported in Chroma")
class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
def __init__(
self,
config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig,
inference_api: Api.inference,
inference_api: Inference,
files_api: Files | None,
) -> None:
super().__init__(files_api=files_api, kvstore=None)
@ -146,11 +127,11 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.inference_api = inference_api
self.client = None
self.cache = {}
self.vector_db_store = None
self.vector_store_table = None
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)
self.vector_db_store = self.kvstore
self.kvstore = await kvstore_impl(self.config.persistence)
self.vector_store_table = self.kvstore
if isinstance(self.config, RemoteChromaVectorIOConfig):
log.info(f"Connecting to Chroma server at: {self.config.url}")
@ -170,70 +151,58 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def register_vector_db(
self,
vector_db: VectorDB,
) -> None:
async def register_vector_store(self, vector_store: VectorStore) -> None:
collection = await maybe_await(
self.client.get_or_create_collection(
name=vector_db.identifier,
metadata={"vector_db": vector_db.model_dump_json()},
name=vector_store.identifier, metadata={"vector_store": vector_store.model_dump_json()}
)
)
self.cache[vector_db.identifier] = VectorDBWithIndex(
vector_db, ChromaIndex(self.client, collection), self.inference_api
self.cache[vector_store.identifier] = VectorStoreWithIndex(
vector_store, ChromaIndex(self.client, collection), self.inference_api
)
async def unregister_vector_db(self, vector_db_id: str) -> None:
if vector_db_id not in self.cache:
log.warning(f"Vector DB {vector_db_id} not found")
async def unregister_vector_store(self, vector_store_id: str) -> None:
if vector_store_id not in self.cache:
log.warning(f"Vector DB {vector_store_id} not found")
return
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
await self.cache[vector_store_id].index.delete()
del self.cache[vector_store_id]
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_store_index(vector_db_id)
if index is None:
raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
await index.insert_chunks(chunks)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
index = await self._get_and_cache_vector_store_index(vector_db_id)
if index is None:
raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
return await index.query_chunks(query, params)
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
if vector_db_id in self.cache:
return self.cache[vector_db_id]
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex:
if vector_store_id in self.cache:
return self.cache[vector_store_id]
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
if not vector_db:
raise ValueError(f"Vector DB {vector_db_id} not found in Llama Stack")
collection = await maybe_await(self.client.get_collection(vector_db_id))
vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
if not vector_store:
raise ValueError(f"Vector DB {vector_store_id} not found in Llama Stack")
collection = await maybe_await(self.client.get_collection(vector_store_id))
if not collection:
raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
index = VectorDBWithIndex(vector_db, ChromaIndex(self.client, collection), self.inference_api)
self.cache[vector_db_id] = index
raise ValueError(f"Vector DB {vector_store_id} not found in Chroma")
index = VectorStoreWithIndex(vector_store, ChromaIndex(self.client, collection), self.inference_api)
self.cache[vector_store_id] = index
return index
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from a Chroma vector store."""
index = await self._get_and_cache_vector_db_index(store_id)
index = await self._get_and_cache_vector_store_index(store_id)
if not index:
raise ValueError(f"Vector DB {store_id} not found")

View file

@ -8,21 +8,21 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class ChromaVectorIOConfig(BaseModel):
url: str | None
kvstore: KVStoreConfig = Field(description="Config for KV store backend")
persistence: KVStoreReference = Field(description="Config for KV store backend")
@classmethod
def sample_run_config(cls, __distro_dir__: str, url: str = "${env.CHROMADB_URL}", **kwargs: Any) -> dict[str, Any]:
return {
"url": url,
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="chroma_remote_registry.db",
),
"persistence": KVStoreReference(
backend="kv_default",
namespace="vector_io::chroma_remote",
).model_dump(exclude_none=True),
}

View file

@ -13,7 +13,6 @@ async def get_adapter_impl(config: MilvusVectorIOConfig, deps: dict[Api, Provide
from .milvus import MilvusVectorIOAdapter
assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"
impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -8,7 +8,7 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.schema_utils import json_schema_type
@ -17,7 +17,7 @@ class MilvusVectorIOConfig(BaseModel):
uri: str = Field(description="The URI of the Milvus server")
token: str | None = Field(description="The token of the Milvus server")
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
kvstore: KVStoreConfig = Field(description="Config for KV store backend")
persistence: KVStoreReference = Field(description="Config for KV store backend")
# This configuration allows additional fields to be passed through to the underlying Milvus client.
# See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general.
@ -28,8 +28,8 @@ class MilvusVectorIOConfig(BaseModel):
return {
"uri": "${env.MILVUS_ENDPOINT}",
"token": "${env.MILVUS_TOKEN}",
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="milvus_remote_registry.db",
),
"persistence": KVStoreReference(
backend="kv_default",
namespace="vector_io::milvus_remote",
).model_dump(exclude_none=True),
}

View file

@ -12,16 +12,12 @@ from numpy.typing import NDArray
from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files.files import Files
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
VectorIO,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
@ -30,7 +26,7 @@ from llama_stack.providers.utils.memory.vector_store import (
RERANKER_TYPE_WEIGHTED,
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
VectorStoreWithIndex,
)
from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collection_name
@ -39,7 +35,7 @@ from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
logger = get_logger(name=__name__, category="vector_io::milvus")
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::"
VECTOR_DBS_PREFIX = f"vector_stores:milvus:{VERSION}::"
VECTOR_INDEX_PREFIX = f"vector_index:milvus:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:milvus:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:milvus:{VERSION}::"
@ -73,46 +69,23 @@ class MilvusIndex(EmbeddingIndex):
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
# Create schema for vector search
schema = self.client.create_schema()
schema.add_field(
field_name="chunk_id",
datatype=DataType.VARCHAR,
is_primary=True,
max_length=100,
)
schema.add_field(field_name="chunk_id", datatype=DataType.VARCHAR, is_primary=True, max_length=100)
schema.add_field(
field_name="content",
datatype=DataType.VARCHAR,
max_length=65535,
enable_analyzer=True, # Enable text analysis for BM25
)
schema.add_field(
field_name="vector",
datatype=DataType.FLOAT_VECTOR,
dim=len(embeddings[0]),
)
schema.add_field(
field_name="chunk_content",
datatype=DataType.JSON,
)
schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=len(embeddings[0]))
schema.add_field(field_name="chunk_content", datatype=DataType.JSON)
# Add sparse vector field for BM25 (required by the function)
schema.add_field(
field_name="sparse",
datatype=DataType.SPARSE_FLOAT_VECTOR,
)
schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR)
# Create indexes
index_params = self.client.prepare_index_params()
index_params.add_index(
field_name="vector",
index_type="FLAT",
metric_type="COSINE",
)
index_params.add_index(field_name="vector", index_type="FLAT", metric_type="COSINE")
# Add index for sparse field (required by BM25 function)
index_params.add_index(
field_name="sparse",
index_type="SPARSE_INVERTED_INDEX",
metric_type="BM25",
)
index_params.add_index(field_name="sparse", index_type="SPARSE_INVERTED_INDEX", metric_type="BM25")
# Add BM25 function for full-text search
bm25_function = Function(
@ -143,11 +116,7 @@ class MilvusIndex(EmbeddingIndex):
}
)
try:
await asyncio.to_thread(
self.client.insert,
self.collection_name,
data=data,
)
await asyncio.to_thread(self.client.insert, self.collection_name, data=data)
except Exception as e:
logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
raise e
@ -166,12 +135,7 @@ class MilvusIndex(EmbeddingIndex):
scores = [res["distance"] for res in search_res[0]]
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
"""
Perform BM25-based keyword search using Milvus's built-in full-text search.
"""
@ -209,12 +173,7 @@ class MilvusIndex(EmbeddingIndex):
# Fallback to simple text search
return await self._fallback_keyword_search(query_string, k, score_threshold)
async def _fallback_keyword_search(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def _fallback_keyword_search(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
"""
Fallback to simple text search when BM25 search is not available.
"""
@ -302,7 +261,7 @@ class MilvusIndex(EmbeddingIndex):
raise
class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
def __init__(
self,
config: RemoteMilvusVectorIOConfig | InlineMilvusVectorIOConfig,
@ -314,28 +273,28 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.cache = {}
self.client = None
self.inference_api = inference_api
self.vector_db_store = None
self.vector_store_table = None
self.metadata_collection_name = "openai_vector_stores_metadata"
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)
self.kvstore = await kvstore_impl(self.config.persistence)
start_key = VECTOR_DBS_PREFIX
end_key = f"{VECTOR_DBS_PREFIX}\xff"
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
stored_vector_stores = await self.kvstore.values_in_range(start_key, end_key)
for vector_db_data in stored_vector_dbs:
vector_db = VectorDB.model_validate_json(vector_db_data)
index = VectorDBWithIndex(
vector_db,
for vector_store_data in stored_vector_stores:
vector_store = VectorStore.model_validate_json(vector_store_data)
index = VectorStoreWithIndex(
vector_store,
index=MilvusIndex(
client=self.client,
collection_name=vector_db.identifier,
collection_name=vector_store.identifier,
consistency_level=self.config.consistency_level,
kvstore=self.kvstore,
),
inference_api=self.inference_api,
)
self.cache[vector_db.identifier] = index
self.cache[vector_store.identifier] = index
if isinstance(self.config, RemoteMilvusVectorIOConfig):
logger.info(f"Connecting to Milvus server at {self.config.uri}")
self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
@ -352,72 +311,61 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def register_vector_db(
self,
vector_db: VectorDB,
) -> None:
async def register_vector_store(self, vector_store: VectorStore) -> None:
if isinstance(self.config, RemoteMilvusVectorIOConfig):
consistency_level = self.config.consistency_level
else:
consistency_level = "Strong"
index = VectorDBWithIndex(
vector_db=vector_db,
index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level),
index = VectorStoreWithIndex(
vector_store=vector_store,
index=MilvusIndex(self.client, vector_store.identifier, consistency_level=consistency_level),
inference_api=self.inference_api,
)
self.cache[vector_db.identifier] = index
self.cache[vector_store.identifier] = index
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
if vector_db_id in self.cache:
return self.cache[vector_db_id]
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
if vector_store_id in self.cache:
return self.cache[vector_store_id]
if self.vector_db_store is None:
raise VectorStoreNotFoundError(vector_db_id)
if self.vector_store_table is None:
raise VectorStoreNotFoundError(vector_store_id)
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
if not vector_db:
raise VectorStoreNotFoundError(vector_db_id)
vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
if not vector_store:
raise VectorStoreNotFoundError(vector_store_id)
index = VectorDBWithIndex(
vector_db=vector_db,
index=MilvusIndex(client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore),
index = VectorStoreWithIndex(
vector_store=vector_store,
index=MilvusIndex(client=self.client, collection_name=vector_store.identifier, kvstore=self.kvstore),
inference_api=self.inference_api,
)
self.cache[vector_db_id] = index
self.cache[vector_store_id] = index
return index
async def unregister_vector_db(self, vector_db_id: str) -> None:
if vector_db_id in self.cache:
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
async def unregister_vector_store(self, vector_store_id: str) -> None:
if vector_store_id in self.cache:
await self.cache[vector_store_id].index.delete()
del self.cache[vector_store_id]
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_store_index(vector_db_id)
if not index:
raise VectorStoreNotFoundError(vector_db_id)
await index.insert_chunks(chunks)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
index = await self._get_and_cache_vector_store_index(vector_db_id)
if not index:
raise VectorStoreNotFoundError(vector_db_id)
return await index.query_chunks(query, params)
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete a chunk from a milvus vector store."""
index = await self._get_and_cache_vector_db_index(store_id)
index = await self._get_and_cache_vector_store_index(store_id)
if not index:
raise VectorStoreNotFoundError(store_id)

View file

@ -12,6 +12,6 @@ from .config import PGVectorVectorIOConfig
async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec]):
from .pgvector import PGVectorVectorIOAdapter
impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -8,10 +8,7 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.schema_utils import json_schema_type
@ -22,7 +19,9 @@ class PGVectorVectorIOConfig(BaseModel):
db: str | None = Field(default="postgres")
user: str | None = Field(default="postgres")
password: str | None = Field(default="mysecretpassword")
kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
persistence: KVStoreReference | None = Field(
description="Config for KV store backend (SQLite only for now)", default=None
)
@classmethod
def sample_run_config(
@ -41,8 +40,8 @@ class PGVectorVectorIOConfig(BaseModel):
"db": db,
"user": user,
"password": password,
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="pgvector_registry.db",
),
"persistence": KVStoreReference(
backend="kv_default",
namespace="vector_io::pgvector",
).model_dump(exclude_none=True),
}

View file

@ -14,27 +14,17 @@ from psycopg2.extras import Json, execute_values
from pydantic import BaseModel, TypeAdapter
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files.files import Files
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
VectorIO,
)
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator, sanitize_collection_name
from .config import PGVectorVectorIOConfig
@ -42,7 +32,7 @@ from .config import PGVectorVectorIOConfig
log = get_logger(name=__name__, category="vector_io::pgvector")
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"
VECTOR_DBS_PREFIX = f"vector_stores:pgvector:{VERSION}::"
VECTOR_INDEX_PREFIX = f"vector_index:pgvector:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:pgvector:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:pgvector:{VERSION}::"
@ -89,13 +79,13 @@ class PGVectorIndex(EmbeddingIndex):
def __init__(
self,
vector_db: VectorDB,
vector_store: VectorStore,
dimension: int,
conn: psycopg2.extensions.connection,
kvstore: KVStore | None = None,
distance_metric: str = "COSINE",
):
self.vector_db = vector_db
self.vector_store = vector_store
self.dimension = dimension
self.conn = conn
self.kvstore = kvstore
@ -107,9 +97,9 @@ class PGVectorIndex(EmbeddingIndex):
try:
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
# Sanitize the table name by replacing hyphens with underscores
# SQL doesn't allow hyphens in table names, and vector_db.identifier may contain hyphens
# SQL doesn't allow hyphens in table names, and vector_store.identifier may contain hyphens
# when created with patterns like "test-vector-db-{uuid4()}"
sanitized_identifier = sanitize_collection_name(self.vector_db.identifier)
sanitized_identifier = sanitize_collection_name(self.vector_store.identifier)
self.table_name = f"vs_{sanitized_identifier}"
cur.execute(
@ -132,8 +122,8 @@ class PGVectorIndex(EmbeddingIndex):
"""
)
except Exception as e:
log.exception(f"Error creating PGVectorIndex for vector_db: {self.vector_db.identifier}")
raise RuntimeError(f"Error creating PGVectorIndex for vector_db: {self.vector_db.identifier}") from e
log.exception(f"Error creating PGVectorIndex for vector_store: {self.vector_store.identifier}")
raise RuntimeError(f"Error creating PGVectorIndex for vector_store: {self.vector_store.identifier}") from e
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
assert len(chunks) == len(embeddings), (
@ -204,12 +194,7 @@ class PGVectorIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
"""
Performs keyword-based search using PostgreSQL's full-text search with ts_rank scoring.
@ -316,7 +301,7 @@ class PGVectorIndex(EmbeddingIndex):
"""Remove a chunk from the PostgreSQL table."""
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids,))
cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids))
def get_pgvector_search_function(self) -> str:
return self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION[self.distance_metric]
@ -338,24 +323,21 @@ class PGVectorIndex(EmbeddingIndex):
)
class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
def __init__(
self,
config: PGVectorVectorIOConfig,
inference_api: Api.inference,
files_api: Files | None = None,
self, config: PGVectorVectorIOConfig, inference_api: Inference, files_api: Files | None = None
) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.inference_api = inference_api
self.conn = None
self.cache = {}
self.vector_db_store = None
self.vector_store_table = None
self.metadata_collection_name = "openai_vector_stores_metadata"
async def initialize(self) -> None:
log.info(f"Initializing PGVector memory adapter with config: {self.config}")
self.kvstore = await kvstore_impl(self.config.kvstore)
self.kvstore = await kvstore_impl(self.config.persistence)
await self.initialize_openai_vector_stores()
try:
@ -393,71 +375,59 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def register_vector_db(self, vector_db: VectorDB) -> None:
async def register_vector_store(self, vector_store: VectorStore) -> None:
# Persist vector DB metadata in the KV store
assert self.kvstore is not None
# Upsert model metadata in Postgres
upsert_models(self.conn, [(vector_db.identifier, vector_db)])
upsert_models(self.conn, [(vector_store.identifier, vector_store)])
# Create and cache the PGVector index table for the vector DB
pgvector_index = PGVectorIndex(
vector_db=vector_db, dimension=vector_db.embedding_dimension, conn=self.conn, kvstore=self.kvstore
vector_store=vector_store, dimension=vector_store.embedding_dimension, conn=self.conn, kvstore=self.kvstore
)
await pgvector_index.initialize()
index = VectorDBWithIndex(
vector_db,
index=pgvector_index,
inference_api=self.inference_api,
)
self.cache[vector_db.identifier] = index
index = VectorStoreWithIndex(vector_store, index=pgvector_index, inference_api=self.inference_api)
self.cache[vector_store.identifier] = index
async def unregister_vector_db(self, vector_db_id: str) -> None:
async def unregister_vector_store(self, vector_store_id: str) -> None:
# Remove provider index and cache
if vector_db_id in self.cache:
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
if vector_store_id in self.cache:
await self.cache[vector_store_id].index.delete()
del self.cache[vector_store_id]
# Delete vector DB metadata from KV store
assert self.kvstore is not None
await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_db_id}")
await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_store_id}")
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_store_index(vector_db_id)
await index.insert_chunks(chunks)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
index = await self._get_and_cache_vector_store_index(vector_db_id)
return await index.query_chunks(query, params)
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
if vector_db_id in self.cache:
return self.cache[vector_db_id]
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex:
if vector_store_id in self.cache:
return self.cache[vector_store_id]
if self.vector_db_store is None:
raise VectorStoreNotFoundError(vector_db_id)
if self.vector_store_table is None:
raise VectorStoreNotFoundError(vector_store_id)
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
if not vector_db:
raise VectorStoreNotFoundError(vector_db_id)
vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
if not vector_store:
raise VectorStoreNotFoundError(vector_store_id)
index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn)
index = PGVectorIndex(vector_store, vector_store.embedding_dimension, self.conn)
await index.initialize()
self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
return self.cache[vector_db_id]
self.cache[vector_store_id] = VectorStoreWithIndex(vector_store, index, self.inference_api)
return self.cache[vector_store_id]
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete a chunk from a PostgreSQL vector store."""
index = await self._get_and_cache_vector_db_index(store_id)
index = await self._get_and_cache_vector_store_index(store_id)
if not index:
raise VectorStoreNotFoundError(store_id)

View file

@ -12,7 +12,6 @@ from .config import QdrantVectorIOConfig
async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]):
from .qdrant import QdrantVectorIOAdapter
files_api = deps.get(Api.files)
impl = QdrantVectorIOAdapter(config, deps[Api.inference], files_api)
impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -8,10 +8,7 @@ from typing import Any
from pydantic import BaseModel
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.schema_utils import json_schema_type
@ -27,14 +24,14 @@ class QdrantVectorIOConfig(BaseModel):
prefix: str | None = None
timeout: int | None = None
host: str | None = None
kvstore: KVStoreConfig
persistence: KVStoreReference
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"api_key": "${env.QDRANT_API_KEY:=}",
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="qdrant_registry.db",
),
"persistence": KVStoreReference(
backend="kv_default",
namespace="vector_io::qdrant_remote",
).model_dump(exclude_none=True),
}

View file

@ -15,8 +15,7 @@ from qdrant_client.models import PointStruct
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
@ -24,16 +23,13 @@ from llama_stack.apis.vector_io import (
VectorStoreChunkingStrategy,
VectorStoreFileObject,
)
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
@ -42,7 +38,7 @@ CHUNK_ID_KEY = "_chunk_id"
# KV store prefixes for vector databases
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:qdrant:{VERSION}::"
VECTOR_DBS_PREFIX = f"vector_stores:qdrant:{VERSION}::"
def convert_id(_id: str) -> str:
@ -98,8 +94,7 @@ class QdrantIndex(EmbeddingIndex):
chunk_ids = [convert_id(c.chunk_id) for c in chunks_for_deletion]
try:
await self.client.delete(
collection_name=self.collection_name,
points_selector=models.PointIdsList(points=chunk_ids),
collection_name=self.collection_name, points_selector=models.PointIdsList(points=chunk_ids)
)
except Exception as e:
log.error(f"Error deleting chunks from Qdrant collection {self.collection_name}: {e}")
@ -132,12 +127,7 @@ class QdrantIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Qdrant")
async def query_hybrid(
@ -155,11 +145,11 @@ class QdrantIndex(EmbeddingIndex):
await self.client.delete_collection(collection_name=self.collection_name)
class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
def __init__(
self,
config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig,
inference_api: Api.inference,
inference_api: Inference,
files_api: Files | None = None,
) -> None:
super().__init__(files_api=files_api, kvstore=None)
@ -167,26 +157,24 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.client: AsyncQdrantClient = None
self.cache = {}
self.inference_api = inference_api
self.vector_db_store = None
self.vector_store_table = None
self._qdrant_lock = asyncio.Lock()
async def initialize(self) -> None:
client_config = self.config.model_dump(exclude_none=True, exclude={"kvstore"})
client_config = self.config.model_dump(exclude_none=True, exclude={"persistence"})
self.client = AsyncQdrantClient(**client_config)
self.kvstore = await kvstore_impl(self.config.kvstore)
self.kvstore = await kvstore_impl(self.config.persistence)
start_key = VECTOR_DBS_PREFIX
end_key = f"{VECTOR_DBS_PREFIX}\xff"
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
stored_vector_stores = await self.kvstore.values_in_range(start_key, end_key)
for vector_db_data in stored_vector_dbs:
vector_db = VectorDB.model_validate_json(vector_db_data)
index = VectorDBWithIndex(
vector_db,
QdrantIndex(self.client, vector_db.identifier),
self.inference_api,
for vector_store_data in stored_vector_stores:
vector_store = VectorStore.model_validate_json(vector_store_data)
index = VectorStoreWithIndex(
vector_store, QdrantIndex(self.client, vector_store.identifier), self.inference_api
)
self.cache[vector_db.identifier] = index
self.cache[vector_store.identifier] = index
self.openai_vector_stores = await self._load_openai_vector_stores()
async def shutdown(self) -> None:
@ -194,68 +182,57 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def register_vector_db(
self,
vector_db: VectorDB,
) -> None:
async def register_vector_store(self, vector_store: VectorStore) -> None:
assert self.kvstore is not None
key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
await self.kvstore.set(key=key, value=vector_db.model_dump_json())
key = f"{VECTOR_DBS_PREFIX}{vector_store.identifier}"
await self.kvstore.set(key=key, value=vector_store.model_dump_json())
index = VectorDBWithIndex(
vector_db=vector_db,
index=QdrantIndex(self.client, vector_db.identifier),
index = VectorStoreWithIndex(
vector_store=vector_store,
index=QdrantIndex(self.client, vector_store.identifier),
inference_api=self.inference_api,
)
self.cache[vector_db.identifier] = index
self.cache[vector_store.identifier] = index
async def unregister_vector_db(self, vector_db_id: str) -> None:
if vector_db_id in self.cache:
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
async def unregister_vector_store(self, vector_store_id: str) -> None:
if vector_store_id in self.cache:
await self.cache[vector_store_id].index.delete()
del self.cache[vector_store_id]
assert self.kvstore is not None
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_db_id}")
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_store_id}")
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
if vector_db_id in self.cache:
return self.cache[vector_db_id]
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
if vector_store_id in self.cache:
return self.cache[vector_store_id]
if self.vector_db_store is None:
raise ValueError(f"Vector DB not found {vector_db_id}")
if self.vector_store_table is None:
raise ValueError(f"Vector DB not found {vector_store_id}")
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
if not vector_db:
raise VectorStoreNotFoundError(vector_db_id)
vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
if not vector_store:
raise VectorStoreNotFoundError(vector_store_id)
index = VectorDBWithIndex(
vector_db=vector_db,
index=QdrantIndex(client=self.client, collection_name=vector_db.identifier),
index = VectorStoreWithIndex(
vector_store=vector_store,
index=QdrantIndex(client=self.client, collection_name=vector_store.identifier),
inference_api=self.inference_api,
)
self.cache[vector_db_id] = index
self.cache[vector_store_id] = index
return index
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_store_index(vector_db_id)
if not index:
raise VectorStoreNotFoundError(vector_db_id)
await index.insert_chunks(chunks)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
index = await self._get_and_cache_vector_store_index(vector_db_id)
if not index:
raise VectorStoreNotFoundError(vector_db_id)
@ -276,7 +253,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from a Qdrant vector store."""
index = await self._get_and_cache_vector_db_index(store_id)
index = await self._get_and_cache_vector_store_index(store_id)
if not index:
raise ValueError(f"Vector DB {store_id} not found")

View file

@ -12,6 +12,6 @@ from .config import WeaviateVectorIOConfig
async def get_adapter_impl(config: WeaviateVectorIOConfig, deps: dict[Api, ProviderSpec]):
from .weaviate import WeaviateVectorIOAdapter
impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -8,10 +8,7 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.schema_utils import json_schema_type
@ -19,19 +16,17 @@ from llama_stack.schema_utils import json_schema_type
class WeaviateVectorIOConfig(BaseModel):
weaviate_api_key: str | None = Field(description="The API key for the Weaviate instance", default=None)
weaviate_cluster_url: str | None = Field(description="The URL of the Weaviate cluster", default="localhost:8080")
kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
persistence: KVStoreReference | None = Field(
description="Config for KV store backend (SQLite only for now)", default=None
)
@classmethod
def sample_run_config(
cls,
__distro_dir__: str,
**kwargs: Any,
) -> dict[str, Any]:
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"weaviate_api_key": None,
"weaviate_cluster_url": "${env.WEAVIATE_CLUSTER_URL:=localhost:8080}",
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="weaviate_registry.db",
),
"persistence": KVStoreReference(
backend="kv_default",
namespace="vector_io::weaviate",
).model_dump(exclude_none=True),
}

View file

@ -14,22 +14,21 @@ from weaviate.classes.query import Filter, HybridFusion
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files.files import Files
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
OpenAIVectorStoreMixin,
)
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
RERANKER_TYPE_RRF,
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
VectorStoreWithIndex,
)
from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collection_name
@ -38,7 +37,7 @@ from .config import WeaviateVectorIOConfig
log = get_logger(name=__name__, category="vector_io::weaviate")
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"
VECTOR_DBS_PREFIX = f"vector_stores:weaviate:{VERSION}::"
VECTOR_INDEX_PREFIX = f"vector_index:weaviate:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:weaviate:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:weaviate:{VERSION}::"
@ -46,12 +45,7 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten
class WeaviateIndex(EmbeddingIndex):
def __init__(
self,
client: weaviate.WeaviateClient,
collection_name: str,
kvstore: KVStore | None = None,
):
def __init__(self, client: weaviate.WeaviateClient, collection_name: str, kvstore: KVStore | None = None):
self.client = client
self.collection_name = sanitize_collection_name(collection_name, weaviate_format=True)
self.kvstore = kvstore
@ -106,9 +100,7 @@ class WeaviateIndex(EmbeddingIndex):
try:
results = collection.query.near_vector(
near_vector=embedding.tolist(),
limit=k,
return_metadata=wvc.query.MetadataQuery(distance=True),
near_vector=embedding.tolist(), limit=k, return_metadata=wvc.query.MetadataQuery(distance=True)
)
except Exception as e:
log.error(f"Weaviate client vector search failed: {e}")
@ -151,12 +143,7 @@ class WeaviateIndex(EmbeddingIndex):
collection = self.client.collections.get(sanitized_collection_name)
collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids))
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
"""
Performs BM25-based keyword search using Weaviate's built-in full-text search.
Args:
@ -173,9 +160,7 @@ class WeaviateIndex(EmbeddingIndex):
# Perform BM25 keyword search on chunk_content field
try:
results = collection.query.bm25(
query=query_string,
limit=k,
return_metadata=wvc.query.MetadataQuery(score=True),
query=query_string, limit=k, return_metadata=wvc.query.MetadataQuery(score=True)
)
except Exception as e:
log.error(f"Weaviate client keyword search failed: {e}")
@ -272,24 +257,14 @@ class WeaviateIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
class WeaviateVectorIOAdapter(
OpenAIVectorStoreMixin,
VectorIO,
NeedsRequestProviderData,
VectorDBsProtocolPrivate,
):
def __init__(
self,
config: WeaviateVectorIOConfig,
inference_api: Api.inference,
files_api: Files | None,
) -> None:
class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProviderData, VectorStoresProtocolPrivate):
def __init__(self, config: WeaviateVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.inference_api = inference_api
self.client_cache = {}
self.cache = {}
self.vector_db_store = None
self.vector_store_table = None
self.metadata_collection_name = "openai_vector_stores_metadata"
def _get_client(self) -> weaviate.WeaviateClient:
@ -297,10 +272,7 @@ class WeaviateVectorIOAdapter(
log.info("Using Weaviate locally in container")
host, port = self.config.weaviate_cluster_url.split(":")
key = "local_test"
client = weaviate.connect_to_local(
host=host,
port=port,
)
client = weaviate.connect_to_local(host=host, port=port)
else:
log.info("Using Weaviate remote cluster with URL")
key = f"{self.config.weaviate_cluster_url}::{self.config.weaviate_api_key}"
@ -316,8 +288,8 @@ class WeaviateVectorIOAdapter(
async def initialize(self) -> None:
"""Set up KV store and load existing vector DBs and OpenAI vector stores."""
# Initialize KV store for metadata if configured
if self.config.kvstore is not None:
self.kvstore = await kvstore_impl(self.config.kvstore)
if self.config.persistence is not None:
self.kvstore = await kvstore_impl(self.config.persistence)
else:
self.kvstore = None
log.info("No kvstore configured, registry will not persist across restarts")
@ -328,17 +300,11 @@ class WeaviateVectorIOAdapter(
end_key = f"{VECTOR_DBS_PREFIX}\xff"
stored = await self.kvstore.values_in_range(start_key, end_key)
for raw in stored:
vector_db = VectorDB.model_validate_json(raw)
vector_store = VectorStore.model_validate_json(raw)
client = self._get_client()
idx = WeaviateIndex(
client=client,
collection_name=vector_db.identifier,
kvstore=self.kvstore,
)
self.cache[vector_db.identifier] = VectorDBWithIndex(
vector_db=vector_db,
index=idx,
inference_api=self.inference_api,
idx = WeaviateIndex(client=client, collection_name=vector_store.identifier, kvstore=self.kvstore)
self.cache[vector_store.identifier] = VectorStoreWithIndex(
vector_store=vector_store, index=idx, inference_api=self.inference_api
)
# Load OpenAI vector stores metadata into cache
@ -350,90 +316,74 @@ class WeaviateVectorIOAdapter(
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def register_vector_db(
self,
vector_db: VectorDB,
) -> None:
async def register_vector_store(self, vector_store: VectorStore) -> None:
client = self._get_client()
sanitized_collection_name = sanitize_collection_name(vector_db.identifier, weaviate_format=True)
sanitized_collection_name = sanitize_collection_name(vector_store.identifier, weaviate_format=True)
# Create collection if it doesn't exist
if not client.collections.exists(sanitized_collection_name):
client.collections.create(
name=sanitized_collection_name,
vectorizer_config=wvc.config.Configure.Vectorizer.none(),
properties=[
wvc.config.Property(
name="chunk_content",
data_type=wvc.config.DataType.TEXT,
),
wvc.config.Property(name="chunk_content", data_type=wvc.config.DataType.TEXT),
],
)
self.cache[vector_db.identifier] = VectorDBWithIndex(
vector_db,
WeaviateIndex(client=client, collection_name=sanitized_collection_name),
self.inference_api,
self.cache[vector_store.identifier] = VectorStoreWithIndex(
vector_store, WeaviateIndex(client=client, collection_name=sanitized_collection_name), self.inference_api
)
async def unregister_vector_db(self, vector_db_id: str) -> None:
async def unregister_vector_store(self, vector_store_id: str) -> None:
client = self._get_client()
sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
if vector_db_id not in self.cache or client.collections.exists(sanitized_collection_name) is False:
sanitized_collection_name = sanitize_collection_name(vector_store_id, weaviate_format=True)
if vector_store_id not in self.cache or client.collections.exists(sanitized_collection_name) is False:
return
client.collections.delete(sanitized_collection_name)
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
await self.cache[vector_store_id].index.delete()
del self.cache[vector_store_id]
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
if vector_db_id in self.cache:
return self.cache[vector_db_id]
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
if vector_store_id in self.cache:
return self.cache[vector_store_id]
if self.vector_db_store is None:
raise VectorStoreNotFoundError(vector_db_id)
if self.vector_store_table is None:
raise VectorStoreNotFoundError(vector_store_id)
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
if not vector_db:
raise VectorStoreNotFoundError(vector_db_id)
vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
if not vector_store:
raise VectorStoreNotFoundError(vector_store_id)
client = self._get_client()
sanitized_collection_name = sanitize_collection_name(vector_db.identifier, weaviate_format=True)
sanitized_collection_name = sanitize_collection_name(vector_store.identifier, weaviate_format=True)
if not client.collections.exists(sanitized_collection_name):
raise ValueError(f"Collection with name `{sanitized_collection_name}` not found")
index = VectorDBWithIndex(
vector_db=vector_db,
index=WeaviateIndex(client=client, collection_name=vector_db.identifier),
index = VectorStoreWithIndex(
vector_store=vector_store,
index=WeaviateIndex(client=client, collection_name=vector_store.identifier),
inference_api=self.inference_api,
)
self.cache[vector_db_id] = index
self.cache[vector_store_id] = index
return index
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_store_index(vector_db_id)
if not index:
raise VectorStoreNotFoundError(vector_db_id)
await index.insert_chunks(chunks)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
index = await self._get_and_cache_vector_store_index(vector_db_id)
if not index:
raise VectorStoreNotFoundError(vector_db_id)
return await index.query_chunks(query, params)
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
index = await self._get_and_cache_vector_db_index(store_id)
index = await self._get_and_cache_vector_store_index(store_id)
if not index:
raise ValueError(f"Vector DB {store_id} not found")