forked from phoenix-oss/llama-stack-mirror
feat: support postgresql inference store (#2310)
# What does this PR do?

* Added support for a PostgreSQL inference store
* Added an 'oracle' template that demos how to configure PostgreSQL stores (except for telemetry, which is not currently supported)

## Test Plan

llama stack build --template oracle --image-type conda --run
LLAMA_STACK_CONFIG=http://localhost:8321 pytest -s -v tests/integration/ --text-model accounts/fireworks/models/llama-v3p3-70b-instruct -k 'inference_store'
commit 2603f10f95
parent 168c7113df
32 changed files with 516 additions and 53 deletions
--- a/llama_stack/providers/utils/kvstore/config.py
+++ b/llama_stack/providers/utils/kvstore/config.py
@@ -65,7 +65,7 @@ class SqliteKVStoreConfig(CommonConfig):
 class PostgresKVStoreConfig(CommonConfig):
     type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value
     host: str = "localhost"
-    port: int = 5432
+    port: str = "5432"
     db: str = "llamastack"
     user: str
     password: str | None = None
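The only change here is `port` widening from `int` to `str`, which lets run configs pass an unexpanded `${env.POSTGRES_PORT:5432}` placeholder through validation (the postgres-demo template below relies on this). A minimal sketch, assuming the module path shown and defaults for the remaining fields:

```python
# Minimal sketch: a str-typed port accepts a literal value or an
# unexpanded ${env....} placeholder; an int field would reject the latter.
from llama_stack.providers.utils.kvstore.config import PostgresKVStoreConfig

cfg = PostgresKVStoreConfig(user="llamastack", port="${env.POSTGRES_PORT:5432}")
print(cfg.port)
```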
--- a/llama_stack/providers/utils/sqlstore/sqlite/sqlite.py
+++ b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py
@@ -19,10 +19,10 @@ from sqlalchemy import (
     Text,
     select,
 )
-from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
 
-from ..api import ColumnDefinition, ColumnType, SqlStore
-from ..sqlstore import SqliteSqlStoreConfig
+from .api import ColumnDefinition, ColumnType, SqlStore
+from .sqlstore import SqlAlchemySqlStoreConfig
 
 TYPE_MAPPING: dict[ColumnType, Any] = {
     ColumnType.INTEGER: Integer,
@@ -35,9 +35,10 @@ TYPE_MAPPING: dict[ColumnType, Any] = {
 }
 
 
-class SqliteSqlStoreImpl(SqlStore):
-    def __init__(self, config: SqliteSqlStoreConfig):
-        self.engine = create_async_engine(config.engine_str)
+class SqlAlchemySqlStoreImpl(SqlStore):
+    def __init__(self, config: SqlAlchemySqlStoreConfig):
+        self.config = config
+        self.async_session = async_sessionmaker(create_async_engine(config.engine_str))
         self.metadata = MetaData()
 
     async def create_table(
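The renamed `SqlAlchemySqlStoreImpl` swaps the single engine attribute for an `async_sessionmaker` factory, so the same class serves any SQLAlchemy-supported backend (`sqlite+aiosqlite` or `postgresql+asyncpg`). A standalone sketch of that SQLAlchemy 2.x pattern, using an in-memory SQLite database for brevity:

```python
# Standalone sketch of the async_sessionmaker pattern adopted above.
import asyncio

from sqlalchemy import text
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine


async def main() -> None:
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    make_session = async_sessionmaker(engine)  # factory; call once per unit of work
    async with make_session() as session:
        print((await session.execute(text("SELECT 1"))).scalar_one())


asyncio.run(main())
```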
@@ -78,13 +79,14 @@ class SqliteSqlStoreImpl(SqlStore):
 
         # Create the table in the database if it doesn't exist
         # checkfirst=True ensures it doesn't try to recreate if it's already there
-        async with self.engine.begin() as conn:
+        engine = create_async_engine(self.config.engine_str)
+        async with engine.begin() as conn:
             await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)
 
     async def insert(self, table: str, data: Mapping[str, Any]) -> None:
-        async with self.engine.begin() as conn:
-            await conn.execute(self.metadata.tables[table].insert(), data)
-            await conn.commit()
+        async with self.async_session() as session:
+            await session.execute(self.metadata.tables[table].insert(), data)
+            await session.commit()
 
     async def fetch_all(
         self,
@@ -93,7 +95,7 @@ class SqliteSqlStoreImpl(SqlStore):
         limit: int | None = None,
         order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
     ) -> list[dict[str, Any]]:
-        async with self.engine.begin() as conn:
+        async with self.async_session() as session:
             query = select(self.metadata.tables[table])
             if where:
                 for key, value in where.items():
@@ -117,7 +119,7 @@ class SqliteSqlStoreImpl(SqlStore):
                     query = query.order_by(self.metadata.tables[table].c[name].desc())
                 else:
                     raise ValueError(f"Invalid order '{order_type}' for column '{name}'")
-            result = await conn.execute(query)
+            result = await session.execute(query)
             if result.rowcount == 0:
                 return []
             return [dict(row._mapping) for row in result]
@@ -142,20 +144,20 @@ class SqliteSqlStoreImpl(SqlStore):
         if not where:
             raise ValueError("where is required for update")
 
-        async with self.engine.begin() as conn:
+        async with self.async_session() as session:
             stmt = self.metadata.tables[table].update()
             for key, value in where.items():
                 stmt = stmt.where(self.metadata.tables[table].c[key] == value)
-            await conn.execute(stmt, data)
-            await conn.commit()
+            await session.execute(stmt, data)
+            await session.commit()
 
     async def delete(self, table: str, where: Mapping[str, Any]) -> None:
         if not where:
             raise ValueError("where is required for delete")
 
-        async with self.engine.begin() as conn:
+        async with self.async_session() as session:
             stmt = self.metadata.tables[table].delete()
             for key, value in where.items():
                 stmt = stmt.where(self.metadata.tables[table].c[key] == value)
-            await conn.execute(stmt)
-            await conn.commit()
+            await session.execute(stmt)
+            await session.commit()
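`update` and `delete` fold the `where` mapping into the statement one `.where()` call at a time; chained `.where()` clauses are ANDed by SQLAlchemy. The loop in isolation (table and column names here are illustrative only):

```python
# Standalone sketch of the where-clause loop; chained .where() calls AND together.
from sqlalchemy import Column, Integer, MetaData, String, Table

metadata = MetaData()
t = Table("demo", metadata, Column("id", String), Column("n", Integer))

stmt = t.update().values(n=2)
for key, value in {"id": "abc"}.items():
    stmt = stmt.where(t.c[key] == value)
print(stmt)  # UPDATE demo SET n=:n WHERE demo.id = :id_1
```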
--- a/llama_stack/providers/utils/sqlstore/sqlstore.py
+++ b/llama_stack/providers/utils/sqlstore/sqlstore.py
@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 
+from abc import abstractmethod
 from enum import Enum
 from pathlib import Path
 from typing import Annotated, Literal
 
@@ -21,7 +22,18 @@ class SqlStoreType(Enum):
     postgres = "postgres"
 
 
-class SqliteSqlStoreConfig(BaseModel):
+class SqlAlchemySqlStoreConfig(BaseModel):
+    @property
+    @abstractmethod
+    def engine_str(self) -> str: ...
+
+    # TODO: move this when we have a better way to specify dependencies with internal APIs
+    @property
+    def pip_packages(self) -> list[str]:
+        return ["sqlalchemy[asyncio]"]
+
+
+class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
     type: Literal["sqlite"] = SqlStoreType.sqlite.value
     db_path: str = Field(
         default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
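`SqlAlchemySqlStoreConfig` becomes a shared base: each subclass supplies `engine_str` and extends `pip_packages`. A minimal sketch of the pattern, with hypothetical class names standing in for the real configs:

```python
# Minimal sketch of the shared-base pattern; class names here are hypothetical.
from abc import abstractmethod

from pydantic import BaseModel


class EngineConfig(BaseModel):
    @property
    @abstractmethod
    def engine_str(self) -> str: ...

    @property
    def pip_packages(self) -> list[str]:
        return ["sqlalchemy[asyncio]"]


class SqliteConfig(EngineConfig):
    db_path: str

    @property
    def engine_str(self) -> str:
        return f"sqlite+aiosqlite:///{self.db_path}"

    @property
    def pip_packages(self) -> list[str]:
        return super().pip_packages + ["aiosqlite"]


cfg = SqliteConfig(db_path="/tmp/demo.db")
print(cfg.engine_str)    # sqlite+aiosqlite:////tmp/demo.db
print(cfg.pip_packages)  # ['sqlalchemy[asyncio]', 'aiosqlite']
```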
@@ -39,18 +51,26 @@ class SqliteSqlStoreConfig(BaseModel):
             db_path="${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name,
         )
 
     # TODO: move this when we have a better way to specify dependencies with internal APIs
     @property
     def pip_packages(self) -> list[str]:
-        return ["sqlalchemy[asyncio]"]
+        return super().pip_packages + ["aiosqlite"]
 
 
-class PostgresSqlStoreConfig(BaseModel):
+class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
     type: Literal["postgres"] = SqlStoreType.postgres.value
     host: str = "localhost"
     port: str = "5432"
     db: str = "llamastack"
     user: str
     password: str | None = None
 
     @property
     def engine_str(self) -> str:
         return f"postgresql+asyncpg://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}"
 
     @property
     def pip_packages(self) -> list[str]:
-        raise NotImplementedError("Postgres is not implemented yet")
+        return super().pip_packages + ["asyncpg"]
 
 
 SqlStoreConfig = Annotated[
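`engine_str` is a plain f-string over the config fields; for the defaults plus an assumed password it renders as follows:

```python
# What engine_str produces for illustrative values (the password is assumed).
user, password, host, port, db = "llamastack", "secret", "localhost", "5432", "llamastack"
print(f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{db}")
# postgresql+asyncpg://llamastack:secret@localhost:5432/llamastack
```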
@@ -60,12 +80,10 @@ SqlStoreConfig = Annotated[
 
 
 def sqlstore_impl(config: SqlStoreConfig) -> SqlStore:
-    if config.type == SqlStoreType.sqlite.value:
-        from .sqlite.sqlite import SqliteSqlStoreImpl
+    if config.type in [SqlStoreType.sqlite.value, SqlStoreType.postgres.value]:
+        from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
 
-        impl = SqliteSqlStoreImpl(config)
-    elif config.type == SqlStoreType.postgres.value:
-        raise NotImplementedError("Postgres is not implemented yet")
+        impl = SqlAlchemySqlStoreImpl(config)
     else:
         raise ValueError(f"Unknown sqlstore type {config.type}")
 
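Both store types now route to the single SQLAlchemy-backed implementation. A hedged usage sketch, with import paths taken from the diff:

```python
# Sketch: sqlite and postgres configs dispatch to the same implementation.
from llama_stack.providers.utils.sqlstore.sqlstore import (
    SqliteSqlStoreConfig,
    sqlstore_impl,
)

store = sqlstore_impl(SqliteSqlStoreConfig(db_path="/tmp/demo.db"))
print(type(store).__name__)  # SqlAlchemySqlStoreImpl
```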
Per-distribution build.yaml updates (sqlalchemy[asyncio] added or re-sorted in additional_pip_packages):

@@ -30,4 +30,5 @@ distribution_spec:
     - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
 - aiosqlite
+- sqlalchemy[asyncio]
@@ -30,4 +30,5 @@ distribution_spec:
     - inline::rag-runtime
 image_type: conda
 additional_pip_packages:
 - aiosqlite
+- sqlalchemy[asyncio]
@@ -31,4 +31,5 @@ distribution_spec:
     - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
 - aiosqlite
+- sqlalchemy[asyncio]
@@ -31,5 +31,5 @@ distribution_spec:
     - inline::rag-runtime
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
 - aiosqlite
+- sqlalchemy[asyncio]
@@ -32,5 +32,5 @@ distribution_spec:
     - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
 - aiosqlite
+- sqlalchemy[asyncio]
@@ -27,4 +27,5 @@ distribution_spec:
     - inline::rag-runtime
 image_type: conda
 additional_pip_packages:
 - aiosqlite
+- sqlalchemy[asyncio]
@@ -30,5 +30,5 @@ distribution_spec:
     - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
 - aiosqlite
+- sqlalchemy[asyncio]
@@ -31,5 +31,5 @@ distribution_spec:
     - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
 - aiosqlite
+- sqlalchemy[asyncio]
@@ -31,4 +31,5 @@ distribution_spec:
     - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
 - aiosqlite
+- sqlalchemy[asyncio]
@@ -30,5 +30,5 @@ distribution_spec:
     - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
 - aiosqlite
+- sqlalchemy[asyncio]
@@ -25,5 +25,5 @@ distribution_spec:
     - inline::rag-runtime
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
 - aiosqlite
+- sqlalchemy[asyncio]
@@ -33,5 +33,5 @@ distribution_spec:
     - remote::wolfram-alpha
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
 - aiosqlite
+- sqlalchemy[asyncio]
@@ -34,4 +34,5 @@ distribution_spec:
     - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
 - aiosqlite
+- sqlalchemy[asyncio]
@@ -32,5 +32,5 @@ distribution_spec:
     - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
 - aiosqlite
+- sqlalchemy[asyncio]
--- /dev/null
+++ b/llama_stack/templates/postgres-demo/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .postgres_demo import get_distribution_template  # noqa: F401
--- /dev/null
+++ b/llama_stack/templates/postgres-demo/build.yaml
@@ -0,0 +1,24 @@
+version: '2'
+distribution_spec:
+  description: Quick start template for running Llama Stack with several popular providers
+  providers:
+    inference:
+    - remote::fireworks
+    - remote::vllm
+    vector_io:
+    - remote::chromadb
+    safety:
+    - inline::llama-guard
+    agents:
+    - inline::meta-reference
+    telemetry:
+    - inline::meta-reference
+    tool_runtime:
+    - remote::brave-search
+    - remote::tavily-search
+    - inline::rag-runtime
+    - remote::model-context-protocol
+image_type: conda
+additional_pip_packages:
+- asyncpg
+- sqlalchemy[asyncio]
--- /dev/null
+++ b/llama_stack/templates/postgres-demo/postgres_demo.py
@@ -0,0 +1,164 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+from llama_stack.distribution.datatypes import (
+    ModelInput,
+    Provider,
+    ShieldInput,
+    ToolGroupInput,
+)
+from llama_stack.providers.remote.inference.fireworks.config import FireworksImplConfig
+from llama_stack.providers.remote.inference.fireworks.models import (
+    MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES,
+)
+from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
+from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
+from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
+from llama_stack.providers.utils.kvstore.config import PostgresKVStoreConfig
+from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
+from llama_stack.templates.template import (
+    DistributionTemplate,
+    RunConfigSettings,
+    get_model_registry,
+)
+
+
+def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderModelEntry]]]:
+    # in this template, we allow each API key to be optional
+    providers = [
+        (
+            "fireworks",
+            FIREWORKS_MODEL_ENTRIES,
+            FireworksImplConfig.sample_run_config(api_key="${env.FIREWORKS_API_KEY:}"),
+        ),
+    ]
+    inference_providers = []
+    available_models = {}
+    for provider_id, model_entries, config in providers:
+        inference_providers.append(
+            Provider(
+                provider_id=provider_id,
+                provider_type=f"remote::{provider_id}",
+                config=config,
+            )
+        )
+        available_models[provider_id] = model_entries
+    inference_providers.append(
+        Provider(
+            provider_id="vllm-inference",
+            provider_type="remote::vllm",
+            config=VLLMInferenceAdapterConfig.sample_run_config(
+                url="${env.VLLM_URL:http://localhost:8000/v1}",
+            ),
+        )
+    )
+    return inference_providers, available_models
+
+
+def get_distribution_template() -> DistributionTemplate:
+    inference_providers, available_models = get_inference_providers()
+    providers = {
+        "inference": ([p.provider_type for p in inference_providers]),
+        "vector_io": ["remote::chromadb"],
+        "safety": ["inline::llama-guard"],
+        "agents": ["inline::meta-reference"],
+        "telemetry": ["inline::meta-reference"],
+        "tool_runtime": [
+            "remote::brave-search",
+            "remote::tavily-search",
+            "inline::rag-runtime",
+            "remote::model-context-protocol",
+        ],
+    }
+    name = "postgres-demo"
+
+    vector_io_providers = [
+        Provider(
+            provider_id="${env.ENABLE_CHROMADB+chromadb}",
+            provider_type="remote::chromadb",
+            config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:}"),
+        ),
+    ]
+    default_tool_groups = [
+        ToolGroupInput(
+            toolgroup_id="builtin::websearch",
+            provider_id="tavily-search",
+        ),
+        ToolGroupInput(
+            toolgroup_id="builtin::rag",
+            provider_id="rag-runtime",
+        ),
+    ]
+
+    default_models = get_model_registry(available_models)
+    default_models.append(
+        ModelInput(
+            model_id="${env.INFERENCE_MODEL}",
+            provider_id="vllm-inference",
+        )
+    )
+    postgres_config = {
+        "type": "postgres",
+        "host": "${env.POSTGRES_HOST:localhost}",
+        "port": "${env.POSTGRES_PORT:5432}",
+        "db": "${env.POSTGRES_DB:llamastack}",
+        "user": "${env.POSTGRES_USER:llamastack}",
+        "password": "${env.POSTGRES_PASSWORD:llamastack}",
+    }
+
+    return DistributionTemplate(
+        name=name,
+        distro_type="self_hosted",
+        description="Quick start template for running Llama Stack with several popular providers",
+        container_image=None,
+        template_path=None,
+        providers=providers,
+        available_models_by_provider=available_models,
+        run_configs={
+            "run.yaml": RunConfigSettings(
+                provider_overrides={
+                    "inference": inference_providers,
+                    "vector_io": vector_io_providers,
+                    "agents": [
+                        Provider(
+                            provider_id="meta-reference",
+                            provider_type="inline::meta-reference",
+                            config=dict(
+                                persistence_store=postgres_config,
+                                responses_store=postgres_config,
+                            ),
+                        )
+                    ],
+                    "telemetry": [
+                        Provider(
+                            provider_id="meta-reference",
+                            provider_type="inline::meta-reference",
+                            config=dict(
+                                service_name="${env.OTEL_SERVICE_NAME:}",
+                                sinks="${env.TELEMETRY_SINKS:console}",
+                            ),
+                        )
+                    ],
+                },
+                default_models=default_models,
+                default_tool_groups=default_tool_groups,
+                default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
+                metadata_store=PostgresKVStoreConfig.model_validate(postgres_config),
+                inference_store=PostgresSqlStoreConfig.model_validate(postgres_config),
+            ),
+        },
+        run_config_env_vars={
+            "LLAMA_STACK_PORT": (
+                "8321",
+                "Port for the Llama Stack distribution server",
+            ),
+            "FIREWORKS_API_KEY": (
+                "",
+                "Fireworks API Key",
+            ),
+        },
+    )
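The template leans on `${env.VAR:default}` placeholders throughout. An illustrative resolver for just that colon-default form (the real substitution happens in Llama Stack's config loader, and the `${env.VAR+value}` conditional form seen above is not handled here):

```python
# Illustrative resolver for ${env.VAR:default} placeholders only.
import os
import re


def resolve(value: str) -> str:
    return re.sub(
        r"\$\{env\.([A-Za-z_]+):([^}]*)\}",
        lambda m: os.environ.get(m.group(1), m.group(2)),
        value,
    )


print(resolve("${env.POSTGRES_PORT:5432}"))  # "5432" unless POSTGRES_PORT is set
```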
--- /dev/null
+++ b/llama_stack/templates/postgres-demo/run.yaml
@@ -0,0 +1,224 @@
+version: '2'
+image_name: postgres-demo
+apis:
+- agents
+- inference
+- safety
+- telemetry
+- tool_runtime
+- vector_io
+providers:
+  inference:
+  - provider_id: fireworks
+    provider_type: remote::fireworks
+    config:
+      url: https://api.fireworks.ai/inference/v1
+      api_key: ${env.FIREWORKS_API_KEY:}
+  - provider_id: vllm-inference
+    provider_type: remote::vllm
+    config:
+      url: ${env.VLLM_URL:http://localhost:8000/v1}
+      max_tokens: ${env.VLLM_MAX_TOKENS:4096}
+      api_token: ${env.VLLM_API_TOKEN:fake}
+      tls_verify: ${env.VLLM_TLS_VERIFY:true}
+  vector_io:
+  - provider_id: ${env.ENABLE_CHROMADB+chromadb}
+    provider_type: remote::chromadb
+    config:
+      url: ${env.CHROMADB_URL:}
+  safety:
+  - provider_id: llama-guard
+    provider_type: inline::llama-guard
+    config:
+      excluded_categories: []
+  agents:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config:
+      persistence_store:
+        type: postgres
+        host: ${env.POSTGRES_HOST:localhost}
+        port: ${env.POSTGRES_PORT:5432}
+        db: ${env.POSTGRES_DB:llamastack}
+        user: ${env.POSTGRES_USER:llamastack}
+        password: ${env.POSTGRES_PASSWORD:llamastack}
+      responses_store:
+        type: postgres
+        host: ${env.POSTGRES_HOST:localhost}
+        port: ${env.POSTGRES_PORT:5432}
+        db: ${env.POSTGRES_DB:llamastack}
+        user: ${env.POSTGRES_USER:llamastack}
+        password: ${env.POSTGRES_PASSWORD:llamastack}
+  telemetry:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config:
+      service_name: ${env.OTEL_SERVICE_NAME:}
+      sinks: ${env.TELEMETRY_SINKS:console}
+  tool_runtime:
+  - provider_id: brave-search
+    provider_type: remote::brave-search
+    config:
+      api_key: ${env.BRAVE_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: tavily-search
+    provider_type: remote::tavily-search
+    config:
+      api_key: ${env.TAVILY_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: rag-runtime
+    provider_type: inline::rag-runtime
+    config: {}
+  - provider_id: model-context-protocol
+    provider_type: remote::model-context-protocol
+    config: {}
+metadata_store:
+  type: postgres
+  host: ${env.POSTGRES_HOST:localhost}
+  port: ${env.POSTGRES_PORT:5432}
+  db: ${env.POSTGRES_DB:llamastack}
+  user: ${env.POSTGRES_USER:llamastack}
+  password: ${env.POSTGRES_PASSWORD:llamastack}
+  table_name: llamastack_kvstore
+inference_store:
+  type: postgres
+  host: ${env.POSTGRES_HOST:localhost}
+  port: ${env.POSTGRES_PORT:5432}
+  db: ${env.POSTGRES_DB:llamastack}
+  user: ${env.POSTGRES_USER:llamastack}
+  password: ${env.POSTGRES_PASSWORD:llamastack}
+models:
+- metadata: {}
+  model_id: accounts/fireworks/models/llama-v3p1-8b-instruct
+  provider_id: fireworks
+  provider_model_id: accounts/fireworks/models/llama-v3p1-8b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-3.1-8B-Instruct
+  provider_id: fireworks
+  provider_model_id: accounts/fireworks/models/llama-v3p1-8b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: accounts/fireworks/models/llama-v3p1-70b-instruct
+  provider_id: fireworks
+  provider_model_id: accounts/fireworks/models/llama-v3p1-70b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-3.1-70B-Instruct
+  provider_id: fireworks
+  provider_model_id: accounts/fireworks/models/llama-v3p1-70b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
+  provider_id: fireworks
+  provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
+  provider_id: fireworks
+  provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
+  provider_id: fireworks
+  provider_model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-3.2-3B-Instruct
+  provider_id: fireworks
+  provider_model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct
+  provider_id: fireworks
+  provider_model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
+  provider_id: fireworks
+  provider_model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct
+  model_type: llm
+- metadata: {}
+  model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct
+  provider_id: fireworks
+  provider_model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
+  provider_id: fireworks
+  provider_model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct
+  model_type: llm
+- metadata: {}
+  model_id: accounts/fireworks/models/llama-v3p3-70b-instruct
+  provider_id: fireworks
+  provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-3.3-70B-Instruct
+  provider_id: fireworks
+  provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: accounts/fireworks/models/llama-guard-3-8b
+  provider_id: fireworks
+  provider_model_id: accounts/fireworks/models/llama-guard-3-8b
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-Guard-3-8B
+  provider_id: fireworks
+  provider_model_id: accounts/fireworks/models/llama-guard-3-8b
+  model_type: llm
+- metadata: {}
+  model_id: accounts/fireworks/models/llama-guard-3-11b-vision
+  provider_id: fireworks
+  provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-Guard-3-11B-Vision
+  provider_id: fireworks
+  provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
+  model_type: llm
+- metadata: {}
+  model_id: accounts/fireworks/models/llama4-scout-instruct-basic
+  provider_id: fireworks
+  provider_model_id: accounts/fireworks/models/llama4-scout-instruct-basic
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
+  provider_id: fireworks
+  provider_model_id: accounts/fireworks/models/llama4-scout-instruct-basic
+  model_type: llm
+- metadata: {}
+  model_id: accounts/fireworks/models/llama4-maverick-instruct-basic
+  provider_id: fireworks
+  provider_model_id: accounts/fireworks/models/llama4-maverick-instruct-basic
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct
+  provider_id: fireworks
+  provider_model_id: accounts/fireworks/models/llama4-maverick-instruct-basic
+  model_type: llm
+- metadata:
+    embedding_dimension: 768
+    context_length: 8192
+  model_id: nomic-ai/nomic-embed-text-v1.5
+  provider_id: fireworks
+  provider_model_id: nomic-ai/nomic-embed-text-v1.5
+  model_type: embedding
+- metadata: {}
+  model_id: ${env.INFERENCE_MODEL}
+  provider_id: vllm-inference
+  model_type: llm
+shields:
+- shield_id: meta-llama/Llama-Guard-3-8B
+vector_dbs: []
+datasets: []
+scoring_fns: []
+benchmarks: []
+tool_groups:
+- toolgroup_id: builtin::websearch
+  provider_id: tavily-search
+- toolgroup_id: builtin::rag
+  provider_id: rag-runtime
+server:
+  port: 8321
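Everything stateful in this run config (agent persistence, the responses store, the metadata store, and the new inference store) points at the same Postgres database. A hedged sanity check against the defaults above, using asyncpg, the driver the store's `engine_str` selects:

```python
# Optional sanity check against the run.yaml defaults (values assumed local).
import asyncio

import asyncpg


async def main() -> None:
    conn = await asyncpg.connect(
        user="llamastack",
        password="llamastack",
        host="localhost",
        port=5432,
        database="llamastack",
    )
    print(await conn.fetchval("SELECT version()"))
    await conn.close()


asyncio.run(main())
```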
More per-distribution build.yaml updates:

@@ -32,5 +32,5 @@ distribution_spec:
     - remote::wolfram-alpha
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
 - aiosqlite
+- sqlalchemy[asyncio]
@@ -23,4 +23,5 @@ distribution_spec:
     - remote::wolfram-alpha
 image_type: conda
 additional_pip_packages:
 - aiosqlite
+- sqlalchemy[asyncio]
@@ -36,4 +36,5 @@ distribution_spec:
     - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
 - aiosqlite
+- sqlalchemy[asyncio]
--- a/llama_stack/templates/template.py
+++ b/llama_stack/templates/template.py
@@ -28,8 +28,8 @@ from llama_stack.distribution.datatypes import (
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
 from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
-from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
-from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
+from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
 
 
 def get_model_registry(
@@ -64,6 +64,8 @@ class RunConfigSettings(BaseModel):
     default_tool_groups: list[ToolGroupInput] | None = None
     default_datasets: list[DatasetInput] | None = None
     default_benchmarks: list[BenchmarkInput] | None = None
+    metadata_store: KVStoreConfig | None = None
+    inference_store: SqlStoreConfig | None = None
 
     def run_config(
         self,
@@ -114,11 +116,13 @@ class RunConfigSettings(BaseModel):
             container_image=container_image,
             apis=apis,
             providers=provider_configs,
-            metadata_store=SqliteKVStoreConfig.sample_run_config(
+            metadata_store=self.metadata_store
+            or SqliteKVStoreConfig.sample_run_config(
                 __distro_dir__=f"~/.llama/distributions/{name}",
                 db_name="registry.db",
             ),
-            inference_store=SqliteSqlStoreConfig.sample_run_config(
+            inference_store=self.inference_store
+            or SqliteSqlStoreConfig.sample_run_config(
                 __distro_dir__=f"~/.llama/distributions/{name}",
                 db_name="inference_store.db",
             ),
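The `or` makes the new fields pure overrides: a template that sets nothing keeps today's per-distro SQLite files. The semantics in isolation (values are illustrative):

```python
# Override-or-default semantics of the run_config change, in isolation.
default_store = {"type": "sqlite", "db_name": "registry.db"}

for override in (None, {"type": "postgres", "db": "llamastack"}):
    effective = override or default_store
    print(effective["type"])  # sqlite, then postgres
```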
@@ -164,7 +168,7 @@ class DistributionTemplate(BaseModel):
                 providers=self.providers,
             ),
             image_type="conda",  # default to conda, can be overridden
-            additional_pip_packages=additional_pip_packages,
+            additional_pip_packages=sorted(set(additional_pip_packages)),
         )
 
     def generate_markdown_docs(self) -> str:
Remaining per-distribution build.yaml updates:

@@ -31,5 +31,5 @@ distribution_spec:
     - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
 - aiosqlite
+- sqlalchemy[asyncio]
@@ -32,5 +32,5 @@ distribution_spec:
     - remote::wolfram-alpha
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
 - aiosqlite
+- sqlalchemy[asyncio]
@@ -36,4 +36,5 @@ distribution_spec:
     - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
 - aiosqlite
+- sqlalchemy[asyncio]
@@ -31,4 +31,5 @@ distribution_spec:
     - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
 - aiosqlite
+- sqlalchemy[asyncio]
@@ -29,4 +29,5 @@ distribution_spec:
     - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
 - aiosqlite
+- sqlalchemy[asyncio]
Inference-store integration tests (under tests/integration/):

@@ -268,9 +268,9 @@ def test_openai_chat_completion_streaming_with_n(compat_client, client_with_models, text_model_id, stream):
         False,
     ],
 )
-def test_inference_store(openai_client, client_with_models, text_model_id, stream):
+def test_inference_store(compat_client, client_with_models, text_model_id, stream):
     skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
-    client = openai_client
+    client = compat_client
     # make a chat completion
     message = "Hello, world!"
     response = client.chat.completions.create(
@@ -301,9 +301,14 @@ def test_inference_store(openai_client, client_with_models, text_model_id, stream):
 
     retrieved_response = client.chat.completions.retrieve(response_id)
     assert retrieved_response.id == response_id
-    assert retrieved_response.input_messages[0]["content"] == message, retrieved_response
     assert retrieved_response.choices[0].message.content == content, retrieved_response
 
+    input_content = (
+        getattr(retrieved_response.input_messages[0], "content", None)
+        or retrieved_response.input_messages[0]["content"]
+    )
+    assert input_content == message, retrieved_response
+
 
 @pytest.mark.parametrize(
     "stream",
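The `getattr(...) or ...` fallback lets one assertion cover both client shapes: attribute access for OpenAI-client response objects and key access for the plain dicts the library client returns. The pattern in isolation:

```python
# The fallback handles both message shapes seen across compat clients.
class Message:  # attribute-style, as from the OpenAI client
    content = "Hello, world!"


for first in (Message(), {"content": "Hello, world!"}):  # dict-style fallback
    content = getattr(first, "content", None) or first["content"]
    assert content == "Hello, world!"
```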
@@ -312,9 +317,9 @@ def test_inference_store(openai_client, client_with_models, text_model_id, stream):
         False,
     ],
 )
-def test_inference_store_tool_calls(openai_client, client_with_models, text_model_id, stream):
+def test_inference_store_tool_calls(compat_client, client_with_models, text_model_id, stream):
     skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
-    client = openai_client
+    client = compat_client
     # make a chat completion
     message = "What's the weather in Tokyo? Use the get_weather function to get the weather."
     response = client.chat.completions.create(
@@ -361,7 +366,11 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_model_id, stream):
 
     retrieved_response = client.chat.completions.retrieve(response_id)
     assert retrieved_response.id == response_id
-    assert retrieved_response.input_messages[0]["content"] == message
+    input_content = (
+        getattr(retrieved_response.input_messages[0], "content", None)
+        or retrieved_response.input_messages[0]["content"]
+    )
+    assert input_content == message, retrieved_response
     tool_calls = retrieved_response.choices[0].message.tool_calls
     # sometimes model doesn't ouptut tool calls, but we still want to test that the tool was called
     if tool_calls:
Unit tests for the SQL store:

@@ -9,7 +9,7 @@ from tempfile import TemporaryDirectory
 import pytest
 
 from llama_stack.providers.utils.sqlstore.api import ColumnType
-from llama_stack.providers.utils.sqlstore.sqlite.sqlite import SqliteSqlStoreImpl
+from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
 from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
 
 
@@ -17,7 +17,7 @@ from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
 async def test_sqlite_sqlstore():
     with TemporaryDirectory() as tmp_dir:
         db_name = "test.db"
-        sqlstore = SqliteSqlStoreImpl(
+        sqlstore = SqlAlchemySqlStoreImpl(
             SqliteSqlStoreConfig(
                 db_path=tmp_dir + "/" + db_name,
             )
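The unit test only renames the implementation class; the same setup can exercise the store end to end. A hedged sketch, assuming `create_table` takes a column-name-to-`ColumnType` mapping (the exact schema shape is not shown in this diff) and that `ColumnType.STRING` exists alongside the `ColumnType.INTEGER` seen above:

```python
# Hedged end-to-end sketch; create_table's schema shape is assumed.
import asyncio
from tempfile import TemporaryDirectory

from llama_stack.providers.utils.sqlstore.api import ColumnType
from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig


async def main() -> None:
    with TemporaryDirectory() as tmp_dir:
        store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=f"{tmp_dir}/test.db"))
        await store.create_table("kv", {"key": ColumnType.STRING, "value": ColumnType.STRING})
        await store.insert("kv", {"key": "a", "value": "1"})
        print(await store.fetch_all("kv"))


asyncio.run(main())
```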