Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 18:00:36 +00:00)

Merge branch 'main' into routeur

Commit 3770963130
255 changed files with 18366 additions and 1909 deletions
@@ -11,10 +11,9 @@ from typing import Any, Literal

 from pydantic import BaseModel, TypeAdapter

 from llama_stack.core.datatypes import AccessRule, StackRunConfig
+from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
+from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
-from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
-from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
 from llama_stack_api import (
     Conversation,
     ConversationDeletedResource,

@@ -25,6 +24,7 @@ from llama_stack_api import (
     Conversations,
     Metadata,
 )
+from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType

 logger = get_logger(name=__name__, category="openai_conversations")
@@ -10,7 +10,7 @@ from typing import Any

 from pydantic import BaseModel

 from llama_stack.core.datatypes import StackRunConfig
-from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
+from llama_stack.core.storage.kvstore import KVStore, kvstore_impl
 from llama_stack_api import ListPromptsResponse, Prompt, Prompts
@@ -11,9 +11,9 @@ from datetime import UTC, datetime, timedelta

 from starlette.types import ASGIApp, Receive, Scope, Send

 from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendType
+from llama_stack.core.storage.kvstore.kvstore import _KVSTORE_BACKENDS, kvstore_impl
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.kvstore.api import KVStore
-from llama_stack.providers.utils.kvstore.kvstore import _KVSTORE_BACKENDS, kvstore_impl
+from llama_stack_api.internal.kvstore import KVStore

 logger = get_logger(name=__name__, category="core::server")

@@ -385,8 +385,8 @@ def _initialize_storage(run_config: StackRunConfig):
         else:
             raise ValueError(f"Unknown storage backend type: {type}")

-    from llama_stack.providers.utils.kvstore.kvstore import register_kvstore_backends
-    from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
+    from llama_stack.core.storage.kvstore.kvstore import register_kvstore_backends
+    from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends

     register_kvstore_backends(kv_backends)
     register_sqlstore_backends(sql_backends)
@@ -12,6 +12,8 @@ from typing import Annotated, Literal

 from pydantic import BaseModel, Field, field_validator

+from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
+

 class StorageBackendType(StrEnum):
     KV_REDIS = "kv_redis"

@@ -256,15 +258,24 @@ class ResponsesStoreReference(InferenceStoreReference):

 class ServerStoresConfig(BaseModel):
     metadata: KVStoreReference | None = Field(
-        default=None,
+        default=KVStoreReference(
+            backend="kv_default",
+            namespace="registry",
+        ),
         description="Metadata store configuration (uses KV backend)",
     )
     inference: InferenceStoreReference | None = Field(
-        default=None,
+        default=InferenceStoreReference(
+            backend="sql_default",
+            table_name="inference_store",
+        ),
         description="Inference store configuration (uses SQL backend)",
     )
     conversations: SqlStoreReference | None = Field(
-        default=None,
+        default=SqlStoreReference(
+            backend="sql_default",
+            table_name="openai_conversations",
+        ),
         description="Conversations store configuration (uses SQL backend)",
     )
     responses: ResponsesStoreReference | None = Field(

@@ -272,13 +283,21 @@ class ServerStoresConfig(BaseModel):
         description="Responses store configuration (uses SQL backend)",
     )
     prompts: KVStoreReference | None = Field(
-        default=None,
+        default=KVStoreReference(backend="kv_default", namespace="prompts"),
         description="Prompts store configuration (uses KV backend)",
     )


 class StorageConfig(BaseModel):
     backends: dict[str, StorageBackendConfig] = Field(
+        default={
+            "kv_default": SqliteKVStoreConfig(
+                db_path=f"${{env.SQLITE_STORE_DIR:={DISTRIBS_BASE_DIR}}}/kvstore.db",
+            ),
+            "sql_default": SqliteSqlStoreConfig(
+                db_path=f"${{env.SQLITE_STORE_DIR:={DISTRIBS_BASE_DIR}}}/sql_store.db",
+            ),
+        },
         description="Named backend configurations (e.g., 'default', 'cache')",
     )
     stores: ServerStoresConfig = Field(
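With these defaults in place, a `ServerStoresConfig` (and the backend map) can be constructed without explicit values; omitted stores fall back to the `kv_default` and `sql_default` backends defined in the default `backends` dict. A minimal, self-contained sketch of the same pydantic pattern — the class names below are illustrative stand-ins, not imports from llama_stack:

```python
# Sketch of a model-typed Field default, assuming pydantic v2.
from pydantic import BaseModel, Field


class KVStoreReference(BaseModel):
    backend: str
    namespace: str


class ServerStoresConfig(BaseModel):
    # If the user omits `metadata`, it falls back to the kv_default backend.
    metadata: KVStoreReference | None = Field(
        default=KVStoreReference(backend="kv_default", namespace="registry"),
        description="Metadata store configuration (uses KV backend)",
    )


stores = ServerStoresConfig()  # no explicit metadata given
print(stores.metadata)  # backend='kv_default' namespace='registry'
```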
@@ -4,4 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from llama_stack_api.internal.kvstore import KVStore as KVStore
+
 from .kvstore import *  # noqa: F401, F403
@@ -13,11 +13,19 @@ from __future__ import annotations

 import asyncio
 from collections import defaultdict
+from datetime import datetime
+from typing import cast

-from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendConfig, StorageBackendType
+from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendConfig
+from llama_stack_api.internal.kvstore import KVStore

-from .api import KVStore
-from .config import KVStoreConfig
+from .config import (
+    KVStoreConfig,
+    MongoDBKVStoreConfig,
+    PostgresKVStoreConfig,
+    RedisKVStoreConfig,
+    SqliteKVStoreConfig,
+)


 def kvstore_dependencies():

@@ -33,7 +41,7 @@ def kvstore_dependencies():

 class InmemoryKVStoreImpl(KVStore):
     def __init__(self):
-        self._store = {}
+        self._store: dict[str, str] = {}

     async def initialize(self) -> None:
         pass

@@ -41,7 +49,7 @@ class InmemoryKVStoreImpl(KVStore):
     async def get(self, key: str) -> str | None:
         return self._store.get(key)

-    async def set(self, key: str, value: str) -> None:
+    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
         self._store[key] = value

     async def values_in_range(self, start_key: str, end_key: str) -> list[str]:

@@ -70,7 +78,8 @@ def register_kvstore_backends(backends: dict[str, StorageBackendConfig]) -> None
     _KVSTORE_INSTANCES.clear()
     _KVSTORE_LOCKS.clear()
     for name, cfg in backends.items():
-        _KVSTORE_BACKENDS[name] = cfg
+        typed_cfg = cast(KVStoreConfig, cfg)
+        _KVSTORE_BACKENDS[name] = typed_cfg


 async def kvstore_impl(reference: KVStoreReference) -> KVStore:

@@ -94,19 +103,20 @@ async def kvstore_impl(reference: KVStoreReference) -> KVStore:
     config = backend_config.model_copy()
     config.namespace = reference.namespace

-    if config.type == StorageBackendType.KV_REDIS.value:
+    impl: KVStore
+    if isinstance(config, RedisKVStoreConfig):
         from .redis import RedisKVStoreImpl

         impl = RedisKVStoreImpl(config)
-    elif config.type == StorageBackendType.KV_SQLITE.value:
+    elif isinstance(config, SqliteKVStoreConfig):
         from .sqlite import SqliteKVStoreImpl

         impl = SqliteKVStoreImpl(config)
-    elif config.type == StorageBackendType.KV_POSTGRES.value:
+    elif isinstance(config, PostgresKVStoreConfig):
         from .postgres import PostgresKVStoreImpl

         impl = PostgresKVStoreImpl(config)
-    elif config.type == StorageBackendType.KV_MONGODB.value:
+    elif isinstance(config, MongoDBKVStoreConfig):
         from .mongodb import MongoDBKVStoreImpl

         impl = MongoDBKVStoreImpl(config)
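The dispatch in `kvstore_impl` now keys off the concrete config class rather than comparing `config.type` strings, which also lets a type checker narrow `config` inside each branch. A rough sketch of the idea with stand-in classes (not the real llama-stack configs):

```python
# Sketch of isinstance-based dispatch replacing string-tag dispatch (stand-in classes).
from dataclasses import dataclass


@dataclass
class RedisConfig:
    url: str


@dataclass
class SqliteConfig:
    db_path: str


def open_store(config: RedisConfig | SqliteConfig) -> str:
    # Each branch sees a narrowed type, so attribute access is checked statically.
    if isinstance(config, RedisConfig):
        return f"redis store at {config.url}"
    elif isinstance(config, SqliteConfig):
        return f"sqlite store at {config.db_path}"
    raise ValueError(f"Unknown backend config: {type(config).__name__}")


print(open_store(SqliteConfig(db_path="/tmp/kvstore.db")))
```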
@@ -9,8 +9,8 @@ from datetime import datetime
 from pymongo import AsyncMongoClient
 from pymongo.asynchronous.collection import AsyncCollection

+from llama_stack.core.storage.kvstore import KVStore
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.kvstore import KVStore

 from ..config import MongoDBKVStoreConfig
@@ -6,12 +6,13 @@

 from datetime import datetime

-import psycopg2
-from psycopg2.extras import DictCursor
+import psycopg2  # type: ignore[import-not-found]
+from psycopg2.extensions import connection as PGConnection  # type: ignore[import-not-found]
+from psycopg2.extras import DictCursor  # type: ignore[import-not-found]

 from llama_stack.log import get_logger
+from llama_stack_api.internal.kvstore import KVStore

-from ..api import KVStore
 from ..config import PostgresKVStoreConfig

 log = get_logger(name=__name__, category="providers::utils")

@@ -20,12 +21,12 @@ log = get_logger(name=__name__, category="providers::utils")
 class PostgresKVStoreImpl(KVStore):
     def __init__(self, config: PostgresKVStoreConfig):
         self.config = config
-        self.conn = None
-        self.cursor = None
+        self._conn: PGConnection | None = None
+        self._cursor: DictCursor | None = None

     async def initialize(self) -> None:
         try:
-            self.conn = psycopg2.connect(
+            self._conn = psycopg2.connect(
                 host=self.config.host,
                 port=self.config.port,
                 database=self.config.db,

@@ -34,11 +35,11 @@ class PostgresKVStoreImpl(KVStore):
                 sslmode=self.config.ssl_mode,
                 sslrootcert=self.config.ca_cert_path,
             )
-            self.conn.autocommit = True
-            self.cursor = self.conn.cursor(cursor_factory=DictCursor)
+            self._conn.autocommit = True
+            self._cursor = self._conn.cursor(cursor_factory=DictCursor)

             # Create table if it doesn't exist
-            self.cursor.execute(
+            self._cursor.execute(
                 f"""
                 CREATE TABLE IF NOT EXISTS {self.config.table_name} (
                     key TEXT PRIMARY KEY,

@@ -51,6 +52,11 @@ class PostgresKVStoreImpl(KVStore):
             log.exception("Could not connect to PostgreSQL database server")
             raise RuntimeError("Could not connect to PostgreSQL database server") from e

+    def _cursor_or_raise(self) -> DictCursor:
+        if self._cursor is None:
+            raise RuntimeError("Postgres client not initialized")
+        return self._cursor
+
     def _namespaced_key(self, key: str) -> str:
         if not self.config.namespace:
             return key

@@ -58,7 +64,8 @@ class PostgresKVStoreImpl(KVStore):

     async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
         key = self._namespaced_key(key)
-        self.cursor.execute(
+        cursor = self._cursor_or_raise()
+        cursor.execute(
             f"""
             INSERT INTO {self.config.table_name} (key, value, expiration)
             VALUES (%s, %s, %s)

@@ -70,7 +77,8 @@ class PostgresKVStoreImpl(KVStore):

     async def get(self, key: str) -> str | None:
         key = self._namespaced_key(key)
-        self.cursor.execute(
+        cursor = self._cursor_or_raise()
+        cursor.execute(
             f"""
             SELECT value FROM {self.config.table_name}
             WHERE key = %s

@@ -78,12 +86,13 @@ class PostgresKVStoreImpl(KVStore):
             """,
             (key,),
         )
-        result = self.cursor.fetchone()
+        result = cursor.fetchone()
         return result[0] if result else None

     async def delete(self, key: str) -> None:
         key = self._namespaced_key(key)
-        self.cursor.execute(
+        cursor = self._cursor_or_raise()
+        cursor.execute(
             f"DELETE FROM {self.config.table_name} WHERE key = %s",
             (key,),
         )

@@ -92,7 +101,8 @@ class PostgresKVStoreImpl(KVStore):
         start_key = self._namespaced_key(start_key)
         end_key = self._namespaced_key(end_key)

-        self.cursor.execute(
+        cursor = self._cursor_or_raise()
+        cursor.execute(
             f"""
             SELECT value FROM {self.config.table_name}
             WHERE key >= %s AND key < %s

@@ -101,14 +111,15 @@ class PostgresKVStoreImpl(KVStore):
             """,
             (start_key, end_key),
         )
-        return [row[0] for row in self.cursor.fetchall()]
+        return [row[0] for row in cursor.fetchall()]

     async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
         start_key = self._namespaced_key(start_key)
         end_key = self._namespaced_key(end_key)

-        self.cursor.execute(
+        cursor = self._cursor_or_raise()
+        cursor.execute(
             f"SELECT key FROM {self.config.table_name} WHERE key >= %s AND key < %s",
             (start_key, end_key),
         )
-        return [row[0] for row in self.cursor.fetchall()]
+        return [row[0] for row in cursor.fetchall()]
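`_cursor_or_raise` is a small guard for a resource that only exists after `initialize()` has run; it turns a silent `None` attribute error into an explicit failure and gives type checkers a non-optional value. A generic sketch of the pattern, using illustrative names rather than the llama-stack implementation:

```python
# Sketch of the "guard a lazily-initialized resource" pattern.
class LazyClient:
    def __init__(self) -> None:
        self._conn: object | None = None  # created later, in initialize()

    def initialize(self) -> None:
        self._conn = object()  # stand-in for psycopg2.connect(...)

    def _conn_or_raise(self) -> object:
        if self._conn is None:
            raise RuntimeError("Client not initialized")
        return self._conn

    def query(self) -> str:
        conn = self._conn_or_raise()  # fails loudly if initialize() was skipped
        return f"using {conn!r}"


client = LazyClient()
client.initialize()
print(client.query())
```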
@@ -6,18 +6,25 @@

 from datetime import datetime

-from redis.asyncio import Redis
+from redis.asyncio import Redis  # type: ignore[import-not-found]

+from llama_stack_api.internal.kvstore import KVStore

-from ..api import KVStore
 from ..config import RedisKVStoreConfig


 class RedisKVStoreImpl(KVStore):
     def __init__(self, config: RedisKVStoreConfig):
         self.config = config
+        self._redis: Redis | None = None

     async def initialize(self) -> None:
-        self.redis = Redis.from_url(self.config.url)
+        self._redis = Redis.from_url(self.config.url)
+
+    def _client(self) -> Redis:
+        if self._redis is None:
+            raise RuntimeError("Redis client not initialized")
+        return self._redis

     def _namespaced_key(self, key: str) -> str:
         if not self.config.namespace:

@@ -26,30 +33,37 @@ class RedisKVStoreImpl(KVStore):

     async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
         key = self._namespaced_key(key)
-        await self.redis.set(key, value)
+        client = self._client()
+        await client.set(key, value)
         if expiration:
-            await self.redis.expireat(key, expiration)
+            await client.expireat(key, expiration)

     async def get(self, key: str) -> str | None:
         key = self._namespaced_key(key)
-        value = await self.redis.get(key)
+        client = self._client()
+        value = await client.get(key)
         if value is None:
             return None
-        await self.redis.ttl(key)
-        return value
+        await client.ttl(key)
+        if isinstance(value, bytes):
+            return value.decode("utf-8")
+        if isinstance(value, str):
+            return value
+        return str(value)

     async def delete(self, key: str) -> None:
         key = self._namespaced_key(key)
-        await self.redis.delete(key)
+        await self._client().delete(key)

     async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
         start_key = self._namespaced_key(start_key)
         end_key = self._namespaced_key(end_key)
+        client = self._client()
         cursor = 0
         pattern = start_key + "*"  # Match all keys starting with start_key prefix
-        matching_keys = []
+        matching_keys: list[str | bytes] = []
         while True:
-            cursor, keys = await self.redis.scan(cursor, match=pattern, count=1000)
+            cursor, keys = await client.scan(cursor, match=pattern, count=1000)

             for key in keys:
                 key_str = key.decode("utf-8") if isinstance(key, bytes) else key

@@ -61,7 +75,7 @@ class RedisKVStoreImpl(KVStore):

         # Then fetch all values in a single MGET call
         if matching_keys:
-            values = await self.redis.mget(matching_keys)
+            values = await client.mget(matching_keys)
             return [
                 value.decode("utf-8") if isinstance(value, bytes) else value for value in values if value is not None
             ]

@@ -70,7 +84,18 @@ class RedisKVStoreImpl(KVStore):

     async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
         """Get all keys in the given range."""
-        matching_keys = await self.redis.zrangebylex(self.namespace, f"[{start_key}", f"[{end_key}")
-        if not matching_keys:
-            return []
-        return [k.decode("utf-8") for k in matching_keys]
+        start_key = self._namespaced_key(start_key)
+        end_key = self._namespaced_key(end_key)
+        client = self._client()
+        cursor = 0
+        pattern = start_key + "*"
+        result: list[str] = []
+        while True:
+            cursor, keys = await client.scan(cursor, match=pattern, count=1000)
+            for key in keys:
+                key_str = key.decode("utf-8") if isinstance(key, bytes) else str(key)
+                if start_key <= key_str <= end_key:
+                    result.append(key_str)
+            if cursor == 0:
+                break
+        return result
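The rewritten `keys_in_range` drops the sorted-set `zrangebylex` call (which assumed keys were also indexed in a ZSET) and instead walks the keyspace with cursor-based SCAN, filtering lexicographically on the client side. The core loop shape is sketched below, with the Redis round-trips faked so it runs standalone (illustrative only):

```python
# Sketch of cursor-based SCAN pagination with client-side range filtering.
# `pages` stands in for successive replies from `client.scan(...)`:
# each reply is (next_cursor, keys), and next_cursor == 0 ends the scan.
pages = iter([
    (42, [b"user:001", b"user:003"]),
    (0, [b"user:007", b"user:999"]),
])


def scan_once(cursor: int) -> tuple[int, list[bytes]]:
    return next(pages)


def keys_in_range(start_key: str, end_key: str) -> list[str]:
    cursor = 0
    result: list[str] = []
    while True:
        cursor, keys = scan_once(cursor)
        for key in keys:
            key_str = key.decode("utf-8") if isinstance(key, bytes) else str(key)
            if start_key <= key_str <= end_key:  # lexicographic bounds check
                result.append(key_str)
        if cursor == 0:  # Redis signals scan completion with cursor 0
            break
    return result


print(keys_in_range("user:000", "user:010"))  # ['user:001', 'user:003', 'user:007']
```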
@@ -10,8 +10,8 @@ from datetime import datetime
 import aiosqlite

 from llama_stack.log import get_logger
+from llama_stack_api.internal.kvstore import KVStore

-from ..api import KVStore
 from ..config import SqliteKVStoreConfig

 logger = get_logger(name=__name__, category="providers::utils")

src/llama_stack/core/storage/sqlstore/__init__.py (new file, 17 lines)
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack_api.internal.sqlstore import (
+    ColumnDefinition as ColumnDefinition,
+)
+from llama_stack_api.internal.sqlstore import (
+    ColumnType as ColumnType,
+)
+from llama_stack_api.internal.sqlstore import (
+    SqlStore as SqlStore,
+)
+
+from .sqlstore import *  # noqa: F401,F403
@@ -14,8 +14,8 @@ from llama_stack.core.datatypes import User
 from llama_stack.core.request_headers import get_authenticated_user
 from llama_stack.core.storage.datatypes import StorageBackendType
 from llama_stack.log import get_logger

-from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
+from llama_stack_api import PaginatedResponse
+from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType, SqlStore

 logger = get_logger(name=__name__, category="providers::utils")

@@ -29,8 +29,7 @@ from sqlalchemy.sql.elements import ColumnElement
 from llama_stack.core.storage.datatypes import SqlAlchemySqlStoreConfig
 from llama_stack.log import get_logger
 from llama_stack_api import PaginatedResponse

-from .api import ColumnDefinition, ColumnType, SqlStore
+from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType, SqlStore

 logger = get_logger(name=__name__, category="providers::utils")

@@ -16,8 +16,7 @@ from llama_stack.core.storage.datatypes import (
     StorageBackendConfig,
     StorageBackendType,
 )

-from .api import SqlStore
+from llama_stack_api.internal.sqlstore import SqlStore

 sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]
@@ -12,8 +12,8 @@ import pydantic

 from llama_stack.core.datatypes import RoutableObjectWithProvider
 from llama_stack.core.storage.datatypes import KVStoreReference
+from llama_stack.core.storage.kvstore import KVStore, kvstore_impl
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl

 logger = get_logger(__name__, category="core::registry")
@@ -17,44 +17,43 @@ providers:
   - provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
     provider_type: remote::cerebras
     config:
-      base_url: https://api.cerebras.ai
+      base_url: https://api.cerebras.ai/v1
       api_key: ${env.CEREBRAS_API_KEY:=}
   - provider_id: ${env.OLLAMA_URL:+ollama}
     provider_type: remote::ollama
     config:
-      url: ${env.OLLAMA_URL:=http://localhost:11434}
+      base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
   - provider_id: ${env.VLLM_URL:+vllm}
     provider_type: remote::vllm
     config:
-      url: ${env.VLLM_URL:=}
+      base_url: ${env.VLLM_URL:=}
      max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
      api_token: ${env.VLLM_API_TOKEN:=fake}
      tls_verify: ${env.VLLM_TLS_VERIFY:=true}
   - provider_id: ${env.TGI_URL:+tgi}
     provider_type: remote::tgi
     config:
-      url: ${env.TGI_URL:=}
+      base_url: ${env.TGI_URL:=}
   - provider_id: fireworks
     provider_type: remote::fireworks
     config:
-      url: https://api.fireworks.ai/inference/v1
+      base_url: https://api.fireworks.ai/inference/v1
      api_key: ${env.FIREWORKS_API_KEY:=}
   - provider_id: together
     provider_type: remote::together
     config:
-      url: https://api.together.xyz/v1
+      base_url: https://api.together.xyz/v1
      api_key: ${env.TOGETHER_API_KEY:=}
   - provider_id: bedrock
     provider_type: remote::bedrock
     config:
-      api_key: ${env.AWS_BEDROCK_API_KEY:=}
+      api_key: ${env.AWS_BEARER_TOKEN_BEDROCK:=}
      region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
   - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
     provider_type: remote::nvidia
     config:
-      url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
+      base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
      api_key: ${env.NVIDIA_API_KEY:=}
      append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
   - provider_id: openai
     provider_type: remote::openai
     config:

@@ -76,18 +75,18 @@ providers:
   - provider_id: groq
     provider_type: remote::groq
     config:
-      url: https://api.groq.com
+      base_url: https://api.groq.com/openai/v1
      api_key: ${env.GROQ_API_KEY:=}
   - provider_id: sambanova
     provider_type: remote::sambanova
     config:
-      url: https://api.sambanova.ai/v1
+      base_url: https://api.sambanova.ai/v1
      api_key: ${env.SAMBANOVA_API_KEY:=}
   - provider_id: ${env.AZURE_API_KEY:+azure}
     provider_type: remote::azure
     config:
      api_key: ${env.AZURE_API_KEY:=}
-      api_base: ${env.AZURE_API_BASE:=}
+      base_url: ${env.AZURE_API_BASE:=}
      api_version: ${env.AZURE_API_VERSION:=}
      api_type: ${env.AZURE_API_TYPE:=}
   - provider_id: sentence-transformers
@@ -16,9 +16,8 @@ providers:
   - provider_id: nvidia
     provider_type: remote::nvidia
     config:
-      url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
+      base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
      api_key: ${env.NVIDIA_API_KEY:=}
      append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
   - provider_id: nvidia
     provider_type: remote::nvidia
     config:

@@ -16,9 +16,8 @@ providers:
   - provider_id: nvidia
     provider_type: remote::nvidia
     config:
-      url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
+      base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
      api_key: ${env.NVIDIA_API_KEY:=}
      append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
   vector_io:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -27,12 +27,12 @@ providers:
   - provider_id: groq
     provider_type: remote::groq
     config:
-      url: https://api.groq.com
+      base_url: https://api.groq.com/openai/v1
      api_key: ${env.GROQ_API_KEY:=}
   - provider_id: together
     provider_type: remote::together
     config:
-      url: https://api.together.xyz/v1
+      base_url: https://api.together.xyz/v1
      api_key: ${env.TOGETHER_API_KEY:=}
   vector_io:
   - provider_id: sqlite-vec
@@ -11,7 +11,7 @@ providers:
   - provider_id: vllm-inference
     provider_type: remote::vllm
     config:
-      url: ${env.VLLM_URL:=http://localhost:8000/v1}
+      base_url: ${env.VLLM_URL:=}
      max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
      api_token: ${env.VLLM_API_TOKEN:=fake}
      tls_verify: ${env.VLLM_TLS_VERIFY:=true}
@@ -17,6 +17,8 @@ from llama_stack.core.datatypes import (
     ToolGroupInput,
     VectorStoresConfig,
 )
+from llama_stack.core.storage.kvstore.config import PostgresKVStoreConfig
+from llama_stack.core.storage.sqlstore.sqlstore import PostgresSqlStoreConfig
 from llama_stack.core.utils.dynamic import instantiate_class_type
 from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
 from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig

@@ -35,8 +37,6 @@ from llama_stack.providers.remote.vector_io.pgvector.config import (
 )
 from llama_stack.providers.remote.vector_io.qdrant.config import QdrantVectorIOConfig
 from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
-from llama_stack.providers.utils.kvstore.config import PostgresKVStoreConfig
-from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
 from llama_stack_api import RemoteProviderSpec
@@ -35,13 +35,13 @@ from llama_stack.core.storage.datatypes import (
     SqlStoreReference,
     StorageBackendType,
 )
+from llama_stack.core.storage.kvstore.config import SqliteKVStoreConfig
+from llama_stack.core.storage.kvstore.config import get_pip_packages as get_kv_pip_packages
+from llama_stack.core.storage.sqlstore.sqlstore import SqliteSqlStoreConfig
+from llama_stack.core.storage.sqlstore.sqlstore import get_pip_packages as get_sql_pip_packages
 from llama_stack.core.utils.dynamic import instantiate_class_type
 from llama_stack.core.utils.image_types import LlamaStackImageType
 from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
-from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
-from llama_stack.providers.utils.kvstore.config import get_pip_packages as get_kv_pip_packages
-from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
-from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as get_sql_pip_packages
 from llama_stack_api import DatasetPurpose, ModelType
@@ -15,7 +15,7 @@ providers:
   - provider_id: watsonx
     provider_type: remote::watsonx
     config:
-      url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
+      base_url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
      api_key: ${env.WATSONX_API_KEY:=}
      project_id: ${env.WATSONX_PROJECT_ID:=}
   vector_io:
@@ -23,12 +23,14 @@ async def get_provider_impl(
         config,
         deps[Api.inference],
         deps[Api.vector_io],
-        deps[Api.safety],
+        deps.get(Api.safety),
         deps[Api.tool_runtime],
         deps[Api.tool_groups],
         deps[Api.conversations],
-        policy,
+        deps[Api.prompts],
+        deps[Api.files],
         telemetry_enabled,
+        policy,
     )
     await impl.initialize()
     return impl
@@ -6,12 +6,13 @@

 from llama_stack.core.datatypes import AccessRule
+from llama_stack.core.storage.kvstore import InmemoryKVStoreImpl, kvstore_impl
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
 from llama_stack.providers.utils.responses.responses_store import ResponsesStore
 from llama_stack_api import (
     Agents,
     Conversations,
+    Files,
     Inference,
     ListOpenAIResponseInputItem,
     ListOpenAIResponseObject,

@@ -22,6 +23,7 @@ from llama_stack_api import (
     OpenAIResponsePrompt,
     OpenAIResponseText,
     Order,
+    Prompts,
     ResponseGuardrail,
     Safety,
     ToolGroups,
@@ -41,10 +43,12 @@ class MetaReferenceAgentsImpl(Agents):
         config: MetaReferenceAgentsImplConfig,
         inference_api: Inference,
         vector_io_api: VectorIO,
-        safety_api: Safety,
+        safety_api: Safety | None,
         tool_runtime_api: ToolRuntime,
         tool_groups_api: ToolGroups,
         conversations_api: Conversations,
+        prompts_api: Prompts,
+        files_api: Files,
         policy: list[AccessRule],
         telemetry_enabled: bool = False,
     ):

@@ -56,7 +60,8 @@ class MetaReferenceAgentsImpl(Agents):
         self.tool_groups_api = tool_groups_api
         self.conversations_api = conversations_api
         self.telemetry_enabled = telemetry_enabled
+        self.prompts_api = prompts_api
+        self.files_api = files_api
         self.in_memory_store = InmemoryKVStoreImpl()
         self.openai_responses_impl: OpenAIResponsesImpl | None = None
         self.policy = policy

@@ -73,6 +78,8 @@ class MetaReferenceAgentsImpl(Agents):
             vector_io_api=self.vector_io_api,
             safety_api=self.safety_api,
             conversations_api=self.conversations_api,
+            prompts_api=self.prompts_api,
+            files_api=self.files_api,
         )

     async def shutdown(self) -> None:

@@ -92,6 +99,7 @@ class MetaReferenceAgentsImpl(Agents):
         model: str,
+        prompt: OpenAIResponsePrompt | None = None,
         instructions: str | None = None,
         parallel_tool_calls: bool | None = True,
         previous_response_id: str | None = None,
         conversation: str | None = None,
         store: bool | None = True,

@@ -120,6 +128,7 @@ class MetaReferenceAgentsImpl(Agents):
             include,
             max_infer_iters,
             guardrails,
             parallel_tool_calls,
             max_tool_calls,
         )
         return result  # type: ignore[no-any-return]
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import re
 import time
 import uuid
 from collections.abc import AsyncIterator

@@ -18,13 +19,17 @@ from llama_stack.providers.utils.responses.responses_store import (
 from llama_stack_api import (
     ConversationItem,
     Conversations,
     Files,
     Inference,
     InvalidConversationIdError,
     ListOpenAIResponseInputItem,
     ListOpenAIResponseObject,
     OpenAIChatCompletionContentPartParam,
     OpenAIDeleteResponseObject,
     OpenAIMessageParam,
     OpenAIResponseInput,
     OpenAIResponseInputMessageContentFile,
     OpenAIResponseInputMessageContentImage,
     OpenAIResponseInputMessageContentText,
     OpenAIResponseInputTool,
     OpenAIResponseMessage,

@@ -34,7 +39,9 @@ from llama_stack_api import (
     OpenAIResponseText,
     OpenAIResponseTextFormat,
     OpenAISystemMessageParam,
+    OpenAIUserMessageParam,
     Order,
+    Prompts,
     ResponseGuardrailSpec,
     Safety,
     ToolGroups,

@@ -46,6 +53,7 @@ from .streaming import StreamingResponseOrchestrator
 from .tool_executor import ToolExecutor
 from .types import ChatCompletionContext, ToolContext
 from .utils import (
+    convert_response_content_to_chat_content,
     convert_response_input_to_chat_messages,
     convert_response_text_to_chat_response_format,
     extract_guardrail_ids,
@@ -67,8 +75,10 @@ class OpenAIResponsesImpl:
         tool_runtime_api: ToolRuntime,
         responses_store: ResponsesStore,
         vector_io_api: VectorIO,  # VectorIO
-        safety_api: Safety,
+        safety_api: Safety | None,
         conversations_api: Conversations,
+        prompts_api: Prompts,
+        files_api: Files,
     ):
         self.inference_api = inference_api
         self.tool_groups_api = tool_groups_api

@@ -82,6 +92,8 @@ class OpenAIResponsesImpl:
             tool_runtime_api=tool_runtime_api,
             vector_io_api=vector_io_api,
         )
+        self.prompts_api = prompts_api
+        self.files_api = files_api

     async def _prepend_previous_response(
         self,
@@ -122,11 +134,13 @@ class OpenAIResponsesImpl:
             # Use stored messages directly and convert only new input
             message_adapter = TypeAdapter(list[OpenAIMessageParam])
             messages = message_adapter.validate_python(previous_response.messages)
-            new_messages = await convert_response_input_to_chat_messages(input, previous_messages=messages)
+            new_messages = await convert_response_input_to_chat_messages(
+                input, previous_messages=messages, files_api=self.files_api
+            )
             messages.extend(new_messages)
         else:
             # Backward compatibility: reconstruct from inputs
-            messages = await convert_response_input_to_chat_messages(all_input)
+            messages = await convert_response_input_to_chat_messages(all_input, files_api=self.files_api)

         tool_context.recover_tools_from_previous_response(previous_response)
     elif conversation is not None:

@@ -138,7 +152,7 @@ class OpenAIResponsesImpl:
         all_input = input
         if not conversation_items.data:
             # First turn - just convert the new input
-            messages = await convert_response_input_to_chat_messages(input)
+            messages = await convert_response_input_to_chat_messages(input, files_api=self.files_api)
         else:
             if not stored_messages:
                 all_input = conversation_items.data
@@ -154,14 +168,82 @@ class OpenAIResponsesImpl:
                 all_input = input

                 messages = stored_messages or []
-                new_messages = await convert_response_input_to_chat_messages(all_input, previous_messages=messages)
+                new_messages = await convert_response_input_to_chat_messages(
+                    all_input, previous_messages=messages, files_api=self.files_api
+                )
                 messages.extend(new_messages)
         else:
             all_input = input
-            messages = await convert_response_input_to_chat_messages(all_input)
+            messages = await convert_response_input_to_chat_messages(all_input, files_api=self.files_api)

         return all_input, messages, tool_context

+    async def _prepend_prompt(
+        self,
+        messages: list[OpenAIMessageParam],
+        openai_response_prompt: OpenAIResponsePrompt | None,
+    ) -> None:
+        """Prepend the prompt template to messages, resolving text/image/file variables.
+
+        :param messages: List of OpenAIMessageParam objects
+        :param openai_response_prompt: (Optional) OpenAIResponsePrompt object with variables
+        :returns: None; the resolved prompt is inserted into messages in place
+        """
+        if not openai_response_prompt or not openai_response_prompt.id:
+            return
+
+        prompt_version = int(openai_response_prompt.version) if openai_response_prompt.version else None
+        cur_prompt = await self.prompts_api.get_prompt(openai_response_prompt.id, prompt_version)
+
+        if not cur_prompt or not cur_prompt.prompt:
+            return
+
+        cur_prompt_text = cur_prompt.prompt
+        cur_prompt_variables = cur_prompt.variables
+
+        if not openai_response_prompt.variables:
+            messages.insert(0, OpenAISystemMessageParam(content=cur_prompt_text))
+            return
+
+        # Validate that all provided variables exist in the prompt
+        for name in openai_response_prompt.variables.keys():
+            if name not in cur_prompt_variables:
+                raise ValueError(f"Variable {name} not found in prompt {openai_response_prompt.id}")
+
+        # Separate text and media variables
+        text_substitutions = {}
+        media_content_parts: list[OpenAIChatCompletionContentPartParam] = []
+
+        for name, value in openai_response_prompt.variables.items():
+            # Text variable found
+            if isinstance(value, OpenAIResponseInputMessageContentText):
+                text_substitutions[name] = value.text
+
+            # Media variable found
+            elif isinstance(value, OpenAIResponseInputMessageContentImage | OpenAIResponseInputMessageContentFile):
+                converted_parts = await convert_response_content_to_chat_content([value], files_api=self.files_api)
+                if isinstance(converted_parts, list):
+                    media_content_parts.extend(converted_parts)
+
+                # E.g. {{product_photo}} becomes "[Image: product_photo]";
+                # this gives the model textual context about what media exists in the prompt
+                var_type = value.type.replace("input_", "").replace("_", " ").title()
+                text_substitutions[name] = f"[{var_type}: {name}]"
+
+        def replace_variable(match: re.Match[str]) -> str:
+            var_name = match.group(1).strip()
+            return str(text_substitutions.get(var_name, match.group(0)))
+
+        pattern = r"\{\{\s*(\w+)\s*\}\}"
+        processed_prompt_text = re.sub(pattern, replace_variable, cur_prompt_text)
+
+        # Insert system message with resolved text
+        messages.insert(0, OpenAISystemMessageParam(content=processed_prompt_text))
+
+        # If we have media, append a new user message so images and files can be ingested
+        if media_content_parts:
+            messages.append(OpenAIUserMessageParam(content=media_content_parts))
+
     async def get_openai_response(
         self,
         response_id: str,
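The variable resolution in `_prepend_prompt` is plain `{{name}}` substitution via `re.sub`, with media variables replaced by textual placeholders. A standalone sketch of just that substitution step, using a hypothetical prompt text and variables rather than the Prompts API:

```python
import re

# Hypothetical stored prompt and caller-supplied variables.
prompt_text = "You are a shopping assistant. Product: {{product_name}}. Photo: {{product_photo}}."
text_substitutions = {
    "product_name": "Trail Runner 2",
    "product_photo": "[Image: product_photo]",  # media variables become placeholders
}


def replace_variable(match: re.Match[str]) -> str:
    var_name = match.group(1).strip()
    # Unknown variables are left untouched rather than dropped.
    return str(text_substitutions.get(var_name, match.group(0)))


pattern = r"\{\{\s*(\w+)\s*\}\}"
print(re.sub(pattern, replace_variable, prompt_text))
# You are a shopping assistant. Product: Trail Runner 2. Photo: [Image: product_photo].
```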
@@ -252,6 +334,7 @@ class OpenAIResponsesImpl:
         include: list[str] | None = None,
         max_infer_iters: int | None = 10,
         guardrails: list[str | ResponseGuardrailSpec] | None = None,
         parallel_tool_calls: bool | None = None,
         max_tool_calls: int | None = None,
     ):
         stream = bool(stream)

@@ -272,6 +355,14 @@ class OpenAIResponsesImpl:

         guardrail_ids = extract_guardrail_ids(guardrails) if guardrails else []

+        # Validate that Safety API is available if guardrails are requested
+        if guardrail_ids and self.safety_api is None:
+            raise ValueError(
+                "Cannot process guardrails: Safety API is not configured.\n\n"
+                "To use guardrails, ensure the Safety API is configured in your stack, or remove "
+                "the 'guardrails' parameter from your request."
+            )
+
         if conversation is not None:
             if previous_response_id is not None:
                 raise ValueError(

@@ -288,6 +379,7 @@ class OpenAIResponsesImpl:
             input=input,
             conversation=conversation,
             model=model,
+            prompt=prompt,
             instructions=instructions,
             previous_response_id=previous_response_id,
             store=store,

@@ -296,6 +388,7 @@ class OpenAIResponsesImpl:
             tools=tools,
             max_infer_iters=max_infer_iters,
             guardrail_ids=guardrail_ids,
             parallel_tool_calls=parallel_tool_calls,
             max_tool_calls=max_tool_calls,
         )
@@ -340,12 +433,14 @@ class OpenAIResponsesImpl:
         instructions: str | None = None,
         previous_response_id: str | None = None,
         conversation: str | None = None,
+        prompt: OpenAIResponsePrompt | None = None,
         store: bool | None = True,
         temperature: float | None = None,
         text: OpenAIResponseText | None = None,
         tools: list[OpenAIResponseInputTool] | None = None,
         max_infer_iters: int | None = 10,
         guardrail_ids: list[str] | None = None,
         parallel_tool_calls: bool | None = True,
         max_tool_calls: int | None = None,
     ) -> AsyncIterator[OpenAIResponseObjectStream]:
         # These should never be None when called from create_openai_response (which sets defaults)

@@ -361,6 +456,9 @@ class OpenAIResponsesImpl:
         if instructions:
             messages.insert(0, OpenAISystemMessageParam(content=instructions))

+        # Prepend reusable prompt (if provided)
+        await self._prepend_prompt(messages, prompt)
+
         # Structured outputs
         response_format = await convert_response_text_to_chat_response_format(text)

@@ -383,8 +481,10 @@ class OpenAIResponsesImpl:
             ctx=ctx,
             response_id=response_id,
             created_at=created_at,
+            prompt=prompt,
             text=text,
             max_infer_iters=max_infer_iters,
             parallel_tool_calls=parallel_tool_calls,
             tool_executor=self.tool_executor,
             safety_api=self.safety_api,
             guardrail_ids=guardrail_ids,
@@ -66,6 +66,8 @@ from llama_stack_api import (
     OpenAIResponseUsage,
     OpenAIResponseUsageInputTokensDetails,
     OpenAIResponseUsageOutputTokensDetails,
+    OpenAIToolMessageParam,
+    Safety,
     WebSearchToolTypes,
 )

@@ -111,9 +113,10 @@ class StreamingResponseOrchestrator:
         max_infer_iters: int,
         tool_executor,  # Will be the tool execution logic from the main class
         instructions: str | None,
-        safety_api,
+        safety_api: Safety | None,
         guardrail_ids: list[str] | None = None,
         prompt: OpenAIResponsePrompt | None = None,
         parallel_tool_calls: bool | None = None,
         max_tool_calls: int | None = None,
     ):
         self.inference_api = inference_api

@@ -128,6 +131,8 @@ class StreamingResponseOrchestrator:
         self.prompt = prompt
         # System message that is inserted into the model's context
         self.instructions = instructions
         # Whether to allow more than one function tool call generated per turn.
         self.parallel_tool_calls = parallel_tool_calls
         # Max number of total calls to built-in tools that can be processed in a response
         self.max_tool_calls = max_tool_calls
         self.sequence_number = 0

@@ -190,6 +195,7 @@ class StreamingResponseOrchestrator:
             usage=self.accumulated_usage,
             instructions=self.instructions,
             prompt=self.prompt,
             parallel_tool_calls=self.parallel_tool_calls,
             max_tool_calls=self.max_tool_calls,
         )
@ -901,10 +907,16 @@ class StreamingResponseOrchestrator:
|
|||
"""Coordinate execution of both function and non-function tool calls."""
|
||||
# Execute non-function tool calls
|
||||
for tool_call in non_function_tool_calls:
|
||||
# Check if total calls made to built-in and mcp tools exceed max_tool_calls
|
||||
# if total calls made to built-in and mcp tools exceed max_tool_calls
|
||||
# then create a tool response message indicating the call was skipped
|
||||
if self.max_tool_calls is not None and self.accumulated_builtin_tool_calls >= self.max_tool_calls:
|
||||
logger.info(f"Ignoring built-in and mcp tool call since reached the limit of {self.max_tool_calls=}.")
|
||||
break
|
||||
skipped_call_message = OpenAIToolMessageParam(
|
||||
content=f"Tool call skipped: maximum tool calls limit ({self.max_tool_calls}) reached.",
|
||||
tool_call_id=tool_call.id,
|
||||
)
|
||||
next_turn_messages.append(skipped_call_message)
|
||||
continue
|
||||
|
||||
# Find the item_id for this tool call
|
||||
matching_item_id = None
|
||||
|
|
|
|||
|
|
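The hunk above caps built-in/MCP tool execution at `max_tool_calls` and, instead of silently breaking out of the loop, records a skip message for each excess call so the conversation stays well-formed. Below is a minimal, self-contained sketch of that cap-and-skip pattern; `ToolCall`, `ToolMessage`, and the placeholder "execution" are hypothetical stand-ins, not the Stack's actual types.

```python
from dataclasses import dataclass


@dataclass
class ToolCall:
    id: str
    name: str


@dataclass
class ToolMessage:
    tool_call_id: str
    content: str


def run_with_cap(tool_calls: list[ToolCall], max_tool_calls: int | None) -> list[ToolMessage]:
    """Execute tool calls up to a cap; record a skip message for the rest."""
    messages: list[ToolMessage] = []
    executed = 0
    for call in tool_calls:
        if max_tool_calls is not None and executed >= max_tool_calls:
            # Over the cap: still answer the call so the model sees a tool response.
            messages.append(
                ToolMessage(call.id, f"Tool call skipped: maximum tool calls limit ({max_tool_calls}) reached.")
            )
            continue
        executed += 1
        messages.append(ToolMessage(call.id, f"executed {call.name}"))  # placeholder for real execution
    return messages


print(run_with_cap([ToolCall("1", "web_search"), ToolCall("2", "web_search")], max_tool_calls=1))
```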
@@ -5,11 +5,14 @@
# the root directory of this source tree.

import asyncio
import base64
import mimetypes
import re
import uuid
from collections.abc import Sequence

from llama_stack_api import (
    Files,
    OpenAIAssistantMessageParam,
    OpenAIChatCompletionContentPartImageParam,
    OpenAIChatCompletionContentPartParam,

@@ -18,6 +21,8 @@ from llama_stack_api import (
    OpenAIChatCompletionToolCallFunction,
    OpenAIChoice,
    OpenAIDeveloperMessageParam,
    OpenAIFile,
    OpenAIFileFile,
    OpenAIImageURL,
    OpenAIJSONSchema,
    OpenAIMessageParam,

@@ -29,6 +34,7 @@ from llama_stack_api import (
    OpenAIResponseInput,
    OpenAIResponseInputFunctionToolCallOutput,
    OpenAIResponseInputMessageContent,
    OpenAIResponseInputMessageContentFile,
    OpenAIResponseInputMessageContentImage,
    OpenAIResponseInputMessageContentText,
    OpenAIResponseInputTool,

@@ -37,9 +43,11 @@ from llama_stack_api import (
    OpenAIResponseMessage,
    OpenAIResponseOutputMessageContent,
    OpenAIResponseOutputMessageContentOutputText,
    OpenAIResponseOutputMessageFileSearchToolCall,
    OpenAIResponseOutputMessageFunctionToolCall,
    OpenAIResponseOutputMessageMCPCall,
    OpenAIResponseOutputMessageMCPListTools,
    OpenAIResponseOutputMessageWebSearchToolCall,
    OpenAIResponseText,
    OpenAISystemMessageParam,
    OpenAIToolMessageParam,

@@ -49,6 +57,46 @@ from llama_stack_api import (
)


async def extract_bytes_from_file(file_id: str, files_api: Files) -> bytes:
    """
    Extract raw bytes from file using the Files API.

    :param file_id: The file identifier (e.g., "file-abc123")
    :param files_api: Files API instance
    :returns: Raw file content as bytes
    :raises: ValueError if file cannot be retrieved
    """
    try:
        response = await files_api.openai_retrieve_file_content(file_id)
        return bytes(response.body)
    except Exception as e:
        raise ValueError(f"Failed to retrieve file content for file_id '{file_id}': {str(e)}") from e


def generate_base64_ascii_text_from_bytes(raw_bytes: bytes) -> str:
    """
    Converts raw binary bytes into a safe ASCII text representation for URLs

    :param raw_bytes: the actual bytes that represents file content
    :returns: string of utf-8 characters
    """
    return base64.b64encode(raw_bytes).decode("utf-8")


def construct_data_url(ascii_text: str, mime_type: str | None) -> str:
    """
    Construct data url with decoded data inside

    :param ascii_text: ASCII content
    :param mime_type: MIME type of file
    :returns: data url string (eg. data:image/png,base64,%3Ch1%3EHello%2C%20World%21%3C%2Fh1%3E)
    """
    if not mime_type:
        mime_type = "application/octet-stream"

    return f"data:{mime_type};base64,{ascii_text}"


async def convert_chat_choice_to_response_message(
    choice: OpenAIChoice,
    citation_files: dict[str, str] | None = None,

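The new helpers above turn Files API content into inline `data:` URLs via base64 so file and image inputs can be forwarded to chat-completion backends. The following is a minimal sketch of that conversion using only the standard library; the filename and payload are made up for illustration, and the Files API call itself is out of scope here.

```python
import base64
import mimetypes


def to_data_url(raw_bytes: bytes, filename: str | None = None) -> str:
    """Encode raw file bytes as a base64 data: URL, guessing the MIME type from the filename."""
    mime_type = None
    if filename:
        mime_type, _ = mimetypes.guess_type(filename)
    if not mime_type:
        mime_type = "application/octet-stream"
    ascii_text = base64.b64encode(raw_bytes).decode("utf-8")
    return f"data:{mime_type};base64,{ascii_text}"


# Example with a hypothetical PNG payload:
print(to_data_url(b"\x89PNG...", filename="chart.png")[:40])  # data:image/png;base64,...
```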
@@ -78,11 +126,15 @@ async def convert_chat_choice_to_response_message(

async def convert_response_content_to_chat_content(
    content: str | Sequence[OpenAIResponseInputMessageContent | OpenAIResponseOutputMessageContent],
    files_api: Files | None,
) -> str | list[OpenAIChatCompletionContentPartParam]:
    """
    Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.

    The content schemas of each API look similar, but are not exactly the same.

    :param content: The content to convert
    :param files_api: Files API for resolving file_id to raw file content (required if content contains files/images)
    """
    if isinstance(content, str):
        return content

@@ -95,9 +147,68 @@ async def convert_response_content_to_chat_content(
        elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText):
            converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
        elif isinstance(content_part, OpenAIResponseInputMessageContentImage):
            detail = content_part.detail
            image_mime_type = None
            if content_part.image_url:
                image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail)
                image_url = OpenAIImageURL(url=content_part.image_url, detail=detail)
                converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
            elif content_part.file_id:
                if files_api is None:
                    raise ValueError("file_ids are not supported by this implementation of the Stack")
                image_file_response = await files_api.openai_retrieve_file(content_part.file_id)
                if image_file_response.filename:
                    image_mime_type, _ = mimetypes.guess_type(image_file_response.filename)
                raw_image_bytes = await extract_bytes_from_file(content_part.file_id, files_api)
                ascii_text = generate_base64_ascii_text_from_bytes(raw_image_bytes)
                image_data_url = construct_data_url(ascii_text, image_mime_type)
                image_url = OpenAIImageURL(url=image_data_url, detail=detail)
                converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
            else:
                raise ValueError(
                    f"Image content must have either 'image_url' or 'file_id'. "
                    f"Got image_url={content_part.image_url}, file_id={content_part.file_id}"
                )
        elif isinstance(content_part, OpenAIResponseInputMessageContentFile):
            resolved_file_data = None
            file_data = content_part.file_data
            file_id = content_part.file_id
            file_url = content_part.file_url
            filename = content_part.filename
            file_mime_type = None
            if not any([file_data, file_id, file_url]):
                raise ValueError(
                    f"File content must have at least one of 'file_data', 'file_id', or 'file_url'. "
                    f"Got file_data={file_data}, file_id={file_id}, file_url={file_url}"
                )
            if file_id:
                if files_api is None:
                    raise ValueError("file_ids are not supported by this implementation of the Stack")

                file_response = await files_api.openai_retrieve_file(file_id)
                if not filename:
                    filename = file_response.filename
                file_mime_type, _ = mimetypes.guess_type(file_response.filename)
                raw_file_bytes = await extract_bytes_from_file(file_id, files_api)
                ascii_text = generate_base64_ascii_text_from_bytes(raw_file_bytes)
                resolved_file_data = construct_data_url(ascii_text, file_mime_type)
            elif file_data:
                if file_data.startswith("data:"):
                    resolved_file_data = file_data
                else:
                    # Raw base64 data, wrap in data URL format
                    if filename:
                        file_mime_type, _ = mimetypes.guess_type(filename)
                    resolved_file_data = construct_data_url(file_data, file_mime_type)
            elif file_url:
                resolved_file_data = file_url
            converted_parts.append(
                OpenAIFile(
                    file=OpenAIFileFile(
                        file_data=resolved_file_data,
                        filename=filename,
                    )
                )
            )
        elif isinstance(content_part, str):
            converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part))
        else:

@@ -110,12 +221,14 @@ async def convert_response_content_to_chat_content(

async def convert_response_input_to_chat_messages(
    input: str | list[OpenAIResponseInput],
    previous_messages: list[OpenAIMessageParam] | None = None,
    files_api: Files | None = None,
) -> list[OpenAIMessageParam]:
    """
    Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.

    :param input: The input to convert
    :param previous_messages: Optional previous messages to check for function_call references
    :param files_api: Files API for resolving file_id to raw file content (optional, required for file/image content)
    """
    messages: list[OpenAIMessageParam] = []
    if isinstance(input, list):

@@ -169,6 +282,12 @@ async def convert_response_input_to_chat_messages(
            elif isinstance(input_item, OpenAIResponseOutputMessageMCPListTools):
                # the tool list will be handled separately
                pass
            elif isinstance(
                input_item,
                OpenAIResponseOutputMessageWebSearchToolCall | OpenAIResponseOutputMessageFileSearchToolCall,
            ):
                # these tool calls are tracked internally but not converted to chat messages
                pass
            elif isinstance(input_item, OpenAIResponseMCPApprovalRequest) or isinstance(
                input_item, OpenAIResponseMCPApprovalResponse
            ):

@@ -176,7 +295,7 @@ async def convert_response_input_to_chat_messages(
                pass
            elif isinstance(input_item, OpenAIResponseMessage):
                # Narrow type to OpenAIResponseMessage which has content and role attributes
                content = await convert_response_content_to_chat_content(input_item.content)
                content = await convert_response_content_to_chat_content(input_item.content, files_api)
                message_type = await get_message_type_by_role(input_item.role)
                if message_type is None:
                    raise ValueError(

@@ -320,11 +439,15 @@ def is_function_tool_call(
    return False


async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> str | None:
async def run_guardrails(safety_api: Safety | None, messages: str, guardrail_ids: list[str]) -> str | None:
    """Run guardrails against messages and return violation message if blocked."""
    if not messages:
        return None

    # If safety API is not available, skip guardrails
    if safety_api is None:
        return None

    # Look up shields to get their provider_resource_id (actual model ID)
    model_ids = []
    # TODO: list_shields not in Safety interface but available at runtime via API routing

@ -7,7 +7,7 @@
|
|||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import AccessRule, Api
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.core.storage.kvstore import kvstore_impl
|
||||
from llama_stack_api import Files, Inference, Models
|
||||
|
||||
from .batches import ReferenceBatchesImpl
|
||||
|
|
|
|||
|
|
@ -16,8 +16,8 @@ from typing import Any
|
|||
from openai.types.batch import BatchError, Errors
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.core.storage.kvstore import KVStore
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
from llama_stack_api import (
|
||||
Batches,
|
||||
BatchObject,
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.storage.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.pagination import paginate_records
|
||||
from llama_stack_api import Dataset, DatasetIO, DatasetsProtocolPrivate, PaginatedResponse
|
||||
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ from typing import Any
|
|||
|
||||
from tqdm import tqdm
|
||||
|
||||
from llama_stack.core.storage.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack_api import (
|
||||
Agents,
|
||||
Benchmark,
|
||||
|
|
|
|||
|
|
@ -13,11 +13,10 @@ from fastapi import Depends, File, Form, Response, UploadFile
|
|||
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.core.id_generation import generate_object_id
|
||||
from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.files.form_data import parse_expires_after
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
|
||||
from llama_stack_api import (
|
||||
ExpiresAfter,
|
||||
Files,
|
||||
|
|
@ -28,6 +27,7 @@ from llama_stack_api import (
|
|||
Order,
|
||||
ResourceNotFoundError,
|
||||
)
|
||||
from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType
|
||||
|
||||
from .config import LocalfsFilesImplConfig
|
||||
|
||||
|
|
|
|||
|
|
@ -14,9 +14,8 @@ import faiss # type: ignore[import-untyped]
|
|||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from llama_stack.core.storage.kvstore import kvstore_impl
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
|
||||
from llama_stack_api import (
|
||||
|
|
@ -32,6 +31,7 @@ from llama_stack_api import (
|
|||
VectorStoreNotFoundError,
|
||||
VectorStoresProtocolPrivate,
|
||||
)
|
||||
from llama_stack_api.internal.kvstore import KVStore
|
||||
|
||||
from .config import FaissVectorIOConfig
|
||||
|
||||
|
|
|
|||
|
|
@ -14,9 +14,8 @@ import numpy as np
|
|||
import sqlite_vec # type: ignore[import-untyped]
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from llama_stack.core.storage.kvstore import kvstore_impl
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
RERANKER_TYPE_RRF,
|
||||
|
|
@ -35,6 +34,7 @@ from llama_stack_api import (
|
|||
VectorStoreNotFoundError,
|
||||
VectorStoresProtocolPrivate,
|
||||
)
|
||||
from llama_stack_api.internal.kvstore import KVStore
|
||||
|
||||
logger = get_logger(name=__name__, category="vector_io")
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.providers.utils.kvstore import kvstore_dependencies
|
||||
from llama_stack.core.storage.kvstore import kvstore_dependencies
|
||||
from llama_stack_api import (
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
|
|
@ -30,11 +30,15 @@ def available_providers() -> list[ProviderSpec]:
|
|||
config_class="llama_stack.providers.inline.agents.meta_reference.MetaReferenceAgentsImplConfig",
|
||||
api_dependencies=[
|
||||
Api.inference,
|
||||
Api.safety,
|
||||
Api.vector_io,
|
||||
Api.tool_runtime,
|
||||
Api.tool_groups,
|
||||
Api.conversations,
|
||||
Api.prompts,
|
||||
Api.files,
|
||||
],
|
||||
optional_api_dependencies=[
|
||||
Api.safety,
|
||||
],
|
||||
description="Meta's reference implementation of an agent system that can use tools, access vector databases, and perform complex reasoning tasks.",
|
||||
),
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages
|
||||
from llama_stack.core.storage.sqlstore.sqlstore import sql_store_pip_packages
|
||||
from llama_stack_api import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
from typing import Any
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.core.storage.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.pagination import paginate_records
|
||||
from llama_stack_api import Dataset, DatasetIO, DatasetsProtocolPrivate, PaginatedResponse
|
||||
|
||||
|
|
|
|||
|
|
@ -10,10 +10,9 @@ from typing import Annotated, Any
|
|||
from fastapi import Depends, File, Form, Response, UploadFile
|
||||
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl
|
||||
from llama_stack.providers.utils.files.form_data import parse_expires_after
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
|
||||
from llama_stack_api import (
|
||||
ExpiresAfter,
|
||||
Files,
|
||||
|
|
@ -24,6 +23,7 @@ from llama_stack_api import (
|
|||
Order,
|
||||
ResourceNotFoundError,
|
||||
)
|
||||
from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType
|
||||
from openai import OpenAI
|
||||
|
||||
from .config import OpenAIFilesImplConfig
|
||||
|
|
|
|||
|
|
@ -19,10 +19,9 @@ if TYPE_CHECKING:
|
|||
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.core.id_generation import generate_object_id
|
||||
from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl
|
||||
from llama_stack.providers.utils.files.form_data import parse_expires_after
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
|
||||
from llama_stack_api import (
|
||||
ExpiresAfter,
|
||||
Files,
|
||||
|
|
@ -33,6 +32,7 @@ from llama_stack_api import (
|
|||
Order,
|
||||
ResourceNotFoundError,
|
||||
)
|
||||
from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType
|
||||
|
||||
from .config import S3FilesImplConfig
|
||||
|
||||
|
|
|
|||
|
|
@@ -4,8 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from urllib.parse import urljoin

from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

from .config import AzureConfig

@@ -22,4 +20,4 @@ class AzureInferenceAdapter(OpenAIMixin):

        Returns the Azure API base URL from the configuration.
        """
        return urljoin(str(self.config.api_base), "/openai/v1")
        return str(self.config.base_url)

@@ -32,8 +32,9 @@ class AzureProviderDataValidator(BaseModel):

@json_schema_type
class AzureConfig(RemoteInferenceProviderConfig):
    api_base: HttpUrl = Field(
        description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
    base_url: HttpUrl | None = Field(
        default=None,
        description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com/openai/v1)",
    )
    api_version: str | None = Field(
        default_factory=lambda: os.getenv("AZURE_API_VERSION"),

@@ -48,14 +49,14 @@ class AzureConfig(RemoteInferenceProviderConfig):
    def sample_run_config(
        cls,
        api_key: str = "${env.AZURE_API_KEY:=}",
        api_base: str = "${env.AZURE_API_BASE:=}",
        base_url: str = "${env.AZURE_API_BASE:=}",
        api_version: str = "${env.AZURE_API_VERSION:=}",
        api_type: str = "${env.AZURE_API_TYPE:=}",
        **kwargs,
    ) -> dict[str, Any]:
        return {
            "api_key": api_key,
            "api_base": api_base,
            "base_url": base_url,
            "api_version": api_version,
            "api_type": api_type,
        }

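Several provider configs in this commit move from plain `str` URL fields to pydantic `HttpUrl`, so adapters convert with `str(self.config.base_url)` before handing the value to an HTTP client. A minimal sketch of that pattern under pydantic v2, using a hypothetical `ExampleConfig` and placeholder URL rather than any real provider class:

```python
from pydantic import BaseModel, Field, HttpUrl


class ExampleConfig(BaseModel):
    # None means "not configured"; a concrete default can be supplied with HttpUrl("...").
    base_url: HttpUrl | None = Field(
        default=HttpUrl("https://example.invalid/openai/v1"),
        description="Base URL the adapter will call",
    )


cfg = ExampleConfig()
# HttpUrl is not a str, so convert explicitly before passing it to an HTTP client.
assert str(cfg.base_url).startswith("https://")
print(str(cfg.base_url))
```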
@@ -37,7 +37,7 @@ class BedrockInferenceAdapter(OpenAIMixin):
    """

    config: BedrockConfig
    provider_data_api_key_field: str = "aws_bedrock_api_key"
    provider_data_api_key_field: str = "aws_bearer_token_bedrock"

    def get_base_url(self) -> str:
        """Get base URL for OpenAI client."""

@@ -111,7 +111,7 @@ class BedrockInferenceAdapter(OpenAIMixin):
            logger.error(f"AWS Bedrock authentication token expired: {error_msg}")
            raise ValueError(
                "AWS Bedrock authentication failed: Bearer token has expired. "
                "The AWS_BEDROCK_API_KEY environment variable contains an expired pre-signed URL. "
                "The AWS_BEARER_TOKEN_BEDROCK environment variable contains an expired pre-signed URL. "
                "Please refresh your token by generating a new pre-signed URL with AWS credentials. "
                "Refer to AWS Bedrock documentation for details on OpenAI-compatible endpoints."
            ) from e

@@ -12,9 +12,9 @@ from llama_stack.providers.utils.inference.model_registry import RemoteInference


class BedrockProviderDataValidator(BaseModel):
    aws_bedrock_api_key: str | None = Field(
    aws_bearer_token_bedrock: str | None = Field(
        default=None,
        description="API key for Amazon Bedrock",
        description="API Key (Bearer token) for Amazon Bedrock",
    )


@@ -27,6 +27,6 @@ class BedrockConfig(RemoteInferenceProviderConfig):
    @classmethod
    def sample_run_config(cls, **kwargs):
        return {
            "api_key": "${env.AWS_BEDROCK_API_KEY:=}",
            "api_key": "${env.AWS_BEARER_TOKEN_BEDROCK:=}",
            "region_name": "${env.AWS_DEFAULT_REGION:=us-east-2}",
        }

@ -4,8 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack_api import (
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
|
|
@ -21,7 +19,7 @@ class CerebrasInferenceAdapter(OpenAIMixin):
|
|||
provider_data_api_key_field: str = "cerebras_api_key"
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return urljoin(self.config.base_url, "v1")
|
||||
return str(self.config.base_url)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -7,12 +7,12 @@
|
|||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, HttpUrl
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack_api import json_schema_type
|
||||
|
||||
DEFAULT_BASE_URL = "https://api.cerebras.ai"
|
||||
DEFAULT_BASE_URL = "https://api.cerebras.ai/v1"
|
||||
|
||||
|
||||
class CerebrasProviderDataValidator(BaseModel):
|
||||
|
|
@ -24,8 +24,8 @@ class CerebrasProviderDataValidator(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class CerebrasImplConfig(RemoteInferenceProviderConfig):
|
||||
base_url: str = Field(
|
||||
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
|
||||
base_url: HttpUrl | None = Field(
|
||||
default=HttpUrl(os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL)),
|
||||
description="Base URL for the Cerebras API",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@@ -6,7 +6,7 @@

from typing import Any

from pydantic import BaseModel, Field, SecretStr
from pydantic import BaseModel, Field, HttpUrl, SecretStr

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type

@@ -21,9 +21,9 @@ class DatabricksProviderDataValidator(BaseModel):

@json_schema_type
class DatabricksImplConfig(RemoteInferenceProviderConfig):
    url: str | None = Field(
    base_url: HttpUrl | None = Field(
        default=None,
        description="The URL for the Databricks model serving endpoint",
        description="The URL for the Databricks model serving endpoint (should include /serving-endpoints path)",
    )
    auth_credential: SecretStr | None = Field(
        default=None,

@@ -34,11 +34,11 @@ class DatabricksImplConfig(RemoteInferenceProviderConfig):
    @classmethod
    def sample_run_config(
        cls,
        url: str = "${env.DATABRICKS_HOST:=}",
        base_url: str = "${env.DATABRICKS_HOST:=}",
        api_token: str = "${env.DATABRICKS_TOKEN:=}",
        **kwargs: Any,
    ) -> dict[str, Any]:
        return {
            "url": url,
            "base_url": base_url,
            "api_token": api_token,
        }

@@ -29,15 +29,21 @@ class DatabricksInferenceAdapter(OpenAIMixin):
    }

    def get_base_url(self) -> str:
        return f"{self.config.url}/serving-endpoints"
        return str(self.config.base_url)

    async def list_provider_model_ids(self) -> Iterable[str]:
        # Filter out None values from endpoint names
        api_token = self._get_api_key_from_config_or_provider_data()
        # WorkspaceClient expects base host without /serving-endpoints suffix
        base_url_str = str(self.config.base_url)
        if base_url_str.endswith("/serving-endpoints"):
            host = base_url_str[:-18]  # Remove '/serving-endpoints'
        else:
            host = base_url_str
        return [
            endpoint.name  # type: ignore[misc]
            for endpoint in WorkspaceClient(
                host=self.config.url, token=api_token
                host=host, token=api_token
            ).serving_endpoints.list()  # TODO: this is not async
        ]

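Because `base_url` now carries the `/serving-endpoints` path while the Databricks `WorkspaceClient` expects the bare workspace host, the adapter strips the suffix before constructing the client. A small standalone sketch of that normalization (the URL below is a placeholder, not a real workspace):

```python
def workspace_host(base_url: str) -> str:
    """Return the Databricks host with any trailing /serving-endpoints path removed."""
    trimmed = base_url.rstrip("/")
    return trimmed.removesuffix("/serving-endpoints")


assert workspace_host("https://example.cloud.databricks.com/serving-endpoints") == "https://example.cloud.databricks.com"
assert workspace_host("https://example.cloud.databricks.com") == "https://example.cloud.databricks.com"
```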
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import Field, HttpUrl
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack_api import json_schema_type
|
||||
|
|
@ -14,14 +14,14 @@ from llama_stack_api import json_schema_type
|
|||
|
||||
@json_schema_type
|
||||
class FireworksImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default="https://api.fireworks.ai/inference/v1",
|
||||
base_url: HttpUrl | None = Field(
|
||||
default=HttpUrl("https://api.fireworks.ai/inference/v1"),
|
||||
description="The URL for the Fireworks server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.fireworks.ai/inference/v1",
|
||||
"base_url": "https://api.fireworks.ai/inference/v1",
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,4 +24,4 @@ class FireworksInferenceAdapter(OpenAIMixin):
|
|||
provider_data_api_key_field: str = "fireworks_api_key"
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return "https://api.fireworks.ai/inference/v1"
|
||||
return str(self.config.base_url)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, HttpUrl
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack_api import json_schema_type
|
||||
|
|
@ -21,14 +21,14 @@ class GroqProviderDataValidator(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class GroqConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default="https://api.groq.com",
|
||||
base_url: HttpUrl | None = Field(
|
||||
default=HttpUrl("https://api.groq.com/openai/v1"),
|
||||
description="The URL for the Groq AI server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.groq.com",
|
||||
"base_url": "https://api.groq.com/openai/v1",
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,4 +15,4 @@ class GroqInferenceAdapter(OpenAIMixin):
|
|||
provider_data_api_key_field: str = "groq_api_key"
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return f"{self.config.url}/openai/v1"
|
||||
return str(self.config.base_url)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, HttpUrl
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack_api import json_schema_type
|
||||
|
|
@ -21,14 +21,14 @@ class LlamaProviderDataValidator(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class LlamaCompatConfig(RemoteInferenceProviderConfig):
|
||||
openai_compat_api_base: str = Field(
|
||||
default="https://api.llama.com/compat/v1/",
|
||||
base_url: HttpUrl | None = Field(
|
||||
default=HttpUrl("https://api.llama.com/compat/v1/"),
|
||||
description="The URL for the Llama API server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.LLAMA_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"openai_compat_api_base": "https://api.llama.com/compat/v1/",
|
||||
"base_url": "https://api.llama.com/compat/v1/",
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ class LlamaCompatInferenceAdapter(OpenAIMixin):
|
|||
|
||||
:return: The Llama API base URL
|
||||
"""
|
||||
return self.config.openai_compat_api_base
|
||||
return str(self.config.base_url)
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, HttpUrl
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack_api import json_schema_type
|
||||
|
|
@ -44,18 +44,14 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
|
|||
URL of your running NVIDIA NIM and do not need to set the api_key.
|
||||
"""
|
||||
|
||||
url: str = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"),
|
||||
base_url: HttpUrl | None = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com/v1"),
|
||||
description="A base url for accessing the NVIDIA NIM",
|
||||
)
|
||||
timeout: int = Field(
|
||||
default=60,
|
||||
description="Timeout for the HTTP requests",
|
||||
)
|
||||
append_api_version: bool = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
|
||||
description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
|
||||
)
|
||||
rerank_model_to_url: dict[str, str] = Field(
|
||||
default_factory=lambda: {
|
||||
"nv-rerank-qa-mistral-4b:1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking",
|
||||
|
|
@ -68,13 +64,11 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
|
|||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
url: str = "${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}",
|
||||
base_url: HttpUrl | None = "${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}",
|
||||
api_key: str = "${env.NVIDIA_API_KEY:=}",
|
||||
append_api_version: bool = "${env.NVIDIA_APPEND_API_VERSION:=True}",
|
||||
**kwargs,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"url": url,
|
||||
"base_url": base_url,
|
||||
"api_key": api_key,
|
||||
"append_api_version": append_api_version,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
|
|||
}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.url})...")
|
||||
logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.base_url})...")
|
||||
|
||||
if _is_nvidia_hosted(self.config):
|
||||
if not self.config.auth_credential:
|
||||
|
|
@ -72,7 +72,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
|
|||
|
||||
:return: The NVIDIA API base URL
|
||||
"""
|
||||
return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
|
||||
return str(self.config.base_url)
|
||||
|
||||
async def list_provider_model_ids(self) -> Iterable[str]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -8,4 +8,4 @@ from . import NVIDIAConfig
|
|||
|
||||
|
||||
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
|
||||
return "integrate.api.nvidia.com" in config.url
|
||||
return "integrate.api.nvidia.com" in str(config.base_url)
|
||||
|
|
|
|||
|
|
@@ -6,20 +6,22 @@

from typing import Any

from pydantic import Field, SecretStr
from pydantic import Field, HttpUrl, SecretStr

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig

DEFAULT_OLLAMA_URL = "http://localhost:11434"
DEFAULT_OLLAMA_URL = "http://localhost:11434/v1"


class OllamaImplConfig(RemoteInferenceProviderConfig):
    auth_credential: SecretStr | None = Field(default=None, exclude=True)

    url: str = DEFAULT_OLLAMA_URL
    base_url: HttpUrl | None = Field(default=HttpUrl(DEFAULT_OLLAMA_URL))

    @classmethod
    def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:
    def sample_run_config(
        cls, base_url: str = "${env.OLLAMA_URL:=http://localhost:11434/v1}", **kwargs
    ) -> dict[str, Any]:
        return {
            "url": url,
            "base_url": base_url,
        }

@@ -55,17 +55,23 @@ class OllamaInferenceAdapter(OpenAIMixin):
        # ollama client attaches itself to the current event loop (sadly?)
        loop = asyncio.get_running_loop()
        if loop not in self._clients:
            self._clients[loop] = AsyncOllamaClient(host=self.config.url)
            # Ollama client expects base URL without /v1 suffix
            base_url_str = str(self.config.base_url)
            if base_url_str.endswith("/v1"):
                host = base_url_str[:-3]
            else:
                host = base_url_str
            self._clients[loop] = AsyncOllamaClient(host=host)
        return self._clients[loop]

    def get_api_key(self):
        return "NO KEY REQUIRED"

    def get_base_url(self):
        return self.config.url.rstrip("/") + "/v1"
        return str(self.config.base_url)

    async def initialize(self) -> None:
        logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
        logger.info(f"checking connectivity to Ollama at `{self.config.base_url}`...")
        r = await self.health()
        if r["status"] == HealthStatus.ERROR:
            logger.warning(

@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, HttpUrl
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack_api import json_schema_type
|
||||
|
|
@ -21,8 +21,8 @@ class OpenAIProviderDataValidator(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIConfig(RemoteInferenceProviderConfig):
|
||||
base_url: str = Field(
|
||||
default="https://api.openai.com/v1",
|
||||
base_url: HttpUrl | None = Field(
|
||||
default=HttpUrl("https://api.openai.com/v1"),
|
||||
description="Base URL for OpenAI API",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -35,4 +35,4 @@ class OpenAIInferenceAdapter(OpenAIMixin):
|
|||
|
||||
Returns the OpenAI API base URL from the configuration.
|
||||
"""
|
||||
return self.config.base_url
|
||||
return str(self.config.base_url)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import Field, HttpUrl
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack_api import json_schema_type
|
||||
|
|
@ -14,16 +14,16 @@ from llama_stack_api import json_schema_type
|
|||
|
||||
@json_schema_type
|
||||
class PassthroughImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
base_url: HttpUrl | None = Field(
|
||||
default=None,
|
||||
description="The URL for the passthrough endpoint",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls, url: str = "${env.PASSTHROUGH_URL}", api_key: str = "${env.PASSTHROUGH_API_KEY}", **kwargs
|
||||
cls, base_url: HttpUrl | None = "${env.PASSTHROUGH_URL}", api_key: str = "${env.PASSTHROUGH_API_KEY}", **kwargs
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"url": url,
|
||||
"base_url": base_url,
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -82,8 +82,8 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
|
|||
|
||||
def _get_passthrough_url(self) -> str:
|
||||
"""Get the passthrough URL from config or provider data."""
|
||||
if self.config.url is not None:
|
||||
return self.config.url
|
||||
if self.config.base_url is not None:
|
||||
return str(self.config.base_url)
|
||||
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import BaseModel, Field, HttpUrl, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack_api import json_schema_type
|
||||
|
|
@ -21,7 +21,7 @@ class RunpodProviderDataValidator(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class RunpodImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str | None = Field(
|
||||
base_url: HttpUrl | None = Field(
|
||||
default=None,
|
||||
description="The URL for the Runpod model serving endpoint",
|
||||
)
|
||||
|
|
@ -34,6 +34,6 @@ class RunpodImplConfig(RemoteInferenceProviderConfig):
|
|||
@classmethod
|
||||
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "${env.RUNPOD_URL:=}",
|
||||
"base_url": "${env.RUNPOD_URL:=}",
|
||||
"api_token": "${env.RUNPOD_API_TOKEN}",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ class RunpodInferenceAdapter(OpenAIMixin):
|
|||
|
||||
def get_base_url(self) -> str:
|
||||
"""Get base URL for OpenAI client."""
|
||||
return self.config.url
|
||||
return str(self.config.base_url)
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, HttpUrl
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack_api import json_schema_type
|
||||
|
|
@ -21,14 +21,14 @@ class SambaNovaProviderDataValidator(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class SambaNovaImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default="https://api.sambanova.ai/v1",
|
||||
base_url: HttpUrl | None = Field(
|
||||
default=HttpUrl("https://api.sambanova.ai/v1"),
|
||||
description="The URL for the SambaNova AI server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.sambanova.ai/v1",
|
||||
"base_url": "https://api.sambanova.ai/v1",
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,4 +25,4 @@ class SambaNovaInferenceAdapter(OpenAIMixin):
|
|||
|
||||
:return: The SambaNova base URL
|
||||
"""
|
||||
return self.config.url
|
||||
return str(self.config.base_url)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import BaseModel, Field, HttpUrl, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack_api import json_schema_type
|
||||
|
|
@ -15,18 +15,19 @@ from llama_stack_api import json_schema_type
|
|||
class TGIImplConfig(RemoteInferenceProviderConfig):
|
||||
auth_credential: SecretStr | None = Field(default=None, exclude=True)
|
||||
|
||||
url: str = Field(
|
||||
description="The URL for the TGI serving endpoint",
|
||||
base_url: HttpUrl | None = Field(
|
||||
default=None,
|
||||
description="The URL for the TGI serving endpoint (should include /v1 path)",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
url: str = "${env.TGI_URL:=}",
|
||||
base_url: str = "${env.TGI_URL:=}",
|
||||
**kwargs,
|
||||
):
|
||||
return {
|
||||
"url": url,
|
||||
"base_url": base_url,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@
|
|||
from collections.abc import Iterable
|
||||
|
||||
from huggingface_hub import AsyncInferenceClient, HfApi
|
||||
from pydantic import SecretStr
|
||||
from pydantic import HttpUrl, SecretStr
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
|
@ -23,7 +23,7 @@ log = get_logger(name=__name__, category="inference::tgi")
|
|||
|
||||
|
||||
class _HfAdapter(OpenAIMixin):
|
||||
url: str
|
||||
base_url: HttpUrl
|
||||
api_key: SecretStr
|
||||
|
||||
hf_client: AsyncInferenceClient
|
||||
|
|
@ -36,7 +36,7 @@ class _HfAdapter(OpenAIMixin):
|
|||
return "NO KEY REQUIRED"
|
||||
|
||||
def get_base_url(self):
|
||||
return self.url
|
||||
return self.base_url
|
||||
|
||||
async def list_provider_model_ids(self) -> Iterable[str]:
|
||||
return [self.model_id]
|
||||
|
|
@ -50,14 +50,20 @@ class _HfAdapter(OpenAIMixin):
|
|||
|
||||
class TGIAdapter(_HfAdapter):
|
||||
async def initialize(self, config: TGIImplConfig) -> None:
|
||||
if not config.url:
|
||||
if not config.base_url:
|
||||
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
|
||||
log.info(f"Initializing TGI client with url={config.url}")
|
||||
self.hf_client = AsyncInferenceClient(model=config.url, provider="hf-inference")
|
||||
log.info(f"Initializing TGI client with url={config.base_url}")
|
||||
# Extract base URL without /v1 for HF client initialization
|
||||
base_url_str = str(config.base_url).rstrip("/")
|
||||
if base_url_str.endswith("/v1"):
|
||||
base_url_for_client = base_url_str[:-3]
|
||||
else:
|
||||
base_url_for_client = base_url_str
|
||||
self.hf_client = AsyncInferenceClient(model=base_url_for_client, provider="hf-inference")
|
||||
endpoint_info = await self.hf_client.get_endpoint_info()
|
||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||
self.model_id = endpoint_info["model_id"]
|
||||
self.url = f"{config.url.rstrip('/')}/v1"
|
||||
self.base_url = config.base_url
|
||||
self.api_key = SecretStr("NO_KEY")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import Field, HttpUrl
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack_api import json_schema_type
|
||||
|
|
@ -14,14 +14,14 @@ from llama_stack_api import json_schema_type
|
|||
|
||||
@json_schema_type
|
||||
class TogetherImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default="https://api.together.xyz/v1",
|
||||
base_url: HttpUrl | None = Field(
|
||||
default=HttpUrl("https://api.together.xyz/v1"),
|
||||
description="The URL for the Together AI server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.together.xyz/v1",
|
||||
"base_url": "https://api.together.xyz/v1",
|
||||
"api_key": "${env.TOGETHER_API_KEY:=}",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ from collections.abc import Iterable
|
|||
from typing import Any, cast
|
||||
|
||||
from together import AsyncTogether # type: ignore[import-untyped]
|
||||
from together.constants import BASE_URL # type: ignore[import-untyped]
|
||||
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -42,7 +41,7 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
|
|||
provider_data_api_key_field: str = "together_api_key"
|
||||
|
||||
def get_base_url(self):
|
||||
return BASE_URL
|
||||
return str(self.config.base_url)
|
||||
|
||||
def _get_client(self) -> AsyncTogether:
|
||||
together_api_key = None
|
||||
|
|
|
|||
|
|
@@ -51,4 +51,4 @@ class VertexAIInferenceAdapter(OpenAIMixin):

        :return: An iterable of model IDs
        """
        return ["vertexai/gemini-2.0-flash", "vertexai/gemini-2.5-flash", "vertexai/gemini-2.5-pro"]
        return ["google/gemini-2.0-flash", "google/gemini-2.5-flash", "google/gemini-2.5-pro"]

@ -6,7 +6,7 @@
|
|||
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field, SecretStr, field_validator
|
||||
from pydantic import Field, HttpUrl, SecretStr, field_validator
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack_api import json_schema_type
|
||||
|
|
@ -14,7 +14,7 @@ from llama_stack_api import json_schema_type
|
|||
|
||||
@json_schema_type
|
||||
class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
|
||||
url: str | None = Field(
|
||||
base_url: HttpUrl | None = Field(
|
||||
default=None,
|
||||
description="The URL for the vLLM model serving endpoint",
|
||||
)
|
||||
|
|
@ -48,11 +48,11 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
|
|||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
url: str = "${env.VLLM_URL:=}",
|
||||
base_url: str = "${env.VLLM_URL:=}",
|
||||
**kwargs,
|
||||
):
|
||||
return {
|
||||
"url": url,
|
||||
"base_url": base_url,
|
||||
"max_tokens": "${env.VLLM_MAX_TOKENS:=4096}",
|
||||
"api_token": "${env.VLLM_API_TOKEN:=fake}",
|
||||
"tls_verify": "${env.VLLM_TLS_VERIFY:=true}",
|
||||
|
|
|
|||
|
|
@ -39,12 +39,12 @@ class VLLMInferenceAdapter(OpenAIMixin):
|
|||
|
||||
def get_base_url(self) -> str:
|
||||
"""Get the base URL from config."""
|
||||
if not self.config.url:
|
||||
if not self.config.base_url:
|
||||
raise ValueError("No base URL configured")
|
||||
return self.config.url
|
||||
return str(self.config.base_url)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
if not self.config.url:
|
||||
if not self.config.base_url:
|
||||
raise ValueError(
|
||||
"You must provide a URL in run.yaml (or via the VLLM_URL environment variable) to use vLLM."
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, HttpUrl
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack_api import json_schema_type
|
||||
|
|
@ -23,7 +23,7 @@ class WatsonXProviderDataValidator(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class WatsonXConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
base_url: HttpUrl | None = Field(
|
||||
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
|
||||
description="A base url for accessing the watsonx.ai",
|
||||
)
|
||||
|
|
@ -39,7 +39,7 @@ class WatsonXConfig(RemoteInferenceProviderConfig):
|
|||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}",
|
||||
"base_url": "${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}",
|
||||
"api_key": "${env.WATSONX_API_KEY:=}",
|
||||
"project_id": "${env.WATSONX_PROJECT_ID:=}",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -255,7 +255,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
)
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return self.config.url
|
||||
return str(self.config.base_url)
|
||||
|
||||
# Copied from OpenAIMixin
|
||||
async def check_model_availability(self, model: str) -> bool:
|
||||
|
|
@ -316,7 +316,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
"""
|
||||
Retrieves foundation model specifications from the watsonx.ai API.
|
||||
"""
|
||||
url = f"{self.config.url}/ml/v1/foundation_model_specs?version=2023-10-25"
|
||||
url = f"{str(self.config.base_url)}/ml/v1/foundation_model_specs?version=2023-10-25"
|
||||
headers = {
|
||||
# Note that there is no authorization header. Listing models does not require authentication.
|
||||
"Content-Type": "application/json",
|
||||
|
|
|
|||
|
|
@@ -48,16 +48,10 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
        if mcp_endpoint is None:
            raise ValueError("mcp_endpoint is required")

        # Phase 1: Support both old header-based auth AND new authorization parameter
        # Get headers and auth from provider data (old approach)
        provider_headers, provider_auth = await self.get_headers_from_request(mcp_endpoint.uri)
        # Get other headers from provider data (but NOT authorization)
        provider_headers = await self.get_headers_from_request(mcp_endpoint.uri)

        # New authorization parameter takes precedence over provider data
        final_authorization = authorization or provider_auth

        return await list_mcp_tools(
            endpoint=mcp_endpoint.uri, headers=provider_headers, authorization=final_authorization
        )
        return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=provider_headers, authorization=authorization)

    async def invoke_tool(
        self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None

@@ -69,39 +63,38 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
        if urlparse(endpoint).scheme not in ("http", "https"):
            raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")

        # Phase 1: Support both old header-based auth AND new authorization parameter
        # Get headers and auth from provider data (old approach)
        provider_headers, provider_auth = await self.get_headers_from_request(endpoint)

        # New authorization parameter takes precedence over provider data
        final_authorization = authorization or provider_auth
        # Get other headers from provider data (but NOT authorization)
        provider_headers = await self.get_headers_from_request(endpoint)

        return await invoke_mcp_tool(
            endpoint=endpoint,
            tool_name=tool_name,
            kwargs=kwargs,
            headers=provider_headers,
            authorization=final_authorization,
            authorization=authorization,
        )

    async def get_headers_from_request(self, mcp_endpoint_uri: str) -> tuple[dict[str, str], str | None]:
    async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
        """
        Extract headers and authorization from request provider data (Phase 1 backward compatibility).
        Extract headers from request provider data, excluding authorization.

        Phase 1: Temporarily allows Authorization to be passed via mcp_headers for backward compatibility.
        Phase 2: Will enforce that Authorization should use the dedicated authorization parameter instead.
        Authorization must be provided via the dedicated authorization parameter.
        If Authorization is found in mcp_headers, raise an error to guide users to the correct approach.

        Args:
            mcp_endpoint_uri: The MCP endpoint URI to match against provider data

        Returns:
            Tuple of (headers_dict, authorization_token)
            - headers_dict: All headers except Authorization
            - authorization_token: Token from Authorization header (with "Bearer " prefix removed), or None
            dict[str, str]: Headers dictionary (without Authorization)

        Raises:
            ValueError: If Authorization header is found in mcp_headers
        """

        def canonicalize_uri(uri: str) -> str:
            return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}"

        headers = {}
        authorization = None

        provider_data = self.get_request_provider_data()
        if provider_data and hasattr(provider_data, "mcp_headers") and provider_data.mcp_headers:

@@ -109,17 +102,14 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
                if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
                    continue

                # Phase 1: Extract Authorization from mcp_headers for backward compatibility
                # (Phase 2 will reject this and require the dedicated authorization parameter)
                # Reject Authorization in mcp_headers - must use authorization parameter
                for key in values.keys():
                    if key.lower() == "authorization":
                        # Extract authorization token and strip "Bearer " prefix if present
                        auth_value = values[key]
                        if auth_value.startswith("Bearer "):
                            authorization = auth_value[7:]  # Remove "Bearer " prefix
                        else:
                            authorization = auth_value
                    else:
                        headers[key] = values[key]
                        raise ValueError(
                            "Authorization cannot be provided via mcp_headers in provider_data. "
                            "Please use the dedicated 'authorization' parameter instead. "
                            "Example: tool_runtime.invoke_tool(..., authorization='your-token')"
                        )
                    headers[key] = values[key]

        return headers, authorization
        return headers

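The reworked `get_headers_from_request` matches provider-supplied headers to the endpoint by a canonicalized netloc/path and now rejects any `Authorization` key, pushing callers onto the dedicated `authorization` parameter. A standalone sketch of that filtering logic, with hypothetical header data and error wording:

```python
from urllib.parse import urlparse


def canonicalize_uri(uri: str) -> str:
    parsed = urlparse(uri)
    return f"{parsed.netloc or ''}/{parsed.path or ''}"


def headers_for_endpoint(mcp_headers: dict[str, dict[str, str]], endpoint: str) -> dict[str, str]:
    """Collect headers registered for this endpoint, refusing inline Authorization."""
    headers: dict[str, str] = {}
    for uri, values in mcp_headers.items():
        if canonicalize_uri(uri) != canonicalize_uri(endpoint):
            continue
        for key, value in values.items():
            if key.lower() == "authorization":
                raise ValueError("Authorization cannot be provided via mcp_headers; use the 'authorization' parameter")
            headers[key] = value
    return headers


print(headers_for_endpoint({"https://mcp.example/tools": {"X-Trace": "abc"}}, "https://mcp.example/tools"))
```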
@@ -11,10 +11,9 @@ from urllib.parse import urlparse
import chromadb
from numpy.typing import NDArray

from llama_stack.core.storage.kvstore import kvstore_impl
from llama_stack.log import get_logger
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
from llama_stack_api import (
@@ -27,6 +26,7 @@ from llama_stack_api import (
    VectorStore,
    VectorStoresProtocolPrivate,
)
from llama_stack_api.internal.kvstore import KVStore

from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig

@@ -11,10 +11,9 @@ from typing import Any
from numpy.typing import NDArray
from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker

from llama_stack.core.storage.kvstore import kvstore_impl
from llama_stack.log import get_logger
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
    RERANKER_TYPE_WEIGHTED,
@@ -34,6 +33,7 @@ from llama_stack_api import (
    VectorStoreNotFoundError,
    VectorStoresProtocolPrivate,
)
from llama_stack_api.internal.kvstore import KVStore

from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig

@@ -13,10 +13,9 @@ from psycopg2 import sql
from psycopg2.extras import Json, execute_values
from pydantic import BaseModel, TypeAdapter

from llama_stack.core.storage.kvstore import kvstore_impl
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator, sanitize_collection_name
@@ -31,6 +30,7 @@ from llama_stack_api import (
    VectorStoreNotFoundError,
    VectorStoresProtocolPrivate,
)
from llama_stack_api.internal.kvstore import KVStore

from .config import PGVectorVectorIOConfig

@@ -13,9 +13,9 @@ from numpy.typing import NDArray
from qdrant_client import AsyncQdrantClient, models
from qdrant_client.models import PointStruct

from llama_stack.core.storage.kvstore import kvstore_impl
from llama_stack.log import get_logger
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
from llama_stack_api import (

@@ -13,9 +13,8 @@ from weaviate.classes.init import Auth
from weaviate.classes.query import Filter, HybridFusion

from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.core.storage.kvstore import kvstore_impl
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
    RERANKER_TYPE_RRF,
@@ -35,6 +34,7 @@ from llama_stack_api import (
    VectorStoreNotFoundError,
    VectorStoresProtocolPrivate,
)
from llama_stack_api.internal.kvstore import KVStore

from .config import WeaviateVectorIOConfig

@@ -10,6 +10,8 @@ from sqlalchemy.exc import IntegrityError

from llama_stack.core.datatypes import AccessRule
from llama_stack.core.storage.datatypes import InferenceStoreReference, StorageBackendType
from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.core.storage.sqlstore.sqlstore import _SQLSTORE_BACKENDS, sqlstore_impl
from llama_stack.log import get_logger
from llama_stack_api import (
    ListOpenAIChatCompletionResponse,
@@ -18,10 +20,7 @@ from llama_stack_api import (
    OpenAIMessageParam,
    Order,
)

from ..sqlstore.api import ColumnDefinition, ColumnType
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
from ..sqlstore.sqlstore import _SQLSTORE_BACKENDS, sqlstore_impl
from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType

logger = get_logger(name=__name__, category="inference")

@@ -3,23 +3,10 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Iterable
from typing import (
    Any,
)

from openai.types.chat import (
    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
)

try:
    from openai.types.chat import (
        ChatCompletionMessageFunctionToolCall as OpenAIChatCompletionMessageFunctionToolCall,
    )
except ImportError:
    from openai.types.chat.chat_completion_message_tool_call import (
        ChatCompletionMessageToolCall as OpenAIChatCompletionMessageFunctionToolCall,
    )
from openai.types.chat import (
    ChatCompletionMessageToolCall,
)
@@ -32,18 +19,6 @@ from llama_stack.models.llama.datatypes import (
    ToolCall,
    ToolDefinition,
)
from llama_stack_api import (
    URL,
    GreedySamplingStrategy,
    ImageContentItem,
    JsonSchemaResponseFormat,
    OpenAIResponseFormatParam,
    SamplingParams,
    TextContentItem,
    TopKSamplingStrategy,
    TopPSamplingStrategy,
    _URLOrData,
)

logger = get_logger(name=__name__, category="providers::utils")

@@ -73,42 +48,6 @@ class OpenAICompatCompletionResponse(BaseModel):
    choices: list[OpenAICompatCompletionChoice]


def get_sampling_strategy_options(params: SamplingParams) -> dict:
    options = {}
    if isinstance(params.strategy, GreedySamplingStrategy):
        options["temperature"] = 0.0
    elif isinstance(params.strategy, TopPSamplingStrategy):
        if params.strategy.temperature is not None:
            options["temperature"] = params.strategy.temperature
        if params.strategy.top_p is not None:
            options["top_p"] = params.strategy.top_p
    elif isinstance(params.strategy, TopKSamplingStrategy):
        options["top_k"] = params.strategy.top_k
    else:
        raise ValueError(f"Unsupported sampling strategy: {params.strategy}")

    return options


def get_sampling_options(params: SamplingParams | None) -> dict:
    if not params:
        return {}

    options = {}
    if params:
        options.update(get_sampling_strategy_options(params))
        if params.max_tokens:
            options["max_tokens"] = params.max_tokens

        if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
            options["repeat_penalty"] = params.repetition_penalty

        if params.stop is not None:
            options["stop"] = params.stop

    return options


def text_from_choice(choice) -> str:
    if hasattr(choice, "delta") and choice.delta:
        return choice.delta.content  # type: ignore[no-any-return] # external OpenAI types lack precise annotations

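For readers following the removal above: these helpers translate a Llama Stack SamplingParams into the flat option dict expected by OpenAI-compatible backends. A small illustration of that mapping, assuming the helpers and types shown above are in scope; the concrete values are arbitrary:

# Top-p sampling with a token budget.
params = SamplingParams(
    strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.9),
    max_tokens=256,
)
print(get_sampling_options(params))
# -> {"temperature": 0.7, "top_p": 0.9, "max_tokens": 256}

# Greedy decoding is expressed as temperature 0.0.
print(get_sampling_strategy_options(SamplingParams(strategy=GreedySamplingStrategy())))
# -> {"temperature": 0.0}
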
@@ -253,154 +192,6 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
    return out


def _convert_stop_reason_to_openai_finish_reason(stop_reason: StopReason) -> str:
    """
    Convert a StopReason to an OpenAI chat completion finish_reason.
    """
    return {
        StopReason.end_of_turn: "stop",
        StopReason.end_of_message: "tool_calls",
        StopReason.out_of_tokens: "length",
    }.get(stop_reason, "stop")


def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
    """
    Convert an OpenAI chat completion finish_reason to a StopReason.

    finish_reason: Literal["stop", "length", "tool_calls", ...]
        - stop: model hit a natural stop point or a provided stop sequence
        - length: maximum number of tokens specified in the request was reached
        - tool_calls: model called a tool

    ->

    class StopReason(Enum):
        end_of_turn = "end_of_turn"
        end_of_message = "end_of_message"
        out_of_tokens = "out_of_tokens"
    """

    # TODO(mf): are end_of_turn and end_of_message semantics correct?
    return {
        "stop": StopReason.end_of_turn,
        "length": StopReason.out_of_tokens,
        "tool_calls": StopReason.end_of_message,
    }.get(finish_reason, StopReason.end_of_turn)


def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) -> list[ToolDefinition]:
    lls_tools: list[ToolDefinition] = []
    if not tools:
        return lls_tools

    for tool in tools:
        tool_fn = tool.get("function", {})
        tool_name = tool_fn.get("name", None)
        tool_desc = tool_fn.get("description", None)
        tool_params = tool_fn.get("parameters", None)

        lls_tool = ToolDefinition(
            tool_name=tool_name,
            description=tool_desc,
            input_schema=tool_params,  # Pass through entire JSON Schema
        )
        lls_tools.append(lls_tool)
    return lls_tools


def _convert_openai_request_response_format(
    response_format: OpenAIResponseFormatParam | None = None,
):
    if not response_format:
        return None
    # response_format can be a dict or a pydantic model
    response_format_dict = dict(response_format)  # type: ignore[arg-type] # OpenAIResponseFormatParam union needs dict conversion
    if response_format_dict.get("type", "") == "json_schema":
        return JsonSchemaResponseFormat(
            type="json_schema",  # type: ignore[arg-type] # Literal["json_schema"] incompatible with expected type
            json_schema=response_format_dict.get("json_schema", {}).get("schema", ""),
        )
    return None


def _convert_openai_tool_calls(
    tool_calls: list[OpenAIChatCompletionMessageFunctionToolCall],
) -> list[ToolCall]:
    """
    Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall.

    OpenAI ChatCompletionMessageToolCall:
        id: str
        function: Function
        type: Literal["function"]

    OpenAI Function:
        arguments: str
        name: str

    ->

    ToolCall:
        call_id: str
        tool_name: str
        arguments: Dict[str, ...]
    """
    if not tool_calls:
        return []  # CompletionMessage tool_calls is not optional

    return [
        ToolCall(
            call_id=call.id,
            tool_name=call.function.name,
            arguments=call.function.arguments,
        )
        for call in tool_calls
    ]


def _convert_openai_sampling_params(
    max_tokens: int | None = None,
    temperature: float | None = None,
    top_p: float | None = None,
) -> SamplingParams:
    sampling_params = SamplingParams()

    if max_tokens:
        sampling_params.max_tokens = max_tokens

    # Map an explicit temperature of 0 to greedy sampling
    if temperature == 0:
        sampling_params.strategy = GreedySamplingStrategy()
    else:
        # OpenAI defaults to 1.0 for temperature and top_p if unset
        if temperature is None:
            temperature = 1.0
        if top_p is None:
            top_p = 1.0
        sampling_params.strategy = TopPSamplingStrategy(temperature=temperature, top_p=top_p)  # type: ignore[assignment] # SamplingParams.strategy union accepts this type

    return sampling_params


def openai_content_to_content(content: str | Iterable[OpenAIChatCompletionContentPartParam] | None):
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    elif isinstance(content, list):
        return [openai_content_to_content(c) for c in content]
    elif hasattr(content, "type"):
        if content.type == "text":
            return TextContentItem(type="text", text=content.text)  # type: ignore[attr-defined] # Iterable narrowed by hasattr check but mypy doesn't track
        elif content.type == "image_url":
            return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url)))  # type: ignore[attr-defined] # Iterable narrowed by hasattr check but mypy doesn't track
        else:
            raise ValueError(f"Unknown content type: {content.type}")
    else:
        raise ValueError(f"Unknown content type: {content}")


async def prepare_openai_completion_params(**params):
    async def _prepare_value(value: Any) -> Any:
        new_value = value

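The removed converters above also encode two behaviours worth keeping in mind: an explicit temperature of 0 becomes greedy sampling (otherwise OpenAI's defaults of 1.0 are filled in), and OpenAI finish_reason strings map onto StopReason values. A short sketch, assuming the module-level helpers shown above:

greedy = _convert_openai_sampling_params(max_tokens=128, temperature=0)
# greedy.strategy is GreedySamplingStrategy(); greedy.max_tokens == 128

defaulted = _convert_openai_sampling_params()
# defaulted.strategy is TopPSamplingStrategy(temperature=1.0, top_p=1.0)

assert _convert_openai_finish_reason("length") == StopReason.out_of_tokens
assert _convert_stop_reason_to_openai_finish_reason(StopReason.end_of_message) == "tool_calls"
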
@@ -213,6 +213,19 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):

        return api_key

    def _validate_model_allowed(self, provider_model_id: str) -> None:
        """
        Validate that the model is in the allowed_models list if configured.

        :param provider_model_id: The provider-specific model ID to validate
        :raises ValueError: If the model is not in the allowed_models list
        """
        if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models:
            raise ValueError(
                f"Model '{provider_model_id}' is not in the allowed models list. "
                f"Allowed models: {self.config.allowed_models}"
            )

    async def _get_provider_model_id(self, model: str) -> str:
        """
        Get the provider-specific model ID from the model store.
@@ -259,8 +272,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
        Direct OpenAI completion API call.
        """
        # TODO: fix openai_completion to return type compatible with OpenAI's API response
        provider_model_id = await self._get_provider_model_id(params.model)
        self._validate_model_allowed(provider_model_id)

        completion_kwargs = await prepare_openai_completion_params(
            model=await self._get_provider_model_id(params.model),
            model=provider_model_id,
            prompt=params.prompt,
            best_of=params.best_of,
            echo=params.echo,
@@ -292,6 +308,9 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
        """
        Direct OpenAI chat completion API call.
        """
        provider_model_id = await self._get_provider_model_id(params.model)
        self._validate_model_allowed(provider_model_id)

        messages = params.messages

        if self.download_images:
@@ -313,7 +332,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
            messages = [await _localize_image_url(m) for m in messages]

        request_params = await prepare_openai_completion_params(
            model=await self._get_provider_model_id(params.model),
            model=provider_model_id,
            messages=messages,
            frequency_penalty=params.frequency_penalty,
            function_call=params.function_call,
@@ -351,10 +370,13 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
        """
        Direct OpenAI embeddings API call.
        """
        provider_model_id = await self._get_provider_model_id(params.model)
        self._validate_model_allowed(provider_model_id)

        # Build request params conditionally to avoid NotGiven/Omit type mismatch
        # The OpenAI SDK uses Omit in signatures but NOT_GIVEN has type NotGiven
        request_params: dict[str, Any] = {
            "model": await self._get_provider_model_id(params.model),
            "model": provider_model_id,
            "input": params.input,
        }
        if params.encoding_format is not None:

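The new _validate_model_allowed hook above gates every completion, chat completion, and embeddings call on config.allowed_models before contacting the provider. A self-contained sketch of that check (a standalone function for illustration, not the mixin itself):

def validate_model_allowed(allowed_models: list[str] | None, provider_model_id: str) -> None:
    # None means "no restriction configured"; a list acts as an allowlist.
    if allowed_models is not None and provider_model_id not in allowed_models:
        raise ValueError(
            f"Model '{provider_model_id}' is not in the allowed models list. Allowed models: {allowed_models}"
        )

validate_model_allowed(None, "gpt-4o-mini")                 # no allowlist -> accepted
validate_model_allowed(["llama-3.3-70b"], "llama-3.3-70b")  # listed -> accepted
# validate_model_allowed(["llama-3.3-70b"], "gpt-4o-mini")  # would raise ValueError
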
@@ -1,20 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel, Field

from llama_stack_api import json_schema_type


@json_schema_type
class SqliteControlPlaneConfig(BaseModel):
    db_path: str = Field(
        description="File path for the sqlite database",
    )
    table_name: str = Field(
        default="llamastack_control_plane",
        description="Table into which all the keys will be placed",
    )

@@ -17,7 +17,6 @@ from pydantic import TypeAdapter

from llama_stack.core.id_generation import generate_object_id
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.vector_store import (
    ChunkForDeletion,
    content_from_data_and_mime_type,
@@ -53,6 +52,7 @@ from llama_stack_api import (
    VectorStoreSearchResponse,
    VectorStoreSearchResponsePage,
)
from llama_stack_api.internal.kvstore import KVStore

EMBEDDING_DIMENSION = 768

@@ -6,6 +6,8 @@

from llama_stack.core.datatypes import AccessRule
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqlStoreReference
from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl
from llama_stack.log import get_logger
from llama_stack_api import (
    ListOpenAIResponseInputItem,
@@ -17,10 +19,7 @@ from llama_stack_api import (
    OpenAIResponseObjectWithInput,
    Order,
)

from ..sqlstore.api import ColumnDefinition, ColumnType
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
from ..sqlstore.sqlstore import sqlstore_impl
from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType

logger = get_logger(name=__name__, category="openai_responses")

@@ -1,140 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Any, Literal, Protocol

from pydantic import BaseModel

from llama_stack_api import PaginatedResponse


class ColumnType(Enum):
    INTEGER = "INTEGER"
    STRING = "STRING"
    TEXT = "TEXT"
    FLOAT = "FLOAT"
    BOOLEAN = "BOOLEAN"
    JSON = "JSON"
    DATETIME = "DATETIME"


class ColumnDefinition(BaseModel):
    type: ColumnType
    primary_key: bool = False
    nullable: bool = True
    default: Any = None


class SqlStore(Protocol):
    """
    A protocol for a SQL store.
    """

    async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None:
        """
        Create a table.
        """
        pass

    async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
        """
        Insert a row or batch of rows into a table.
        """
        pass

    async def upsert(
        self,
        table: str,
        data: Mapping[str, Any],
        conflict_columns: list[str],
        update_columns: list[str] | None = None,
    ) -> None:
        """
        Insert a row and update specified columns when conflicts occur.
        """
        pass

    async def fetch_all(
        self,
        table: str,
        where: Mapping[str, Any] | None = None,
        where_sql: str | None = None,
        limit: int | None = None,
        order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
        cursor: tuple[str, str] | None = None,
    ) -> PaginatedResponse:
        """
        Fetch all rows from a table with optional cursor-based pagination.

        :param table: The table name
        :param where: Simple key-value WHERE conditions
        :param where_sql: Raw SQL WHERE clause for complex queries
        :param limit: Maximum number of records to return
        :param order_by: List of (column, order) tuples for sorting
        :param cursor: Tuple of (key_column, cursor_id) for pagination (None for first page)
            Requires order_by with exactly one column when used
        :return: PaginatedResult with data and has_more flag

        Note: Cursor pagination only supports single-column ordering for simplicity.
        Multi-column ordering is allowed without cursor but will raise an error with cursor.
        """
        pass

    async def fetch_one(
        self,
        table: str,
        where: Mapping[str, Any] | None = None,
        where_sql: str | None = None,
        order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
    ) -> dict[str, Any] | None:
        """
        Fetch one row from a table.
        """
        pass

    async def update(
        self,
        table: str,
        data: Mapping[str, Any],
        where: Mapping[str, Any],
    ) -> None:
        """
        Update a row in a table.
        """
        pass

    async def delete(
        self,
        table: str,
        where: Mapping[str, Any],
    ) -> None:
        """
        Delete a row from a table.
        """
        pass

    async def add_column_if_not_exists(
        self,
        table: str,
        column_name: str,
        column_type: ColumnType,
        nullable: bool = True,
    ) -> None:
        """
        Add a column to an existing table if the column doesn't already exist.

        This is useful for table migrations when adding new functionality.
        If the table doesn't exist, this method should do nothing.
        If the column already exists, this method should do nothing.

        :param table: Table name
        :param column_name: Name of the column to add
        :param column_type: Type of the column to add
        :param nullable: Whether the column should be nullable (default: True)
        """
        pass
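The fetch_all contract above (the protocol now lives under llama_stack_api.internal.sqlstore) supports cursor pagination with a single order_by column. A sketch of how a caller might page through a table under that contract; the table and column names here are illustrative only:

async def read_all_rows(store) -> list[dict]:
    # 'store' is any SqlStore implementation; PaginatedResponse exposes .data and .has_more.
    rows: list[dict] = []
    cursor: tuple[str, str] | None = None  # (key_column, last_seen_id); None requests the first page
    while True:
        page = await store.fetch_all(
            table="chat_completions",   # illustrative table name
            order_by=[("id", "desc")],  # exactly one column when a cursor is used
            limit=100,
            cursor=cursor,
        )
        rows.extend(page.data)
        if not page.has_more:
            return rows
        cursor = ("id", page.data[-1]["id"])  # resume after the last row of this page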