Merge branch 'main' into routeur

Sébastien Han 2025-11-24 14:58:43 +01:00 committed by GitHub
commit 3770963130
255 changed files with 18366 additions and 1909 deletions

@@ -11,10 +11,9 @@ from typing import Any, Literal
from pydantic import BaseModel, TypeAdapter
from llama_stack.core.datatypes import AccessRule, StackRunConfig
from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl
from llama_stack.log import get_logger
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
from llama_stack_api import (
Conversation,
ConversationDeletedResource,
@@ -25,6 +24,7 @@ from llama_stack_api import (
Conversations,
Metadata,
)
from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType
logger = get_logger(name=__name__, category="openai_conversations")

@@ -10,7 +10,7 @@ from typing import Any
from pydantic import BaseModel
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.core.storage.kvstore import KVStore, kvstore_impl
from llama_stack_api import ListPromptsResponse, Prompt, Prompts

@@ -11,9 +11,9 @@ from datetime import UTC, datetime, timedelta
from starlette.types import ASGIApp, Receive, Scope, Send
from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendType
from llama_stack.core.storage.kvstore.kvstore import _KVSTORE_BACKENDS, kvstore_impl
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.kvstore.kvstore import _KVSTORE_BACKENDS, kvstore_impl
from llama_stack_api.internal.kvstore import KVStore
logger = get_logger(name=__name__, category="core::server")

@@ -385,8 +385,8 @@ def _initialize_storage(run_config: StackRunConfig):
else:
raise ValueError(f"Unknown storage backend type: {type}")
from llama_stack.providers.utils.kvstore.kvstore import register_kvstore_backends
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
from llama_stack.core.storage.kvstore.kvstore import register_kvstore_backends
from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends
register_kvstore_backends(kv_backends)
register_sqlstore_backends(sql_backends)
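
For orientation, a minimal sketch of the relocated registration flow (module paths as imported above; the SqliteKVStoreConfig signature follows the defaults shown later in this diff):

from llama_stack.core.storage.kvstore.config import SqliteKVStoreConfig
from llama_stack.core.storage.kvstore.kvstore import register_kvstore_backends

# Register named KV backends once at startup; stores later resolve
# themselves against these names by reference.
register_kvstore_backends({"kv_default": SqliteKVStoreConfig(db_path="/tmp/kvstore.db")})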

@@ -12,6 +12,8 @@ from typing import Annotated, Literal
from pydantic import BaseModel, Field, field_validator
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
class StorageBackendType(StrEnum):
KV_REDIS = "kv_redis"
@@ -256,15 +258,24 @@ class ResponsesStoreReference(InferenceStoreReference):
class ServerStoresConfig(BaseModel):
metadata: KVStoreReference | None = Field(
default=None,
default=KVStoreReference(
backend="kv_default",
namespace="registry",
),
description="Metadata store configuration (uses KV backend)",
)
inference: InferenceStoreReference | None = Field(
default=None,
default=InferenceStoreReference(
backend="sql_default",
table_name="inference_store",
),
description="Inference store configuration (uses SQL backend)",
)
conversations: SqlStoreReference | None = Field(
default=None,
default=SqlStoreReference(
backend="sql_default",
table_name="openai_conversations",
),
description="Conversations store configuration (uses SQL backend)",
)
responses: ResponsesStoreReference | None = Field(
@@ -272,13 +283,21 @@ class ServerStoresConfig(BaseModel):
description="Responses store configuration (uses SQL backend)",
)
prompts: KVStoreReference | None = Field(
default=None,
default=KVStoreReference(backend="kv_default", namespace="prompts"),
description="Prompts store configuration (uses KV backend)",
)
class StorageConfig(BaseModel):
backends: dict[str, StorageBackendConfig] = Field(
default={
"kv_default": SqliteKVStoreConfig(
db_path=f"${{env.SQLITE_STORE_DIR:={DISTRIBS_BASE_DIR}}}/kvstore.db",
),
"sql_default": SqliteSqlStoreConfig(
db_path=f"${{env.SQLITE_STORE_DIR:={DISTRIBS_BASE_DIR}}}/sql_store.db",
),
},
description="Named backend configurations (e.g., 'default', 'cache')",
)
stores: ServerStoresConfig = Field(
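
The practical effect of these defaults: a run config that omits storage entirely now resolves to SQLite-backed stores. A rough sketch, assuming StorageConfig lives in llama_stack.core.storage.datatypes alongside the reference types and that stores also defaults:

from llama_stack.core.storage.datatypes import StorageConfig

cfg = StorageConfig()                    # no explicit backends or stores
print(cfg.stores.metadata.backend)       # "kv_default"
print(cfg.stores.inference.table_name)   # "inference_store"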

@@ -4,4 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack_api.internal.kvstore import KVStore as KVStore
from .kvstore import * # noqa: F401, F403

@@ -13,11 +13,19 @@ from __future__ import annotations
import asyncio
from collections import defaultdict
from datetime import datetime
from typing import cast
from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendConfig, StorageBackendType
from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendConfig
from llama_stack_api.internal.kvstore import KVStore
from .api import KVStore
from .config import KVStoreConfig
from .config import (
KVStoreConfig,
MongoDBKVStoreConfig,
PostgresKVStoreConfig,
RedisKVStoreConfig,
SqliteKVStoreConfig,
)
def kvstore_dependencies():
@@ -33,7 +41,7 @@ def kvstore_dependencies():
class InmemoryKVStoreImpl(KVStore):
def __init__(self):
self._store = {}
self._store: dict[str, str] = {}
async def initialize(self) -> None:
pass
@@ -41,7 +49,7 @@ class InmemoryKVStoreImpl(KVStore):
async def get(self, key: str) -> str | None:
return self._store.get(key)
async def set(self, key: str, value: str) -> None:
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
self._store[key] = value
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
@@ -70,7 +78,8 @@ def register_kvstore_backends(backends: dict[str, StorageBackendConfig]) -> None
_KVSTORE_INSTANCES.clear()
_KVSTORE_LOCKS.clear()
for name, cfg in backends.items():
_KVSTORE_BACKENDS[name] = cfg
typed_cfg = cast(KVStoreConfig, cfg)
_KVSTORE_BACKENDS[name] = typed_cfg
async def kvstore_impl(reference: KVStoreReference) -> KVStore:
@@ -94,19 +103,20 @@ async def kvstore_impl(reference: KVStoreReference) -> KVStore:
config = backend_config.model_copy()
config.namespace = reference.namespace
if config.type == StorageBackendType.KV_REDIS.value:
impl: KVStore
if isinstance(config, RedisKVStoreConfig):
from .redis import RedisKVStoreImpl
impl = RedisKVStoreImpl(config)
elif config.type == StorageBackendType.KV_SQLITE.value:
elif isinstance(config, SqliteKVStoreConfig):
from .sqlite import SqliteKVStoreImpl
impl = SqliteKVStoreImpl(config)
elif config.type == StorageBackendType.KV_POSTGRES.value:
elif isinstance(config, PostgresKVStoreConfig):
from .postgres import PostgresKVStoreImpl
impl = PostgresKVStoreImpl(config)
elif config.type == StorageBackendType.KV_MONGODB.value:
elif isinstance(config, MongoDBKVStoreConfig):
from .mongodb import MongoDBKVStoreImpl
impl = MongoDBKVStoreImpl(config)
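
With dispatch keyed on the config class rather than on StorageBackendType string comparison, the type checker can narrow impl per branch. A usage sketch, assuming a "kv_default" backend was registered as shown earlier:

import asyncio

from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.core.storage.kvstore import kvstore_impl

async def demo() -> None:
    # Resolve a namespaced store against the registered "kv_default" backend.
    store = await kvstore_impl(KVStoreReference(backend="kv_default", namespace="demo"))
    await store.set("greeting", "hello")
    assert await store.get("greeting") == "hello"

asyncio.run(demo())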

@@ -9,8 +9,8 @@ from datetime import datetime
from pymongo import AsyncMongoClient
from pymongo.asynchronous.collection import AsyncCollection
from llama_stack.core.storage.kvstore import KVStore
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore
from ..config import MongoDBKVStoreConfig

@@ -6,12 +6,13 @@
from datetime import datetime
import psycopg2
from psycopg2.extras import DictCursor
import psycopg2 # type: ignore[import-not-found]
from psycopg2.extensions import connection as PGConnection # type: ignore[import-not-found]
from psycopg2.extras import DictCursor # type: ignore[import-not-found]
from llama_stack.log import get_logger
from llama_stack_api.internal.kvstore import KVStore
from ..api import KVStore
from ..config import PostgresKVStoreConfig
log = get_logger(name=__name__, category="providers::utils")
@@ -20,12 +21,12 @@ log = get_logger(name=__name__, category="providers::utils")
class PostgresKVStoreImpl(KVStore):
def __init__(self, config: PostgresKVStoreConfig):
self.config = config
self.conn = None
self.cursor = None
self._conn: PGConnection | None = None
self._cursor: DictCursor | None = None
async def initialize(self) -> None:
try:
self.conn = psycopg2.connect(
self._conn = psycopg2.connect(
host=self.config.host,
port=self.config.port,
database=self.config.db,
@@ -34,11 +35,11 @@ class PostgresKVStoreImpl(KVStore):
sslmode=self.config.ssl_mode,
sslrootcert=self.config.ca_cert_path,
)
self.conn.autocommit = True
self.cursor = self.conn.cursor(cursor_factory=DictCursor)
self._conn.autocommit = True
self._cursor = self._conn.cursor(cursor_factory=DictCursor)
# Create table if it doesn't exist
self.cursor.execute(
self._cursor.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.config.table_name} (
key TEXT PRIMARY KEY,
@@ -51,6 +52,11 @@
log.exception("Could not connect to PostgreSQL database server")
raise RuntimeError("Could not connect to PostgreSQL database server") from e
def _cursor_or_raise(self) -> DictCursor:
if self._cursor is None:
raise RuntimeError("Postgres client not initialized")
return self._cursor
def _namespaced_key(self, key: str) -> str:
if not self.config.namespace:
return key
@@ -58,7 +64,8 @@
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
key = self._namespaced_key(key)
self.cursor.execute(
cursor = self._cursor_or_raise()
cursor.execute(
f"""
INSERT INTO {self.config.table_name} (key, value, expiration)
VALUES (%s, %s, %s)
@@ -70,7 +77,8 @@
async def get(self, key: str) -> str | None:
key = self._namespaced_key(key)
self.cursor.execute(
cursor = self._cursor_or_raise()
cursor.execute(
f"""
SELECT value FROM {self.config.table_name}
WHERE key = %s
@@ -78,12 +86,13 @@
""",
(key,),
)
result = self.cursor.fetchone()
result = cursor.fetchone()
return result[0] if result else None
async def delete(self, key: str) -> None:
key = self._namespaced_key(key)
self.cursor.execute(
cursor = self._cursor_or_raise()
cursor.execute(
f"DELETE FROM {self.config.table_name} WHERE key = %s",
(key,),
)
@@ -92,7 +101,8 @@ class PostgresKVStoreImpl(KVStore):
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
self.cursor.execute(
cursor = self._cursor_or_raise()
cursor.execute(
f"""
SELECT value FROM {self.config.table_name}
WHERE key >= %s AND key < %s
@@ -101,14 +111,15 @@
""",
(start_key, end_key),
)
return [row[0] for row in self.cursor.fetchall()]
return [row[0] for row in cursor.fetchall()]
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
self.cursor.execute(
cursor = self._cursor_or_raise()
cursor.execute(
f"SELECT key FROM {self.config.table_name} WHERE key >= %s AND key < %s",
(start_key, end_key),
)
return [row[0] for row in self.cursor.fetchall()]
return [row[0] for row in cursor.fetchall()]
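
The refactor above swaps the bare self.cursor attribute for a private optional plus a guard, so every call site sees a non-optional cursor. The idiom in isolation, as a sketch (DictCursor stands in for the psycopg2 cursor type):

class _Example:
    def __init__(self) -> None:
        self._cursor: "DictCursor | None" = None  # populated in initialize()

    def _cursor_or_raise(self) -> "DictCursor":
        # Fail fast, and give the type checker a non-optional result,
        # if initialize() was never called.
        if self._cursor is None:
            raise RuntimeError("Postgres client not initialized")
        return self._cursor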

@@ -6,18 +6,25 @@
from datetime import datetime
from redis.asyncio import Redis
from redis.asyncio import Redis # type: ignore[import-not-found]
from llama_stack_api.internal.kvstore import KVStore
from ..api import KVStore
from ..config import RedisKVStoreConfig
class RedisKVStoreImpl(KVStore):
def __init__(self, config: RedisKVStoreConfig):
self.config = config
self._redis: Redis | None = None
async def initialize(self) -> None:
self.redis = Redis.from_url(self.config.url)
self._redis = Redis.from_url(self.config.url)
def _client(self) -> Redis:
if self._redis is None:
raise RuntimeError("Redis client not initialized")
return self._redis
def _namespaced_key(self, key: str) -> str:
if not self.config.namespace:
@@ -26,30 +33,37 @@ class RedisKVStoreImpl(KVStore):
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
key = self._namespaced_key(key)
await self.redis.set(key, value)
client = self._client()
await client.set(key, value)
if expiration:
await self.redis.expireat(key, expiration)
await client.expireat(key, expiration)
async def get(self, key: str) -> str | None:
key = self._namespaced_key(key)
value = await self.redis.get(key)
client = self._client()
value = await client.get(key)
if value is None:
return None
await self.redis.ttl(key)
return value
await client.ttl(key)
if isinstance(value, bytes):
return value.decode("utf-8")
if isinstance(value, str):
return value
return str(value)
async def delete(self, key: str) -> None:
key = self._namespaced_key(key)
await self.redis.delete(key)
await self._client().delete(key)
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
client = self._client()
cursor = 0
pattern = start_key + "*" # Match all keys starting with start_key prefix
matching_keys = []
matching_keys: list[str | bytes] = []
while True:
cursor, keys = await self.redis.scan(cursor, match=pattern, count=1000)
cursor, keys = await client.scan(cursor, match=pattern, count=1000)
for key in keys:
key_str = key.decode("utf-8") if isinstance(key, bytes) else key
@@ -61,7 +75,7 @@ class RedisKVStoreImpl(KVStore):
# Then fetch all values in a single MGET call
if matching_keys:
values = await self.redis.mget(matching_keys)
values = await client.mget(matching_keys)
return [
value.decode("utf-8") if isinstance(value, bytes) else value for value in values if value is not None
]
@@ -70,7 +84,18 @@
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
"""Get all keys in the given range."""
matching_keys = await self.redis.zrangebylex(self.namespace, f"[{start_key}", f"[{end_key}")
if not matching_keys:
return []
return [k.decode("utf-8") for k in matching_keys]
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
client = self._client()
cursor = 0
pattern = start_key + "*"
result: list[str] = []
while True:
cursor, keys = await client.scan(cursor, match=pattern, count=1000)
for key in keys:
key_str = key.decode("utf-8") if isinstance(key, bytes) else str(key)
if start_key <= key_str <= end_key:
result.append(key_str)
if cursor == 0:
break
return result
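
Note that keys_in_range now uses the same SCAN-based prefix walk as values_in_range, replacing the old zrangebylex call that assumed a sorted-set index which was never maintained. A usage sketch, assuming a local Redis and the config class imported above (constructor keyword assumed from the config's url attribute):

import asyncio

from llama_stack.core.storage.kvstore.config import RedisKVStoreConfig

async def demo() -> None:
    store = RedisKVStoreImpl(RedisKVStoreConfig(url="redis://localhost:6379"))
    await store.initialize()
    await store.set("batch:001", "pending")
    print(await store.keys_in_range("batch:", "batch:~"))  # ["batch:001"]

asyncio.run(demo())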

@@ -10,8 +10,8 @@ from datetime import datetime
import aiosqlite
from llama_stack.log import get_logger
from llama_stack_api.internal.kvstore import KVStore
from ..api import KVStore
from ..config import SqliteKVStoreConfig
logger = get_logger(name=__name__, category="providers::utils")

@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack_api.internal.sqlstore import (
ColumnDefinition as ColumnDefinition,
)
from llama_stack_api.internal.sqlstore import (
ColumnType as ColumnType,
)
from llama_stack_api.internal.sqlstore import (
SqlStore as SqlStore,
)
from .sqlstore import * # noqa: F401,F403
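
Presumably the package __init__: the explicit "X as X" aliases keep the re-exports visible to type checkers while the canonical definitions move to llama_stack_api.internal.sqlstore. Downstream code can then import from either location; a sketch:

from llama_stack_api.internal.sqlstore import ColumnType as CanonicalColumnType

from llama_stack.core.storage.sqlstore import ColumnType  # re-exported alias

assert ColumnType is CanonicalColumnType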

@@ -14,8 +14,8 @@ from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.core.storage.datatypes import StorageBackendType
from llama_stack.log import get_logger
from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
from llama_stack_api import PaginatedResponse
from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType, SqlStore
logger = get_logger(name=__name__, category="providers::utils")

@@ -29,8 +29,7 @@ from sqlalchemy.sql.elements import ColumnElement
from llama_stack.core.storage.datatypes import SqlAlchemySqlStoreConfig
from llama_stack.log import get_logger
from llama_stack_api import PaginatedResponse
from .api import ColumnDefinition, ColumnType, SqlStore
from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType, SqlStore
logger = get_logger(name=__name__, category="providers::utils")

@@ -16,8 +16,7 @@ from llama_stack.core.storage.datatypes import (
StorageBackendConfig,
StorageBackendType,
)
from .api import SqlStore
from llama_stack_api.internal.sqlstore import SqlStore
sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]

@@ -12,8 +12,8 @@ import pydantic
from llama_stack.core.datatypes import RoutableObjectWithProvider
from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.core.storage.kvstore import KVStore, kvstore_impl
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
logger = get_logger(__name__, category="core::registry")

@@ -17,44 +17,43 @@ providers:
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
provider_type: remote::cerebras
config:
base_url: https://api.cerebras.ai
base_url: https://api.cerebras.ai/v1
api_key: ${env.CEREBRAS_API_KEY:=}
- provider_id: ${env.OLLAMA_URL:+ollama}
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:=http://localhost:11434}
base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
- provider_id: ${env.VLLM_URL:+vllm}
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=}
base_url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
url: ${env.TGI_URL:=}
base_url: ${env.TGI_URL:=}
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
base_url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
base_url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
config:
api_key: ${env.AWS_BEDROCK_API_KEY:=}
api_key: ${env.AWS_BEARER_TOKEN_BEDROCK:=}
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: openai
provider_type: remote::openai
config:
@@ -76,18 +75,18 @@ providers:
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
base_url: https://api.groq.com/openai/v1
api_key: ${env.GROQ_API_KEY:=}
- provider_id: sambanova
provider_type: remote::sambanova
config:
url: https://api.sambanova.ai/v1
base_url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
- provider_id: ${env.AZURE_API_KEY:+azure}
provider_type: remote::azure
config:
api_key: ${env.AZURE_API_KEY:=}
api_base: ${env.AZURE_API_BASE:=}
base_url: ${env.AZURE_API_BASE:=}
api_version: ${env.AZURE_API_VERSION:=}
api_type: ${env.AZURE_API_TYPE:=}
- provider_id: sentence-transformers
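
These run configs lean on two env-substitution forms: ${env.VAR:=default} (use VAR if set, else the default) and ${env.VAR:+value} (yield value only when VAR is set, e.g. to enable a provider conditionally). A rough Python sketch of the semantics as used here, not the stack's actual parser:

import os
import re

def substitute(template: str) -> str:
    def repl(m: re.Match[str]) -> str:
        name, op, rest = m.group(1), m.group(2), m.group(3)
        val = os.environ.get(name)
        if op == ":=":
            return val if val is not None else rest  # default form
        return rest if val else ""                   # conditional ":+" form
    return re.sub(r"\$\{env\.(\w+)(:=|:\+)([^}]*)\}", repl, template)

print(substitute("base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}"))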

@@ -17,44 +17,43 @@ providers:
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
provider_type: remote::cerebras
config:
base_url: https://api.cerebras.ai
base_url: https://api.cerebras.ai/v1
api_key: ${env.CEREBRAS_API_KEY:=}
- provider_id: ${env.OLLAMA_URL:+ollama}
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:=http://localhost:11434}
base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
- provider_id: ${env.VLLM_URL:+vllm}
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=}
base_url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
url: ${env.TGI_URL:=}
base_url: ${env.TGI_URL:=}
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
base_url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
base_url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
config:
api_key: ${env.AWS_BEDROCK_API_KEY:=}
api_key: ${env.AWS_BEARER_TOKEN_BEDROCK:=}
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: openai
provider_type: remote::openai
config:
@@ -76,18 +75,18 @@ providers:
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
base_url: https://api.groq.com/openai/v1
api_key: ${env.GROQ_API_KEY:=}
- provider_id: sambanova
provider_type: remote::sambanova
config:
url: https://api.sambanova.ai/v1
base_url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
- provider_id: ${env.AZURE_API_KEY:+azure}
provider_type: remote::azure
config:
api_key: ${env.AZURE_API_KEY:=}
api_base: ${env.AZURE_API_BASE:=}
base_url: ${env.AZURE_API_BASE:=}
api_version: ${env.AZURE_API_VERSION:=}
api_type: ${env.AZURE_API_TYPE:=}
- provider_id: sentence-transformers

@@ -16,9 +16,8 @@ providers:
- provider_id: nvidia
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: nvidia
provider_type: remote::nvidia
config:

@@ -16,9 +16,8 @@ providers:
- provider_id: nvidia
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
vector_io:
- provider_id: faiss
provider_type: inline::faiss

@@ -27,12 +27,12 @@ providers:
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
base_url: https://api.groq.com/openai/v1
api_key: ${env.GROQ_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
base_url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
vector_io:
- provider_id: sqlite-vec

@@ -11,7 +11,7 @@ providers:
- provider_id: vllm-inference
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=http://localhost:8000/v1}
base_url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}

@@ -17,44 +17,43 @@ providers:
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
provider_type: remote::cerebras
config:
base_url: https://api.cerebras.ai
base_url: https://api.cerebras.ai/v1
api_key: ${env.CEREBRAS_API_KEY:=}
- provider_id: ${env.OLLAMA_URL:+ollama}
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:=http://localhost:11434}
base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
- provider_id: ${env.VLLM_URL:+vllm}
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=}
base_url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
url: ${env.TGI_URL:=}
base_url: ${env.TGI_URL:=}
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
base_url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
base_url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
config:
api_key: ${env.AWS_BEDROCK_API_KEY:=}
api_key: ${env.AWS_BEARER_TOKEN_BEDROCK:=}
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: openai
provider_type: remote::openai
config:
@@ -76,18 +75,18 @@ providers:
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
base_url: https://api.groq.com/openai/v1
api_key: ${env.GROQ_API_KEY:=}
- provider_id: sambanova
provider_type: remote::sambanova
config:
url: https://api.sambanova.ai/v1
base_url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
- provider_id: ${env.AZURE_API_KEY:+azure}
provider_type: remote::azure
config:
api_key: ${env.AZURE_API_KEY:=}
api_base: ${env.AZURE_API_BASE:=}
base_url: ${env.AZURE_API_BASE:=}
api_version: ${env.AZURE_API_VERSION:=}
api_type: ${env.AZURE_API_TYPE:=}
- provider_id: sentence-transformers

@@ -17,44 +17,43 @@ providers:
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
provider_type: remote::cerebras
config:
base_url: https://api.cerebras.ai
base_url: https://api.cerebras.ai/v1
api_key: ${env.CEREBRAS_API_KEY:=}
- provider_id: ${env.OLLAMA_URL:+ollama}
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:=http://localhost:11434}
base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
- provider_id: ${env.VLLM_URL:+vllm}
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=}
base_url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
url: ${env.TGI_URL:=}
base_url: ${env.TGI_URL:=}
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
base_url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
base_url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
config:
api_key: ${env.AWS_BEDROCK_API_KEY:=}
api_key: ${env.AWS_BEARER_TOKEN_BEDROCK:=}
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: openai
provider_type: remote::openai
config:
@@ -76,18 +75,18 @@ providers:
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
base_url: https://api.groq.com/openai/v1
api_key: ${env.GROQ_API_KEY:=}
- provider_id: sambanova
provider_type: remote::sambanova
config:
url: https://api.sambanova.ai/v1
base_url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
- provider_id: ${env.AZURE_API_KEY:+azure}
provider_type: remote::azure
config:
api_key: ${env.AZURE_API_KEY:=}
api_base: ${env.AZURE_API_BASE:=}
base_url: ${env.AZURE_API_BASE:=}
api_version: ${env.AZURE_API_VERSION:=}
api_type: ${env.AZURE_API_TYPE:=}
- provider_id: sentence-transformers

@@ -17,44 +17,43 @@ providers:
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
provider_type: remote::cerebras
config:
base_url: https://api.cerebras.ai
base_url: https://api.cerebras.ai/v1
api_key: ${env.CEREBRAS_API_KEY:=}
- provider_id: ${env.OLLAMA_URL:+ollama}
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:=http://localhost:11434}
base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
- provider_id: ${env.VLLM_URL:+vllm}
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=}
base_url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
url: ${env.TGI_URL:=}
base_url: ${env.TGI_URL:=}
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
base_url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
base_url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
config:
api_key: ${env.AWS_BEDROCK_API_KEY:=}
api_key: ${env.AWS_BEARER_TOKEN_BEDROCK:=}
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: openai
provider_type: remote::openai
config:
@@ -76,18 +75,18 @@ providers:
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
base_url: https://api.groq.com/openai/v1
api_key: ${env.GROQ_API_KEY:=}
- provider_id: sambanova
provider_type: remote::sambanova
config:
url: https://api.sambanova.ai/v1
base_url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
- provider_id: ${env.AZURE_API_KEY:+azure}
provider_type: remote::azure
config:
api_key: ${env.AZURE_API_KEY:=}
api_base: ${env.AZURE_API_BASE:=}
base_url: ${env.AZURE_API_BASE:=}
api_version: ${env.AZURE_API_VERSION:=}
api_type: ${env.AZURE_API_TYPE:=}
- provider_id: sentence-transformers

@@ -17,44 +17,43 @@ providers:
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
provider_type: remote::cerebras
config:
base_url: https://api.cerebras.ai
base_url: https://api.cerebras.ai/v1
api_key: ${env.CEREBRAS_API_KEY:=}
- provider_id: ${env.OLLAMA_URL:+ollama}
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:=http://localhost:11434}
base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
- provider_id: ${env.VLLM_URL:+vllm}
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=}
base_url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
url: ${env.TGI_URL:=}
base_url: ${env.TGI_URL:=}
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
base_url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
base_url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
config:
api_key: ${env.AWS_BEDROCK_API_KEY:=}
api_key: ${env.AWS_BEARER_TOKEN_BEDROCK:=}
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: openai
provider_type: remote::openai
config:
@@ -76,18 +75,18 @@ providers:
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
base_url: https://api.groq.com/openai/v1
api_key: ${env.GROQ_API_KEY:=}
- provider_id: sambanova
provider_type: remote::sambanova
config:
url: https://api.sambanova.ai/v1
base_url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
- provider_id: ${env.AZURE_API_KEY:+azure}
provider_type: remote::azure
config:
api_key: ${env.AZURE_API_KEY:=}
api_base: ${env.AZURE_API_BASE:=}
base_url: ${env.AZURE_API_BASE:=}
api_version: ${env.AZURE_API_VERSION:=}
api_type: ${env.AZURE_API_TYPE:=}
- provider_id: sentence-transformers

@@ -17,6 +17,8 @@ from llama_stack.core.datatypes import (
ToolGroupInput,
VectorStoresConfig,
)
from llama_stack.core.storage.kvstore.config import PostgresKVStoreConfig
from llama_stack.core.storage.sqlstore.sqlstore import PostgresSqlStoreConfig
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
@@ -35,8 +37,6 @@ from llama_stack.providers.remote.vector_io.pgvector.config import (
)
from llama_stack.providers.remote.vector_io.qdrant.config import QdrantVectorIOConfig
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
from llama_stack.providers.utils.kvstore.config import PostgresKVStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
from llama_stack_api import RemoteProviderSpec

@@ -35,13 +35,13 @@ from llama_stack.core.storage.datatypes import (
SqlStoreReference,
StorageBackendType,
)
from llama_stack.core.storage.kvstore.config import SqliteKVStoreConfig
from llama_stack.core.storage.kvstore.config import get_pip_packages as get_kv_pip_packages
from llama_stack.core.storage.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack.core.storage.sqlstore.sqlstore import get_pip_packages as get_sql_pip_packages
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.core.utils.image_types import LlamaStackImageType
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.config import get_pip_packages as get_kv_pip_packages
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as get_sql_pip_packages
from llama_stack_api import DatasetPurpose, ModelType

@@ -15,7 +15,7 @@ providers:
- provider_id: watsonx
provider_type: remote::watsonx
config:
url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
base_url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
api_key: ${env.WATSONX_API_KEY:=}
project_id: ${env.WATSONX_PROJECT_ID:=}
vector_io:

@@ -23,12 +23,14 @@ async def get_provider_impl(
config,
deps[Api.inference],
deps[Api.vector_io],
deps[Api.safety],
deps.get(Api.safety),
deps[Api.tool_runtime],
deps[Api.tool_groups],
deps[Api.conversations],
policy,
deps[Api.prompts],
deps[Api.files],
telemetry_enabled,
policy,
)
await impl.initialize()
return impl

@@ -6,12 +6,13 @@
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.storage.kvstore import InmemoryKVStoreImpl, kvstore_impl
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
from llama_stack_api import (
Agents,
Conversations,
Files,
Inference,
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
@@ -22,6 +23,7 @@ from llama_stack_api import (
OpenAIResponsePrompt,
OpenAIResponseText,
Order,
Prompts,
ResponseGuardrail,
Safety,
ToolGroups,
@@ -41,10 +43,12 @@ class MetaReferenceAgentsImpl(Agents):
config: MetaReferenceAgentsImplConfig,
inference_api: Inference,
vector_io_api: VectorIO,
safety_api: Safety,
safety_api: Safety | None,
tool_runtime_api: ToolRuntime,
tool_groups_api: ToolGroups,
conversations_api: Conversations,
prompts_api: Prompts,
files_api: Files,
policy: list[AccessRule],
telemetry_enabled: bool = False,
):
@@ -56,7 +60,8 @@ class MetaReferenceAgentsImpl(Agents):
self.tool_groups_api = tool_groups_api
self.conversations_api = conversations_api
self.telemetry_enabled = telemetry_enabled
self.prompts_api = prompts_api
self.files_api = files_api
self.in_memory_store = InmemoryKVStoreImpl()
self.openai_responses_impl: OpenAIResponsesImpl | None = None
self.policy = policy
@@ -73,6 +78,8 @@
vector_io_api=self.vector_io_api,
safety_api=self.safety_api,
conversations_api=self.conversations_api,
prompts_api=self.prompts_api,
files_api=self.files_api,
)
async def shutdown(self) -> None:
@@ -92,6 +99,7 @@
model: str,
prompt: OpenAIResponsePrompt | None = None,
instructions: str | None = None,
parallel_tool_calls: bool | None = True,
previous_response_id: str | None = None,
conversation: str | None = None,
store: bool | None = True,
@@ -120,6 +128,7 @@
include,
max_infer_iters,
guardrails,
parallel_tool_calls,
max_tool_calls,
)
return result # type: ignore[no-any-return]
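
A hypothetical call sketch for the new parameter (keyword names per the signature above; the input keyword and model id are illustrative assumptions):

async def ask(agents_impl) -> None:
    response = await agents_impl.create_openai_response(
        input="What's the weather in Paris and in Tokyo?",  # 'input' name assumed
        model="llama3.2-3b",
        parallel_tool_calls=False,  # at most one function tool call per turn
    )
    print(response.id)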

@@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
import time
import uuid
from collections.abc import AsyncIterator
@@ -18,13 +19,17 @@ from llama_stack.providers.utils.responses.responses_store import (
from llama_stack_api import (
ConversationItem,
Conversations,
Files,
Inference,
InvalidConversationIdError,
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIChatCompletionContentPartParam,
OpenAIDeleteResponseObject,
OpenAIMessageParam,
OpenAIResponseInput,
OpenAIResponseInputMessageContentFile,
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseMessage,
@@ -34,7 +39,9 @@ from llama_stack_api import (
OpenAIResponseText,
OpenAIResponseTextFormat,
OpenAISystemMessageParam,
OpenAIUserMessageParam,
Order,
Prompts,
ResponseGuardrailSpec,
Safety,
ToolGroups,
@@ -46,6 +53,7 @@ from .streaming import StreamingResponseOrchestrator
from .tool_executor import ToolExecutor
from .types import ChatCompletionContext, ToolContext
from .utils import (
convert_response_content_to_chat_content,
convert_response_input_to_chat_messages,
convert_response_text_to_chat_response_format,
extract_guardrail_ids,
@@ -67,8 +75,10 @@ class OpenAIResponsesImpl:
tool_runtime_api: ToolRuntime,
responses_store: ResponsesStore,
vector_io_api: VectorIO, # VectorIO
safety_api: Safety,
safety_api: Safety | None,
conversations_api: Conversations,
prompts_api: Prompts,
files_api: Files,
):
self.inference_api = inference_api
self.tool_groups_api = tool_groups_api
@@ -82,6 +92,8 @@
tool_runtime_api=tool_runtime_api,
vector_io_api=vector_io_api,
)
self.prompts_api = prompts_api
self.files_api = files_api
async def _prepend_previous_response(
self,
@@ -122,11 +134,13 @@
# Use stored messages directly and convert only new input
message_adapter = TypeAdapter(list[OpenAIMessageParam])
messages = message_adapter.validate_python(previous_response.messages)
new_messages = await convert_response_input_to_chat_messages(input, previous_messages=messages)
new_messages = await convert_response_input_to_chat_messages(
input, previous_messages=messages, files_api=self.files_api
)
messages.extend(new_messages)
else:
# Backward compatibility: reconstruct from inputs
messages = await convert_response_input_to_chat_messages(all_input)
messages = await convert_response_input_to_chat_messages(all_input, files_api=self.files_api)
tool_context.recover_tools_from_previous_response(previous_response)
elif conversation is not None:
@@ -138,7 +152,7 @@
all_input = input
if not conversation_items.data:
# First turn - just convert the new input
messages = await convert_response_input_to_chat_messages(input)
messages = await convert_response_input_to_chat_messages(input, files_api=self.files_api)
else:
if not stored_messages:
all_input = conversation_items.data
@@ -154,14 +168,82 @@
all_input = input
messages = stored_messages or []
new_messages = await convert_response_input_to_chat_messages(all_input, previous_messages=messages)
new_messages = await convert_response_input_to_chat_messages(
all_input, previous_messages=messages, files_api=self.files_api
)
messages.extend(new_messages)
else:
all_input = input
messages = await convert_response_input_to_chat_messages(all_input)
messages = await convert_response_input_to_chat_messages(all_input, files_api=self.files_api)
return all_input, messages, tool_context
async def _prepend_prompt(
self,
messages: list[OpenAIMessageParam],
openai_response_prompt: OpenAIResponsePrompt | None,
) -> None:
"""Prepend prompt template to messages, resolving text/image/file variables.
:param messages: List of OpenAIMessageParam objects
:param openai_response_prompt: (Optional) OpenAIResponsePrompt object with variables
:returns: None; the resolved prompt is inserted into messages in place
"""
if not openai_response_prompt or not openai_response_prompt.id:
return
prompt_version = int(openai_response_prompt.version) if openai_response_prompt.version else None
cur_prompt = await self.prompts_api.get_prompt(openai_response_prompt.id, prompt_version)
if not cur_prompt or not cur_prompt.prompt:
return
cur_prompt_text = cur_prompt.prompt
cur_prompt_variables = cur_prompt.variables
if not openai_response_prompt.variables:
messages.insert(0, OpenAISystemMessageParam(content=cur_prompt_text))
return
# Validate that all provided variables exist in the prompt
for name in openai_response_prompt.variables.keys():
if name not in cur_prompt_variables:
raise ValueError(f"Variable {name} not found in prompt {openai_response_prompt.id}")
# Separate text and media variables
text_substitutions = {}
media_content_parts: list[OpenAIChatCompletionContentPartParam] = []
for name, value in openai_response_prompt.variables.items():
# Text variable found
if isinstance(value, OpenAIResponseInputMessageContentText):
text_substitutions[name] = value.text
# Media variable found
elif isinstance(value, OpenAIResponseInputMessageContentImage | OpenAIResponseInputMessageContentFile):
converted_parts = await convert_response_content_to_chat_content([value], files_api=self.files_api)
if isinstance(converted_parts, list):
media_content_parts.extend(converted_parts)
# Eg: {{product_photo}} becomes "[Image: product_photo]"
# This gives the model textual context about what media exists in the prompt
var_type = value.type.replace("input_", "").replace("_", " ").title()
text_substitutions[name] = f"[{var_type}: {name}]"
def replace_variable(match: re.Match[str]) -> str:
var_name = match.group(1).strip()
return str(text_substitutions.get(var_name, match.group(0)))
pattern = r"\{\{\s*(\w+)\s*\}\}"
processed_prompt_text = re.sub(pattern, replace_variable, cur_prompt_text)
# Insert system message with resolved text
messages.insert(0, OpenAISystemMessageParam(content=processed_prompt_text))
# If we have media, append a new user message, since user messages can carry image and file parts
if media_content_parts:
messages.append(OpenAIUserMessageParam(content=media_content_parts))
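
The substitution itself is plain regex templating; media variables are replaced inline with a textual marker while their content travels in a separate user message. The same pattern and replacement logic as above, in isolation:

import re

def render(prompt_text: str, text_substitutions: dict[str, str]) -> str:
    # {{ name }} with optional surrounding whitespace; unknown names are left as-is.
    def replace_variable(match: re.Match[str]) -> str:
        var_name = match.group(1).strip()
        return str(text_substitutions.get(var_name, match.group(0)))
    return re.sub(r"\{\{\s*(\w+)\s*\}\}", replace_variable, prompt_text)

print(render("Describe {{ product_photo }} briefly.", {"product_photo": "[Image: product_photo]"}))
# -> Describe [Image: product_photo] briefly.
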
async def get_openai_response(
self,
response_id: str,
@@ -252,6 +334,7 @@
include: list[str] | None = None,
max_infer_iters: int | None = 10,
guardrails: list[str | ResponseGuardrailSpec] | None = None,
parallel_tool_calls: bool | None = None,
max_tool_calls: int | None = None,
):
stream = bool(stream)
@@ -272,6 +355,14 @@
guardrail_ids = extract_guardrail_ids(guardrails) if guardrails else []
# Validate that Safety API is available if guardrails are requested
if guardrail_ids and self.safety_api is None:
raise ValueError(
"Cannot process guardrails: Safety API is not configured.\n\n"
"To use guardrails, ensure the Safety API is configured in your stack, or remove "
"the 'guardrails' parameter from your request."
)
if conversation is not None:
if previous_response_id is not None:
raise ValueError(
@@ -288,6 +379,7 @@
input=input,
conversation=conversation,
model=model,
prompt=prompt,
instructions=instructions,
previous_response_id=previous_response_id,
store=store,
@@ -296,6 +388,7 @@
tools=tools,
max_infer_iters=max_infer_iters,
guardrail_ids=guardrail_ids,
parallel_tool_calls=parallel_tool_calls,
max_tool_calls=max_tool_calls,
)
@@ -340,12 +433,14 @@
instructions: str | None = None,
previous_response_id: str | None = None,
conversation: str | None = None,
prompt: OpenAIResponsePrompt | None = None,
store: bool | None = True,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
max_infer_iters: int | None = 10,
guardrail_ids: list[str] | None = None,
parallel_tool_calls: bool | None = True,
max_tool_calls: int | None = None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# These should never be None when called from create_openai_response (which sets defaults)
@@ -361,6 +456,9 @@
if instructions:
messages.insert(0, OpenAISystemMessageParam(content=instructions))
# Prepend reusable prompt (if provided)
await self._prepend_prompt(messages, prompt)
# Structured outputs
response_format = await convert_response_text_to_chat_response_format(text)
@@ -383,8 +481,10 @@
ctx=ctx,
response_id=response_id,
created_at=created_at,
prompt=prompt,
text=text,
max_infer_iters=max_infer_iters,
parallel_tool_calls=parallel_tool_calls,
tool_executor=self.tool_executor,
safety_api=self.safety_api,
guardrail_ids=guardrail_ids,

@@ -66,6 +66,8 @@ from llama_stack_api import (
OpenAIResponseUsage,
OpenAIResponseUsageInputTokensDetails,
OpenAIResponseUsageOutputTokensDetails,
OpenAIToolMessageParam,
Safety,
WebSearchToolTypes,
)
@@ -111,9 +113,10 @@ class StreamingResponseOrchestrator:
max_infer_iters: int,
tool_executor, # Will be the tool execution logic from the main class
instructions: str | None,
safety_api,
safety_api: Safety | None,
guardrail_ids: list[str] | None = None,
prompt: OpenAIResponsePrompt | None = None,
parallel_tool_calls: bool | None = None,
max_tool_calls: int | None = None,
):
self.inference_api = inference_api
@@ -128,6 +131,8 @@
self.prompt = prompt
# System message that is inserted into the model's context
self.instructions = instructions
# Whether to allow more than one function tool call generated per turn.
self.parallel_tool_calls = parallel_tool_calls
# Max number of total calls to built-in tools that can be processed in a response
self.max_tool_calls = max_tool_calls
self.sequence_number = 0
@@ -190,6 +195,7 @@
usage=self.accumulated_usage,
instructions=self.instructions,
prompt=self.prompt,
parallel_tool_calls=self.parallel_tool_calls,
max_tool_calls=self.max_tool_calls,
)
@@ -901,10 +907,16 @@
"""Coordinate execution of both function and non-function tool calls."""
# Execute non-function tool calls
for tool_call in non_function_tool_calls:
# Check if total calls made to built-in and mcp tools exceed max_tool_calls
# if total calls made to built-in and mcp tools exceed max_tool_calls
# then create a tool response message indicating the call was skipped
if self.max_tool_calls is not None and self.accumulated_builtin_tool_calls >= self.max_tool_calls:
logger.info(f"Ignoring built-in and mcp tool call since reached the limit of {self.max_tool_calls=}.")
break
skipped_call_message = OpenAIToolMessageParam(
content=f"Tool call skipped: maximum tool calls limit ({self.max_tool_calls}) reached.",
tool_call_id=tool_call.id,
)
next_turn_messages.append(skipped_call_message)
continue
# Find the item_id for this tool call
matching_item_id = None
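
The behavioural change here: the old code broke out of the loop once the cap was hit, leaving later tool_call entries with no tool-role reply; the new code gives each skipped call an OpenAIToolMessageParam answer, which chat-completion APIs require for every emitted tool_call id. A standalone sketch of the preserved invariant:

# Invariant: every tool_call id gets a reply, even when the tool never runs.
calls = ["call_1", "call_2", "call_3"]
max_tool_calls, executed, replies = 1, 0, []
for call_id in calls:
    if executed >= max_tool_calls:
        replies.append((call_id, "Tool call skipped: maximum tool calls limit (1) reached."))
        continue
    executed += 1
    replies.append((call_id, "<real tool output>"))
assert len(replies) == len(calls)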

@@ -5,11 +5,14 @@
# the root directory of this source tree.
import asyncio
import base64
import mimetypes
import re
import uuid
from collections.abc import Sequence
from llama_stack_api import (
Files,
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartParam,
@@ -18,6 +21,8 @@ from llama_stack_api import (
OpenAIChatCompletionToolCallFunction,
OpenAIChoice,
OpenAIDeveloperMessageParam,
OpenAIFile,
OpenAIFileFile,
OpenAIImageURL,
OpenAIJSONSchema,
OpenAIMessageParam,
@@ -29,6 +34,7 @@ from llama_stack_api import (
OpenAIResponseInput,
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseInputMessageContent,
OpenAIResponseInputMessageContentFile,
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
@@ -37,9 +43,11 @@ from llama_stack_api import (
OpenAIResponseMessage,
OpenAIResponseOutputMessageContent,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFileSearchToolCall,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseOutputMessageWebSearchToolCall,
OpenAIResponseText,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
@@ -49,6 +57,46 @@ from llama_stack_api import (
)
async def extract_bytes_from_file(file_id: str, files_api: Files) -> bytes:
"""
Extract raw bytes from file using the Files API.
:param file_id: The file identifier (e.g., "file-abc123")
:param files_api: Files API instance
:returns: Raw file content as bytes
:raises: ValueError if file cannot be retrieved
"""
try:
response = await files_api.openai_retrieve_file_content(file_id)
return bytes(response.body)
except Exception as e:
raise ValueError(f"Failed to retrieve file content for file_id '{file_id}': {str(e)}") from e
def generate_base64_ascii_text_from_bytes(raw_bytes: bytes) -> str:
"""
Converts raw binary bytes into a safe ASCII text representation for URLs
:param raw_bytes: the raw bytes of the file content
:returns: base64-encoded ASCII string
"""
return base64.b64encode(raw_bytes).decode("utf-8")
def construct_data_url(ascii_text: str, mime_type: str | None) -> str:
"""
Construct a data URL embedding the base64-encoded content
:param ascii_text: base64-encoded ASCII content
:param mime_type: MIME type of the file
:returns: data URL string (e.g. data:image/png;base64,iVBORw0KGgo...)
"""
if not mime_type:
mime_type = "application/octet-stream"
return f"data:{mime_type};base64,{ascii_text}"
async def convert_chat_choice_to_response_message(
choice: OpenAIChoice,
citation_files: dict[str, str] | None = None,
@@ -78,11 +126,15 @@ async def convert_chat_choice_to_response_message(
async def convert_response_content_to_chat_content(
content: str | Sequence[OpenAIResponseInputMessageContent | OpenAIResponseOutputMessageContent],
files_api: Files | None,
) -> str | list[OpenAIChatCompletionContentPartParam]:
"""
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
The content schemas of each API look similar, but are not exactly the same.
:param content: The content to convert
:param files_api: Files API for resolving file_id to raw file content (required if content contains files/images)
"""
if isinstance(content, str):
return content
@@ -95,9 +147,68 @@
elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
elif isinstance(content_part, OpenAIResponseInputMessageContentImage):
detail = content_part.detail
image_mime_type = None
if content_part.image_url:
image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail)
image_url = OpenAIImageURL(url=content_part.image_url, detail=detail)
converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
elif content_part.file_id:
if files_api is None:
raise ValueError("file_ids are not supported by this implementation of the Stack")
image_file_response = await files_api.openai_retrieve_file(content_part.file_id)
if image_file_response.filename:
image_mime_type, _ = mimetypes.guess_type(image_file_response.filename)
raw_image_bytes = await extract_bytes_from_file(content_part.file_id, files_api)
ascii_text = generate_base64_ascii_text_from_bytes(raw_image_bytes)
image_data_url = construct_data_url(ascii_text, image_mime_type)
image_url = OpenAIImageURL(url=image_data_url, detail=detail)
converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
else:
raise ValueError(
f"Image content must have either 'image_url' or 'file_id'. "
f"Got image_url={content_part.image_url}, file_id={content_part.file_id}"
)
elif isinstance(content_part, OpenAIResponseInputMessageContentFile):
resolved_file_data = None
file_data = content_part.file_data
file_id = content_part.file_id
file_url = content_part.file_url
filename = content_part.filename
file_mime_type = None
if not any([file_data, file_id, file_url]):
raise ValueError(
f"File content must have at least one of 'file_data', 'file_id', or 'file_url'. "
f"Got file_data={file_data}, file_id={file_id}, file_url={file_url}"
)
if file_id:
if files_api is None:
raise ValueError("file_ids are not supported by this implementation of the Stack")
file_response = await files_api.openai_retrieve_file(file_id)
if not filename:
filename = file_response.filename
file_mime_type, _ = mimetypes.guess_type(file_response.filename)
raw_file_bytes = await extract_bytes_from_file(file_id, files_api)
ascii_text = generate_base64_ascii_text_from_bytes(raw_file_bytes)
resolved_file_data = construct_data_url(ascii_text, file_mime_type)
elif file_data:
if file_data.startswith("data:"):
resolved_file_data = file_data
else:
# Raw base64 data, wrap in data URL format
if filename:
file_mime_type, _ = mimetypes.guess_type(filename)
resolved_file_data = construct_data_url(file_data, file_mime_type)
elif file_url:
resolved_file_data = file_url
converted_parts.append(
OpenAIFile(
file=OpenAIFileFile(
file_data=resolved_file_data,
filename=filename,
)
)
)
elif isinstance(content_part, str):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part))
else:
@@ -110,12 +221,14 @@ async def convert_response_content_to_chat_content(
async def convert_response_input_to_chat_messages(
input: str | list[OpenAIResponseInput],
previous_messages: list[OpenAIMessageParam] | None = None,
files_api: Files | None = None,
) -> list[OpenAIMessageParam]:
"""
Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
:param input: The input to convert
:param previous_messages: Optional previous messages to check for function_call references
:param files_api: Files API for resolving file_id to raw file content (optional, required for file/image content)
"""
messages: list[OpenAIMessageParam] = []
if isinstance(input, list):
@@ -169,6 +282,12 @@
elif isinstance(input_item, OpenAIResponseOutputMessageMCPListTools):
# the tool list will be handled separately
pass
elif isinstance(
input_item,
OpenAIResponseOutputMessageWebSearchToolCall | OpenAIResponseOutputMessageFileSearchToolCall,
):
# these tool calls are tracked internally but not converted to chat messages
pass
elif isinstance(input_item, OpenAIResponseMCPApprovalRequest) or isinstance(
input_item, OpenAIResponseMCPApprovalResponse
):
@@ -176,7 +295,7 @@
pass
elif isinstance(input_item, OpenAIResponseMessage):
# Narrow type to OpenAIResponseMessage which has content and role attributes
content = await convert_response_content_to_chat_content(input_item.content)
content = await convert_response_content_to_chat_content(input_item.content, files_api)
message_type = await get_message_type_by_role(input_item.role)
if message_type is None:
raise ValueError(
@@ -320,11 +439,15 @@
return False
async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> str | None:
async def run_guardrails(safety_api: Safety | None, messages: str, guardrail_ids: list[str]) -> str | None:
"""Run guardrails against messages and return violation message if blocked."""
if not messages:
return None
# If safety API is not available, skip guardrails
if safety_api is None:
return None
# Look up shields to get their provider_resource_id (actual model ID)
model_ids = []
# TODO: list_shields not in Safety interface but available at runtime via API routing

@@ -7,7 +7,7 @@
from typing import Any
from llama_stack.core.datatypes import AccessRule, Api
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.core.storage.kvstore import kvstore_impl
from llama_stack_api import Files, Inference, Models
from .batches import ReferenceBatchesImpl

@@ -16,8 +16,8 @@ from typing import Any
from openai.types.batch import BatchError, Errors
from pydantic import BaseModel
from llama_stack.core.storage.kvstore import KVStore
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack_api import (
Batches,
BatchObject,

@@ -5,8 +5,8 @@
# the root directory of this source tree.
from typing import Any
from llama_stack.core.storage.kvstore import kvstore_impl
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records
from llama_stack_api import Dataset, DatasetIO, DatasetsProtocolPrivate, PaginatedResponse

@@ -8,8 +8,8 @@ from typing import Any
from tqdm import tqdm
from llama_stack.core.storage.kvstore import kvstore_impl
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack_api import (
Agents,
Benchmark,

@@ -13,11 +13,10 @@ from fastapi import Depends, File, Form, Response, UploadFile
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.id_generation import generate_object_id
from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl
from llama_stack.log import get_logger
from llama_stack.providers.utils.files.form_data import parse_expires_after
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
from llama_stack_api import (
ExpiresAfter,
Files,
@@ -28,6 +27,7 @@ from llama_stack_api import (
Order,
ResourceNotFoundError,
)
from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType
from .config import LocalfsFilesImplConfig

@@ -14,9 +14,8 @@ import faiss # type: ignore[import-untyped]
import numpy as np
from numpy.typing import NDArray
from llama_stack.core.storage.kvstore import kvstore_impl
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
from llama_stack_api import (
@@ -32,6 +31,7 @@ from llama_stack_api import (
VectorStoreNotFoundError,
VectorStoresProtocolPrivate,
)
from llama_stack_api.internal.kvstore import KVStore
from .config import FaissVectorIOConfig

@@ -14,9 +14,8 @@ import numpy as np
import sqlite_vec # type: ignore[import-untyped]
from numpy.typing import NDArray
from llama_stack.core.storage.kvstore import kvstore_impl
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
RERANKER_TYPE_RRF,
@@ -35,6 +34,7 @@ from llama_stack_api import (
VectorStoreNotFoundError,
VectorStoresProtocolPrivate,
)
from llama_stack_api.internal.kvstore import KVStore
logger = get_logger(name=__name__, category="vector_io")

@@ -5,7 +5,7 @@
# the root directory of this source tree.
from llama_stack.providers.utils.kvstore import kvstore_dependencies
from llama_stack.core.storage.kvstore import kvstore_dependencies
from llama_stack_api import (
Api,
InlineProviderSpec,
@@ -30,11 +30,15 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.inline.agents.meta_reference.MetaReferenceAgentsImplConfig",
api_dependencies=[
Api.inference,
Api.safety,
Api.vector_io,
Api.tool_runtime,
Api.tool_groups,
Api.conversations,
Api.prompts,
Api.files,
],
optional_api_dependencies=[
Api.safety,
],
description="Meta's reference implementation of an agent system that can use tools, access vector databases, and perform complex reasoning tasks.",
),
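
Moving Api.safety into optional_api_dependencies means the agents provider can now start in a distribution with no safety provider at all. A hedged sketch of how an optional dependency is typically consumed at construction time (the deps mapping, factory, and constructor names here are illustrative, not the exact factory code):

async def get_provider_impl(config, deps):
    # Required APIs are indexed directly; optional ones are fetched with .get(),
    # yielding None when the distribution does not configure them.
    safety_api = deps.get(Api.safety)
    return MetaReferenceAgentsImpl(config=config, safety_api=safety_api)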

@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages
from llama_stack.core.storage.sqlstore.sqlstore import sql_store_pip_packages
from llama_stack_api import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec

@@ -6,7 +6,7 @@
from typing import Any
from urllib.parse import parse_qs, urlparse
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.core.storage.kvstore import kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records
from llama_stack_api import Dataset, DatasetIO, DatasetsProtocolPrivate, PaginatedResponse

@@ -10,10 +10,9 @@ from typing import Annotated, Any
from fastapi import Depends, File, Form, Response, UploadFile
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl
from llama_stack.providers.utils.files.form_data import parse_expires_after
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
from llama_stack_api import (
ExpiresAfter,
Files,
@@ -24,6 +23,7 @@ from llama_stack_api import (
Order,
ResourceNotFoundError,
)
from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType
from openai import OpenAI
from .config import OpenAIFilesImplConfig

@@ -19,10 +19,9 @@ if TYPE_CHECKING:
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.id_generation import generate_object_id
from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl
from llama_stack.providers.utils.files.form_data import parse_expires_after
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
from llama_stack_api import (
ExpiresAfter,
Files,
@@ -33,6 +32,7 @@ from llama_stack_api import (
Order,
ResourceNotFoundError,
)
from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType
from .config import S3FilesImplConfig

@@ -4,8 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from urllib.parse import urljoin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import AzureConfig
@@ -22,4 +20,4 @@ class AzureInferenceAdapter(OpenAIMixin):
Returns the Azure API base URL from the configuration.
"""
return urljoin(str(self.config.api_base), "/openai/v1")
return str(self.config.base_url)

@@ -32,8 +32,9 @@ class AzureProviderDataValidator(BaseModel):
@json_schema_type
class AzureConfig(RemoteInferenceProviderConfig):
api_base: HttpUrl = Field(
description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
base_url: HttpUrl | None = Field(
default=None,
description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com/openai/v1)",
)
api_version: str | None = Field(
default_factory=lambda: os.getenv("AZURE_API_VERSION"),
@@ -48,14 +49,14 @@ class AzureConfig(RemoteInferenceProviderConfig):
def sample_run_config(
cls,
api_key: str = "${env.AZURE_API_KEY:=}",
api_base: str = "${env.AZURE_API_BASE:=}",
base_url: str = "${env.AZURE_API_BASE:=}",
api_version: str = "${env.AZURE_API_VERSION:=}",
api_type: str = "${env.AZURE_API_TYPE:=}",
**kwargs,
) -> dict[str, Any]:
return {
"api_key": api_key,
"api_base": api_base,
"base_url": base_url,
"api_version": api_version,
"api_type": api_type,
}
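
For anyone migrating a distribution config, the practical change is one key rename plus the /openai/v1 suffix on the value. A before/after sketch with a placeholder resource name:

# Before: the adapter appended /openai/v1 to api_base itself.
old = {"api_base": "https://my-resource.openai.azure.com"}
# After: base_url is used verbatim, so it must carry the OpenAI-compatible path.
new = {"base_url": "https://my-resource.openai.azure.com/openai/v1"}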

@@ -37,7 +37,7 @@ class BedrockInferenceAdapter(OpenAIMixin):
"""
config: BedrockConfig
provider_data_api_key_field: str = "aws_bedrock_api_key"
provider_data_api_key_field: str = "aws_bearer_token_bedrock"
def get_base_url(self) -> str:
"""Get base URL for OpenAI client."""
@@ -111,7 +111,7 @@ class BedrockInferenceAdapter(OpenAIMixin):
logger.error(f"AWS Bedrock authentication token expired: {error_msg}")
raise ValueError(
"AWS Bedrock authentication failed: Bearer token has expired. "
"The AWS_BEDROCK_API_KEY environment variable contains an expired pre-signed URL. "
"The AWS_BEARER_TOKEN_BEDROCK environment variable contains an expired pre-signed URL. "
"Please refresh your token by generating a new pre-signed URL with AWS credentials. "
"Refer to AWS Bedrock documentation for details on OpenAI-compatible endpoints."
) from e

@@ -12,9 +12,9 @@ from llama_stack.providers.utils.inference.model_registry import RemoteInference
class BedrockProviderDataValidator(BaseModel):
aws_bedrock_api_key: str | None = Field(
aws_bearer_token_bedrock: str | None = Field(
default=None,
description="API key for Amazon Bedrock",
description="API Key (Bearer token) for Amazon Bedrock",
)
@@ -27,6 +27,6 @@ class BedrockConfig(RemoteInferenceProviderConfig):
@classmethod
def sample_run_config(cls, **kwargs):
return {
"api_key": "${env.AWS_BEDROCK_API_KEY:=}",
"api_key": "${env.AWS_BEARER_TOKEN_BEDROCK:=}",
"region_name": "${env.AWS_DEFAULT_REGION:=us-east-2}",
}
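
The rename matches the AWS_BEARER_TOKEN_BEDROCK variable AWS itself uses for Bedrock API keys, so the migration is a one-line environment change (token value is a placeholder):

import os

# Before: export AWS_BEDROCK_API_KEY=abc123
# After:  export AWS_BEARER_TOKEN_BEDROCK=abc123
token = os.environ.get("AWS_BEARER_TOKEN_BEDROCK")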

@@ -4,8 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from urllib.parse import urljoin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack_api import (
OpenAIEmbeddingsRequestWithExtraBody,
@@ -21,7 +19,7 @@ class CerebrasInferenceAdapter(OpenAIMixin):
provider_data_api_key_field: str = "cerebras_api_key"
def get_base_url(self) -> str:
return urljoin(self.config.base_url, "v1")
return str(self.config.base_url)
async def openai_embeddings(
self,

@@ -7,12 +7,12 @@
import os
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
DEFAULT_BASE_URL = "https://api.cerebras.ai"
DEFAULT_BASE_URL = "https://api.cerebras.ai/v1"
class CerebrasProviderDataValidator(BaseModel):
@@ -24,8 +24,8 @@ class CerebrasProviderDataValidator(BaseModel):
@json_schema_type
class CerebrasImplConfig(RemoteInferenceProviderConfig):
base_url: str = Field(
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
base_url: HttpUrl | None = Field(
default=HttpUrl(os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL)),
description="Base URL for the Cerebras API",
)

@@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from pydantic import BaseModel, Field, HttpUrl, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@@ -21,9 +21,9 @@ class DatabricksProviderDataValidator(BaseModel):
@json_schema_type
class DatabricksImplConfig(RemoteInferenceProviderConfig):
url: str | None = Field(
base_url: HttpUrl | None = Field(
default=None,
description="The URL for the Databricks model serving endpoint",
description="The URL for the Databricks model serving endpoint (should include /serving-endpoints path)",
)
auth_credential: SecretStr | None = Field(
default=None,
@@ -34,11 +34,11 @@ class DatabricksImplConfig(RemoteInferenceProviderConfig):
@classmethod
def sample_run_config(
cls,
url: str = "${env.DATABRICKS_HOST:=}",
base_url: str = "${env.DATABRICKS_HOST:=}",
api_token: str = "${env.DATABRICKS_TOKEN:=}",
**kwargs: Any,
) -> dict[str, Any]:
return {
"url": url,
"base_url": base_url,
"api_token": api_token,
}

@@ -29,15 +29,21 @@ class DatabricksInferenceAdapter(OpenAIMixin):
}
def get_base_url(self) -> str:
return f"{self.config.url}/serving-endpoints"
return str(self.config.base_url)
async def list_provider_model_ids(self) -> Iterable[str]:
# Filter out None values from endpoint names
api_token = self._get_api_key_from_config_or_provider_data()
# WorkspaceClient expects base host without /serving-endpoints suffix
base_url_str = str(self.config.base_url)
if base_url_str.endswith("/serving-endpoints"):
host = base_url_str[:-18] # Remove '/serving-endpoints'
else:
host = base_url_str
return [
endpoint.name # type: ignore[misc]
for endpoint in WorkspaceClient(
host=self.config.url, token=api_token
host=host, token=api_token
).serving_endpoints.list() # TODO: this is not async
]
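
The slice above hard-codes the suffix length (18 characters), which is easy to break if the path ever changes. On Python 3.9+ the same normalization reads more directly with str.removesuffix; a standalone sketch:

def databricks_host(base_url: str) -> str:
    # WorkspaceClient wants the workspace host without the serving path.
    return base_url.removesuffix("/serving-endpoints")

assert databricks_host("https://x.cloud.databricks.com/serving-endpoints") == "https://x.cloud.databricks.com"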

@@ -6,7 +6,7 @@
from typing import Any
from pydantic import Field
from pydantic import Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@@ -14,14 +14,14 @@ from llama_stack_api import json_schema_type
@json_schema_type
class FireworksImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
default="https://api.fireworks.ai/inference/v1",
base_url: HttpUrl | None = Field(
default=HttpUrl("https://api.fireworks.ai/inference/v1"),
description="The URL for the Fireworks server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"url": "https://api.fireworks.ai/inference/v1",
"base_url": "https://api.fireworks.ai/inference/v1",
"api_key": api_key,
}

@@ -24,4 +24,4 @@ class FireworksInferenceAdapter(OpenAIMixin):
provider_data_api_key_field: str = "fireworks_api_key"
def get_base_url(self) -> str:
return "https://api.fireworks.ai/inference/v1"
return str(self.config.base_url)

@@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@@ -21,14 +21,14 @@ class GroqProviderDataValidator(BaseModel):
@json_schema_type
class GroqConfig(RemoteInferenceProviderConfig):
url: str = Field(
default="https://api.groq.com",
base_url: HttpUrl | None = Field(
default=HttpUrl("https://api.groq.com/openai/v1"),
description="The URL for the Groq AI server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"url": "https://api.groq.com",
"base_url": "https://api.groq.com/openai/v1",
"api_key": api_key,
}

@@ -15,4 +15,4 @@ class GroqInferenceAdapter(OpenAIMixin):
provider_data_api_key_field: str = "groq_api_key"
def get_base_url(self) -> str:
return f"{self.config.url}/openai/v1"
return str(self.config.base_url)

@@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@@ -21,14 +21,14 @@ class LlamaProviderDataValidator(BaseModel):
@json_schema_type
class LlamaCompatConfig(RemoteInferenceProviderConfig):
openai_compat_api_base: str = Field(
default="https://api.llama.com/compat/v1/",
base_url: HttpUrl | None = Field(
default=HttpUrl("https://api.llama.com/compat/v1/"),
description="The URL for the Llama API server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.LLAMA_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"openai_compat_api_base": "https://api.llama.com/compat/v1/",
"base_url": "https://api.llama.com/compat/v1/",
"api_key": api_key,
}

@@ -31,7 +31,7 @@ class LlamaCompatInferenceAdapter(OpenAIMixin):
:return: The Llama API base URL
"""
return self.config.openai_compat_api_base
return str(self.config.base_url)
async def openai_completion(
self,

@@ -7,7 +7,7 @@
import os
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@@ -44,18 +44,14 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
URL of your running NVIDIA NIM and do not need to set the api_key.
"""
url: str = Field(
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"),
base_url: HttpUrl | None = Field(
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com/v1"),
description="A base url for accessing the NVIDIA NIM",
)
timeout: int = Field(
default=60,
description="Timeout for the HTTP requests",
)
append_api_version: bool = Field(
default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
)
rerank_model_to_url: dict[str, str] = Field(
default_factory=lambda: {
"nv-rerank-qa-mistral-4b:1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking",
@@ -68,13 +64,11 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
@classmethod
def sample_run_config(
cls,
url: str = "${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}",
base_url: HttpUrl | None = "${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}",
api_key: str = "${env.NVIDIA_API_KEY:=}",
append_api_version: bool = "${env.NVIDIA_APPEND_API_VERSION:=True}",
**kwargs,
) -> dict[str, Any]:
return {
"url": url,
"base_url": base_url,
"api_key": api_key,
"append_api_version": append_api_version,
}

@@ -44,7 +44,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
}
async def initialize(self) -> None:
logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.url})...")
logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.base_url})...")
if _is_nvidia_hosted(self.config):
if not self.config.auth_credential:
@@ -72,7 +72,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
:return: The NVIDIA API base URL
"""
return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
return str(self.config.base_url)
async def list_provider_model_ids(self) -> Iterable[str]:
"""

@@ -8,4 +8,4 @@ from . import NVIDIAConfig
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
return "integrate.api.nvidia.com" in config.url
return "integrate.api.nvidia.com" in str(config.base_url)

@@ -6,20 +6,22 @@
from typing import Any
from pydantic import Field, SecretStr
from pydantic import Field, HttpUrl, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
DEFAULT_OLLAMA_URL = "http://localhost:11434"
DEFAULT_OLLAMA_URL = "http://localhost:11434/v1"
class OllamaImplConfig(RemoteInferenceProviderConfig):
auth_credential: SecretStr | None = Field(default=None, exclude=True)
url: str = DEFAULT_OLLAMA_URL
base_url: HttpUrl | None = Field(default=HttpUrl(DEFAULT_OLLAMA_URL))
@classmethod
def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:
def sample_run_config(
cls, base_url: str = "${env.OLLAMA_URL:=http://localhost:11434/v1}", **kwargs
) -> dict[str, Any]:
return {
"url": url,
"base_url": base_url,
}
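
Deployments that set OLLAMA_URL must now include the /v1 path themselves, since the adapter uses the configured URL verbatim for the OpenAI client (the native Ollama client strips the suffix back off, as the next hunk shows). A placeholder migration sketch:

import os

# Before: export OLLAMA_URL=http://localhost:11434      (adapter appended /v1)
# After:  export OLLAMA_URL=http://localhost:11434/v1   (used as-is)
base_url = os.environ.get("OLLAMA_URL", "http://localhost:11434/v1")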

@@ -55,17 +55,23 @@ class OllamaInferenceAdapter(OpenAIMixin):
# ollama client attaches itself to the current event loop (sadly?)
loop = asyncio.get_running_loop()
if loop not in self._clients:
self._clients[loop] = AsyncOllamaClient(host=self.config.url)
# Ollama client expects base URL without /v1 suffix
base_url_str = str(self.config.base_url)
if base_url_str.endswith("/v1"):
host = base_url_str[:-3]
else:
host = base_url_str
self._clients[loop] = AsyncOllamaClient(host=host)
return self._clients[loop]
def get_api_key(self):
return "NO KEY REQUIRED"
def get_base_url(self):
return self.config.url.rstrip("/") + "/v1"
return str(self.config.base_url)
async def initialize(self) -> None:
logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
logger.info(f"checking connectivity to Ollama at `{self.config.base_url}`...")
r = await self.health()
if r["status"] == HealthStatus.ERROR:
logger.warning(

@@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@@ -21,8 +21,8 @@ class OpenAIProviderDataValidator(BaseModel):
@json_schema_type
class OpenAIConfig(RemoteInferenceProviderConfig):
base_url: str = Field(
default="https://api.openai.com/v1",
base_url: HttpUrl | None = Field(
default=HttpUrl("https://api.openai.com/v1"),
description="Base URL for OpenAI API",
)

@@ -35,4 +35,4 @@ class OpenAIInferenceAdapter(OpenAIMixin):
Returns the OpenAI API base URL from the configuration.
"""
return self.config.base_url
return str(self.config.base_url)

@@ -6,7 +6,7 @@
from typing import Any
from pydantic import Field
from pydantic import Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@@ -14,16 +14,16 @@ from llama_stack_api import json_schema_type
@json_schema_type
class PassthroughImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
base_url: HttpUrl | None = Field(
default=None,
description="The URL for the passthrough endpoint",
)
@classmethod
def sample_run_config(
cls, url: str = "${env.PASSTHROUGH_URL}", api_key: str = "${env.PASSTHROUGH_API_KEY}", **kwargs
cls, base_url: HttpUrl | None = "${env.PASSTHROUGH_URL}", api_key: str = "${env.PASSTHROUGH_API_KEY}", **kwargs
) -> dict[str, Any]:
return {
"url": url,
"base_url": base_url,
"api_key": api_key,
}

@@ -82,8 +82,8 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
def _get_passthrough_url(self) -> str:
"""Get the passthrough URL from config or provider data."""
if self.config.url is not None:
return self.config.url
if self.config.base_url is not None:
return str(self.config.base_url)
provider_data = self.get_request_provider_data()
if provider_data is None:

@@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from pydantic import BaseModel, Field, HttpUrl, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@@ -21,7 +21,7 @@ class RunpodProviderDataValidator(BaseModel):
@json_schema_type
class RunpodImplConfig(RemoteInferenceProviderConfig):
url: str | None = Field(
base_url: HttpUrl | None = Field(
default=None,
description="The URL for the Runpod model serving endpoint",
)
@@ -34,6 +34,6 @@ class RunpodImplConfig(RemoteInferenceProviderConfig):
@classmethod
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
return {
"url": "${env.RUNPOD_URL:=}",
"base_url": "${env.RUNPOD_URL:=}",
"api_token": "${env.RUNPOD_API_TOKEN}",
}

@@ -28,7 +28,7 @@ class RunpodInferenceAdapter(OpenAIMixin):
def get_base_url(self) -> str:
"""Get base URL for OpenAI client."""
return self.config.url
return str(self.config.base_url)
async def openai_chat_completion(
self,

@@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@@ -21,14 +21,14 @@ class SambaNovaProviderDataValidator(BaseModel):
@json_schema_type
class SambaNovaImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
default="https://api.sambanova.ai/v1",
base_url: HttpUrl | None = Field(
default=HttpUrl("https://api.sambanova.ai/v1"),
description="The URL for the SambaNova AI server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"url": "https://api.sambanova.ai/v1",
"base_url": "https://api.sambanova.ai/v1",
"api_key": api_key,
}

@@ -25,4 +25,4 @@ class SambaNovaInferenceAdapter(OpenAIMixin):
:return: The SambaNova base URL
"""
return self.config.url
return str(self.config.base_url)

@@ -5,7 +5,7 @@
# the root directory of this source tree.
from pydantic import BaseModel, Field, SecretStr
from pydantic import BaseModel, Field, HttpUrl, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@@ -15,18 +15,19 @@ from llama_stack_api import json_schema_type
class TGIImplConfig(RemoteInferenceProviderConfig):
auth_credential: SecretStr | None = Field(default=None, exclude=True)
url: str = Field(
description="The URL for the TGI serving endpoint",
base_url: HttpUrl | None = Field(
default=None,
description="The URL for the TGI serving endpoint (should include /v1 path)",
)
@classmethod
def sample_run_config(
cls,
url: str = "${env.TGI_URL:=}",
base_url: str = "${env.TGI_URL:=}",
**kwargs,
):
return {
"url": url,
"base_url": base_url,
}

@@ -8,7 +8,7 @@
from collections.abc import Iterable
from huggingface_hub import AsyncInferenceClient, HfApi
from pydantic import SecretStr
from pydantic import HttpUrl, SecretStr
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@@ -23,7 +23,7 @@ log = get_logger(name=__name__, category="inference::tgi")
class _HfAdapter(OpenAIMixin):
url: str
base_url: HttpUrl
api_key: SecretStr
hf_client: AsyncInferenceClient
@@ -36,7 +36,7 @@ class _HfAdapter(OpenAIMixin):
return "NO KEY REQUIRED"
def get_base_url(self):
return self.url
return self.base_url
async def list_provider_model_ids(self) -> Iterable[str]:
return [self.model_id]
@@ -50,14 +50,20 @@ class _HfAdapter(OpenAIMixin):
class TGIAdapter(_HfAdapter):
async def initialize(self, config: TGIImplConfig) -> None:
if not config.url:
if not config.base_url:
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
log.info(f"Initializing TGI client with url={config.url}")
self.hf_client = AsyncInferenceClient(model=config.url, provider="hf-inference")
log.info(f"Initializing TGI client with url={config.base_url}")
# Extract base URL without /v1 for HF client initialization
base_url_str = str(config.base_url).rstrip("/")
if base_url_str.endswith("/v1"):
base_url_for_client = base_url_str[:-3]
else:
base_url_for_client = base_url_str
self.hf_client = AsyncInferenceClient(model=base_url_for_client, provider="hf-inference")
endpoint_info = await self.hf_client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]
self.model_id = endpoint_info["model_id"]
self.url = f"{config.url.rstrip('/')}/v1"
self.base_url = config.base_url
self.api_key = SecretStr("NO_KEY")

@@ -6,7 +6,7 @@
from typing import Any
from pydantic import Field
from pydantic import Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@@ -14,14 +14,14 @@ from llama_stack_api import json_schema_type
@json_schema_type
class TogetherImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
default="https://api.together.xyz/v1",
base_url: HttpUrl | None = Field(
default=HttpUrl("https://api.together.xyz/v1"),
description="The URL for the Together AI server",
)
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"url": "https://api.together.xyz/v1",
"base_url": "https://api.together.xyz/v1",
"api_key": "${env.TOGETHER_API_KEY:=}",
}

@@ -9,7 +9,6 @@ from collections.abc import Iterable
from typing import Any, cast
from together import AsyncTogether # type: ignore[import-untyped]
from together.constants import BASE_URL # type: ignore[import-untyped]
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
@@ -42,7 +41,7 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
provider_data_api_key_field: str = "together_api_key"
def get_base_url(self):
return BASE_URL
return str(self.config.base_url)
def _get_client(self) -> AsyncTogether:
together_api_key = None

@@ -51,4 +51,4 @@ class VertexAIInferenceAdapter(OpenAIMixin):
:return: An iterable of model IDs
"""
return ["vertexai/gemini-2.0-flash", "vertexai/gemini-2.5-flash", "vertexai/gemini-2.5-pro"]
return ["google/gemini-2.0-flash", "google/gemini-2.5-flash", "google/gemini-2.5-pro"]

@@ -6,7 +6,7 @@
from pathlib import Path
from pydantic import Field, SecretStr, field_validator
from pydantic import Field, HttpUrl, SecretStr, field_validator
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@@ -14,7 +14,7 @@ from llama_stack_api import json_schema_type
@json_schema_type
class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
url: str | None = Field(
base_url: HttpUrl | None = Field(
default=None,
description="The URL for the vLLM model serving endpoint",
)
@@ -48,11 +48,11 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
@classmethod
def sample_run_config(
cls,
url: str = "${env.VLLM_URL:=}",
base_url: str = "${env.VLLM_URL:=}",
**kwargs,
):
return {
"url": url,
"base_url": base_url,
"max_tokens": "${env.VLLM_MAX_TOKENS:=4096}",
"api_token": "${env.VLLM_API_TOKEN:=fake}",
"tls_verify": "${env.VLLM_TLS_VERIFY:=true}",

@@ -39,12 +39,12 @@ class VLLMInferenceAdapter(OpenAIMixin):
def get_base_url(self) -> str:
"""Get the base URL from config."""
if not self.config.url:
if not self.config.base_url:
raise ValueError("No base URL configured")
return self.config.url
return str(self.config.base_url)
async def initialize(self) -> None:
if not self.config.url:
if not self.config.base_url:
raise ValueError(
"You must provide a URL in run.yaml (or via the VLLM_URL environment variable) to use vLLM."
)

@@ -7,7 +7,7 @@
import os
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type
@@ -23,7 +23,7 @@ class WatsonXProviderDataValidator(BaseModel):
@json_schema_type
class WatsonXConfig(RemoteInferenceProviderConfig):
url: str = Field(
base_url: HttpUrl | None = Field(
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
description="A base url for accessing the watsonx.ai",
)
@@ -39,7 +39,7 @@ class WatsonXConfig(RemoteInferenceProviderConfig):
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"url": "${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}",
"base_url": "${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}",
"api_key": "${env.WATSONX_API_KEY:=}",
"project_id": "${env.WATSONX_PROJECT_ID:=}",
}

@@ -255,7 +255,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
)
def get_base_url(self) -> str:
return self.config.url
return str(self.config.base_url)
# Copied from OpenAIMixin
async def check_model_availability(self, model: str) -> bool:
@@ -316,7 +316,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
"""
Retrieves foundation model specifications from the watsonx.ai API.
"""
url = f"{self.config.url}/ml/v1/foundation_model_specs?version=2023-10-25"
url = f"{str(self.config.base_url)}/ml/v1/foundation_model_specs?version=2023-10-25"
headers = {
# Note that there is no authorization header. Listing models does not require authentication.
"Content-Type": "application/json",

@@ -48,16 +48,10 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required")
# Phase 1: Support both old header-based auth AND new authorization parameter
# Get headers and auth from provider data (old approach)
provider_headers, provider_auth = await self.get_headers_from_request(mcp_endpoint.uri)
# Get other headers from provider data (but NOT authorization)
provider_headers = await self.get_headers_from_request(mcp_endpoint.uri)
# New authorization parameter takes precedence over provider data
final_authorization = authorization or provider_auth
return await list_mcp_tools(
endpoint=mcp_endpoint.uri, headers=provider_headers, authorization=final_authorization
)
return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=provider_headers, authorization=authorization)
async def invoke_tool(
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
@@ -69,39 +63,38 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
if urlparse(endpoint).scheme not in ("http", "https"):
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
# Phase 1: Support both old header-based auth AND new authorization parameter
# Get headers and auth from provider data (old approach)
provider_headers, provider_auth = await self.get_headers_from_request(endpoint)
# New authorization parameter takes precedence over provider data
final_authorization = authorization or provider_auth
# Get other headers from provider data (but NOT authorization)
provider_headers = await self.get_headers_from_request(endpoint)
return await invoke_mcp_tool(
endpoint=endpoint,
tool_name=tool_name,
kwargs=kwargs,
headers=provider_headers,
authorization=final_authorization,
authorization=authorization,
)
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> tuple[dict[str, str], str | None]:
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
"""
Extract headers and authorization from request provider data (Phase 1 backward compatibility).
Extract headers from request provider data, excluding authorization.
Phase 1: Temporarily allows Authorization to be passed via mcp_headers for backward compatibility.
Phase 2: Will enforce that Authorization should use the dedicated authorization parameter instead.
Authorization must be provided via the dedicated authorization parameter.
If Authorization is found in mcp_headers, raise an error to guide users to the correct approach.
Args:
mcp_endpoint_uri: The MCP endpoint URI to match against provider data
Returns:
Tuple of (headers_dict, authorization_token)
- headers_dict: All headers except Authorization
- authorization_token: Token from Authorization header (with "Bearer " prefix removed), or None
dict[str, str]: Headers dictionary (without Authorization)
Raises:
ValueError: If Authorization header is found in mcp_headers
"""
def canonicalize_uri(uri: str) -> str:
return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}"
headers = {}
authorization = None
provider_data = self.get_request_provider_data()
if provider_data and hasattr(provider_data, "mcp_headers") and provider_data.mcp_headers:
@@ -109,17 +102,14 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
continue
# Phase 1: Extract Authorization from mcp_headers for backward compatibility
# (Phase 2 will reject this and require the dedicated authorization parameter)
# Reject Authorization in mcp_headers - must use authorization parameter
for key in values.keys():
if key.lower() == "authorization":
# Extract authorization token and strip "Bearer " prefix if present
auth_value = values[key]
if auth_value.startswith("Bearer "):
authorization = auth_value[7:] # Remove "Bearer " prefix
else:
authorization = auth_value
else:
headers[key] = values[key]
raise ValueError(
"Authorization cannot be provided via mcp_headers in provider_data. "
"Please use the dedicated 'authorization' parameter instead. "
"Example: tool_runtime.invoke_tool(..., authorization='your-token')"
)
headers[key] = values[key]
return headers, authorization
return headers
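
Callers that previously smuggled a bearer token through mcp_headers in provider_data now get a hard error; the token moves to the dedicated parameter, as the new error message itself advises. A hedged usage sketch (the tool_runtime handle, tool name, arguments, and token are placeholders):

result = await tool_runtime.invoke_tool(
    tool_name="search_docs",
    kwargs={"query": "release notes"},
    authorization="my-oauth-token",  # sent as the Authorization credential
)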

@@ -11,10 +11,9 @@ from urllib.parse import urlparse
import chromadb
from numpy.typing import NDArray
from llama_stack.core.storage.kvstore import kvstore_impl
from llama_stack.log import get_logger
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
from llama_stack_api import (
@@ -27,6 +26,7 @@ from llama_stack_api import (
VectorStore,
VectorStoresProtocolPrivate,
)
from llama_stack_api.internal.kvstore import KVStore
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig

@@ -11,10 +11,9 @@ from typing import Any
from numpy.typing import NDArray
from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker
from llama_stack.core.storage.kvstore import kvstore_impl
from llama_stack.log import get_logger
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
RERANKER_TYPE_WEIGHTED,
@@ -34,6 +33,7 @@ from llama_stack_api import (
VectorStoreNotFoundError,
VectorStoresProtocolPrivate,
)
from llama_stack_api.internal.kvstore import KVStore
from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig

@@ -13,10 +13,9 @@ from psycopg2 import sql
from psycopg2.extras import Json, execute_values
from pydantic import BaseModel, TypeAdapter
from llama_stack.core.storage.kvstore import kvstore_impl
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator, sanitize_collection_name
@@ -31,6 +30,7 @@ from llama_stack_api import (
VectorStoreNotFoundError,
VectorStoresProtocolPrivate,
)
from llama_stack_api.internal.kvstore import KVStore
from .config import PGVectorVectorIOConfig

@@ -13,9 +13,9 @@ from numpy.typing import NDArray
from qdrant_client import AsyncQdrantClient, models
from qdrant_client.models import PointStruct
from llama_stack.core.storage.kvstore import kvstore_impl
from llama_stack.log import get_logger
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
from llama_stack_api import (

@@ -13,9 +13,8 @@ from weaviate.classes.init import Auth
from weaviate.classes.query import Filter, HybridFusion
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.core.storage.kvstore import kvstore_impl
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
RERANKER_TYPE_RRF,
@@ -35,6 +34,7 @@ from llama_stack_api import (
VectorStoreNotFoundError,
VectorStoresProtocolPrivate,
)
from llama_stack_api.internal.kvstore import KVStore
from .config import WeaviateVectorIOConfig

@@ -10,6 +10,8 @@ from sqlalchemy.exc import IntegrityError
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.storage.datatypes import InferenceStoreReference, StorageBackendType
from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.core.storage.sqlstore.sqlstore import _SQLSTORE_BACKENDS, sqlstore_impl
from llama_stack.log import get_logger
from llama_stack_api import (
ListOpenAIChatCompletionResponse,
@@ -18,10 +20,7 @@ from llama_stack_api import (
OpenAIMessageParam,
Order,
)
from ..sqlstore.api import ColumnDefinition, ColumnType
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
from ..sqlstore.sqlstore import _SQLSTORE_BACKENDS, sqlstore_impl
from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType
logger = get_logger(name=__name__, category="inference")

@@ -3,23 +3,10 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Iterable
from typing import (
Any,
)
from openai.types.chat import (
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
)
try:
from openai.types.chat import (
ChatCompletionMessageFunctionToolCall as OpenAIChatCompletionMessageFunctionToolCall,
)
except ImportError:
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall as OpenAIChatCompletionMessageFunctionToolCall,
)
from openai.types.chat import (
ChatCompletionMessageToolCall,
)
@@ -32,18 +19,6 @@ from llama_stack.models.llama.datatypes import (
ToolCall,
ToolDefinition,
)
from llama_stack_api import (
URL,
GreedySamplingStrategy,
ImageContentItem,
JsonSchemaResponseFormat,
OpenAIResponseFormatParam,
SamplingParams,
TextContentItem,
TopKSamplingStrategy,
TopPSamplingStrategy,
_URLOrData,
)
logger = get_logger(name=__name__, category="providers::utils")
@@ -73,42 +48,6 @@ class OpenAICompatCompletionResponse(BaseModel):
choices: list[OpenAICompatCompletionChoice]
def get_sampling_strategy_options(params: SamplingParams) -> dict:
options = {}
if isinstance(params.strategy, GreedySamplingStrategy):
options["temperature"] = 0.0
elif isinstance(params.strategy, TopPSamplingStrategy):
if params.strategy.temperature is not None:
options["temperature"] = params.strategy.temperature
if params.strategy.top_p is not None:
options["top_p"] = params.strategy.top_p
elif isinstance(params.strategy, TopKSamplingStrategy):
options["top_k"] = params.strategy.top_k
else:
raise ValueError(f"Unsupported sampling strategy: {params.strategy}")
return options
def get_sampling_options(params: SamplingParams | None) -> dict:
if not params:
return {}
options = {}
if params:
options.update(get_sampling_strategy_options(params))
if params.max_tokens:
options["max_tokens"] = params.max_tokens
if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
options["repeat_penalty"] = params.repetition_penalty
if params.stop is not None:
options["stop"] = params.stop
return options
def text_from_choice(choice) -> str:
if hasattr(choice, "delta") and choice.delta:
return choice.delta.content # type: ignore[no-any-return] # external OpenAI types lack precise annotations
@@ -253,154 +192,6 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
return out
def _convert_stop_reason_to_openai_finish_reason(stop_reason: StopReason) -> str:
"""
Convert a StopReason to an OpenAI chat completion finish_reason.
"""
return {
StopReason.end_of_turn: "stop",
StopReason.end_of_message: "tool_calls",
StopReason.out_of_tokens: "length",
}.get(stop_reason, "stop")
def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
"""
Convert an OpenAI chat completion finish_reason to a StopReason.
finish_reason: Literal["stop", "length", "tool_calls", ...]
- stop: model hit a natural stop point or a provided stop sequence
- length: maximum number of tokens specified in the request was reached
- tool_calls: model called a tool
->
class StopReason(Enum):
end_of_turn = "end_of_turn"
end_of_message = "end_of_message"
out_of_tokens = "out_of_tokens"
"""
# TODO(mf): are end_of_turn and end_of_message semantics correct?
return {
"stop": StopReason.end_of_turn,
"length": StopReason.out_of_tokens,
"tool_calls": StopReason.end_of_message,
}.get(finish_reason, StopReason.end_of_turn)
def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) -> list[ToolDefinition]:
lls_tools: list[ToolDefinition] = []
if not tools:
return lls_tools
for tool in tools:
tool_fn = tool.get("function", {})
tool_name = tool_fn.get("name", None)
tool_desc = tool_fn.get("description", None)
tool_params = tool_fn.get("parameters", None)
lls_tool = ToolDefinition(
tool_name=tool_name,
description=tool_desc,
input_schema=tool_params, # Pass through entire JSON Schema
)
lls_tools.append(lls_tool)
return lls_tools
def _convert_openai_request_response_format(
response_format: OpenAIResponseFormatParam | None = None,
):
if not response_format:
return None
# response_format can be a dict or a pydantic model
response_format_dict = dict(response_format) # type: ignore[arg-type] # OpenAIResponseFormatParam union needs dict conversion
if response_format_dict.get("type", "") == "json_schema":
return JsonSchemaResponseFormat(
type="json_schema", # type: ignore[arg-type] # Literal["json_schema"] incompatible with expected type
json_schema=response_format_dict.get("json_schema", {}).get("schema", ""),
)
return None
def _convert_openai_tool_calls(
tool_calls: list[OpenAIChatCompletionMessageFunctionToolCall],
) -> list[ToolCall]:
"""
Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall.
OpenAI ChatCompletionMessageToolCall:
id: str
function: Function
type: Literal["function"]
OpenAI Function:
arguments: str
name: str
->
ToolCall:
call_id: str
tool_name: str
arguments: Dict[str, ...]
"""
if not tool_calls:
return [] # CompletionMessage tool_calls is not optional
return [
ToolCall(
call_id=call.id,
tool_name=call.function.name,
arguments=call.function.arguments,
)
for call in tool_calls
]
def _convert_openai_sampling_params(
max_tokens: int | None = None,
temperature: float | None = None,
top_p: float | None = None,
) -> SamplingParams:
sampling_params = SamplingParams()
if max_tokens:
sampling_params.max_tokens = max_tokens
# Map an explicit temperature of 0 to greedy sampling
if temperature == 0:
sampling_params.strategy = GreedySamplingStrategy()
else:
# OpenAI defaults to 1.0 for temperature and top_p if unset
if temperature is None:
temperature = 1.0
if top_p is None:
top_p = 1.0
sampling_params.strategy = TopPSamplingStrategy(temperature=temperature, top_p=top_p) # type: ignore[assignment] # SamplingParams.strategy union accepts this type
return sampling_params
def openai_content_to_content(content: str | Iterable[OpenAIChatCompletionContentPartParam] | None):
if content is None:
return ""
if isinstance(content, str):
return content
elif isinstance(content, list):
return [openai_content_to_content(c) for c in content]
elif hasattr(content, "type"):
if content.type == "text":
return TextContentItem(type="text", text=content.text) # type: ignore[attr-defined] # Iterable narrowed by hasattr check but mypy doesn't track
elif content.type == "image_url":
return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url))) # type: ignore[attr-defined] # Iterable narrowed by hasattr check but mypy doesn't track
else:
raise ValueError(f"Unknown content type: {content.type}")
else:
raise ValueError(f"Unknown content type: {content}")
async def prepare_openai_completion_params(**params):
async def _prepare_value(value: Any) -> Any:
new_value = value

@@ -213,6 +213,19 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
return api_key
def _validate_model_allowed(self, provider_model_id: str) -> None:
"""
Validate that the model is in the allowed_models list if configured.
:param provider_model_id: The provider-specific model ID to validate
:raises ValueError: If the model is not in the allowed_models list
"""
if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models:
raise ValueError(
f"Model '{provider_model_id}' is not in the allowed models list. "
f"Allowed models: {self.config.allowed_models}"
)
async def _get_provider_model_id(self, model: str) -> str:
"""
Get the provider-specific model ID from the model store.
@@ -259,8 +272,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
Direct OpenAI completion API call.
"""
# TODO: fix openai_completion to return type compatible with OpenAI's API response
provider_model_id = await self._get_provider_model_id(params.model)
self._validate_model_allowed(provider_model_id)
completion_kwargs = await prepare_openai_completion_params(
model=await self._get_provider_model_id(params.model),
model=provider_model_id,
prompt=params.prompt,
best_of=params.best_of,
echo=params.echo,
@@ -292,6 +308,9 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
"""
Direct OpenAI chat completion API call.
"""
provider_model_id = await self._get_provider_model_id(params.model)
self._validate_model_allowed(provider_model_id)
messages = params.messages
if self.download_images:
@@ -313,7 +332,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
messages = [await _localize_image_url(m) for m in messages]
request_params = await prepare_openai_completion_params(
model=await self._get_provider_model_id(params.model),
model=provider_model_id,
messages=messages,
frequency_penalty=params.frequency_penalty,
function_call=params.function_call,
@@ -351,10 +370,13 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
"""
Direct OpenAI embeddings API call.
"""
provider_model_id = await self._get_provider_model_id(params.model)
self._validate_model_allowed(provider_model_id)
# Build request params conditionally to avoid NotGiven/Omit type mismatch
# The OpenAI SDK uses Omit in signatures but NOT_GIVEN has type NotGiven
request_params: dict[str, Any] = {
"model": await self._get_provider_model_id(params.model),
"model": provider_model_id,
"input": params.input,
}
if params.encoding_format is not None:
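
With _validate_model_allowed called from completion, chat completion, and embeddings alike, a request for a model outside allowed_models now fails fast, before any HTTP call. A standalone mirror of the check, for illustration only:

def validate_model_allowed(allowed_models, provider_model_id):
    # Same logic as the mixin's guard, extracted for demonstration.
    if allowed_models is not None and provider_model_id not in allowed_models:
        raise ValueError(
            f"Model '{provider_model_id}' is not in the allowed models list. "
            f"Allowed models: {allowed_models}"
        )

validate_model_allowed(["llama-3.1-8b"], "llama-3.1-8b")  # passes silently
# validate_model_allowed(["llama-3.1-8b"], "gpt-4o")      # raises ValueError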

@@ -1,20 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel, Field
from llama_stack_api import json_schema_type
@json_schema_type
class SqliteControlPlaneConfig(BaseModel):
db_path: str = Field(
description="File path for the sqlite database",
)
table_name: str = Field(
default="llamastack_control_plane",
description="Table into which all the keys will be placed",
)

@@ -17,7 +17,6 @@ from pydantic import TypeAdapter
from llama_stack.core.id_generation import generate_object_id
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
content_from_data_and_mime_type,
@@ -53,6 +52,7 @@ from llama_stack_api import (
VectorStoreSearchResponse,
VectorStoreSearchResponsePage,
)
from llama_stack_api.internal.kvstore import KVStore
EMBEDDING_DIMENSION = 768

@@ -6,6 +6,8 @@
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqlStoreReference
from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl
from llama_stack.log import get_logger
from llama_stack_api import (
ListOpenAIResponseInputItem,
@@ -17,10 +19,7 @@ from llama_stack_api import (
OpenAIResponseObjectWithInput,
Order,
)
from ..sqlstore.api import ColumnDefinition, ColumnType
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
from ..sqlstore.sqlstore import sqlstore_impl
from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType
logger = get_logger(name=__name__, category="openai_responses")

@@ -1,140 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Any, Literal, Protocol
from pydantic import BaseModel
from llama_stack_api import PaginatedResponse
class ColumnType(Enum):
INTEGER = "INTEGER"
STRING = "STRING"
TEXT = "TEXT"
FLOAT = "FLOAT"
BOOLEAN = "BOOLEAN"
JSON = "JSON"
DATETIME = "DATETIME"
class ColumnDefinition(BaseModel):
type: ColumnType
primary_key: bool = False
nullable: bool = True
default: Any = None
class SqlStore(Protocol):
"""
A protocol for a SQL store.
"""
async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None:
"""
Create a table.
"""
pass
async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
"""
Insert a row or batch of rows into a table.
"""
pass
async def upsert(
self,
table: str,
data: Mapping[str, Any],
conflict_columns: list[str],
update_columns: list[str] | None = None,
) -> None:
"""
Insert a row and update specified columns when conflicts occur.
"""
pass
async def fetch_all(
self,
table: str,
where: Mapping[str, Any] | None = None,
where_sql: str | None = None,
limit: int | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
cursor: tuple[str, str] | None = None,
) -> PaginatedResponse:
"""
Fetch all rows from a table with optional cursor-based pagination.
:param table: The table name
:param where: Simple key-value WHERE conditions
:param where_sql: Raw SQL WHERE clause for complex queries
:param limit: Maximum number of records to return
:param order_by: List of (column, order) tuples for sorting
:param cursor: Tuple of (key_column, cursor_id) for pagination (None for first page)
Requires order_by with exactly one column when used
:return: PaginatedResult with data and has_more flag
Note: Cursor pagination only supports single-column ordering for simplicity.
Multi-column ordering is allowed without cursor but will raise an error with cursor.
"""
pass
async def fetch_one(
self,
table: str,
where: Mapping[str, Any] | None = None,
where_sql: str | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
) -> dict[str, Any] | None:
"""
Fetch one row from a table.
"""
pass
async def update(
self,
table: str,
data: Mapping[str, Any],
where: Mapping[str, Any],
) -> None:
"""
Update a row in a table.
"""
pass
async def delete(
self,
table: str,
where: Mapping[str, Any],
) -> None:
"""
Delete a row from a table.
"""
pass
async def add_column_if_not_exists(
self,
table: str,
column_name: str,
column_type: ColumnType,
nullable: bool = True,
) -> None:
"""
Add a column to an existing table if the column doesn't already exist.
This is useful for table migrations when adding new functionality.
If the table doesn't exist, this method should do nothing.
If the column already exists, this method should do nothing.
:param table: Table name
:param column_name: Name of the column to add
:param column_type: Type of the column to add
:param nullable: Whether the column should be nullable (default: True)
"""
pass
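
The deleted protocol now lives at llama_stack_api.internal.sqlstore, which is where the import rewrites throughout this diff point. A minimal sketch of building a create_table schema against the relocated types (table and column names are illustrative):

from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType

# create_table accepts a Mapping[str, ColumnType | ColumnDefinition].
schema = {
    "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
    "created_at": ColumnType.DATETIME,
    "payload": ColumnType.JSON,
}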
