mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
feat: support postgresql inference store (#2310)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 3s
Integration Tests / test-matrix (http, post_training) (push) Failing after 11s
Integration Tests / test-matrix (library, inference) (push) Failing after 13s
Integration Tests / test-matrix (http, providers) (push) Failing after 15s
Integration Tests / test-matrix (http, tool_runtime) (push) Failing after 16s
Integration Tests / test-matrix (http, datasets) (push) Failing after 18s
Integration Tests / test-matrix (http, scoring) (push) Failing after 16s
Integration Tests / test-matrix (http, agents) (push) Failing after 19s
Integration Tests / test-matrix (library, datasets) (push) Failing after 16s
Integration Tests / test-matrix (http, inspect) (push) Failing after 18s
Integration Tests / test-matrix (library, agents) (push) Failing after 18s
Integration Tests / test-matrix (http, inference) (push) Failing after 20s
Integration Tests / test-matrix (library, inspect) (push) Failing after 9s
Integration Tests / test-matrix (library, post_training) (push) Failing after 10s
Integration Tests / test-matrix (library, tool_runtime) (push) Failing after 8s
Test External Providers / test-external-providers (venv) (push) Failing after 8s
Integration Tests / test-matrix (library, scoring) (push) Failing after 9s
Integration Tests / test-matrix (library, providers) (push) Failing after 11s
Unit Tests / unit-tests (3.11) (push) Failing after 8s
Unit Tests / unit-tests (3.10) (push) Failing after 8s
Unit Tests / unit-tests (3.12) (push) Failing after 8s
Unit Tests / unit-tests (3.13) (push) Failing after 8s
Pre-commit / pre-commit (push) Successful in 57s
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 3s
Integration Tests / test-matrix (http, post_training) (push) Failing after 11s
Integration Tests / test-matrix (library, inference) (push) Failing after 13s
Integration Tests / test-matrix (http, providers) (push) Failing after 15s
Integration Tests / test-matrix (http, tool_runtime) (push) Failing after 16s
Integration Tests / test-matrix (http, datasets) (push) Failing after 18s
Integration Tests / test-matrix (http, scoring) (push) Failing after 16s
Integration Tests / test-matrix (http, agents) (push) Failing after 19s
Integration Tests / test-matrix (library, datasets) (push) Failing after 16s
Integration Tests / test-matrix (http, inspect) (push) Failing after 18s
Integration Tests / test-matrix (library, agents) (push) Failing after 18s
Integration Tests / test-matrix (http, inference) (push) Failing after 20s
Integration Tests / test-matrix (library, inspect) (push) Failing after 9s
Integration Tests / test-matrix (library, post_training) (push) Failing after 10s
Integration Tests / test-matrix (library, tool_runtime) (push) Failing after 8s
Test External Providers / test-external-providers (venv) (push) Failing after 8s
Integration Tests / test-matrix (library, scoring) (push) Failing after 9s
Integration Tests / test-matrix (library, providers) (push) Failing after 11s
Unit Tests / unit-tests (3.11) (push) Failing after 8s
Unit Tests / unit-tests (3.10) (push) Failing after 8s
Unit Tests / unit-tests (3.12) (push) Failing after 8s
Unit Tests / unit-tests (3.13) (push) Failing after 8s
Pre-commit / pre-commit (push) Successful in 57s
# What does this PR do? * Added support postgresql inference store * Added 'oracle' template that demos how to config postgresql stores (except for telemetry, which is not supported currently) ## Test Plan llama stack build --template oracle --image-type conda --run LLAMA_STACK_CONFIG=http://localhost:8321 pytest -s -v tests/integration/ --text-model accounts/fireworks/models/llama-v3p3-70b-instruct -k 'inference_store'
This commit is contained in:
parent
168c7113df
commit
2603f10f95
32 changed files with 516 additions and 53 deletions
|
@ -65,7 +65,7 @@ class SqliteKVStoreConfig(CommonConfig):
|
||||||
class PostgresKVStoreConfig(CommonConfig):
|
class PostgresKVStoreConfig(CommonConfig):
|
||||||
type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value
|
type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value
|
||||||
host: str = "localhost"
|
host: str = "localhost"
|
||||||
port: int = 5432
|
port: str = "5432"
|
||||||
db: str = "llamastack"
|
db: str = "llamastack"
|
||||||
user: str
|
user: str
|
||||||
password: str | None = None
|
password: str | None = None
|
||||||
|
|
|
@ -19,10 +19,10 @@ from sqlalchemy import (
|
||||||
Text,
|
Text,
|
||||||
select,
|
select,
|
||||||
)
|
)
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
from ..api import ColumnDefinition, ColumnType, SqlStore
|
from .api import ColumnDefinition, ColumnType, SqlStore
|
||||||
from ..sqlstore import SqliteSqlStoreConfig
|
from .sqlstore import SqlAlchemySqlStoreConfig
|
||||||
|
|
||||||
TYPE_MAPPING: dict[ColumnType, Any] = {
|
TYPE_MAPPING: dict[ColumnType, Any] = {
|
||||||
ColumnType.INTEGER: Integer,
|
ColumnType.INTEGER: Integer,
|
||||||
|
@ -35,9 +35,10 @@ TYPE_MAPPING: dict[ColumnType, Any] = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class SqliteSqlStoreImpl(SqlStore):
|
class SqlAlchemySqlStoreImpl(SqlStore):
|
||||||
def __init__(self, config: SqliteSqlStoreConfig):
|
def __init__(self, config: SqlAlchemySqlStoreConfig):
|
||||||
self.engine = create_async_engine(config.engine_str)
|
self.config = config
|
||||||
|
self.async_session = async_sessionmaker(create_async_engine(config.engine_str))
|
||||||
self.metadata = MetaData()
|
self.metadata = MetaData()
|
||||||
|
|
||||||
async def create_table(
|
async def create_table(
|
||||||
|
@ -78,13 +79,14 @@ class SqliteSqlStoreImpl(SqlStore):
|
||||||
|
|
||||||
# Create the table in the database if it doesn't exist
|
# Create the table in the database if it doesn't exist
|
||||||
# checkfirst=True ensures it doesn't try to recreate if it's already there
|
# checkfirst=True ensures it doesn't try to recreate if it's already there
|
||||||
async with self.engine.begin() as conn:
|
engine = create_async_engine(self.config.engine_str)
|
||||||
|
async with engine.begin() as conn:
|
||||||
await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)
|
await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)
|
||||||
|
|
||||||
async def insert(self, table: str, data: Mapping[str, Any]) -> None:
|
async def insert(self, table: str, data: Mapping[str, Any]) -> None:
|
||||||
async with self.engine.begin() as conn:
|
async with self.async_session() as session:
|
||||||
await conn.execute(self.metadata.tables[table].insert(), data)
|
await session.execute(self.metadata.tables[table].insert(), data)
|
||||||
await conn.commit()
|
await session.commit()
|
||||||
|
|
||||||
async def fetch_all(
|
async def fetch_all(
|
||||||
self,
|
self,
|
||||||
|
@ -93,7 +95,7 @@ class SqliteSqlStoreImpl(SqlStore):
|
||||||
limit: int | None = None,
|
limit: int | None = None,
|
||||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
async with self.engine.begin() as conn:
|
async with self.async_session() as session:
|
||||||
query = select(self.metadata.tables[table])
|
query = select(self.metadata.tables[table])
|
||||||
if where:
|
if where:
|
||||||
for key, value in where.items():
|
for key, value in where.items():
|
||||||
|
@ -117,7 +119,7 @@ class SqliteSqlStoreImpl(SqlStore):
|
||||||
query = query.order_by(self.metadata.tables[table].c[name].desc())
|
query = query.order_by(self.metadata.tables[table].c[name].desc())
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid order '{order_type}' for column '{name}'")
|
raise ValueError(f"Invalid order '{order_type}' for column '{name}'")
|
||||||
result = await conn.execute(query)
|
result = await session.execute(query)
|
||||||
if result.rowcount == 0:
|
if result.rowcount == 0:
|
||||||
return []
|
return []
|
||||||
return [dict(row._mapping) for row in result]
|
return [dict(row._mapping) for row in result]
|
||||||
|
@ -142,20 +144,20 @@ class SqliteSqlStoreImpl(SqlStore):
|
||||||
if not where:
|
if not where:
|
||||||
raise ValueError("where is required for update")
|
raise ValueError("where is required for update")
|
||||||
|
|
||||||
async with self.engine.begin() as conn:
|
async with self.async_session() as session:
|
||||||
stmt = self.metadata.tables[table].update()
|
stmt = self.metadata.tables[table].update()
|
||||||
for key, value in where.items():
|
for key, value in where.items():
|
||||||
stmt = stmt.where(self.metadata.tables[table].c[key] == value)
|
stmt = stmt.where(self.metadata.tables[table].c[key] == value)
|
||||||
await conn.execute(stmt, data)
|
await session.execute(stmt, data)
|
||||||
await conn.commit()
|
await session.commit()
|
||||||
|
|
||||||
async def delete(self, table: str, where: Mapping[str, Any]) -> None:
|
async def delete(self, table: str, where: Mapping[str, Any]) -> None:
|
||||||
if not where:
|
if not where:
|
||||||
raise ValueError("where is required for delete")
|
raise ValueError("where is required for delete")
|
||||||
|
|
||||||
async with self.engine.begin() as conn:
|
async with self.async_session() as session:
|
||||||
stmt = self.metadata.tables[table].delete()
|
stmt = self.metadata.tables[table].delete()
|
||||||
for key, value in where.items():
|
for key, value in where.items():
|
||||||
stmt = stmt.where(self.metadata.tables[table].c[key] == value)
|
stmt = stmt.where(self.metadata.tables[table].c[key] == value)
|
||||||
await conn.execute(stmt)
|
await session.execute(stmt)
|
||||||
await conn.commit()
|
await session.commit()
|
|
@ -5,6 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
from abc import abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Annotated, Literal
|
from typing import Annotated, Literal
|
||||||
|
@ -21,7 +22,18 @@ class SqlStoreType(Enum):
|
||||||
postgres = "postgres"
|
postgres = "postgres"
|
||||||
|
|
||||||
|
|
||||||
class SqliteSqlStoreConfig(BaseModel):
|
class SqlAlchemySqlStoreConfig(BaseModel):
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def engine_str(self) -> str: ...
|
||||||
|
|
||||||
|
# TODO: move this when we have a better way to specify dependencies with internal APIs
|
||||||
|
@property
|
||||||
|
def pip_packages(self) -> list[str]:
|
||||||
|
return ["sqlalchemy[asyncio]"]
|
||||||
|
|
||||||
|
|
||||||
|
class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
||||||
type: Literal["sqlite"] = SqlStoreType.sqlite.value
|
type: Literal["sqlite"] = SqlStoreType.sqlite.value
|
||||||
db_path: str = Field(
|
db_path: str = Field(
|
||||||
default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
||||||
|
@ -39,18 +51,26 @@ class SqliteSqlStoreConfig(BaseModel):
|
||||||
db_path="${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name,
|
db_path="${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: move this when we have a better way to specify dependencies with internal APIs
|
|
||||||
@property
|
@property
|
||||||
def pip_packages(self) -> list[str]:
|
def pip_packages(self) -> list[str]:
|
||||||
return ["sqlalchemy[asyncio]"]
|
return super().pip_packages + ["aiosqlite"]
|
||||||
|
|
||||||
|
|
||||||
class PostgresSqlStoreConfig(BaseModel):
|
class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
||||||
type: Literal["postgres"] = SqlStoreType.postgres.value
|
type: Literal["postgres"] = SqlStoreType.postgres.value
|
||||||
|
host: str = "localhost"
|
||||||
|
port: str = "5432"
|
||||||
|
db: str = "llamastack"
|
||||||
|
user: str
|
||||||
|
password: str | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def engine_str(self) -> str:
|
||||||
|
return f"postgresql+asyncpg://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pip_packages(self) -> list[str]:
|
def pip_packages(self) -> list[str]:
|
||||||
raise NotImplementedError("Postgres is not implemented yet")
|
return super().pip_packages + ["asyncpg"]
|
||||||
|
|
||||||
|
|
||||||
SqlStoreConfig = Annotated[
|
SqlStoreConfig = Annotated[
|
||||||
|
@ -60,12 +80,10 @@ SqlStoreConfig = Annotated[
|
||||||
|
|
||||||
|
|
||||||
def sqlstore_impl(config: SqlStoreConfig) -> SqlStore:
|
def sqlstore_impl(config: SqlStoreConfig) -> SqlStore:
|
||||||
if config.type == SqlStoreType.sqlite.value:
|
if config.type in [SqlStoreType.sqlite.value, SqlStoreType.postgres.value]:
|
||||||
from .sqlite.sqlite import SqliteSqlStoreImpl
|
from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
|
||||||
|
|
||||||
impl = SqliteSqlStoreImpl(config)
|
impl = SqlAlchemySqlStoreImpl(config)
|
||||||
elif config.type == SqlStoreType.postgres.value:
|
|
||||||
raise NotImplementedError("Postgres is not implemented yet")
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown sqlstore type {config.type}")
|
raise ValueError(f"Unknown sqlstore type {config.type}")
|
||||||
|
|
||||||
|
|
|
@ -30,4 +30,5 @@ distribution_spec:
|
||||||
- remote::model-context-protocol
|
- remote::model-context-protocol
|
||||||
image_type: conda
|
image_type: conda
|
||||||
additional_pip_packages:
|
additional_pip_packages:
|
||||||
|
- aiosqlite
|
||||||
- sqlalchemy[asyncio]
|
- sqlalchemy[asyncio]
|
||||||
|
|
|
@ -30,4 +30,5 @@ distribution_spec:
|
||||||
- inline::rag-runtime
|
- inline::rag-runtime
|
||||||
image_type: conda
|
image_type: conda
|
||||||
additional_pip_packages:
|
additional_pip_packages:
|
||||||
|
- aiosqlite
|
||||||
- sqlalchemy[asyncio]
|
- sqlalchemy[asyncio]
|
||||||
|
|
|
@ -31,4 +31,5 @@ distribution_spec:
|
||||||
- remote::model-context-protocol
|
- remote::model-context-protocol
|
||||||
image_type: conda
|
image_type: conda
|
||||||
additional_pip_packages:
|
additional_pip_packages:
|
||||||
|
- aiosqlite
|
||||||
- sqlalchemy[asyncio]
|
- sqlalchemy[asyncio]
|
||||||
|
|
|
@ -31,5 +31,5 @@ distribution_spec:
|
||||||
- inline::rag-runtime
|
- inline::rag-runtime
|
||||||
image_type: conda
|
image_type: conda
|
||||||
additional_pip_packages:
|
additional_pip_packages:
|
||||||
- sqlalchemy[asyncio]
|
- aiosqlite
|
||||||
- sqlalchemy[asyncio]
|
- sqlalchemy[asyncio]
|
||||||
|
|
|
@ -32,5 +32,5 @@ distribution_spec:
|
||||||
- remote::model-context-protocol
|
- remote::model-context-protocol
|
||||||
image_type: conda
|
image_type: conda
|
||||||
additional_pip_packages:
|
additional_pip_packages:
|
||||||
- sqlalchemy[asyncio]
|
- aiosqlite
|
||||||
- sqlalchemy[asyncio]
|
- sqlalchemy[asyncio]
|
||||||
|
|
|
@ -27,4 +27,5 @@ distribution_spec:
|
||||||
- inline::rag-runtime
|
- inline::rag-runtime
|
||||||
image_type: conda
|
image_type: conda
|
||||||
additional_pip_packages:
|
additional_pip_packages:
|
||||||
|
- aiosqlite
|
||||||
- sqlalchemy[asyncio]
|
- sqlalchemy[asyncio]
|
||||||
|
|
|
@ -30,5 +30,5 @@ distribution_spec:
|
||||||
- remote::model-context-protocol
|
- remote::model-context-protocol
|
||||||
image_type: conda
|
image_type: conda
|
||||||
additional_pip_packages:
|
additional_pip_packages:
|
||||||
- sqlalchemy[asyncio]
|
- aiosqlite
|
||||||
- sqlalchemy[asyncio]
|
- sqlalchemy[asyncio]
|
||||||
|
|
|
@ -31,5 +31,5 @@ distribution_spec:
|
||||||
- remote::model-context-protocol
|
- remote::model-context-protocol
|
||||||
image_type: conda
|
image_type: conda
|
||||||
additional_pip_packages:
|
additional_pip_packages:
|
||||||
- sqlalchemy[asyncio]
|
- aiosqlite
|
||||||
- sqlalchemy[asyncio]
|
- sqlalchemy[asyncio]
|
||||||
|
|
|
@ -31,4 +31,5 @@ distribution_spec:
|
||||||
- remote::model-context-protocol
|
- remote::model-context-protocol
|
||||||
image_type: conda
|
image_type: conda
|
||||||
additional_pip_packages:
|
additional_pip_packages:
|
||||||
|
- aiosqlite
|
||||||
- sqlalchemy[asyncio]
|
- sqlalchemy[asyncio]
|
||||||
|
|
|
@ -30,5 +30,5 @@ distribution_spec:
|
||||||
- remote::model-context-protocol
|
- remote::model-context-protocol
|
||||||
image_type: conda
|
image_type: conda
|
||||||
additional_pip_packages:
|
additional_pip_packages:
|
||||||
- sqlalchemy[asyncio]
|
- aiosqlite
|
||||||
- sqlalchemy[asyncio]
|
- sqlalchemy[asyncio]
|
||||||
|
|
|
@ -25,5 +25,5 @@ distribution_spec:
|
||||||
- inline::rag-runtime
|
- inline::rag-runtime
|
||||||
image_type: conda
|
image_type: conda
|
||||||
additional_pip_packages:
|
additional_pip_packages:
|
||||||
- sqlalchemy[asyncio]
|
- aiosqlite
|
||||||
- sqlalchemy[asyncio]
|
- sqlalchemy[asyncio]
|
||||||
|
|
|
@ -33,5 +33,5 @@ distribution_spec:
|
||||||
- remote::wolfram-alpha
|
- remote::wolfram-alpha
|
||||||
image_type: conda
|
image_type: conda
|
||||||
additional_pip_packages:
|
additional_pip_packages:
|
||||||
- sqlalchemy[asyncio]
|
- aiosqlite
|
||||||
- sqlalchemy[asyncio]
|
- sqlalchemy[asyncio]
|
||||||
|
|
|
@ -34,4 +34,5 @@ distribution_spec:
|
||||||
- remote::model-context-protocol
|
- remote::model-context-protocol
|
||||||
image_type: conda
|
image_type: conda
|
||||||
additional_pip_packages:
|
additional_pip_packages:
|
||||||
|
- aiosqlite
|
||||||
- sqlalchemy[asyncio]
|
- sqlalchemy[asyncio]
|
||||||
|
|
|
@ -32,5 +32,5 @@ distribution_spec:
|
||||||
- remote::model-context-protocol
|
- remote::model-context-protocol
|
||||||
image_type: conda
|
image_type: conda
|
||||||
additional_pip_packages:
|
additional_pip_packages:
|
||||||
- sqlalchemy[asyncio]
|
- aiosqlite
|
||||||
- sqlalchemy[asyncio]
|
- sqlalchemy[asyncio]
|
||||||
|
|
7
llama_stack/templates/postgres-demo/__init__.py
Normal file
7
llama_stack/templates/postgres-demo/__init__.py
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from .postgres_demo import get_distribution_template # noqa: F401
|
24
llama_stack/templates/postgres-demo/build.yaml
Normal file
24
llama_stack/templates/postgres-demo/build.yaml
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
version: '2'
|
||||||
|
distribution_spec:
|
||||||
|
description: Quick start template for running Llama Stack with several popular providers
|
||||||
|
providers:
|
||||||
|
inference:
|
||||||
|
- remote::fireworks
|
||||||
|
- remote::vllm
|
||||||
|
vector_io:
|
||||||
|
- remote::chromadb
|
||||||
|
safety:
|
||||||
|
- inline::llama-guard
|
||||||
|
agents:
|
||||||
|
- inline::meta-reference
|
||||||
|
telemetry:
|
||||||
|
- inline::meta-reference
|
||||||
|
tool_runtime:
|
||||||
|
- remote::brave-search
|
||||||
|
- remote::tavily-search
|
||||||
|
- inline::rag-runtime
|
||||||
|
- remote::model-context-protocol
|
||||||
|
image_type: conda
|
||||||
|
additional_pip_packages:
|
||||||
|
- asyncpg
|
||||||
|
- sqlalchemy[asyncio]
|
164
llama_stack/templates/postgres-demo/postgres_demo.py
Normal file
164
llama_stack/templates/postgres-demo/postgres_demo.py
Normal file
|
@ -0,0 +1,164 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import (
|
||||||
|
ModelInput,
|
||||||
|
Provider,
|
||||||
|
ShieldInput,
|
||||||
|
ToolGroupInput,
|
||||||
|
)
|
||||||
|
from llama_stack.providers.remote.inference.fireworks.config import FireworksImplConfig
|
||||||
|
from llama_stack.providers.remote.inference.fireworks.models import (
|
||||||
|
MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES,
|
||||||
|
)
|
||||||
|
from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
|
||||||
|
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
|
||||||
|
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
|
||||||
|
from llama_stack.providers.utils.kvstore.config import PostgresKVStoreConfig
|
||||||
|
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
|
||||||
|
from llama_stack.templates.template import (
|
||||||
|
DistributionTemplate,
|
||||||
|
RunConfigSettings,
|
||||||
|
get_model_registry,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderModelEntry]]]:
|
||||||
|
# in this template, we allow each API key to be optional
|
||||||
|
providers = [
|
||||||
|
(
|
||||||
|
"fireworks",
|
||||||
|
FIREWORKS_MODEL_ENTRIES,
|
||||||
|
FireworksImplConfig.sample_run_config(api_key="${env.FIREWORKS_API_KEY:}"),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
inference_providers = []
|
||||||
|
available_models = {}
|
||||||
|
for provider_id, model_entries, config in providers:
|
||||||
|
inference_providers.append(
|
||||||
|
Provider(
|
||||||
|
provider_id=provider_id,
|
||||||
|
provider_type=f"remote::{provider_id}",
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
available_models[provider_id] = model_entries
|
||||||
|
inference_providers.append(
|
||||||
|
Provider(
|
||||||
|
provider_id="vllm-inference",
|
||||||
|
provider_type="remote::vllm",
|
||||||
|
config=VLLMInferenceAdapterConfig.sample_run_config(
|
||||||
|
url="${env.VLLM_URL:http://localhost:8000/v1}",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return inference_providers, available_models
|
||||||
|
|
||||||
|
|
||||||
|
def get_distribution_template() -> DistributionTemplate:
|
||||||
|
inference_providers, available_models = get_inference_providers()
|
||||||
|
providers = {
|
||||||
|
"inference": ([p.provider_type for p in inference_providers]),
|
||||||
|
"vector_io": ["remote::chromadb"],
|
||||||
|
"safety": ["inline::llama-guard"],
|
||||||
|
"agents": ["inline::meta-reference"],
|
||||||
|
"telemetry": ["inline::meta-reference"],
|
||||||
|
"tool_runtime": [
|
||||||
|
"remote::brave-search",
|
||||||
|
"remote::tavily-search",
|
||||||
|
"inline::rag-runtime",
|
||||||
|
"remote::model-context-protocol",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
name = "postgres-demo"
|
||||||
|
|
||||||
|
vector_io_providers = [
|
||||||
|
Provider(
|
||||||
|
provider_id="${env.ENABLE_CHROMADB+chromadb}",
|
||||||
|
provider_type="remote::chromadb",
|
||||||
|
config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:}"),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
default_tool_groups = [
|
||||||
|
ToolGroupInput(
|
||||||
|
toolgroup_id="builtin::websearch",
|
||||||
|
provider_id="tavily-search",
|
||||||
|
),
|
||||||
|
ToolGroupInput(
|
||||||
|
toolgroup_id="builtin::rag",
|
||||||
|
provider_id="rag-runtime",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
default_models = get_model_registry(available_models)
|
||||||
|
default_models.append(
|
||||||
|
ModelInput(
|
||||||
|
model_id="${env.INFERENCE_MODEL}",
|
||||||
|
provider_id="vllm-inference",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
postgres_config = {
|
||||||
|
"type": "postgres",
|
||||||
|
"host": "${env.POSTGRES_HOST:localhost}",
|
||||||
|
"port": "${env.POSTGRES_PORT:5432}",
|
||||||
|
"db": "${env.POSTGRES_DB:llamastack}",
|
||||||
|
"user": "${env.POSTGRES_USER:llamastack}",
|
||||||
|
"password": "${env.POSTGRES_PASSWORD:llamastack}",
|
||||||
|
}
|
||||||
|
|
||||||
|
return DistributionTemplate(
|
||||||
|
name=name,
|
||||||
|
distro_type="self_hosted",
|
||||||
|
description="Quick start template for running Llama Stack with several popular providers",
|
||||||
|
container_image=None,
|
||||||
|
template_path=None,
|
||||||
|
providers=providers,
|
||||||
|
available_models_by_provider=available_models,
|
||||||
|
run_configs={
|
||||||
|
"run.yaml": RunConfigSettings(
|
||||||
|
provider_overrides={
|
||||||
|
"inference": inference_providers,
|
||||||
|
"vector_io": vector_io_providers,
|
||||||
|
"agents": [
|
||||||
|
Provider(
|
||||||
|
provider_id="meta-reference",
|
||||||
|
provider_type="inline::meta-reference",
|
||||||
|
config=dict(
|
||||||
|
persistence_store=postgres_config,
|
||||||
|
responses_store=postgres_config,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
"telemetry": [
|
||||||
|
Provider(
|
||||||
|
provider_id="meta-reference",
|
||||||
|
provider_type="inline::meta-reference",
|
||||||
|
config=dict(
|
||||||
|
service_name="${env.OTEL_SERVICE_NAME:}",
|
||||||
|
sinks="${env.TELEMETRY_SINKS:console}",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
},
|
||||||
|
default_models=default_models,
|
||||||
|
default_tool_groups=default_tool_groups,
|
||||||
|
default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
|
||||||
|
metadata_store=PostgresKVStoreConfig.model_validate(postgres_config),
|
||||||
|
inference_store=PostgresSqlStoreConfig.model_validate(postgres_config),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
run_config_env_vars={
|
||||||
|
"LLAMA_STACK_PORT": (
|
||||||
|
"8321",
|
||||||
|
"Port for the Llama Stack distribution server",
|
||||||
|
),
|
||||||
|
"FIREWORKS_API_KEY": (
|
||||||
|
"",
|
||||||
|
"Fireworks API Key",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
224
llama_stack/templates/postgres-demo/run.yaml
Normal file
224
llama_stack/templates/postgres-demo/run.yaml
Normal file
|
@ -0,0 +1,224 @@
|
||||||
|
version: '2'
|
||||||
|
image_name: postgres-demo
|
||||||
|
apis:
|
||||||
|
- agents
|
||||||
|
- inference
|
||||||
|
- safety
|
||||||
|
- telemetry
|
||||||
|
- tool_runtime
|
||||||
|
- vector_io
|
||||||
|
providers:
|
||||||
|
inference:
|
||||||
|
- provider_id: fireworks
|
||||||
|
provider_type: remote::fireworks
|
||||||
|
config:
|
||||||
|
url: https://api.fireworks.ai/inference/v1
|
||||||
|
api_key: ${env.FIREWORKS_API_KEY:}
|
||||||
|
- provider_id: vllm-inference
|
||||||
|
provider_type: remote::vllm
|
||||||
|
config:
|
||||||
|
url: ${env.VLLM_URL:http://localhost:8000/v1}
|
||||||
|
max_tokens: ${env.VLLM_MAX_TOKENS:4096}
|
||||||
|
api_token: ${env.VLLM_API_TOKEN:fake}
|
||||||
|
tls_verify: ${env.VLLM_TLS_VERIFY:true}
|
||||||
|
vector_io:
|
||||||
|
- provider_id: ${env.ENABLE_CHROMADB+chromadb}
|
||||||
|
provider_type: remote::chromadb
|
||||||
|
config:
|
||||||
|
url: ${env.CHROMADB_URL:}
|
||||||
|
safety:
|
||||||
|
- provider_id: llama-guard
|
||||||
|
provider_type: inline::llama-guard
|
||||||
|
config:
|
||||||
|
excluded_categories: []
|
||||||
|
agents:
|
||||||
|
- provider_id: meta-reference
|
||||||
|
provider_type: inline::meta-reference
|
||||||
|
config:
|
||||||
|
persistence_store:
|
||||||
|
type: postgres
|
||||||
|
host: ${env.POSTGRES_HOST:localhost}
|
||||||
|
port: ${env.POSTGRES_PORT:5432}
|
||||||
|
db: ${env.POSTGRES_DB:llamastack}
|
||||||
|
user: ${env.POSTGRES_USER:llamastack}
|
||||||
|
password: ${env.POSTGRES_PASSWORD:llamastack}
|
||||||
|
responses_store:
|
||||||
|
type: postgres
|
||||||
|
host: ${env.POSTGRES_HOST:localhost}
|
||||||
|
port: ${env.POSTGRES_PORT:5432}
|
||||||
|
db: ${env.POSTGRES_DB:llamastack}
|
||||||
|
user: ${env.POSTGRES_USER:llamastack}
|
||||||
|
password: ${env.POSTGRES_PASSWORD:llamastack}
|
||||||
|
telemetry:
|
||||||
|
- provider_id: meta-reference
|
||||||
|
provider_type: inline::meta-reference
|
||||||
|
config:
|
||||||
|
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||||
|
sinks: ${env.TELEMETRY_SINKS:console}
|
||||||
|
tool_runtime:
|
||||||
|
- provider_id: brave-search
|
||||||
|
provider_type: remote::brave-search
|
||||||
|
config:
|
||||||
|
api_key: ${env.BRAVE_SEARCH_API_KEY:}
|
||||||
|
max_results: 3
|
||||||
|
- provider_id: tavily-search
|
||||||
|
provider_type: remote::tavily-search
|
||||||
|
config:
|
||||||
|
api_key: ${env.TAVILY_SEARCH_API_KEY:}
|
||||||
|
max_results: 3
|
||||||
|
- provider_id: rag-runtime
|
||||||
|
provider_type: inline::rag-runtime
|
||||||
|
config: {}
|
||||||
|
- provider_id: model-context-protocol
|
||||||
|
provider_type: remote::model-context-protocol
|
||||||
|
config: {}
|
||||||
|
metadata_store:
|
||||||
|
type: postgres
|
||||||
|
host: ${env.POSTGRES_HOST:localhost}
|
||||||
|
port: ${env.POSTGRES_PORT:5432}
|
||||||
|
db: ${env.POSTGRES_DB:llamastack}
|
||||||
|
user: ${env.POSTGRES_USER:llamastack}
|
||||||
|
password: ${env.POSTGRES_PASSWORD:llamastack}
|
||||||
|
table_name: llamastack_kvstore
|
||||||
|
inference_store:
|
||||||
|
type: postgres
|
||||||
|
host: ${env.POSTGRES_HOST:localhost}
|
||||||
|
port: ${env.POSTGRES_PORT:5432}
|
||||||
|
db: ${env.POSTGRES_DB:llamastack}
|
||||||
|
user: ${env.POSTGRES_USER:llamastack}
|
||||||
|
password: ${env.POSTGRES_PASSWORD:llamastack}
|
||||||
|
models:
|
||||||
|
- metadata: {}
|
||||||
|
model_id: accounts/fireworks/models/llama-v3p1-8b-instruct
|
||||||
|
provider_id: fireworks
|
||||||
|
provider_model_id: accounts/fireworks/models/llama-v3p1-8b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.1-8B-Instruct
|
||||||
|
provider_id: fireworks
|
||||||
|
provider_model_id: accounts/fireworks/models/llama-v3p1-8b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: accounts/fireworks/models/llama-v3p1-70b-instruct
|
||||||
|
provider_id: fireworks
|
||||||
|
provider_model_id: accounts/fireworks/models/llama-v3p1-70b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.1-70B-Instruct
|
||||||
|
provider_id: fireworks
|
||||||
|
provider_model_id: accounts/fireworks/models/llama-v3p1-70b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
|
||||||
|
provider_id: fireworks
|
||||||
|
provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
|
||||||
|
provider_id: fireworks
|
||||||
|
provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
|
||||||
|
provider_id: fireworks
|
||||||
|
provider_model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.2-3B-Instruct
|
||||||
|
provider_id: fireworks
|
||||||
|
provider_model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct
|
||||||
|
provider_id: fireworks
|
||||||
|
provider_model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
|
||||||
|
provider_id: fireworks
|
||||||
|
provider_model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct
|
||||||
|
provider_id: fireworks
|
||||||
|
provider_model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
|
||||||
|
provider_id: fireworks
|
||||||
|
provider_model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: accounts/fireworks/models/llama-v3p3-70b-instruct
|
||||||
|
provider_id: fireworks
|
||||||
|
provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.3-70B-Instruct
|
||||||
|
provider_id: fireworks
|
||||||
|
provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: accounts/fireworks/models/llama-guard-3-8b
|
||||||
|
provider_id: fireworks
|
||||||
|
provider_model_id: accounts/fireworks/models/llama-guard-3-8b
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-Guard-3-8B
|
||||||
|
provider_id: fireworks
|
||||||
|
provider_model_id: accounts/fireworks/models/llama-guard-3-8b
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: accounts/fireworks/models/llama-guard-3-11b-vision
|
||||||
|
provider_id: fireworks
|
||||||
|
provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-Guard-3-11B-Vision
|
||||||
|
provider_id: fireworks
|
||||||
|
provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: accounts/fireworks/models/llama4-scout-instruct-basic
|
||||||
|
provider_id: fireworks
|
||||||
|
provider_model_id: accounts/fireworks/models/llama4-scout-instruct-basic
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
|
||||||
|
provider_id: fireworks
|
||||||
|
provider_model_id: accounts/fireworks/models/llama4-scout-instruct-basic
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: accounts/fireworks/models/llama4-maverick-instruct-basic
|
||||||
|
provider_id: fireworks
|
||||||
|
provider_model_id: accounts/fireworks/models/llama4-maverick-instruct-basic
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct
|
||||||
|
provider_id: fireworks
|
||||||
|
provider_model_id: accounts/fireworks/models/llama4-maverick-instruct-basic
|
||||||
|
model_type: llm
|
||||||
|
- metadata:
|
||||||
|
embedding_dimension: 768
|
||||||
|
context_length: 8192
|
||||||
|
model_id: nomic-ai/nomic-embed-text-v1.5
|
||||||
|
provider_id: fireworks
|
||||||
|
provider_model_id: nomic-ai/nomic-embed-text-v1.5
|
||||||
|
model_type: embedding
|
||||||
|
- metadata: {}
|
||||||
|
model_id: ${env.INFERENCE_MODEL}
|
||||||
|
provider_id: vllm-inference
|
||||||
|
model_type: llm
|
||||||
|
shields:
|
||||||
|
- shield_id: meta-llama/Llama-Guard-3-8B
|
||||||
|
vector_dbs: []
|
||||||
|
datasets: []
|
||||||
|
scoring_fns: []
|
||||||
|
benchmarks: []
|
||||||
|
tool_groups:
|
||||||
|
- toolgroup_id: builtin::websearch
|
||||||
|
provider_id: tavily-search
|
||||||
|
- toolgroup_id: builtin::rag
|
||||||
|
provider_id: rag-runtime
|
||||||
|
server:
|
||||||
|
port: 8321
|
|
@ -32,5 +32,5 @@ distribution_spec:
|
||||||
- remote::wolfram-alpha
|
- remote::wolfram-alpha
|
||||||
image_type: conda
|
image_type: conda
|
||||||
additional_pip_packages:
|
additional_pip_packages:
|
||||||
- sqlalchemy[asyncio]
|
- aiosqlite
|
||||||
- sqlalchemy[asyncio]
|
- sqlalchemy[asyncio]
|
||||||
|
|
|
@ -23,4 +23,5 @@ distribution_spec:
|
||||||
- remote::wolfram-alpha
|
- remote::wolfram-alpha
|
||||||
image_type: conda
|
image_type: conda
|
||||||
additional_pip_packages:
|
additional_pip_packages:
|
||||||
|
- aiosqlite
|
||||||
- sqlalchemy[asyncio]
|
- sqlalchemy[asyncio]
|
||||||
|
|
|
@ -36,4 +36,5 @@ distribution_spec:
|
||||||
- remote::model-context-protocol
|
- remote::model-context-protocol
|
||||||
image_type: conda
|
image_type: conda
|
||||||
additional_pip_packages:
|
additional_pip_packages:
|
||||||
|
- aiosqlite
|
||||||
- sqlalchemy[asyncio]
|
- sqlalchemy[asyncio]
|
||||||
|
|
|
@ -28,8 +28,8 @@ from llama_stack.distribution.datatypes import (
|
||||||
from llama_stack.distribution.distribution import get_provider_registry
|
from llama_stack.distribution.distribution import get_provider_registry
|
||||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||||
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
|
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
|
||||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
|
||||||
|
|
||||||
|
|
||||||
def get_model_registry(
|
def get_model_registry(
|
||||||
|
@ -64,6 +64,8 @@ class RunConfigSettings(BaseModel):
|
||||||
default_tool_groups: list[ToolGroupInput] | None = None
|
default_tool_groups: list[ToolGroupInput] | None = None
|
||||||
default_datasets: list[DatasetInput] | None = None
|
default_datasets: list[DatasetInput] | None = None
|
||||||
default_benchmarks: list[BenchmarkInput] | None = None
|
default_benchmarks: list[BenchmarkInput] | None = None
|
||||||
|
metadata_store: KVStoreConfig | None = None
|
||||||
|
inference_store: SqlStoreConfig | None = None
|
||||||
|
|
||||||
def run_config(
|
def run_config(
|
||||||
self,
|
self,
|
||||||
|
@ -114,11 +116,13 @@ class RunConfigSettings(BaseModel):
|
||||||
container_image=container_image,
|
container_image=container_image,
|
||||||
apis=apis,
|
apis=apis,
|
||||||
providers=provider_configs,
|
providers=provider_configs,
|
||||||
metadata_store=SqliteKVStoreConfig.sample_run_config(
|
metadata_store=self.metadata_store
|
||||||
|
or SqliteKVStoreConfig.sample_run_config(
|
||||||
__distro_dir__=f"~/.llama/distributions/{name}",
|
__distro_dir__=f"~/.llama/distributions/{name}",
|
||||||
db_name="registry.db",
|
db_name="registry.db",
|
||||||
),
|
),
|
||||||
inference_store=SqliteSqlStoreConfig.sample_run_config(
|
inference_store=self.inference_store
|
||||||
|
or SqliteSqlStoreConfig.sample_run_config(
|
||||||
__distro_dir__=f"~/.llama/distributions/{name}",
|
__distro_dir__=f"~/.llama/distributions/{name}",
|
||||||
db_name="inference_store.db",
|
db_name="inference_store.db",
|
||||||
),
|
),
|
||||||
|
@ -164,7 +168,7 @@ class DistributionTemplate(BaseModel):
|
||||||
providers=self.providers,
|
providers=self.providers,
|
||||||
),
|
),
|
||||||
image_type="conda", # default to conda, can be overridden
|
image_type="conda", # default to conda, can be overridden
|
||||||
additional_pip_packages=additional_pip_packages,
|
additional_pip_packages=sorted(set(additional_pip_packages)),
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_markdown_docs(self) -> str:
|
def generate_markdown_docs(self) -> str:
|
||||||
|
|
|
@ -31,5 +31,5 @@ distribution_spec:
|
||||||
- remote::model-context-protocol
|
- remote::model-context-protocol
|
||||||
image_type: conda
|
image_type: conda
|
||||||
additional_pip_packages:
|
additional_pip_packages:
|
||||||
- sqlalchemy[asyncio]
|
- aiosqlite
|
||||||
- sqlalchemy[asyncio]
|
- sqlalchemy[asyncio]
|
||||||
|
|
|
@ -32,5 +32,5 @@ distribution_spec:
|
||||||
- remote::wolfram-alpha
|
- remote::wolfram-alpha
|
||||||
image_type: conda
|
image_type: conda
|
||||||
additional_pip_packages:
|
additional_pip_packages:
|
||||||
- sqlalchemy[asyncio]
|
- aiosqlite
|
||||||
- sqlalchemy[asyncio]
|
- sqlalchemy[asyncio]
|
||||||
|
|
|
@ -36,4 +36,5 @@ distribution_spec:
|
||||||
- remote::model-context-protocol
|
- remote::model-context-protocol
|
||||||
image_type: conda
|
image_type: conda
|
||||||
additional_pip_packages:
|
additional_pip_packages:
|
||||||
|
- aiosqlite
|
||||||
- sqlalchemy[asyncio]
|
- sqlalchemy[asyncio]
|
||||||
|
|
|
@ -31,4 +31,5 @@ distribution_spec:
|
||||||
- remote::model-context-protocol
|
- remote::model-context-protocol
|
||||||
image_type: conda
|
image_type: conda
|
||||||
additional_pip_packages:
|
additional_pip_packages:
|
||||||
|
- aiosqlite
|
||||||
- sqlalchemy[asyncio]
|
- sqlalchemy[asyncio]
|
||||||
|
|
|
@ -29,4 +29,5 @@ distribution_spec:
|
||||||
- remote::model-context-protocol
|
- remote::model-context-protocol
|
||||||
image_type: conda
|
image_type: conda
|
||||||
additional_pip_packages:
|
additional_pip_packages:
|
||||||
|
- aiosqlite
|
||||||
- sqlalchemy[asyncio]
|
- sqlalchemy[asyncio]
|
||||||
|
|
|
@ -268,9 +268,9 @@ def test_openai_chat_completion_streaming_with_n(compat_client, client_with_mode
|
||||||
False,
|
False,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_inference_store(openai_client, client_with_models, text_model_id, stream):
|
def test_inference_store(compat_client, client_with_models, text_model_id, stream):
|
||||||
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
|
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
|
||||||
client = openai_client
|
client = compat_client
|
||||||
# make a chat completion
|
# make a chat completion
|
||||||
message = "Hello, world!"
|
message = "Hello, world!"
|
||||||
response = client.chat.completions.create(
|
response = client.chat.completions.create(
|
||||||
|
@ -301,9 +301,14 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
|
||||||
|
|
||||||
retrieved_response = client.chat.completions.retrieve(response_id)
|
retrieved_response = client.chat.completions.retrieve(response_id)
|
||||||
assert retrieved_response.id == response_id
|
assert retrieved_response.id == response_id
|
||||||
assert retrieved_response.input_messages[0]["content"] == message, retrieved_response
|
|
||||||
assert retrieved_response.choices[0].message.content == content, retrieved_response
|
assert retrieved_response.choices[0].message.content == content, retrieved_response
|
||||||
|
|
||||||
|
input_content = (
|
||||||
|
getattr(retrieved_response.input_messages[0], "content", None)
|
||||||
|
or retrieved_response.input_messages[0]["content"]
|
||||||
|
)
|
||||||
|
assert input_content == message, retrieved_response
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"stream",
|
"stream",
|
||||||
|
@ -312,9 +317,9 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
|
||||||
False,
|
False,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_inference_store_tool_calls(openai_client, client_with_models, text_model_id, stream):
|
def test_inference_store_tool_calls(compat_client, client_with_models, text_model_id, stream):
|
||||||
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
|
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
|
||||||
client = openai_client
|
client = compat_client
|
||||||
# make a chat completion
|
# make a chat completion
|
||||||
message = "What's the weather in Tokyo? Use the get_weather function to get the weather."
|
message = "What's the weather in Tokyo? Use the get_weather function to get the weather."
|
||||||
response = client.chat.completions.create(
|
response = client.chat.completions.create(
|
||||||
|
@ -361,7 +366,11 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_mode
|
||||||
|
|
||||||
retrieved_response = client.chat.completions.retrieve(response_id)
|
retrieved_response = client.chat.completions.retrieve(response_id)
|
||||||
assert retrieved_response.id == response_id
|
assert retrieved_response.id == response_id
|
||||||
assert retrieved_response.input_messages[0]["content"] == message
|
input_content = (
|
||||||
|
getattr(retrieved_response.input_messages[0], "content", None)
|
||||||
|
or retrieved_response.input_messages[0]["content"]
|
||||||
|
)
|
||||||
|
assert input_content == message, retrieved_response
|
||||||
tool_calls = retrieved_response.choices[0].message.tool_calls
|
tool_calls = retrieved_response.choices[0].message.tool_calls
|
||||||
# sometimes model doesn't ouptut tool calls, but we still want to test that the tool was called
|
# sometimes model doesn't ouptut tool calls, but we still want to test that the tool was called
|
||||||
if tool_calls:
|
if tool_calls:
|
||||||
|
|
|
@ -9,7 +9,7 @@ from tempfile import TemporaryDirectory
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from llama_stack.providers.utils.sqlstore.api import ColumnType
|
from llama_stack.providers.utils.sqlstore.api import ColumnType
|
||||||
from llama_stack.providers.utils.sqlstore.sqlite.sqlite import SqliteSqlStoreImpl
|
from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
|
||||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||||
|
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||||
async def test_sqlite_sqlstore():
|
async def test_sqlite_sqlstore():
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
db_name = "test.db"
|
db_name = "test.db"
|
||||||
sqlstore = SqliteSqlStoreImpl(
|
sqlstore = SqlAlchemySqlStoreImpl(
|
||||||
SqliteSqlStoreConfig(
|
SqliteSqlStoreConfig(
|
||||||
db_path=tmp_dir + "/" + db_name,
|
db_path=tmp_dir + "/" + db_name,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue