Merge branch 'embeddings' of https://github.com/hardikjshah/llama-stack into embeddings

Hardik Shah committed on 2025-05-30 12:25:07 -07:00 · commit 535e55d7dd
37 changed files with 605 additions and 66 deletions

View file

@@ -292,12 +292,12 @@ class OpenAIResponsesImpl:
     async def _store_response(
         self,
         response: OpenAIResponseObject,
-        original_input: str | list[OpenAIResponseInput],
+        input: str | list[OpenAIResponseInput],
     ) -> None:
         new_input_id = f"msg_{uuid.uuid4()}"
-        if isinstance(original_input, str):
+        if isinstance(input, str):
             # synthesize a message from the input string
-            input_content = OpenAIResponseInputMessageContentText(text=original_input)
+            input_content = OpenAIResponseInputMessageContentText(text=input)
             input_content_item = OpenAIResponseMessage(
                 role="user",
                 content=[input_content],
@@ -307,7 +307,7 @@ class OpenAIResponsesImpl:
         else:
             # we already have a list of messages
             input_items_data = []
-            for input_item in original_input:
+            for input_item in input:
                 if isinstance(input_item, OpenAIResponseMessage):
                     # These may or may not already have an id, so dump to dict, check for id, and add if missing
                     input_item_dict = input_item.model_dump()
@@ -334,7 +334,6 @@ class OpenAIResponsesImpl:
         tools: list[OpenAIResponseInputTool] | None = None,
     ):
         stream = False if stream is None else stream
-        original_input = input  # Keep reference for storage

         output_messages: list[OpenAIResponseOutput] = []
@@ -372,7 +371,7 @@
                 inference_result=inference_result,
                 ctx=ctx,
                 output_messages=output_messages,
-                original_input=original_input,
+                input=input,
                 model=model,
                 store=store,
                 tools=tools,
@@ -382,7 +381,7 @@
                 inference_result=inference_result,
                 ctx=ctx,
                 output_messages=output_messages,
-                original_input=original_input,
+                input=input,
                 model=model,
                 store=store,
                 tools=tools,
@@ -393,7 +392,7 @@
         inference_result: Any,
         ctx: ChatCompletionContext,
         output_messages: list[OpenAIResponseOutput],
-        original_input: str | list[OpenAIResponseInput],
+        input: str | list[OpenAIResponseInput],
         model: str,
         store: bool | None,
         tools: list[OpenAIResponseInputTool] | None,
@@ -423,7 +422,7 @@
         if store:
             await self._store_response(
                 response=response,
-                original_input=original_input,
+                input=input,
             )

         return response
@@ -433,7 +432,7 @@
         inference_result: Any,
         ctx: ChatCompletionContext,
         output_messages: list[OpenAIResponseOutput],
-        original_input: str | list[OpenAIResponseInput],
+        input: str | list[OpenAIResponseInput],
         model: str,
         store: bool | None,
         tools: list[OpenAIResponseInputTool] | None,
@@ -544,7 +543,7 @@
         if store:
             await self._store_response(
                 response=final_response,
-                original_input=original_input,
+                input=input,
             )

         # Emit response.completed
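The net effect of dropping original_input: _store_response now receives the same input that was re-hydrated from previous_response_id, so a stored response carries the whole conversation rather than only the newest turn (the new unit test near the end of this commit verifies exactly that). A hedged usage sketch against the OpenAI-compatible endpoint; the base URL and model id are illustrative, not taken from this commit:

from openai import OpenAI

# Llama Stack serves an OpenAI-compatible API; adjust base_url and model to your deployment.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

first = client.responses.create(model="meta-llama/Llama-3.1-8B-Instruct", input="What is 2+2?")
second = client.responses.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    input="Now what is 3+3?",
    previous_response_id=first.id,
)
# With this change, the stored input for `second` includes the first user
# message and the assistant reply, not just "Now what is 3+3?".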

View file

@@ -75,7 +75,9 @@ class PromptGuardShield:
         self.temperature = temperature
         self.threshold = threshold

-        self.device = "cuda"
+        self.device = "cpu"
+        if torch.cuda.is_available():
+            self.device = "cuda"

         # load model and tokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
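The hard-coded "cuda" device would make the shield fail to load on CPU-only hosts; the replacement is the standard torch fallback idiom. A minimal sketch of the pattern in isolation:

import torch

def pick_device() -> str:
    # Prefer a GPU when torch can see one, otherwise stay on CPU so the
    # model can still load (more slowly) on CPU-only machines.
    return "cuda" if torch.cuda.is_available() else "cpu"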

View file

@@ -218,7 +218,7 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
                 "json_schema": {
                     "name": name,
                     "schema": fmt,
-                    "strict": True,
+                    "strict": False,
                 },
             }
         if request.tools:
if request.tools:

View file

@@ -65,7 +65,7 @@ class SqliteKVStoreConfig(CommonConfig):
 class PostgresKVStoreConfig(CommonConfig):
     type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value
     host: str = "localhost"
-    port: int = 5432
+    port: str = "5432"
     db: str = "llamastack"
     user: str
     password: str | None = None
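Typing the port as str (here and in PostgresSqlStoreConfig below) lets run configs hold unresolved templates such as ${env.POSTGRES_PORT:5432} until env substitution happens. A small sketch of why the int field broke, assuming standard pydantic validation; the class names are illustrative:

from pydantic import BaseModel, ValidationError

class StrPort(BaseModel):
    port: str = "5432"

class IntPort(BaseModel):
    port: int = 5432

StrPort(port="${env.POSTGRES_PORT:5432}")  # passes: the template survives as-is
try:
    IntPort(port="${env.POSTGRES_PORT:5432}")  # fails: not parseable as an int
except ValidationError as e:
    print(e)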

View file

@@ -19,10 +19,10 @@ from sqlalchemy import (
     Text,
     select,
 )
-from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

-from ..api import ColumnDefinition, ColumnType, SqlStore
-from ..sqlstore import SqliteSqlStoreConfig
+from .api import ColumnDefinition, ColumnType, SqlStore
+from .sqlstore import SqlAlchemySqlStoreConfig

 TYPE_MAPPING: dict[ColumnType, Any] = {
     ColumnType.INTEGER: Integer,
@@ -35,9 +35,10 @@ TYPE_MAPPING: dict[ColumnType, Any] = {
 }

-class SqliteSqlStoreImpl(SqlStore):
-    def __init__(self, config: SqliteSqlStoreConfig):
-        self.engine = create_async_engine(config.engine_str)
+class SqlAlchemySqlStoreImpl(SqlStore):
+    def __init__(self, config: SqlAlchemySqlStoreConfig):
+        self.config = config
+        self.async_session = async_sessionmaker(create_async_engine(config.engine_str))
         self.metadata = MetaData()

     async def create_table(
@@ -78,13 +79,14 @@ class SqliteSqlStoreImpl(SqlStore):
         # Create the table in the database if it doesn't exist
         # checkfirst=True ensures it doesn't try to recreate if it's already there
-        async with self.engine.begin() as conn:
+        engine = create_async_engine(self.config.engine_str)
+        async with engine.begin() as conn:
             await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)

     async def insert(self, table: str, data: Mapping[str, Any]) -> None:
-        async with self.engine.begin() as conn:
-            await conn.execute(self.metadata.tables[table].insert(), data)
-            await conn.commit()
+        async with self.async_session() as session:
+            await session.execute(self.metadata.tables[table].insert(), data)
+            await session.commit()

     async def fetch_all(
         self,
@@ -93,7 +95,7 @@ class SqliteSqlStoreImpl(SqlStore):
         limit: int | None = None,
         order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
     ) -> list[dict[str, Any]]:
-        async with self.engine.begin() as conn:
+        async with self.async_session() as session:
             query = select(self.metadata.tables[table])
             if where:
                 for key, value in where.items():
@@ -117,7 +119,7 @@
                     query = query.order_by(self.metadata.tables[table].c[name].desc())
                 else:
                     raise ValueError(f"Invalid order '{order_type}' for column '{name}'")
-            result = await conn.execute(query)
+            result = await session.execute(query)
             if result.rowcount == 0:
                 return []
             return [dict(row._mapping) for row in result]
@@ -142,20 +144,20 @@ class SqliteSqlStoreImpl(SqlStore):
         if not where:
             raise ValueError("where is required for update")

-        async with self.engine.begin() as conn:
+        async with self.async_session() as session:
             stmt = self.metadata.tables[table].update()
             for key, value in where.items():
                 stmt = stmt.where(self.metadata.tables[table].c[key] == value)
-            await conn.execute(stmt, data)
-            await conn.commit()
+            await session.execute(stmt, data)
+            await session.commit()

     async def delete(self, table: str, where: Mapping[str, Any]) -> None:
         if not where:
             raise ValueError("where is required for delete")

-        async with self.engine.begin() as conn:
+        async with self.async_session() as session:
             stmt = self.metadata.tables[table].delete()
             for key, value in where.items():
                 stmt = stmt.where(self.metadata.tables[table].c[key] == value)
-            await conn.execute(stmt)
-            await conn.commit()
+            await session.execute(stmt)
+            await session.commit()
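With async_sessionmaker, each CRUD call now opens a short-lived session and commits explicitly instead of reusing one engine-level transaction. A self-contained sketch of the same pattern; the table name and columns are illustrative, and the in-memory SQLite URL relies on the aiosqlite dialect pooling a single connection for :memory: databases:

import asyncio

from sqlalchemy import Column, Integer, MetaData, String, Table, select
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

metadata = MetaData()
# Hypothetical table for the demo; the real store builds tables from ColumnDefinition.
items = Table("items", metadata, Column("id", Integer, primary_key=True), Column("name", String))

async def main() -> None:
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    async with engine.begin() as conn:
        # DDL still runs on a raw engine connection, mirroring create_table above.
        await conn.run_sync(metadata.create_all)

    async_session = async_sessionmaker(engine)
    async with async_session() as session:
        # DML goes through a session with an explicit commit, mirroring insert/update/delete.
        await session.execute(items.insert(), {"id": 1, "name": "demo"})
        await session.commit()
        rows = await session.execute(select(items))
        print([dict(row._mapping) for row in rows])

asyncio.run(main())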

View file

@@ -5,6 +5,7 @@
 # the root directory of this source tree.

+from abc import abstractmethod
 from enum import Enum
 from pathlib import Path
 from typing import Annotated, Literal
@@ -21,7 +22,18 @@ class SqlStoreType(Enum):
     postgres = "postgres"

-class SqliteSqlStoreConfig(BaseModel):
+class SqlAlchemySqlStoreConfig(BaseModel):
+    @property
+    @abstractmethod
+    def engine_str(self) -> str: ...
+
+    # TODO: move this when we have a better way to specify dependencies with internal APIs
+    @property
+    def pip_packages(self) -> list[str]:
+        return ["sqlalchemy[asyncio]"]
+
+
+class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
     type: Literal["sqlite"] = SqlStoreType.sqlite.value
     db_path: str = Field(
         default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
@@ -39,18 +51,26 @@ class SqliteSqlStoreConfig(BaseModel):
             db_path="${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name,
         )

-    # TODO: move this when we have a better way to specify dependencies with internal APIs
     @property
     def pip_packages(self) -> list[str]:
-        return ["sqlalchemy[asyncio]"]
+        return super().pip_packages + ["aiosqlite"]

-class PostgresSqlStoreConfig(BaseModel):
+class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
     type: Literal["postgres"] = SqlStoreType.postgres.value
     host: str = "localhost"
     port: str = "5432"
     db: str = "llamastack"
     user: str
     password: str | None = None

+    @property
+    def engine_str(self) -> str:
+        return f"postgresql+asyncpg://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}"
+
     @property
     def pip_packages(self) -> list[str]:
-        raise NotImplementedError("Postgres is not implemented yet")
+        return super().pip_packages + ["asyncpg"]

 SqlStoreConfig = Annotated[
@@ -60,12 +80,10 @@ SqlStoreConfig = Annotated[

 def sqlstore_impl(config: SqlStoreConfig) -> SqlStore:
-    if config.type == SqlStoreType.sqlite.value:
-        from .sqlite.sqlite import SqliteSqlStoreImpl
+    if config.type in [SqlStoreType.sqlite.value, SqlStoreType.postgres.value]:
+        from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl

-        impl = SqliteSqlStoreImpl(config)
-    elif config.type == SqlStoreType.postgres.value:
-        raise NotImplementedError("Postgres is not implemented yet")
+        impl = SqlAlchemySqlStoreImpl(config)
     else:
         raise ValueError(f"Unknown sqlstore type {config.type}")
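Since both backends now route through SqlAlchemySqlStoreImpl, supporting another SQLAlchemy dialect is mostly a config subclass: supply an engine_str and the extra driver package. A hypothetical sketch (MySQLSqlStoreConfig is not part of this commit; it only illustrates the extension point):

from typing import Literal

class MySQLSqlStoreConfig(SqlAlchemySqlStoreConfig):  # hypothetical, for illustration
    type: Literal["mysql"] = "mysql"
    host: str = "localhost"
    port: str = "3306"
    db: str = "llamastack"
    user: str
    password: str | None = None

    @property
    def engine_str(self) -> str:
        return f"mysql+aiomysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}"

    @property
    def pip_packages(self) -> list[str]:
        return super().pip_packages + ["aiomysql"]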

View file

@@ -30,4 +30,5 @@ distribution_spec:
 - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -30,4 +30,5 @@ distribution_spec:
 - inline::rag-runtime
 image_type: conda
 additional_pip_packages:
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -31,4 +31,5 @@ distribution_spec:
 - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -31,5 +31,5 @@ distribution_spec:
 - inline::rag-runtime
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
 - aiosqlite
+- sqlalchemy[asyncio]

View file

@@ -32,5 +32,5 @@ distribution_spec:
 - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
 - aiosqlite
+- sqlalchemy[asyncio]

View file

@@ -27,4 +27,5 @@ distribution_spec:
 - inline::rag-runtime
 image_type: conda
 additional_pip_packages:
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -30,5 +30,5 @@ distribution_spec:
 - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
 - aiosqlite
+- sqlalchemy[asyncio]

View file

@@ -31,5 +31,5 @@ distribution_spec:
 - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
 - aiosqlite
+- sqlalchemy[asyncio]

View file

@@ -31,4 +31,5 @@ distribution_spec:
 - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -30,5 +30,5 @@ distribution_spec:
 - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
 - aiosqlite
+- sqlalchemy[asyncio]

View file

@@ -25,5 +25,5 @@ distribution_spec:
 - inline::rag-runtime
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
 - aiosqlite
+- sqlalchemy[asyncio]

View file

@@ -33,5 +33,5 @@ distribution_spec:
 - remote::wolfram-alpha
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
 - aiosqlite
+- sqlalchemy[asyncio]

View file

@@ -34,4 +34,5 @@ distribution_spec:
 - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -32,5 +32,5 @@ distribution_spec:
 - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
 - aiosqlite
+- sqlalchemy[asyncio]

View file

@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .postgres_demo import get_distribution_template # noqa: F401

View file

@@ -0,0 +1,24 @@
version: '2'
distribution_spec:
description: Quick start template for running Llama Stack with several popular providers
providers:
inference:
- remote::fireworks
- remote::vllm
vector_io:
- remote::chromadb
safety:
- inline::llama-guard
agents:
- inline::meta-reference
telemetry:
- inline::meta-reference
tool_runtime:
- remote::brave-search
- remote::tavily-search
- inline::rag-runtime
- remote::model-context-protocol
image_type: conda
additional_pip_packages:
- asyncpg
- sqlalchemy[asyncio]

View file

@@ -0,0 +1,164 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.distribution.datatypes import (
ModelInput,
Provider,
ShieldInput,
ToolGroupInput,
)
from llama_stack.providers.remote.inference.fireworks.config import FireworksImplConfig
from llama_stack.providers.remote.inference.fireworks.models import (
MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
from llama_stack.providers.utils.kvstore.config import PostgresKVStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
from llama_stack.templates.template import (
DistributionTemplate,
RunConfigSettings,
get_model_registry,
)
def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderModelEntry]]]:
# in this template, we allow each API key to be optional
providers = [
(
"fireworks",
FIREWORKS_MODEL_ENTRIES,
FireworksImplConfig.sample_run_config(api_key="${env.FIREWORKS_API_KEY:}"),
),
]
inference_providers = []
available_models = {}
for provider_id, model_entries, config in providers:
inference_providers.append(
Provider(
provider_id=provider_id,
provider_type=f"remote::{provider_id}",
config=config,
)
)
available_models[provider_id] = model_entries
inference_providers.append(
Provider(
provider_id="vllm-inference",
provider_type="remote::vllm",
config=VLLMInferenceAdapterConfig.sample_run_config(
url="${env.VLLM_URL:http://localhost:8000/v1}",
),
)
)
return inference_providers, available_models
def get_distribution_template() -> DistributionTemplate:
inference_providers, available_models = get_inference_providers()
providers = {
"inference": ([p.provider_type for p in inference_providers]),
"vector_io": ["remote::chromadb"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
"tool_runtime": [
"remote::brave-search",
"remote::tavily-search",
"inline::rag-runtime",
"remote::model-context-protocol",
],
}
name = "postgres-demo"
vector_io_providers = [
Provider(
provider_id="${env.ENABLE_CHROMADB+chromadb}",
provider_type="remote::chromadb",
config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:}"),
),
]
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
]
default_models = get_model_registry(available_models)
default_models.append(
ModelInput(
model_id="${env.INFERENCE_MODEL}",
provider_id="vllm-inference",
)
)
postgres_config = {
"type": "postgres",
"host": "${env.POSTGRES_HOST:localhost}",
"port": "${env.POSTGRES_PORT:5432}",
"db": "${env.POSTGRES_DB:llamastack}",
"user": "${env.POSTGRES_USER:llamastack}",
"password": "${env.POSTGRES_PASSWORD:llamastack}",
}
return DistributionTemplate(
name=name,
distro_type="self_hosted",
description="Quick start template for running Llama Stack with several popular providers",
container_image=None,
template_path=None,
providers=providers,
available_models_by_provider=available_models,
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": inference_providers,
"vector_io": vector_io_providers,
"agents": [
Provider(
provider_id="meta-reference",
provider_type="inline::meta-reference",
config=dict(
persistence_store=postgres_config,
responses_store=postgres_config,
),
)
],
"telemetry": [
Provider(
provider_id="meta-reference",
provider_type="inline::meta-reference",
config=dict(
service_name="${env.OTEL_SERVICE_NAME:}",
sinks="${env.TELEMETRY_SINKS:console}",
),
)
],
},
default_models=default_models,
default_tool_groups=default_tool_groups,
default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
metadata_store=PostgresKVStoreConfig.model_validate(postgres_config),
inference_store=PostgresSqlStoreConfig.model_validate(postgres_config),
),
},
run_config_env_vars={
"LLAMA_STACK_PORT": (
"8321",
"Port for the Llama Stack distribution server",
),
"FIREWORKS_API_KEY": (
"",
"Fireworks API Key",
),
},
)

View file

@@ -0,0 +1,224 @@
version: '2'
image_name: postgres-demo
apis:
- agents
- inference
- safety
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:}
- provider_id: vllm-inference
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:http://localhost:8000/v1}
max_tokens: ${env.VLLM_MAX_TOKENS:4096}
api_token: ${env.VLLM_API_TOKEN:fake}
tls_verify: ${env.VLLM_TLS_VERIFY:true}
vector_io:
- provider_id: ${env.ENABLE_CHROMADB+chromadb}
provider_type: remote::chromadb
config:
url: ${env.CHROMADB_URL:}
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config:
excluded_categories: []
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: postgres
host: ${env.POSTGRES_HOST:localhost}
port: ${env.POSTGRES_PORT:5432}
db: ${env.POSTGRES_DB:llamastack}
user: ${env.POSTGRES_USER:llamastack}
password: ${env.POSTGRES_PASSWORD:llamastack}
responses_store:
type: postgres
host: ${env.POSTGRES_HOST:localhost}
port: ${env.POSTGRES_PORT:5432}
db: ${env.POSTGRES_DB:llamastack}
user: ${env.POSTGRES_USER:llamastack}
password: ${env.POSTGRES_PASSWORD:llamastack}
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
metadata_store:
type: postgres
host: ${env.POSTGRES_HOST:localhost}
port: ${env.POSTGRES_PORT:5432}
db: ${env.POSTGRES_DB:llamastack}
user: ${env.POSTGRES_USER:llamastack}
password: ${env.POSTGRES_PASSWORD:llamastack}
table_name: llamastack_kvstore
inference_store:
type: postgres
host: ${env.POSTGRES_HOST:localhost}
port: ${env.POSTGRES_PORT:5432}
db: ${env.POSTGRES_DB:llamastack}
user: ${env.POSTGRES_USER:llamastack}
password: ${env.POSTGRES_PASSWORD:llamastack}
models:
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p1-8b-instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p1-8b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-8B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p1-8b-instruct
model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p1-70b-instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p1-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-70B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p1-70b-instruct
model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-3B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct
model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct
model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p3-70b-instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.3-70B-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct
model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama-guard-3-8b
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-guard-3-8b
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-Guard-3-8B
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-guard-3-8b
model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama-guard-3-11b-vision
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-Guard-3-11B-Vision
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama4-scout-instruct-basic
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama4-scout-instruct-basic
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama4-scout-instruct-basic
model_type: llm
- metadata: {}
model_id: accounts/fireworks/models/llama4-maverick-instruct-basic
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama4-maverick-instruct-basic
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama4-maverick-instruct-basic
model_type: llm
- metadata:
embedding_dimension: 768
context_length: 8192
model_id: nomic-ai/nomic-embed-text-v1.5
provider_id: fireworks
provider_model_id: nomic-ai/nomic-embed-text-v1.5
model_type: embedding
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: vllm-inference
model_type: llm
shields:
- shield_id: meta-llama/Llama-Guard-3-8B
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321

View file

@@ -32,5 +32,5 @@ distribution_spec:
 - remote::wolfram-alpha
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
 - aiosqlite
+- sqlalchemy[asyncio]

View file

@@ -23,4 +23,5 @@ distribution_spec:
 - remote::wolfram-alpha
 image_type: conda
 additional_pip_packages:
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -36,4 +36,5 @@ distribution_spec:
 - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -28,8 +28,8 @@ from llama_stack.distribution.datatypes import (
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
 from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
-from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
-from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
+from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
@@ -64,6 +64,8 @@ class RunConfigSettings(BaseModel):
     default_tool_groups: list[ToolGroupInput] | None = None
     default_datasets: list[DatasetInput] | None = None
     default_benchmarks: list[BenchmarkInput] | None = None
+    metadata_store: KVStoreConfig | None = None
+    inference_store: SqlStoreConfig | None = None

     def run_config(
         self,
@@ -114,11 +116,13 @@
             container_image=container_image,
             apis=apis,
             providers=provider_configs,
-            metadata_store=SqliteKVStoreConfig.sample_run_config(
+            metadata_store=self.metadata_store
+            or SqliteKVStoreConfig.sample_run_config(
                 __distro_dir__=f"~/.llama/distributions/{name}",
                 db_name="registry.db",
             ),
-            inference_store=SqliteSqlStoreConfig.sample_run_config(
+            inference_store=self.inference_store
+            or SqliteSqlStoreConfig.sample_run_config(
                 __distro_dir__=f"~/.llama/distributions/{name}",
                 db_name="inference_store.db",
             ),
@@ -164,7 +168,7 @@ class DistributionTemplate(BaseModel):
                 providers=self.providers,
             ),
             image_type="conda",  # default to conda, can be overridden
-            additional_pip_packages=additional_pip_packages,
+            additional_pip_packages=sorted(set(additional_pip_packages)),
         )

     def generate_markdown_docs(self) -> str:
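The sorted(set(...)) normalization is what reorders additional_pip_packages in all the build.yaml files above: duplicates collapse and the list comes out alphabetized, so aiosqlite now sorts ahead of sqlalchemy[asyncio]. For instance:

# Deduplicate, then order deterministically.
assert sorted({"sqlalchemy[asyncio]", "aiosqlite", "aiosqlite"}) == ["aiosqlite", "sqlalchemy[asyncio]"]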

View file

@@ -31,5 +31,5 @@ distribution_spec:
 - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
 - aiosqlite
+- sqlalchemy[asyncio]

View file

@@ -32,5 +32,5 @@ distribution_spec:
 - remote::wolfram-alpha
 image_type: conda
 additional_pip_packages:
-- sqlalchemy[asyncio]
 - aiosqlite
+- sqlalchemy[asyncio]

View file

@@ -36,4 +36,5 @@ distribution_spec:
 - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -31,4 +31,5 @@ distribution_spec:
 - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -29,4 +29,5 @@ distribution_spec:
 - remote::model-context-protocol
 image_type: conda
 additional_pip_packages:
+- aiosqlite
 - sqlalchemy[asyncio]

View file

@@ -107,6 +107,13 @@ def collect_template_dependencies(template_dir: Path) -> tuple[str | None, list[
         return None, []

+def pre_import_templates(template_dirs: list[Path]) -> None:
+    # Pre-import all template modules to avoid deadlocks.
+    for template_dir in template_dirs:
+        module_name = f"llama_stack.templates.{template_dir.name}"
+        importlib.import_module(module_name)
+
 def main():
     templates_dir = REPO_ROOT / "llama_stack" / "templates"
     change_tracker = ChangedPathTracker()
@@ -118,6 +125,8 @@ def main():
         template_dirs = list(find_template_dirs(templates_dir))
         task = progress.add_task("Processing distribution templates...", total=len(template_dirs))

+        pre_import_templates(template_dirs)
+
         # Create a partial function with the progress bar
         process_func = partial(process_template, progress=progress, change_tracker=change_tracker)

View file

@@ -268,9 +268,9 @@ def test_openai_chat_completion_streaming_with_n(compat_client, client_with_mode
         False,
     ],
 )
-def test_inference_store(openai_client, client_with_models, text_model_id, stream):
+def test_inference_store(compat_client, client_with_models, text_model_id, stream):
     skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
-    client = openai_client
+    client = compat_client

     # make a chat completion
     message = "Hello, world!"
     response = client.chat.completions.create(
@@ -301,9 +301,14 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
     retrieved_response = client.chat.completions.retrieve(response_id)
     assert retrieved_response.id == response_id
-    assert retrieved_response.input_messages[0]["content"] == message, retrieved_response
     assert retrieved_response.choices[0].message.content == content, retrieved_response
+    input_content = (
+        getattr(retrieved_response.input_messages[0], "content", None)
+        or retrieved_response.input_messages[0]["content"]
+    )
+    assert input_content == message, retrieved_response

 @pytest.mark.parametrize(
     "stream",
@@ -312,9 +317,9 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
         False,
     ],
 )
-def test_inference_store_tool_calls(openai_client, client_with_models, text_model_id, stream):
+def test_inference_store_tool_calls(compat_client, client_with_models, text_model_id, stream):
     skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
-    client = openai_client
+    client = compat_client

     # make a chat completion
     message = "What's the weather in Tokyo? Use the get_weather function to get the weather."
     response = client.chat.completions.create(
@@ -361,7 +366,11 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_mode
     retrieved_response = client.chat.completions.retrieve(response_id)
     assert retrieved_response.id == response_id
-    assert retrieved_response.input_messages[0]["content"] == message
+    input_content = (
+        getattr(retrieved_response.input_messages[0], "content", None)
+        or retrieved_response.input_messages[0]["content"]
+    )
+    assert input_content == message, retrieved_response
     tool_calls = retrieved_response.choices[0].message.tool_calls
     # sometimes model doesn't output tool calls, but we still want to test that the tool was called
     if tool_calls:

View file

@@ -628,3 +628,69 @@
result = await responses_store.list_response_input_items("resp_123", limit=0, order=Order.asc)
assert result.object == "list"
assert len(result.data) == 0 # Should return no items
@pytest.mark.asyncio
async def test_store_response_uses_rehydrated_input_with_previous_response(
openai_responses_impl, mock_responses_store, mock_inference_api
):
"""Test that _store_response uses the full re-hydrated input (including previous responses)
rather than just the original input when previous_response_id is provided."""
# Setup - Create a previous response that should be included in the stored input
previous_response = OpenAIResponseObjectWithInput(
id="resp-previous-123",
object="response",
created_at=1234567890,
model="meta-llama/Llama-3.1-8B-Instruct",
status="completed",
input=[
OpenAIResponseMessage(
id="msg-prev-user", role="user", content=[OpenAIResponseInputMessageContentText(text="What is 2+2?")]
)
],
output=[
OpenAIResponseMessage(
id="msg-prev-assistant",
role="assistant",
content=[OpenAIResponseOutputMessageContentOutputText(text="2+2 equals 4.")],
)
],
)
mock_responses_store.get_response_object.return_value = previous_response
current_input = "Now what is 3+3?"
model = "meta-llama/Llama-3.1-8B-Instruct"
mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
# Execute - Create response with previous_response_id
result = await openai_responses_impl.create_openai_response(
input=current_input,
model=model,
previous_response_id="resp-previous-123",
store=True,
)
store_call_args = mock_responses_store.store_response_object.call_args
stored_input = store_call_args.kwargs["input"]
# Verify that the stored input contains the full re-hydrated conversation:
# 1. Previous user message
# 2. Previous assistant response
# 3. Current user message
assert len(stored_input) == 3
assert stored_input[0].role == "user"
assert stored_input[0].content[0].text == "What is 2+2?"
assert stored_input[1].role == "assistant"
assert stored_input[1].content[0].text == "2+2 equals 4."
assert stored_input[2].role == "user"
assert stored_input[2].content == "Now what is 3+3?"
# Verify the response itself is correct
assert result.model == model
assert result.status == "completed"

View file

@@ -9,7 +9,7 @@ from tempfile import TemporaryDirectory
 import pytest

 from llama_stack.providers.utils.sqlstore.api import ColumnType
-from llama_stack.providers.utils.sqlstore.sqlite.sqlite import SqliteSqlStoreImpl
+from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
 from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
@@ -17,7 +17,7 @@ from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
 async def test_sqlite_sqlstore():
     with TemporaryDirectory() as tmp_dir:
         db_name = "test.db"
-        sqlstore = SqliteSqlStoreImpl(
+        sqlstore = SqlAlchemySqlStoreImpl(
             SqliteSqlStoreConfig(
                 db_path=tmp_dir + "/" + db_name,
             )