Merge branch 'embeddings' of https://github.com/hardikjshah/llama-stack into embeddings

Hardik Shah 2025-05-30 12:25:07 -07:00
commit 535e55d7dd
37 changed files with 605 additions and 66 deletions

View file

@@ -292,12 +292,12 @@ class OpenAIResponsesImpl:
     async def _store_response(
         self,
         response: OpenAIResponseObject,
-        original_input: str | list[OpenAIResponseInput],
+        input: str | list[OpenAIResponseInput],
     ) -> None:
         new_input_id = f"msg_{uuid.uuid4()}"
-        if isinstance(original_input, str):
+        if isinstance(input, str):
             # synthesize a message from the input string
-            input_content = OpenAIResponseInputMessageContentText(text=original_input)
+            input_content = OpenAIResponseInputMessageContentText(text=input)
             input_content_item = OpenAIResponseMessage(
                 role="user",
                 content=[input_content],
@@ -307,7 +307,7 @@ class OpenAIResponsesImpl:
         else:
             # we already have a list of messages
             input_items_data = []
-            for input_item in original_input:
+            for input_item in input:
                 if isinstance(input_item, OpenAIResponseMessage):
                     # These may or may not already have an id, so dump to dict, check for id, and add if missing
                     input_item_dict = input_item.model_dump()
@@ -334,7 +334,6 @@
         tools: list[OpenAIResponseInputTool] | None = None,
     ):
         stream = False if stream is None else stream
-        original_input = input  # Keep reference for storage
         output_messages: list[OpenAIResponseOutput] = []
@@ -372,7 +371,7 @@
                 inference_result=inference_result,
                 ctx=ctx,
                 output_messages=output_messages,
-                original_input=original_input,
+                input=input,
                 model=model,
                 store=store,
                 tools=tools,
@@ -382,7 +381,7 @@
                 inference_result=inference_result,
                 ctx=ctx,
                 output_messages=output_messages,
-                original_input=original_input,
+                input=input,
                 model=model,
                 store=store,
                 tools=tools,
@@ -393,7 +392,7 @@
         inference_result: Any,
         ctx: ChatCompletionContext,
         output_messages: list[OpenAIResponseOutput],
-        original_input: str | list[OpenAIResponseInput],
+        input: str | list[OpenAIResponseInput],
         model: str,
         store: bool | None,
         tools: list[OpenAIResponseInputTool] | None,
@@ -423,7 +422,7 @@
         if store:
             await self._store_response(
                 response=response,
-                original_input=original_input,
+                input=input,
             )
         return response
@@ -433,7 +432,7 @@
         inference_result: Any,
         ctx: ChatCompletionContext,
         output_messages: list[OpenAIResponseOutput],
-        original_input: str | list[OpenAIResponseInput],
+        input: str | list[OpenAIResponseInput],
         model: str,
         store: bool | None,
         tools: list[OpenAIResponseInputTool] | None,
@@ -544,7 +543,7 @@
         if store:
             await self._store_response(
                 response=final_response,
-                original_input=original_input,
+                input=input,
             )
         # Emit response.completed

View file

@@ -75,7 +75,9 @@ class PromptGuardShield:
         self.temperature = temperature
         self.threshold = threshold
-        self.device = "cuda"
+        self.device = "cpu"
+        if torch.cuda.is_available():
+            self.device = "cuda"
         # load model and tokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
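
The device change above is the usual torch CPU-first fallback. A minimal standalone sketch of the pattern (assuming only that torch is installed; the real PromptGuardShield constructor takes more arguments than shown here):

import torch

# Prefer CPU so the shield can load on GPU-less hosts, and only switch to CUDA
# when a GPU is actually visible to torch.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(device)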

View file

@@ -218,7 +218,7 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
                 "json_schema": {
                     "name": name,
                     "schema": fmt,
-                    "strict": True,
+                    "strict": False,
                 },
             }
         if request.tools:
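
For context, this hunk sits inside the adapter's structured-output handling; the full response_format payload it builds looks roughly like the sketch below. The outer "type" key and the name/fmt variables are assumed from the OpenAI-style json_schema convention and are not shown in this hunk.

response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": name,    # schema name taken from the request (assumed)
        "schema": fmt,   # the JSON schema dict itself (assumed)
        "strict": False, # this commit flips strict from True to False
    },
}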

View file

@@ -65,7 +65,7 @@ class SqliteKVStoreConfig(CommonConfig):
 class PostgresKVStoreConfig(CommonConfig):
     type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value
     host: str = "localhost"
-    port: int = 5432
+    port: str = "5432"
     db: str = "llamastack"
     user: str
     password: str | None = None

View file

@@ -19,10 +19,10 @@ from sqlalchemy import (
     Text,
     select,
 )
-from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
-from ..api import ColumnDefinition, ColumnType, SqlStore
-from ..sqlstore import SqliteSqlStoreConfig
+from .api import ColumnDefinition, ColumnType, SqlStore
+from .sqlstore import SqlAlchemySqlStoreConfig
 TYPE_MAPPING: dict[ColumnType, Any] = {
     ColumnType.INTEGER: Integer,
@@ -35,9 +35,10 @@ TYPE_MAPPING: dict[ColumnType, Any] = {
 }
-class SqliteSqlStoreImpl(SqlStore):
-    def __init__(self, config: SqliteSqlStoreConfig):
-        self.engine = create_async_engine(config.engine_str)
+class SqlAlchemySqlStoreImpl(SqlStore):
+    def __init__(self, config: SqlAlchemySqlStoreConfig):
+        self.config = config
+        self.async_session = async_sessionmaker(create_async_engine(config.engine_str))
         self.metadata = MetaData()
     async def create_table(
@@ -78,13 +79,14 @@ class SqliteSqlStoreImpl(SqlStore):
         # Create the table in the database if it doesn't exist
         # checkfirst=True ensures it doesn't try to recreate if it's already there
-        async with self.engine.begin() as conn:
+        engine = create_async_engine(self.config.engine_str)
+        async with engine.begin() as conn:
             await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)
     async def insert(self, table: str, data: Mapping[str, Any]) -> None:
-        async with self.engine.begin() as conn:
-            await conn.execute(self.metadata.tables[table].insert(), data)
-            await conn.commit()
+        async with self.async_session() as session:
+            await session.execute(self.metadata.tables[table].insert(), data)
+            await session.commit()
     async def fetch_all(
         self,
@@ -93,7 +95,7 @@ class SqliteSqlStoreImpl(SqlStore):
         limit: int | None = None,
         order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
     ) -> list[dict[str, Any]]:
-        async with self.engine.begin() as conn:
+        async with self.async_session() as session:
             query = select(self.metadata.tables[table])
             if where:
                 for key, value in where.items():
@@ -117,7 +119,7 @@ class SqliteSqlStoreImpl(SqlStore):
                         query = query.order_by(self.metadata.tables[table].c[name].desc())
                     else:
                         raise ValueError(f"Invalid order '{order_type}' for column '{name}'")
-            result = await conn.execute(query)
+            result = await session.execute(query)
             if result.rowcount == 0:
                 return []
             return [dict(row._mapping) for row in result]
@@ -142,20 +144,20 @@ class SqliteSqlStoreImpl(SqlStore):
         if not where:
             raise ValueError("where is required for update")
-        async with self.engine.begin() as conn:
+        async with self.async_session() as session:
             stmt = self.metadata.tables[table].update()
             for key, value in where.items():
                 stmt = stmt.where(self.metadata.tables[table].c[key] == value)
-            await conn.execute(stmt, data)
-            await conn.commit()
+            await session.execute(stmt, data)
+            await session.commit()
     async def delete(self, table: str, where: Mapping[str, Any]) -> None:
         if not where:
             raise ValueError("where is required for delete")
-        async with self.engine.begin() as conn:
+        async with self.async_session() as session:
             stmt = self.metadata.tables[table].delete()
             for key, value in where.items():
                 stmt = stmt.where(self.metadata.tables[table].c[key] == value)
-            await conn.execute(stmt)
-            await conn.commit()
+            await session.execute(stmt)
+            await session.commit()
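
The renamed store is used the same way the old SQLite-only implementation was; a minimal usage sketch, mirroring the unit test at the end of this commit (the exact create_table schema shape and ColumnType members are assumed from the api module, which this diff does not show):

import asyncio

from llama_stack.providers.utils.sqlstore.api import ColumnType
from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig

async def demo() -> None:
    # Any SqlAlchemySqlStoreConfig works here; SQLite is the simplest to run locally.
    store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path="/tmp/demo_sqlstore.db"))
    await store.create_table("users", {"id": ColumnType.INTEGER, "name": ColumnType.STRING})
    await store.insert("users", {"id": 1, "name": "alice"})
    rows = await store.fetch_all("users", where={"id": 1})
    print(rows)  # [{'id': 1, 'name': 'alice'}]

asyncio.run(demo())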

View file

@@ -5,6 +5,7 @@
 # the root directory of this source tree.
+from abc import abstractmethod
 from enum import Enum
 from pathlib import Path
 from typing import Annotated, Literal
@@ -21,7 +22,18 @@ class SqlStoreType(Enum):
     postgres = "postgres"
-class SqliteSqlStoreConfig(BaseModel):
+class SqlAlchemySqlStoreConfig(BaseModel):
+    @property
+    @abstractmethod
+    def engine_str(self) -> str: ...
+    # TODO: move this when we have a better way to specify dependencies with internal APIs
+    @property
+    def pip_packages(self) -> list[str]:
+        return ["sqlalchemy[asyncio]"]
+class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
     type: Literal["sqlite"] = SqlStoreType.sqlite.value
     db_path: str = Field(
         default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
@@ -39,18 +51,26 @@ class SqliteSqlStoreConfig(BaseModel):
             db_path="${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name,
         )
-    # TODO: move this when we have a better way to specify dependencies with internal APIs
     @property
     def pip_packages(self) -> list[str]:
-        return ["sqlalchemy[asyncio]"]
+        return super().pip_packages + ["aiosqlite"]
-class PostgresSqlStoreConfig(BaseModel):
+class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
     type: Literal["postgres"] = SqlStoreType.postgres.value
+    host: str = "localhost"
+    port: str = "5432"
+    db: str = "llamastack"
+    user: str
+    password: str | None = None
+    @property
+    def engine_str(self) -> str:
+        return f"postgresql+asyncpg://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}"
     @property
     def pip_packages(self) -> list[str]:
-        raise NotImplementedError("Postgres is not implemented yet")
+        return super().pip_packages + ["asyncpg"]
 SqlStoreConfig = Annotated[
@@ -60,12 +80,10 @@ SqlStoreConfig = Annotated[
 def sqlstore_impl(config: SqlStoreConfig) -> SqlStore:
-    if config.type == SqlStoreType.sqlite.value:
-        from .sqlite.sqlite import SqliteSqlStoreImpl
-        impl = SqliteSqlStoreImpl(config)
-    elif config.type == SqlStoreType.postgres.value:
-        raise NotImplementedError("Postgres is not implemented yet")
+    if config.type in [SqlStoreType.sqlite.value, SqlStoreType.postgres.value]:
+        from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
+        impl = SqlAlchemySqlStoreImpl(config)
     else:
         raise ValueError(f"Unknown sqlstore type {config.type}")

View file

@@ -30,4 +30,5 @@ distribution_spec:
     - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -30,4 +30,5 @@ distribution_spec:
     - inline::rag-runtime
 image_type: conda
 additional_pip_packages:
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -31,4 +31,5 @@ distribution_spec:
     - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -31,5 +31,5 @@ distribution_spec:
     - inline::rag-runtime
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -32,5 +32,5 @@ distribution_spec:
     - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -27,4 +27,5 @@ distribution_spec:
     - inline::rag-runtime
 image_type: conda
 additional_pip_packages:
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -30,5 +30,5 @@ distribution_spec:
     - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -31,5 +31,5 @@ distribution_spec:
     - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -31,4 +31,5 @@ distribution_spec:
     - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -30,5 +30,5 @@ distribution_spec:
     - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -25,5 +25,5 @@ distribution_spec:
     - inline::rag-runtime
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -33,5 +33,5 @@ distribution_spec:
     - remote::wolfram-alpha
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -34,4 +34,5 @@ distribution_spec:
     - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -32,5 +32,5 @@ distribution_spec:
     - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .postgres_demo import get_distribution_template # noqa: F401

View file

@@ -0,0 +1,24 @@
version: '2'
distribution_spec:
description: Quick start template for running Llama Stack with several popular providers
providers:
inference:
- remote::fireworks
- remote::vllm
vector_io:
- remote::chromadb
safety:
- inline::llama-guard
agents:
- inline::meta-reference
telemetry:
- inline::meta-reference
tool_runtime:
- remote::brave-search
- remote::tavily-search
- inline::rag-runtime
- remote::model-context-protocol
image_type: conda
additional_pip_packages:
- asyncpg
- sqlalchemy[asyncio]

View file

@@ -0,0 +1,164 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.distribution.datatypes import (
ModelInput,
Provider,
ShieldInput,
ToolGroupInput,
)
from llama_stack.providers.remote.inference.fireworks.config import FireworksImplConfig
from llama_stack.providers.remote.inference.fireworks.models import (
MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
from llama_stack.providers.utils.kvstore.config import PostgresKVStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
from llama_stack.templates.template import (
DistributionTemplate,
RunConfigSettings,
get_model_registry,
)
def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderModelEntry]]]:
# in this template, we allow each API key to be optional
providers = [
(
"fireworks",
FIREWORKS_MODEL_ENTRIES,
FireworksImplConfig.sample_run_config(api_key="${env.FIREWORKS_API_KEY:}"),
),
]
inference_providers = []
available_models = {}
for provider_id, model_entries, config in providers:
inference_providers.append(
Provider(
provider_id=provider_id,
provider_type=f"remote::{provider_id}",
config=config,
)
)
available_models[provider_id] = model_entries
inference_providers.append(
Provider(
provider_id="vllm-inference",
provider_type="remote::vllm",
config=VLLMInferenceAdapterConfig.sample_run_config(
url="${env.VLLM_URL:http://localhost:8000/v1}",
),
)
)
return inference_providers, available_models
def get_distribution_template() -> DistributionTemplate:
inference_providers, available_models = get_inference_providers()
providers = {
"inference": ([p.provider_type for p in inference_providers]),
"vector_io": ["remote::chromadb"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
"tool_runtime": [
"remote::brave-search",
"remote::tavily-search",
"inline::rag-runtime",
"remote::model-context-protocol",
],
}
name = "postgres-demo"
vector_io_providers = [
Provider(
provider_id="${env.ENABLE_CHROMADB+chromadb}",
provider_type="remote::chromadb",
config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:}"),
),
]
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
]
default_models = get_model_registry(available_models)
default_models.append(
ModelInput(
model_id="${env.INFERENCE_MODEL}",
provider_id="vllm-inference",
)
)
postgres_config = {
"type": "postgres",
"host": "${env.POSTGRES_HOST:localhost}",
"port": "${env.POSTGRES_PORT:5432}",
"db": "${env.POSTGRES_DB:llamastack}",
"user": "${env.POSTGRES_USER:llamastack}",
"password": "${env.POSTGRES_PASSWORD:llamastack}",
}
return DistributionTemplate(
name=name,
distro_type="self_hosted",
description="Quick start template for running Llama Stack with several popular providers",
container_image=None,
template_path=None,
providers=providers,
available_models_by_provider=available_models,
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": inference_providers,
"vector_io": vector_io_providers,
"agents": [
Provider(
provider_id="meta-reference",
provider_type="inline::meta-reference",
config=dict(
persistence_store=postgres_config,
responses_store=postgres_config,
),
)
],
"telemetry": [
Provider(
provider_id="meta-reference",
provider_type="inline::meta-reference",
config=dict(
service_name="${env.OTEL_SERVICE_NAME:}",
sinks="${env.TELEMETRY_SINKS:console}",
),
)
],
},
default_models=default_models,
default_tool_groups=default_tool_groups,
default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
metadata_store=PostgresKVStoreConfig.model_validate(postgres_config),
inference_store=PostgresSqlStoreConfig.model_validate(postgres_config),
),
},
run_config_env_vars={
"LLAMA_STACK_PORT": (
"8321",
"Port for the Llama Stack distribution server",
),
"FIREWORKS_API_KEY": (
"",
"Fireworks API Key",
),
},
)
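
Because the same postgres_config dict feeds both model_validate calls above, the registry KV store and the inference store end up pointing at one Postgres database. Roughly, with the env placeholders collapsed to their defaults:

from llama_stack.providers.utils.kvstore.config import PostgresKVStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig

postgres_config = {
    "type": "postgres",
    "host": "localhost",
    "port": "5432",
    "db": "llamastack",
    "user": "llamastack",
    "password": "llamastack",
}
metadata_store = PostgresKVStoreConfig.model_validate(postgres_config)
inference_store = PostgresSqlStoreConfig.model_validate(postgres_config)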

View file

@@ -0,0 +1,224 @@
version: '2'
image_name: postgres-demo
apis:
- agents
- inference
- safety
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:}
- provider_id: vllm-inference
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:http://localhost:8000/v1}
max_tokens: ${env.VLLM_MAX_TOKENS:4096}
api_token: ${env.VLLM_API_TOKEN:fake}
tls_verify: ${env.VLLM_TLS_VERIFY:true}
vector_io:
- provider_id: ${env.ENABLE_CHROMADB+chromadb}
provider_type: remote::chromadb
config:
url: ${env.CHROMADB_URL:}
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config:
excluded_categories: []
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: postgres
host: ${env.POSTGRES_HOST:localhost}
port: ${env.POSTGRES_PORT:5432}
db: ${env.POSTGRES_DB:llamastack}
user: ${env.POSTGRES_USER:llamastack}
password: ${env.POSTGRES_PASSWORD:llamastack}
responses_store:
type: postgres
host: ${env.POSTGRES_HOST:localhost}
port: ${env.POSTGRES_PORT:5432}
db: ${env.POSTGRES_DB:llamastack}
user: ${env.POSTGRES_USER:llamastack}
password: ${env.POSTGRES_PASSWORD:llamastack}
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
metadata_store:
type: postgres
host: ${env.POSTGRES_HOST:localhost}
port: ${env.POSTGRES_PORT:5432}
db: ${env.POSTGRES_DB:llamastack}
user: ${env.POSTGRES_USER:llamastack}
password: ${env.POSTGRES_PASSWORD:llamastack}
table_name: llamastack_kvstore
inference_store:
type: postgres
host: ${env.POSTGRES_HOST:localhost}
port: ${env.POSTGRES_PORT:5432}
db: ${env.POSTGRES_DB:llamastack}
user: ${env.POSTGRES_USER:llamastack}
password: ${env.POSTGRES_PASSWORD:llamastack}
models:
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p1-8b-instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p1-8b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-8B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p1-8b-instruct
model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p1-70b-instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p1-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-70B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p1-70b-instruct
model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-3B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct
model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct
model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p3-70b-instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.3-70B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct
model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama-guard-3-8b
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-guard-3-8b
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-Guard-3-8B
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-guard-3-8b
model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama-guard-3-11b-vision
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-Guard-3-11B-Vision
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama4-scout-instruct-basic
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama4-scout-instruct-basic
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama4-scout-instruct-basic
model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama4-maverick-instruct-basic
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama4-maverick-instruct-basic
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama4-maverick-instruct-basic
model_type: llm
- metadata:
embedding_dimension: 768
context_length: 8192
model_id: nomic-ai/nomic-embed-text-v1.5
provider_id: fireworks
provider_model_id: nomic-ai/nomic-embed-text-v1.5
model_type: embedding
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: vllm-inference
model_type: llm
shields:
- shield_id: meta-llama/Llama-Guard-3-8B
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321

View file

@@ -32,5 +32,5 @@ distribution_spec:
     - remote::wolfram-alpha
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -23,4 +23,5 @@ distribution_spec:
     - remote::wolfram-alpha
 image_type: conda
 additional_pip_packages:
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -36,4 +36,5 @@ distribution_spec:
     - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -28,8 +28,8 @@ from llama_stack.distribution.datatypes import (
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
 from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
-from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
-from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
+from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
 def get_model_registry(
@@ -64,6 +64,8 @@ class RunConfigSettings(BaseModel):
     default_tool_groups: list[ToolGroupInput] | None = None
     default_datasets: list[DatasetInput] | None = None
     default_benchmarks: list[BenchmarkInput] | None = None
+    metadata_store: KVStoreConfig | None = None
+    inference_store: SqlStoreConfig | None = None
     def run_config(
         self,
@@ -114,11 +116,13 @@ class RunConfigSettings(BaseModel):
             container_image=container_image,
             apis=apis,
             providers=provider_configs,
-            metadata_store=SqliteKVStoreConfig.sample_run_config(
+            metadata_store=self.metadata_store
+            or SqliteKVStoreConfig.sample_run_config(
                 __distro_dir__=f"~/.llama/distributions/{name}",
                 db_name="registry.db",
             ),
-            inference_store=SqliteSqlStoreConfig.sample_run_config(
+            inference_store=self.inference_store
+            or SqliteSqlStoreConfig.sample_run_config(
                 __distro_dir__=f"~/.llama/distributions/{name}",
                 db_name="inference_store.db",
             ),
@@ -164,7 +168,7 @@ class DistributionTemplate(BaseModel):
                 providers=self.providers,
             ),
             image_type="conda",  # default to conda, can be overridden
-            additional_pip_packages=additional_pip_packages,
+            additional_pip_packages=sorted(set(additional_pip_packages)),
         )
     def generate_markdown_docs(self) -> str:
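
The sorted(set(...)) change is what collapses the duplicated sqlalchemy[asyncio] entries seen in the build.yaml diffs elsewhere in this commit; a tiny illustration:

additional_pip_packages = ["sqlalchemy[asyncio]", "sqlalchemy[asyncio]", "aiosqlite"]
print(sorted(set(additional_pip_packages)))  # ['aiosqlite', 'sqlalchemy[asyncio]']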

View file

@@ -31,5 +31,5 @@ distribution_spec:
     - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -32,5 +32,5 @@ distribution_spec:
     - remote::wolfram-alpha
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -36,4 +36,5 @@ distribution_spec:
     - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -31,4 +31,5 @@ distribution_spec:
     - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -29,4 +29,5 @@ distribution_spec:
     - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -107,6 +107,13 @@ def collect_template_dependencies(template_dir: Path) -> tuple[str | None, list[
         return None, []
+def pre_import_templates(template_dirs: list[Path]) -> None:
+    # Pre-import all template modules to avoid deadlocks.
+    for template_dir in template_dirs:
+        module_name = f"llama_stack.templates.{template_dir.name}"
+        importlib.import_module(module_name)
 def main():
     templates_dir = REPO_ROOT / "llama_stack" / "templates"
     change_tracker = ChangedPathTracker()
@@ -118,6 +125,8 @@ def main():
         template_dirs = list(find_template_dirs(templates_dir))
         task = progress.add_task("Processing distribution templates...", total=len(template_dirs))
+        pre_import_templates(template_dirs)
         # Create a partial function with the progress bar
         process_func = partial(process_template, progress=progress, change_tracker=change_tracker)
View file

@@ -268,9 +268,9 @@ def test_openai_chat_completion_streaming_with_n(compat_client, client_with_mode
         False,
     ],
 )
-def test_inference_store(openai_client, client_with_models, text_model_id, stream):
+def test_inference_store(compat_client, client_with_models, text_model_id, stream):
     skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
-    client = openai_client
+    client = compat_client
     # make a chat completion
     message = "Hello, world!"
     response = client.chat.completions.create(
@@ -301,9 +301,14 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
     retrieved_response = client.chat.completions.retrieve(response_id)
     assert retrieved_response.id == response_id
-    assert retrieved_response.input_messages[0]["content"] == message, retrieved_response
     assert retrieved_response.choices[0].message.content == content, retrieved_response
+    input_content = (
+        getattr(retrieved_response.input_messages[0], "content", None)
+        or retrieved_response.input_messages[0]["content"]
+    )
+    assert input_content == message, retrieved_response
 @pytest.mark.parametrize(
     "stream",
@@ -312,9 +317,9 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
         False,
     ],
 )
-def test_inference_store_tool_calls(openai_client, client_with_models, text_model_id, stream):
+def test_inference_store_tool_calls(compat_client, client_with_models, text_model_id, stream):
     skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
-    client = openai_client
+    client = compat_client
     # make a chat completion
     message = "What's the weather in Tokyo? Use the get_weather function to get the weather."
     response = client.chat.completions.create(
@@ -361,7 +366,11 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_mode
     retrieved_response = client.chat.completions.retrieve(response_id)
     assert retrieved_response.id == response_id
-    assert retrieved_response.input_messages[0]["content"] == message
+    input_content = (
+        getattr(retrieved_response.input_messages[0], "content", None)
+        or retrieved_response.input_messages[0]["content"]
+    )
+    assert input_content == message, retrieved_response
     tool_calls = retrieved_response.choices[0].message.tool_calls
     # sometimes model doesn't ouptut tool calls, but we still want to test that the tool was called
     if tool_calls:

View file

@@ -628,3 +628,69 @@ async def test_responses_store_list_input_items_logic():
     result = await responses_store.list_response_input_items("resp_123", limit=0, order=Order.asc)
     assert result.object == "list"
     assert len(result.data) == 0  # Should return no items
@pytest.mark.asyncio
async def test_store_response_uses_rehydrated_input_with_previous_response(
openai_responses_impl, mock_responses_store, mock_inference_api
):
"""Test that _store_response uses the full re-hydrated input (including previous responses)
rather than just the original input when previous_response_id is provided."""
# Setup - Create a previous response that should be included in the stored input
previous_response = OpenAIResponseObjectWithInput(
id="resp-previous-123",
object="response",
created_at=1234567890,
model="meta-llama/Llama-3.1-8B-Instruct",
status="completed",
input=[
OpenAIResponseMessage(
id="msg-prev-user", role="user", content=[OpenAIResponseInputMessageContentText(text="What is 2+2?")]
)
],
output=[
OpenAIResponseMessage(
id="msg-prev-assistant",
role="assistant",
content=[OpenAIResponseOutputMessageContentOutputText(text="2+2 equals 4.")],
)
],
)
mock_responses_store.get_response_object.return_value = previous_response
current_input = "Now what is 3+3?"
model = "meta-llama/Llama-3.1-8B-Instruct"
mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
# Execute - Create response with previous_response_id
result = await openai_responses_impl.create_openai_response(
input=current_input,
model=model,
previous_response_id="resp-previous-123",
store=True,
)
store_call_args = mock_responses_store.store_response_object.call_args
stored_input = store_call_args.kwargs["input"]
# Verify that the stored input contains the full re-hydrated conversation:
# 1. Previous user message
# 2. Previous assistant response
# 3. Current user message
assert len(stored_input) == 3
assert stored_input[0].role == "user"
assert stored_input[0].content[0].text == "What is 2+2?"
assert stored_input[1].role == "assistant"
assert stored_input[1].content[0].text == "2+2 equals 4."
assert stored_input[2].role == "user"
assert stored_input[2].content == "Now what is 3+3?"
# Verify the response itself is correct
assert result.model == model
assert result.status == "completed"

View file

@@ -9,7 +9,7 @@ from tempfile import TemporaryDirectory
 import pytest
 from llama_stack.providers.utils.sqlstore.api import ColumnType
-from llama_stack.providers.utils.sqlstore.sqlite.sqlite import SqliteSqlStoreImpl
+from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
 from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
@@ -17,7 +17,7 @@ from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
 async def test_sqlite_sqlstore():
     with TemporaryDirectory() as tmp_dir:
         db_name = "test.db"
-        sqlstore = SqliteSqlStoreImpl(
+        sqlstore = SqlAlchemySqlStoreImpl(
             SqliteSqlStoreConfig(
                 db_path=tmp_dir + "/" + db_name,
             )