diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py index 3a56d41ef..1fcb1c461 100644 --- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py +++ b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py @@ -292,12 +292,12 @@ class OpenAIResponsesImpl: async def _store_response( self, response: OpenAIResponseObject, - original_input: str | list[OpenAIResponseInput], + input: str | list[OpenAIResponseInput], ) -> None: new_input_id = f"msg_{uuid.uuid4()}" - if isinstance(original_input, str): + if isinstance(input, str): # synthesize a message from the input string - input_content = OpenAIResponseInputMessageContentText(text=original_input) + input_content = OpenAIResponseInputMessageContentText(text=input) input_content_item = OpenAIResponseMessage( role="user", content=[input_content], @@ -307,7 +307,7 @@ class OpenAIResponsesImpl: else: # we already have a list of messages input_items_data = [] - for input_item in original_input: + for input_item in input: if isinstance(input_item, OpenAIResponseMessage): # These may or may not already have an id, so dump to dict, check for id, and add if missing input_item_dict = input_item.model_dump() @@ -334,7 +334,6 @@ class OpenAIResponsesImpl: tools: list[OpenAIResponseInputTool] | None = None, ): stream = False if stream is None else stream - original_input = input # Keep reference for storage output_messages: list[OpenAIResponseOutput] = [] @@ -372,7 +371,7 @@ class OpenAIResponsesImpl: inference_result=inference_result, ctx=ctx, output_messages=output_messages, - original_input=original_input, + input=input, model=model, store=store, tools=tools, @@ -382,7 +381,7 @@ class OpenAIResponsesImpl: inference_result=inference_result, ctx=ctx, output_messages=output_messages, - original_input=original_input, + input=input, model=model, store=store, tools=tools, @@ -393,7 +392,7 @@ class OpenAIResponsesImpl: inference_result: Any, ctx: ChatCompletionContext, output_messages: list[OpenAIResponseOutput], - original_input: str | list[OpenAIResponseInput], + input: str | list[OpenAIResponseInput], model: str, store: bool | None, tools: list[OpenAIResponseInputTool] | None, @@ -423,7 +422,7 @@ class OpenAIResponsesImpl: if store: await self._store_response( response=response, - original_input=original_input, + input=input, ) return response @@ -433,7 +432,7 @@ class OpenAIResponsesImpl: inference_result: Any, ctx: ChatCompletionContext, output_messages: list[OpenAIResponseOutput], - original_input: str | list[OpenAIResponseInput], + input: str | list[OpenAIResponseInput], model: str, store: bool | None, tools: list[OpenAIResponseInputTool] | None, @@ -544,7 +543,7 @@ class OpenAIResponsesImpl: if store: await self._store_response( response=final_response, - original_input=original_input, + input=input, ) # Emit response.completed diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py index 56ce8285f..ff87889ea 100644 --- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py @@ -75,7 +75,9 @@ class PromptGuardShield: self.temperature = temperature self.threshold = threshold - self.device = "cuda" + self.device = "cpu" + if torch.cuda.is_available(): + self.device = "cuda" # load model and tokenizer self.tokenizer = 
AutoTokenizer.from_pretrained(model_dir) diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index d182aa1dc..20f863665 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -218,7 +218,7 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin): "json_schema": { "name": name, "schema": fmt, - "strict": True, + "strict": False, }, } if request.tools: diff --git a/llama_stack/providers/utils/kvstore/config.py b/llama_stack/providers/utils/kvstore/config.py index e9aac6e8c..bbb0c5c0a 100644 --- a/llama_stack/providers/utils/kvstore/config.py +++ b/llama_stack/providers/utils/kvstore/config.py @@ -65,7 +65,7 @@ class SqliteKVStoreConfig(CommonConfig): class PostgresKVStoreConfig(CommonConfig): type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value host: str = "localhost" - port: int = 5432 + port: str = "5432" db: str = "llamastack" user: str password: str | None = None diff --git a/llama_stack/providers/utils/sqlstore/sqlite/sqlite.py b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py similarity index 83% rename from llama_stack/providers/utils/sqlstore/sqlite/sqlite.py rename to llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py index 0ef5f0fa1..825220679 100644 --- a/llama_stack/providers/utils/sqlstore/sqlite/sqlite.py +++ b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py @@ -19,10 +19,10 @@ from sqlalchemy import ( Text, select, ) -from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine -from ..api import ColumnDefinition, ColumnType, SqlStore -from ..sqlstore import SqliteSqlStoreConfig +from .api import ColumnDefinition, ColumnType, SqlStore +from .sqlstore import SqlAlchemySqlStoreConfig TYPE_MAPPING: dict[ColumnType, Any] = { ColumnType.INTEGER: Integer, @@ -35,9 +35,10 @@ TYPE_MAPPING: dict[ColumnType, Any] = { } -class SqliteSqlStoreImpl(SqlStore): - def __init__(self, config: SqliteSqlStoreConfig): - self.engine = create_async_engine(config.engine_str) +class SqlAlchemySqlStoreImpl(SqlStore): + def __init__(self, config: SqlAlchemySqlStoreConfig): + self.config = config + self.async_session = async_sessionmaker(create_async_engine(config.engine_str)) self.metadata = MetaData() async def create_table( @@ -78,13 +79,14 @@ class SqliteSqlStoreImpl(SqlStore): # Create the table in the database if it doesn't exist # checkfirst=True ensures it doesn't try to recreate if it's already there - async with self.engine.begin() as conn: + engine = create_async_engine(self.config.engine_str) + async with engine.begin() as conn: await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True) async def insert(self, table: str, data: Mapping[str, Any]) -> None: - async with self.engine.begin() as conn: - await conn.execute(self.metadata.tables[table].insert(), data) - await conn.commit() + async with self.async_session() as session: + await session.execute(self.metadata.tables[table].insert(), data) + await session.commit() async def fetch_all( self, @@ -93,7 +95,7 @@ class SqliteSqlStoreImpl(SqlStore): limit: int | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, ) -> list[dict[str, Any]]: - async with self.engine.begin() as conn: + async with self.async_session() as session: query = select(self.metadata.tables[table]) if where: for key, value in 
where.items(): @@ -117,7 +119,7 @@ class SqliteSqlStoreImpl(SqlStore): query = query.order_by(self.metadata.tables[table].c[name].desc()) else: raise ValueError(f"Invalid order '{order_type}' for column '{name}'") - result = await conn.execute(query) + result = await session.execute(query) if result.rowcount == 0: return [] return [dict(row._mapping) for row in result] @@ -142,20 +144,20 @@ class SqliteSqlStoreImpl(SqlStore): if not where: raise ValueError("where is required for update") - async with self.engine.begin() as conn: + async with self.async_session() as session: stmt = self.metadata.tables[table].update() for key, value in where.items(): stmt = stmt.where(self.metadata.tables[table].c[key] == value) - await conn.execute(stmt, data) - await conn.commit() + await session.execute(stmt, data) + await session.commit() async def delete(self, table: str, where: Mapping[str, Any]) -> None: if not where: raise ValueError("where is required for delete") - async with self.engine.begin() as conn: + async with self.async_session() as session: stmt = self.metadata.tables[table].delete() for key, value in where.items(): stmt = stmt.where(self.metadata.tables[table].c[key] == value) - await conn.execute(stmt) - await conn.commit() + await session.execute(stmt) + await session.commit() diff --git a/llama_stack/providers/utils/sqlstore/sqlstore.py b/llama_stack/providers/utils/sqlstore/sqlstore.py index 99f64805f..3091e8f96 100644 --- a/llama_stack/providers/utils/sqlstore/sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/sqlstore.py @@ -5,6 +5,7 @@ # the root directory of this source tree. +from abc import abstractmethod from enum import Enum from pathlib import Path from typing import Annotated, Literal @@ -21,7 +22,18 @@ class SqlStoreType(Enum): postgres = "postgres" -class SqliteSqlStoreConfig(BaseModel): +class SqlAlchemySqlStoreConfig(BaseModel): + @property + @abstractmethod + def engine_str(self) -> str: ... 
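
For readers unfamiliar with the session-per-operation style the renamed SqlAlchemySqlStoreImpl adopts above, a minimal self-contained sketch follows. It assumes SQLAlchemy 2.x with the aiosqlite driver installed; the table, file name, and values are illustrative only and not part of this change.

    import asyncio

    from sqlalchemy import Column, Integer, MetaData, String, Table, select
    from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine


    async def main() -> None:
        # One engine per store; sessions are created per operation, as in the impl above.
        engine = create_async_engine("sqlite+aiosqlite:///demo_sqlstore.db")
        make_session = async_sessionmaker(engine)

        metadata = MetaData()
        items = Table("items", metadata, Column("id", Integer, primary_key=True), Column("name", String))

        # DDL still runs through the engine, mirroring create_table()
        async with engine.begin() as conn:
            await conn.run_sync(metadata.create_all, checkfirst=True)

        # DML runs in a short-lived session with an explicit commit, mirroring insert()
        async with make_session() as session:
            await session.execute(items.insert(), {"id": 1, "name": "demo"})
            await session.commit()

        # Reads also go through a session, mirroring fetch_all()
        async with make_session() as session:
            result = await session.execute(select(items))
            print([dict(row._mapping) for row in result])


    asyncio.run(main())
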
+ + # TODO: move this when we have a better way to specify dependencies with internal APIs + @property + def pip_packages(self) -> list[str]: + return ["sqlalchemy[asyncio]"] + + +class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig): type: Literal["sqlite"] = SqlStoreType.sqlite.value db_path: str = Field( default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), @@ -39,18 +51,26 @@ class SqliteSqlStoreConfig(BaseModel): db_path="${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name, ) - # TODO: move this when we have a better way to specify dependencies with internal APIs @property def pip_packages(self) -> list[str]: - return ["sqlalchemy[asyncio]"] + return super().pip_packages + ["aiosqlite"] -class PostgresSqlStoreConfig(BaseModel): +class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig): type: Literal["postgres"] = SqlStoreType.postgres.value + host: str = "localhost" + port: str = "5432" + db: str = "llamastack" + user: str + password: str | None = None + + @property + def engine_str(self) -> str: + return f"postgresql+asyncpg://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}" @property def pip_packages(self) -> list[str]: - raise NotImplementedError("Postgres is not implemented yet") + return super().pip_packages + ["asyncpg"] SqlStoreConfig = Annotated[ @@ -60,12 +80,10 @@ SqlStoreConfig = Annotated[ def sqlstore_impl(config: SqlStoreConfig) -> SqlStore: - if config.type == SqlStoreType.sqlite.value: - from .sqlite.sqlite import SqliteSqlStoreImpl + if config.type in [SqlStoreType.sqlite.value, SqlStoreType.postgres.value]: + from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl - impl = SqliteSqlStoreImpl(config) - elif config.type == SqlStoreType.postgres.value: - raise NotImplementedError("Postgres is not implemented yet") + impl = SqlAlchemySqlStoreImpl(config) else: raise ValueError(f"Unknown sqlstore type {config.type}") diff --git a/llama_stack/templates/bedrock/build.yaml b/llama_stack/templates/bedrock/build.yaml index 09fbf307d..97a06f77a 100644 --- a/llama_stack/templates/bedrock/build.yaml +++ b/llama_stack/templates/bedrock/build.yaml @@ -30,4 +30,5 @@ distribution_spec: - remote::model-context-protocol image_type: conda additional_pip_packages: +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/cerebras/build.yaml b/llama_stack/templates/cerebras/build.yaml index 95b0302f2..f26f4ed9b 100644 --- a/llama_stack/templates/cerebras/build.yaml +++ b/llama_stack/templates/cerebras/build.yaml @@ -30,4 +30,5 @@ distribution_spec: - inline::rag-runtime image_type: conda additional_pip_packages: +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/ci-tests/build.yaml b/llama_stack/templates/ci-tests/build.yaml index 6fe96c603..9f4fbbdda 100644 --- a/llama_stack/templates/ci-tests/build.yaml +++ b/llama_stack/templates/ci-tests/build.yaml @@ -31,4 +31,5 @@ distribution_spec: - remote::model-context-protocol image_type: conda additional_pip_packages: +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/dell/build.yaml b/llama_stack/templates/dell/build.yaml index d37215f35..513df16c1 100644 --- a/llama_stack/templates/dell/build.yaml +++ b/llama_stack/templates/dell/build.yaml @@ -31,5 +31,5 @@ distribution_spec: - inline::rag-runtime image_type: conda additional_pip_packages: -- sqlalchemy[asyncio] +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/fireworks/build.yaml b/llama_stack/templates/fireworks/build.yaml index f162d9b43..be19181c0 100644 --- 
a/llama_stack/templates/fireworks/build.yaml +++ b/llama_stack/templates/fireworks/build.yaml @@ -32,5 +32,5 @@ distribution_spec: - remote::model-context-protocol image_type: conda additional_pip_packages: -- sqlalchemy[asyncio] +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/groq/build.yaml b/llama_stack/templates/groq/build.yaml index 92b46ce66..819df22f0 100644 --- a/llama_stack/templates/groq/build.yaml +++ b/llama_stack/templates/groq/build.yaml @@ -27,4 +27,5 @@ distribution_spec: - inline::rag-runtime image_type: conda additional_pip_packages: +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/hf-endpoint/build.yaml b/llama_stack/templates/hf-endpoint/build.yaml index 4d09cc33e..8ede83694 100644 --- a/llama_stack/templates/hf-endpoint/build.yaml +++ b/llama_stack/templates/hf-endpoint/build.yaml @@ -30,5 +30,5 @@ distribution_spec: - remote::model-context-protocol image_type: conda additional_pip_packages: -- sqlalchemy[asyncio] +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/hf-serverless/build.yaml b/llama_stack/templates/hf-serverless/build.yaml index d06c628ac..d0752db9a 100644 --- a/llama_stack/templates/hf-serverless/build.yaml +++ b/llama_stack/templates/hf-serverless/build.yaml @@ -31,5 +31,5 @@ distribution_spec: - remote::model-context-protocol image_type: conda additional_pip_packages: -- sqlalchemy[asyncio] +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/llama_api/build.yaml b/llama_stack/templates/llama_api/build.yaml index d0dc08923..857e5f014 100644 --- a/llama_stack/templates/llama_api/build.yaml +++ b/llama_stack/templates/llama_api/build.yaml @@ -31,4 +31,5 @@ distribution_spec: - remote::model-context-protocol image_type: conda additional_pip_packages: +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/meta-reference-gpu/build.yaml b/llama_stack/templates/meta-reference-gpu/build.yaml index e0ac87e47..53ad411e3 100644 --- a/llama_stack/templates/meta-reference-gpu/build.yaml +++ b/llama_stack/templates/meta-reference-gpu/build.yaml @@ -30,5 +30,5 @@ distribution_spec: - remote::model-context-protocol image_type: conda additional_pip_packages: -- sqlalchemy[asyncio] +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/nvidia/build.yaml b/llama_stack/templates/nvidia/build.yaml index e1e6fb3d8..6bd8a0100 100644 --- a/llama_stack/templates/nvidia/build.yaml +++ b/llama_stack/templates/nvidia/build.yaml @@ -25,5 +25,5 @@ distribution_spec: - inline::rag-runtime image_type: conda additional_pip_packages: -- sqlalchemy[asyncio] +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/ollama/build.yaml b/llama_stack/templates/ollama/build.yaml index 9d8ba3a1e..36a120897 100644 --- a/llama_stack/templates/ollama/build.yaml +++ b/llama_stack/templates/ollama/build.yaml @@ -33,5 +33,5 @@ distribution_spec: - remote::wolfram-alpha image_type: conda additional_pip_packages: -- sqlalchemy[asyncio] +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/open-benchmark/build.yaml b/llama_stack/templates/open-benchmark/build.yaml index aa6d876fe..840f1e1db 100644 --- a/llama_stack/templates/open-benchmark/build.yaml +++ b/llama_stack/templates/open-benchmark/build.yaml @@ -34,4 +34,5 @@ distribution_spec: - remote::model-context-protocol image_type: conda additional_pip_packages: +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/passthrough/build.yaml b/llama_stack/templates/passthrough/build.yaml index 
7560f1032..46b99cb75 100644 --- a/llama_stack/templates/passthrough/build.yaml +++ b/llama_stack/templates/passthrough/build.yaml @@ -32,5 +32,5 @@ distribution_spec: - remote::model-context-protocol image_type: conda additional_pip_packages: -- sqlalchemy[asyncio] +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/postgres-demo/__init__.py b/llama_stack/templates/postgres-demo/__init__.py new file mode 100644 index 000000000..81473cb73 --- /dev/null +++ b/llama_stack/templates/postgres-demo/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .postgres_demo import get_distribution_template # noqa: F401 diff --git a/llama_stack/templates/postgres-demo/build.yaml b/llama_stack/templates/postgres-demo/build.yaml new file mode 100644 index 000000000..8f3648abe --- /dev/null +++ b/llama_stack/templates/postgres-demo/build.yaml @@ -0,0 +1,24 @@ +version: '2' +distribution_spec: + description: Quick start template for running Llama Stack with several popular providers + providers: + inference: + - remote::fireworks + - remote::vllm + vector_io: + - remote::chromadb + safety: + - inline::llama-guard + agents: + - inline::meta-reference + telemetry: + - inline::meta-reference + tool_runtime: + - remote::brave-search + - remote::tavily-search + - inline::rag-runtime + - remote::model-context-protocol +image_type: conda +additional_pip_packages: +- asyncpg +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/postgres-demo/postgres_demo.py b/llama_stack/templates/postgres-demo/postgres_demo.py new file mode 100644 index 000000000..d2e352320 --- /dev/null +++ b/llama_stack/templates/postgres-demo/postgres_demo.py @@ -0,0 +1,164 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
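
For reference, the Postgres sqlstore config introduced in sqlstore.py above resolves to an asyncpg connection URL as shown in this small sketch; the credential values simply mirror the template's ${env.POSTGRES_*} defaults and are not meaningful on their own.

    from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig

    config = PostgresSqlStoreConfig(
        host="localhost",
        port="5432",  # kept as a string so "${env.POSTGRES_PORT:5432}" can be substituted verbatim
        db="llamastack",
        user="llamastack",
        password="llamastack",
    )
    print(config.engine_str)    # postgresql+asyncpg://llamastack:llamastack@localhost:5432/llamastack
    print(config.pip_packages)  # ['sqlalchemy[asyncio]', 'asyncpg']
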
+ + +from llama_stack.distribution.datatypes import ( + ModelInput, + Provider, + ShieldInput, + ToolGroupInput, +) +from llama_stack.providers.remote.inference.fireworks.config import FireworksImplConfig +from llama_stack.providers.remote.inference.fireworks.models import ( + MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES, +) +from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig +from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig +from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry +from llama_stack.providers.utils.kvstore.config import PostgresKVStoreConfig +from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig +from llama_stack.templates.template import ( + DistributionTemplate, + RunConfigSettings, + get_model_registry, +) + + +def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderModelEntry]]]: + # in this template, we allow each API key to be optional + providers = [ + ( + "fireworks", + FIREWORKS_MODEL_ENTRIES, + FireworksImplConfig.sample_run_config(api_key="${env.FIREWORKS_API_KEY:}"), + ), + ] + inference_providers = [] + available_models = {} + for provider_id, model_entries, config in providers: + inference_providers.append( + Provider( + provider_id=provider_id, + provider_type=f"remote::{provider_id}", + config=config, + ) + ) + available_models[provider_id] = model_entries + inference_providers.append( + Provider( + provider_id="vllm-inference", + provider_type="remote::vllm", + config=VLLMInferenceAdapterConfig.sample_run_config( + url="${env.VLLM_URL:http://localhost:8000/v1}", + ), + ) + ) + return inference_providers, available_models + + +def get_distribution_template() -> DistributionTemplate: + inference_providers, available_models = get_inference_providers() + providers = { + "inference": ([p.provider_type for p in inference_providers]), + "vector_io": ["remote::chromadb"], + "safety": ["inline::llama-guard"], + "agents": ["inline::meta-reference"], + "telemetry": ["inline::meta-reference"], + "tool_runtime": [ + "remote::brave-search", + "remote::tavily-search", + "inline::rag-runtime", + "remote::model-context-protocol", + ], + } + name = "postgres-demo" + + vector_io_providers = [ + Provider( + provider_id="${env.ENABLE_CHROMADB+chromadb}", + provider_type="remote::chromadb", + config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:}"), + ), + ] + default_tool_groups = [ + ToolGroupInput( + toolgroup_id="builtin::websearch", + provider_id="tavily-search", + ), + ToolGroupInput( + toolgroup_id="builtin::rag", + provider_id="rag-runtime", + ), + ] + + default_models = get_model_registry(available_models) + default_models.append( + ModelInput( + model_id="${env.INFERENCE_MODEL}", + provider_id="vllm-inference", + ) + ) + postgres_config = { + "type": "postgres", + "host": "${env.POSTGRES_HOST:localhost}", + "port": "${env.POSTGRES_PORT:5432}", + "db": "${env.POSTGRES_DB:llamastack}", + "user": "${env.POSTGRES_USER:llamastack}", + "password": "${env.POSTGRES_PASSWORD:llamastack}", + } + + return DistributionTemplate( + name=name, + distro_type="self_hosted", + description="Quick start template for running Llama Stack with several popular providers", + container_image=None, + template_path=None, + providers=providers, + available_models_by_provider=available_models, + run_configs={ + "run.yaml": RunConfigSettings( + provider_overrides={ + "inference": inference_providers, + "vector_io": vector_io_providers, + 
"agents": [ + Provider( + provider_id="meta-reference", + provider_type="inline::meta-reference", + config=dict( + persistence_store=postgres_config, + responses_store=postgres_config, + ), + ) + ], + "telemetry": [ + Provider( + provider_id="meta-reference", + provider_type="inline::meta-reference", + config=dict( + service_name="${env.OTEL_SERVICE_NAME:}", + sinks="${env.TELEMETRY_SINKS:console}", + ), + ) + ], + }, + default_models=default_models, + default_tool_groups=default_tool_groups, + default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")], + metadata_store=PostgresKVStoreConfig.model_validate(postgres_config), + inference_store=PostgresSqlStoreConfig.model_validate(postgres_config), + ), + }, + run_config_env_vars={ + "LLAMA_STACK_PORT": ( + "8321", + "Port for the Llama Stack distribution server", + ), + "FIREWORKS_API_KEY": ( + "", + "Fireworks API Key", + ), + }, + ) diff --git a/llama_stack/templates/postgres-demo/run.yaml b/llama_stack/templates/postgres-demo/run.yaml new file mode 100644 index 000000000..889b8eaa7 --- /dev/null +++ b/llama_stack/templates/postgres-demo/run.yaml @@ -0,0 +1,224 @@ +version: '2' +image_name: postgres-demo +apis: +- agents +- inference +- safety +- telemetry +- tool_runtime +- vector_io +providers: + inference: + - provider_id: fireworks + provider_type: remote::fireworks + config: + url: https://api.fireworks.ai/inference/v1 + api_key: ${env.FIREWORKS_API_KEY:} + - provider_id: vllm-inference + provider_type: remote::vllm + config: + url: ${env.VLLM_URL:http://localhost:8000/v1} + max_tokens: ${env.VLLM_MAX_TOKENS:4096} + api_token: ${env.VLLM_API_TOKEN:fake} + tls_verify: ${env.VLLM_TLS_VERIFY:true} + vector_io: + - provider_id: ${env.ENABLE_CHROMADB+chromadb} + provider_type: remote::chromadb + config: + url: ${env.CHROMADB_URL:} + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: + excluded_categories: [] + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: postgres + host: ${env.POSTGRES_HOST:localhost} + port: ${env.POSTGRES_PORT:5432} + db: ${env.POSTGRES_DB:llamastack} + user: ${env.POSTGRES_USER:llamastack} + password: ${env.POSTGRES_PASSWORD:llamastack} + responses_store: + type: postgres + host: ${env.POSTGRES_HOST:localhost} + port: ${env.POSTGRES_PORT:5432} + db: ${env.POSTGRES_DB:llamastack} + user: ${env.POSTGRES_USER:llamastack} + password: ${env.POSTGRES_PASSWORD:llamastack} + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + service_name: ${env.OTEL_SERVICE_NAME:} + sinks: ${env.TELEMETRY_SINKS:console} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:} + max_results: 3 + - provider_id: rag-runtime + provider_type: inline::rag-runtime + config: {} + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol + config: {} +metadata_store: + type: postgres + host: ${env.POSTGRES_HOST:localhost} + port: ${env.POSTGRES_PORT:5432} + db: ${env.POSTGRES_DB:llamastack} + user: ${env.POSTGRES_USER:llamastack} + password: ${env.POSTGRES_PASSWORD:llamastack} + table_name: llamastack_kvstore +inference_store: + type: postgres + host: ${env.POSTGRES_HOST:localhost} + port: ${env.POSTGRES_PORT:5432} + db: ${env.POSTGRES_DB:llamastack} + 
user: ${env.POSTGRES_USER:llamastack} + password: ${env.POSTGRES_PASSWORD:llamastack} +models: +- metadata: {} + model_id: accounts/fireworks/models/llama-v3p1-8b-instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p1-8b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.1-8B-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p1-8b-instruct + model_type: llm +- metadata: {} + model_id: accounts/fireworks/models/llama-v3p1-70b-instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p1-70b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.1-70B-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p1-70b-instruct + model_type: llm +- metadata: {} + model_id: accounts/fireworks/models/llama-v3p1-405b-instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.1-405B-Instruct-FP8 + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct + model_type: llm +- metadata: {} + model_id: accounts/fireworks/models/llama-v3p2-3b-instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p2-3b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-3B-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p2-3b-instruct + model_type: llm +- metadata: {} + model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-11B-Vision-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct + model_type: llm +- metadata: {} + model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-90B-Vision-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct + model_type: llm +- metadata: {} + model_id: accounts/fireworks/models/llama-v3p3-70b-instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.3-70B-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct + model_type: llm +- metadata: {} + model_id: accounts/fireworks/models/llama-guard-3-8b + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-guard-3-8b + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-Guard-3-8B + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-guard-3-8b + model_type: llm +- metadata: {} + model_id: accounts/fireworks/models/llama-guard-3-11b-vision + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-Guard-3-11B-Vision + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision + model_type: llm +- metadata: {} + model_id: accounts/fireworks/models/llama4-scout-instruct-basic + provider_id: fireworks + 
provider_model_id: accounts/fireworks/models/llama4-scout-instruct-basic + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama4-scout-instruct-basic + model_type: llm +- metadata: {} + model_id: accounts/fireworks/models/llama4-maverick-instruct-basic + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama4-maverick-instruct-basic + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama4-maverick-instruct-basic + model_type: llm +- metadata: + embedding_dimension: 768 + context_length: 8192 + model_id: nomic-ai/nomic-embed-text-v1.5 + provider_id: fireworks + provider_model_id: nomic-ai/nomic-embed-text-v1.5 + model_type: embedding +- metadata: {} + model_id: ${env.INFERENCE_MODEL} + provider_id: vllm-inference + model_type: llm +shields: +- shield_id: meta-llama/Llama-Guard-3-8B +vector_dbs: [] +datasets: [] +scoring_fns: [] +benchmarks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::rag + provider_id: rag-runtime +server: + port: 8321 diff --git a/llama_stack/templates/remote-vllm/build.yaml b/llama_stack/templates/remote-vllm/build.yaml index fcd4deeff..16fe5d4fd 100644 --- a/llama_stack/templates/remote-vllm/build.yaml +++ b/llama_stack/templates/remote-vllm/build.yaml @@ -32,5 +32,5 @@ distribution_spec: - remote::wolfram-alpha image_type: conda additional_pip_packages: -- sqlalchemy[asyncio] +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/sambanova/build.yaml b/llama_stack/templates/sambanova/build.yaml index b644dcfdc..14b1c8974 100644 --- a/llama_stack/templates/sambanova/build.yaml +++ b/llama_stack/templates/sambanova/build.yaml @@ -23,4 +23,5 @@ distribution_spec: - remote::wolfram-alpha image_type: conda additional_pip_packages: +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/starter/build.yaml b/llama_stack/templates/starter/build.yaml index 652814ffd..ec97c7d3e 100644 --- a/llama_stack/templates/starter/build.yaml +++ b/llama_stack/templates/starter/build.yaml @@ -36,4 +36,5 @@ distribution_spec: - remote::model-context-protocol image_type: conda additional_pip_packages: +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/template.py b/llama_stack/templates/template.py index ec5cd38ea..4013f08f9 100644 --- a/llama_stack/templates/template.py +++ b/llama_stack/templates/template.py @@ -28,8 +28,8 @@ from llama_stack.distribution.datatypes import ( from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry -from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig -from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig +from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig +from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig def get_model_registry( @@ -64,6 +64,8 @@ class RunConfigSettings(BaseModel): default_tool_groups: list[ToolGroupInput] | None = None default_datasets: list[DatasetInput] | None = None default_benchmarks: list[BenchmarkInput] | None = None + metadata_store: KVStoreConfig | None = None + inference_store: 
SqlStoreConfig | None = None def run_config( self, @@ -114,11 +116,13 @@ class RunConfigSettings(BaseModel): container_image=container_image, apis=apis, providers=provider_configs, - metadata_store=SqliteKVStoreConfig.sample_run_config( + metadata_store=self.metadata_store + or SqliteKVStoreConfig.sample_run_config( __distro_dir__=f"~/.llama/distributions/{name}", db_name="registry.db", ), - inference_store=SqliteSqlStoreConfig.sample_run_config( + inference_store=self.inference_store + or SqliteSqlStoreConfig.sample_run_config( __distro_dir__=f"~/.llama/distributions/{name}", db_name="inference_store.db", ), @@ -164,7 +168,7 @@ class DistributionTemplate(BaseModel): providers=self.providers, ), image_type="conda", # default to conda, can be overridden - additional_pip_packages=additional_pip_packages, + additional_pip_packages=sorted(set(additional_pip_packages)), ) def generate_markdown_docs(self) -> str: diff --git a/llama_stack/templates/tgi/build.yaml b/llama_stack/templates/tgi/build.yaml index 652900c84..361b0b680 100644 --- a/llama_stack/templates/tgi/build.yaml +++ b/llama_stack/templates/tgi/build.yaml @@ -31,5 +31,5 @@ distribution_spec: - remote::model-context-protocol image_type: conda additional_pip_packages: -- sqlalchemy[asyncio] +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/together/build.yaml b/llama_stack/templates/together/build.yaml index 4a556a66f..5ffeac873 100644 --- a/llama_stack/templates/together/build.yaml +++ b/llama_stack/templates/together/build.yaml @@ -32,5 +32,5 @@ distribution_spec: - remote::wolfram-alpha image_type: conda additional_pip_packages: -- sqlalchemy[asyncio] +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/verification/build.yaml b/llama_stack/templates/verification/build.yaml index cb7ab4798..ce083dbba 100644 --- a/llama_stack/templates/verification/build.yaml +++ b/llama_stack/templates/verification/build.yaml @@ -36,4 +36,5 @@ distribution_spec: - remote::model-context-protocol image_type: conda additional_pip_packages: +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/vllm-gpu/build.yaml b/llama_stack/templates/vllm-gpu/build.yaml index 5a9d003cb..d5ff0f1f4 100644 --- a/llama_stack/templates/vllm-gpu/build.yaml +++ b/llama_stack/templates/vllm-gpu/build.yaml @@ -31,4 +31,5 @@ distribution_spec: - remote::model-context-protocol image_type: conda additional_pip_packages: +- aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/watsonx/build.yaml b/llama_stack/templates/watsonx/build.yaml index 87233fb26..e68ace183 100644 --- a/llama_stack/templates/watsonx/build.yaml +++ b/llama_stack/templates/watsonx/build.yaml @@ -29,4 +29,5 @@ distribution_spec: - remote::model-context-protocol image_type: conda additional_pip_packages: +- aiosqlite - sqlalchemy[asyncio] diff --git a/scripts/distro_codegen.py b/scripts/distro_codegen.py index 8820caf55..d33c5de67 100755 --- a/scripts/distro_codegen.py +++ b/scripts/distro_codegen.py @@ -107,6 +107,13 @@ def collect_template_dependencies(template_dir: Path) -> tuple[str | None, list[ return None, [] +def pre_import_templates(template_dirs: list[Path]) -> None: + # Pre-import all template modules to avoid deadlocks. 
+ for template_dir in template_dirs: + module_name = f"llama_stack.templates.{template_dir.name}" + importlib.import_module(module_name) + + def main(): templates_dir = REPO_ROOT / "llama_stack" / "templates" change_tracker = ChangedPathTracker() @@ -118,6 +125,8 @@ def main(): template_dirs = list(find_template_dirs(templates_dir)) task = progress.add_task("Processing distribution templates...", total=len(template_dirs)) + pre_import_templates(template_dirs) + # Create a partial function with the progress bar process_func = partial(process_template, progress=progress, change_tracker=change_tracker) diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py index 2cd76a23d..190840f70 100644 --- a/tests/integration/inference/test_openai_completion.py +++ b/tests/integration/inference/test_openai_completion.py @@ -268,9 +268,9 @@ def test_openai_chat_completion_streaming_with_n(compat_client, client_with_mode False, ], ) -def test_inference_store(openai_client, client_with_models, text_model_id, stream): +def test_inference_store(compat_client, client_with_models, text_model_id, stream): skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id) - client = openai_client + client = compat_client # make a chat completion message = "Hello, world!" response = client.chat.completions.create( @@ -301,9 +301,14 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea retrieved_response = client.chat.completions.retrieve(response_id) assert retrieved_response.id == response_id - assert retrieved_response.input_messages[0]["content"] == message, retrieved_response assert retrieved_response.choices[0].message.content == content, retrieved_response + input_content = ( + getattr(retrieved_response.input_messages[0], "content", None) + or retrieved_response.input_messages[0]["content"] + ) + assert input_content == message, retrieved_response + @pytest.mark.parametrize( "stream", @@ -312,9 +317,9 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea False, ], ) -def test_inference_store_tool_calls(openai_client, client_with_models, text_model_id, stream): +def test_inference_store_tool_calls(compat_client, client_with_models, text_model_id, stream): skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id) - client = openai_client + client = compat_client # make a chat completion message = "What's the weather in Tokyo? Use the get_weather function to get the weather." 
response = client.chat.completions.create( @@ -361,7 +366,11 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_mode retrieved_response = client.chat.completions.retrieve(response_id) assert retrieved_response.id == response_id - assert retrieved_response.input_messages[0]["content"] == message + input_content = ( + getattr(retrieved_response.input_messages[0], "content", None) + or retrieved_response.input_messages[0]["content"] + ) + assert input_content == message, retrieved_response tool_calls = retrieved_response.choices[0].message.tool_calls # sometimes model doesn't ouptut tool calls, but we still want to test that the tool was called if tool_calls: diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py index 9c491accb..5b6cee0ec 100644 --- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py +++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py @@ -628,3 +628,69 @@ async def test_responses_store_list_input_items_logic(): result = await responses_store.list_response_input_items("resp_123", limit=0, order=Order.asc) assert result.object == "list" assert len(result.data) == 0 # Should return no items + + +@pytest.mark.asyncio +async def test_store_response_uses_rehydrated_input_with_previous_response( + openai_responses_impl, mock_responses_store, mock_inference_api +): + """Test that _store_response uses the full re-hydrated input (including previous responses) + rather than just the original input when previous_response_id is provided.""" + + # Setup - Create a previous response that should be included in the stored input + previous_response = OpenAIResponseObjectWithInput( + id="resp-previous-123", + object="response", + created_at=1234567890, + model="meta-llama/Llama-3.1-8B-Instruct", + status="completed", + input=[ + OpenAIResponseMessage( + id="msg-prev-user", role="user", content=[OpenAIResponseInputMessageContentText(text="What is 2+2?")] + ) + ], + output=[ + OpenAIResponseMessage( + id="msg-prev-assistant", + role="assistant", + content=[OpenAIResponseOutputMessageContentOutputText(text="2+2 equals 4.")], + ) + ], + ) + + mock_responses_store.get_response_object.return_value = previous_response + + current_input = "Now what is 3+3?" + model = "meta-llama/Llama-3.1-8B-Instruct" + mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml") + mock_inference_api.openai_chat_completion.return_value = mock_chat_completion + + # Execute - Create response with previous_response_id + result = await openai_responses_impl.create_openai_response( + input=current_input, + model=model, + previous_response_id="resp-previous-123", + store=True, + ) + + store_call_args = mock_responses_store.store_response_object.call_args + stored_input = store_call_args.kwargs["input"] + + # Verify that the stored input contains the full re-hydrated conversation: + # 1. Previous user message + # 2. Previous assistant response + # 3. Current user message + assert len(stored_input) == 3 + + assert stored_input[0].role == "user" + assert stored_input[0].content[0].text == "What is 2+2?" + + assert stored_input[1].role == "assistant" + assert stored_input[1].content[0].text == "2+2 equals 4." + + assert stored_input[2].role == "user" + assert stored_input[2].content == "Now what is 3+3?" 
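
The getattr-or-index fallback used in the assertions above exists because the retrieved input messages may come back either as plain dicts or as attribute-style objects, depending on which client fixture is in use. A tiny self-contained sketch of the pattern, using SimpleNamespace as a stand-in for the attribute-style case rather than the real client types:

    from types import SimpleNamespace


    def message_content(message) -> str:
        # Try attribute access first, then fall back to dict indexing.
        return getattr(message, "content", None) or message["content"]


    print(message_content({"content": "Hello, world!"}))              # dict-style message
    print(message_content(SimpleNamespace(content="Hello, world!")))  # attribute-style message
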
+ + # Verify the response itself is correct + assert result.model == model + assert result.status == "completed" diff --git a/tests/unit/utils/test_sqlstore.py b/tests/unit/utils/test_sqlstore.py index 8ded760ef..6231e9082 100644 --- a/tests/unit/utils/test_sqlstore.py +++ b/tests/unit/utils/test_sqlstore.py @@ -9,7 +9,7 @@ from tempfile import TemporaryDirectory import pytest from llama_stack.providers.utils.sqlstore.api import ColumnType -from llama_stack.providers.utils.sqlstore.sqlite.sqlite import SqliteSqlStoreImpl +from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig @@ -17,7 +17,7 @@ from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig async def test_sqlite_sqlstore(): with TemporaryDirectory() as tmp_dir: db_name = "test.db" - sqlstore = SqliteSqlStoreImpl( + sqlstore = SqlAlchemySqlStoreImpl( SqliteSqlStoreConfig( db_path=tmp_dir + "/" + db_name, )
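
As a closing sketch of the dispatch behavior after this change: both SQLite and Postgres configs now resolve to the same SQLAlchemy-backed implementation, differing only in engine_str and their extra pip packages. This assumes the aiosqlite and asyncpg drivers are installed; the paths and credentials are illustrative.

    from llama_stack.providers.utils.sqlstore.sqlstore import (
        PostgresSqlStoreConfig,
        SqliteSqlStoreConfig,
        sqlstore_impl,
    )

    sqlite_store = sqlstore_impl(SqliteSqlStoreConfig(db_path="/tmp/demo_sqlstore.db"))
    postgres_store = sqlstore_impl(PostgresSqlStoreConfig(user="llamastack", password="llamastack"))

    # Both are SqlAlchemySqlStoreImpl instances; no Postgres-specific store class is needed anymore.
    print(type(sqlite_store).__name__, type(postgres_store).__name__)
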