Merge branch 'llamastack:main' into model_unregisteration_error_message

Omar Abdelwahab 2025-10-06 10:14:30 -07:00 committed by GitHub
commit aa09a44c94
1036 changed files with 314835 additions and 114394 deletions


@@ -324,14 +324,14 @@ fi
 RUN pip uninstall -y uv
 EOF
-# If a run config is provided, we use the --config flag
+# If a run config is provided, we use the llama stack CLI
 if [[ -n "$run_config" ]]; then
 add_to_container << EOF
-ENTRYPOINT ["python", "-m", "llama_stack.core.server.server", "$RUN_CONFIG_PATH"]
+ENTRYPOINT ["llama", "stack", "run", "$RUN_CONFIG_PATH"]
 EOF
 elif [[ "$distro_or_config" != *.yaml ]]; then
 add_to_container << EOF
-ENTRYPOINT ["python", "-m", "llama_stack.core.server.server", "$distro_or_config"]
+ENTRYPOINT ["llama", "stack", "run", "$distro_or_config"]
 EOF
 fi


@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.


@@ -0,0 +1,306 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+import secrets
+import time
+from typing import Any
+
+from openai import NOT_GIVEN
+from pydantic import BaseModel, TypeAdapter
+
+from llama_stack.apis.conversations.conversations import (
+    Conversation,
+    ConversationDeletedResource,
+    ConversationItem,
+    ConversationItemDeletedResource,
+    ConversationItemList,
+    Conversations,
+    Metadata,
+)
+from llama_stack.core.datatypes import AccessRule
+from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
+from llama_stack.log import get_logger
+from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
+from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
+from llama_stack.providers.utils.sqlstore.sqlstore import (
+    SqliteSqlStoreConfig,
+    SqlStoreConfig,
+    sqlstore_impl,
+)
+
+logger = get_logger(name=__name__, category="openai::conversations")
+
+
+class ConversationServiceConfig(BaseModel):
+    """Configuration for the built-in conversation service.
+
+    :param conversations_store: SQL store configuration for conversations (defaults to SQLite)
+    :param policy: Access control rules
+    """
+
+    conversations_store: SqlStoreConfig = SqliteSqlStoreConfig(
+        db_path=(DISTRIBS_BASE_DIR / "conversations.db").as_posix()
+    )
+    policy: list[AccessRule] = []
+
+
+async def get_provider_impl(config: ConversationServiceConfig, deps: dict[Any, Any]):
+    """Get the conversation service implementation."""
+    impl = ConversationServiceImpl(config, deps)
+    await impl.initialize()
+    return impl
+
+
+class ConversationServiceImpl(Conversations):
+    """Built-in conversation service implementation using AuthorizedSqlStore."""
+
+    def __init__(self, config: ConversationServiceConfig, deps: dict[Any, Any]):
+        self.config = config
+        self.deps = deps
+        self.policy = config.policy
+        base_sql_store = sqlstore_impl(config.conversations_store)
+        self.sql_store = AuthorizedSqlStore(base_sql_store, self.policy)
+
+    async def initialize(self) -> None:
+        """Initialize the store and create tables."""
+        if isinstance(self.config.conversations_store, SqliteSqlStoreConfig):
+            os.makedirs(os.path.dirname(self.config.conversations_store.db_path), exist_ok=True)
+
+        await self.sql_store.create_table(
+            "openai_conversations",
+            {
+                "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
+                "created_at": ColumnType.INTEGER,
+                "items": ColumnType.JSON,
+                "metadata": ColumnType.JSON,
+            },
+        )
+        await self.sql_store.create_table(
+            "conversation_items",
+            {
+                "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
+                "conversation_id": ColumnType.STRING,
+                "created_at": ColumnType.INTEGER,
+                "item_data": ColumnType.JSON,
+            },
+        )
+
+    async def create_conversation(
+        self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
+    ) -> Conversation:
+        """Create a conversation."""
+        random_bytes = secrets.token_bytes(24)
+        conversation_id = f"conv_{random_bytes.hex()}"
+        created_at = int(time.time())
+
+        record_data = {
+            "id": conversation_id,
+            "created_at": created_at,
+            "items": [],
+            "metadata": metadata,
+        }
+        await self.sql_store.insert(
+            table="openai_conversations",
+            data=record_data,
+        )
+
+        if items:
+            item_records = []
+            for item in items:
+                item_dict = item.model_dump()
+                item_id = self._get_or_generate_item_id(item, item_dict)
+                item_record = {
+                    "id": item_id,
+                    "conversation_id": conversation_id,
+                    "created_at": created_at,
+                    "item_data": item_dict,
+                }
+                item_records.append(item_record)
+            await self.sql_store.insert(table="conversation_items", data=item_records)
+
+        conversation = Conversation(
+            id=conversation_id,
+            created_at=created_at,
+            metadata=metadata,
+            object="conversation",
+        )
+        logger.info(f"Created conversation {conversation_id}")
+        return conversation
+
+    async def get_conversation(self, conversation_id: str) -> Conversation:
+        """Get a conversation with the given ID."""
+        record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
+        if record is None:
+            raise ValueError(f"Conversation {conversation_id} not found")
+        return Conversation(
+            id=record["id"], created_at=record["created_at"], metadata=record.get("metadata"), object="conversation"
+        )
+
+    async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
+        """Update a conversation's metadata with the given ID."""
+        await self.sql_store.update(
+            table="openai_conversations", data={"metadata": metadata}, where={"id": conversation_id}
+        )
+        return await self.get_conversation(conversation_id)
+
+    async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
+        """Delete a conversation with the given ID."""
+        await self.sql_store.delete(table="openai_conversations", where={"id": conversation_id})
+        logger.info(f"Deleted conversation {conversation_id}")
+        return ConversationDeletedResource(id=conversation_id)
+
+    def _validate_conversation_id(self, conversation_id: str) -> None:
+        """Validate conversation ID format."""
+        if not conversation_id.startswith("conv_"):
+            raise ValueError(
+                f"Invalid 'conversation_id': '{conversation_id}'. Expected an ID that begins with 'conv_'."
+            )
+
+    def _get_or_generate_item_id(self, item: ConversationItem, item_dict: dict) -> str:
+        """Get existing item ID or generate one if missing."""
+        if item.id is None:
+            random_bytes = secrets.token_bytes(24)
+            if item.type == "message":
+                item_id = f"msg_{random_bytes.hex()}"
+            else:
+                item_id = f"item_{random_bytes.hex()}"
+            item_dict["id"] = item_id
+            return item_id
+        return item.id
+
+    async def _get_validated_conversation(self, conversation_id: str) -> Conversation:
+        """Validate conversation ID and return the conversation if it exists."""
+        self._validate_conversation_id(conversation_id)
+        return await self.get_conversation(conversation_id)
+
+    async def add_items(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
+        """Create (add) items to a conversation."""
+        await self._get_validated_conversation(conversation_id)
+
+        created_items = []
+        created_at = int(time.time())
+        for item in items:
+            item_dict = item.model_dump()
+            item_id = self._get_or_generate_item_id(item, item_dict)
+            item_record = {
+                "id": item_id,
+                "conversation_id": conversation_id,
+                "created_at": created_at,
+                "item_data": item_dict,
+            }
+            # TODO: add upsert support to sql_store; until then, the insert fails when the ID already exists and we fall back to an update
+            try:
+                await self.sql_store.insert(table="conversation_items", data=item_record)
+            except Exception:
+                # If insert fails due to ID conflict, update existing record
+                await self.sql_store.update(
+                    table="conversation_items",
+                    data={"created_at": created_at, "item_data": item_dict},
+                    where={"id": item_id},
+                )
+            created_items.append(item_dict)
+
+        logger.info(f"Created {len(created_items)} items in conversation {conversation_id}")
+
+        # Convert created items (dicts) to proper ConversationItem types
+        adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
+        response_items: list[ConversationItem] = [adapter.validate_python(item_dict) for item_dict in created_items]
+
+        return ConversationItemList(
+            data=response_items,
+            first_id=created_items[0]["id"] if created_items else None,
+            last_id=created_items[-1]["id"] if created_items else None,
+            has_more=False,
+        )
+
+    async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
+        """Retrieve a conversation item."""
+        if not conversation_id:
+            raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
+        if not item_id:
+            raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
+
+        # Get item from conversation_items table
+        record = await self.sql_store.fetch_one(
+            table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
+        )
+        if record is None:
+            raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
+
+        adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
+        return adapter.validate_python(record["item_data"])
+
+    async def list(self, conversation_id: str, after=NOT_GIVEN, include=NOT_GIVEN, limit=NOT_GIVEN, order=NOT_GIVEN):
+        """List items in the conversation."""
+        result = await self.sql_store.fetch_all(table="conversation_items", where={"conversation_id": conversation_id})
+        records = result.data
+
+        if order != NOT_GIVEN and order == "asc":
+            records.sort(key=lambda x: x["created_at"])
+        else:
+            records.sort(key=lambda x: x["created_at"], reverse=True)
+
+        actual_limit = 20
+        if limit != NOT_GIVEN and isinstance(limit, int):
+            actual_limit = limit
+
+        records = records[:actual_limit]
+        items = [record["item_data"] for record in records]
+
+        adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
+        response_items: list[ConversationItem] = [adapter.validate_python(item) for item in items]
+
+        first_id = response_items[0].id if response_items else None
+        last_id = response_items[-1].id if response_items else None
+
+        return ConversationItemList(
+            data=response_items,
+            first_id=first_id,
+            last_id=last_id,
+            has_more=False,
+        )
+
+    async def openai_delete_conversation_item(
+        self, conversation_id: str, item_id: str
+    ) -> ConversationItemDeletedResource:
+        """Delete a conversation item."""
+        if not conversation_id:
+            raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
+        if not item_id:
+            raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
+
+        _ = await self._get_validated_conversation(conversation_id)
+
+        record = await self.sql_store.fetch_one(
+            table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
+        )
+        if record is None:
+            raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
+
+        await self.sql_store.delete(
+            table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
+        )
+        logger.info(f"Deleted item {item_id} from conversation {conversation_id}")
+        return ConversationItemDeletedResource(id=item_id)
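
Note: the service above can be exercised directly. A rough usage sketch against the implementation as committed; the metadata values and the empty `deps` dict are assumptions, and items are omitted because they must validate as `ConversationItem` models:

import asyncio

from llama_stack.core.conversations.conversations import (
    ConversationServiceConfig,
    get_provider_impl,
)


async def demo() -> None:
    # With the default config this persists to <DISTRIBS_BASE_DIR>/conversations.db
    svc = await get_provider_impl(ConversationServiceConfig(), deps={})

    conv = await svc.create_conversation(metadata={"topic": "demo"})
    assert conv.id.startswith("conv_")  # IDs are conv_ + 48 hex chars

    await svc.update_conversation(conv.id, metadata={"topic": "demo", "state": "active"})
    fetched = await svc.get_conversation(conv.id)
    print(fetched.id, fetched.metadata)

    deleted = await svc.openai_delete_conversation(conv.id)
    print("deleted:", deleted.id)


asyncio.run(demo())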


@@ -22,7 +22,7 @@ from llama_stack.apis.safety import Safety
 from llama_stack.apis.scoring import Scoring
 from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
 from llama_stack.apis.shields import Shield, ShieldInput
-from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
+from llama_stack.apis.tools import ToolGroup, ToolGroupInput, ToolRuntime
 from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.core.access_control.datatypes import AccessRule
@@ -84,15 +84,11 @@ class BenchmarkWithOwner(Benchmark, ResourceWithOwner):
     pass

-class ToolWithOwner(Tool, ResourceWithOwner):
-    pass

 class ToolGroupWithOwner(ToolGroup, ResourceWithOwner):
     pass

-RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | Tool | ToolGroup
+RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | ToolGroup

 RoutableObjectWithProvider = Annotated[
     ModelWithOwner
@@ -101,7 +97,6 @@ RoutableObjectWithProvider = Annotated[
     | DatasetWithOwner
     | ScoringFnWithOwner
     | BenchmarkWithOwner
-    | ToolWithOwner
     | ToolGroupWithOwner,
     Field(discriminator="type"),
 ]
@@ -480,6 +475,13 @@ InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (depreca
 If not specified, a default SQLite store will be used.""",
     )

+    conversations_store: SqlStoreConfig | None = Field(
+        default=None,
+        description="""
+Configuration for the persistence store used by the conversations API.
+If not specified, a default SQLite store will be used.""",
+    )
+
     # registry of "resources" in the distribution
     models: list[ModelInput] = Field(default_factory=list)
     shields: list[ShieldInput] = Field(default_factory=list)
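
Note: a sketch of what the new field allows, pinning the conversations API to its own store when building a run config. This is a hypothetical minimal config; real StackRunConfig instances carry providers, apis, and other fields that may be required, and the SQLite config class is the one used by the conversations service above:

from llama_stack.core.datatypes import StackRunConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig

# Assumed-minimal construction; additional required fields may apply.
run_config = StackRunConfig(
    image_name="my-distro",
    providers={},
    conversations_store=SqliteSqlStoreConfig(db_path="/tmp/conversations.db"),
)
print(run_config.conversations_store)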


@@ -25,7 +25,7 @@ from llama_stack.providers.datatypes import (
 logger = get_logger(name=__name__, category="core")

-INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts}
+INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts, Api.conversations}


 def stack_apis() -> list[Api]:
@@ -243,6 +243,7 @@ def get_external_providers_from_module(
                 spec = module.get_provider_spec()
             else:
                 # pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run
+                # in the case we are building we CANNOT import this module of course because it has not been installed.
                 spec = ProviderSpec(
                     api=Api(provider_api),
                     provider_type=provider.provider_type,
@@ -251,9 +252,20 @@
                     config_class="",
                 )
             provider_type = provider.provider_type
-            # in the case we are building we CANNOT import this module of course because it has not been installed.
-            # return a partially filled out spec that the build script will populate.
-            registry[Api(provider_api)][provider_type] = spec
+            if isinstance(spec, list):
+                # optionally allow people to pass inline and remote provider specs as a returned list.
+                # with the old method, users could pass in directories of specs using overlapping code
+                # we want to ensure we preserve that flexibility in this method.
+                logger.info(
+                    f"Detected a list of external provider specs from {provider.module} adding all to the registry"
+                )
+                for provider_spec in spec:
+                    if provider_spec.provider_type != provider.provider_type:
+                        continue
+                    logger.info(f"Adding {provider.provider_type} to registry")
+                    registry[Api(provider_api)][provider.provider_type] = provider_spec
+            else:
+                registry[Api(provider_api)][provider_type] = spec
         except ModuleNotFoundError as exc:
             raise ValueError(
                 "get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available"


@@ -374,6 +374,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         body = options.params or {}
         body |= options.json_data or {}

+        # Merge extra_json parameters (extra_body from SDK is converted to extra_json)
+        if hasattr(options, "extra_json") and options.extra_json:
+            body |= options.extra_json
+
         matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
         body |= path_params
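
Note: a sketch of the caller-side effect, request-level extras now flow through the library client into the request body. The "starter" distro name and the OpenAI-style chat.completions surface are assumptions about the client setup; the grounded part is that `extra_body` arrives as `extra_json` on the request options and is merged above:

# (inside an async context)
client = AsyncLlamaStackAsLibraryClient("starter")
await client.initialize()

response = await client.chat.completions.create(
    model="llama3.2:3b",
    messages=[{"role": "user", "content": "hi"}],
    extra_body={"my_vendor_option": True},  # hypothetical extra field; surfaces as options.extra_json
)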


@@ -10,6 +10,7 @@ from typing import Any
 from llama_stack.apis.agents import Agents
 from llama_stack.apis.batches import Batches
 from llama_stack.apis.benchmarks import Benchmarks
+from llama_stack.apis.conversations import Conversations
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.datatypes import ExternalApiSpec
@@ -96,6 +97,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
         Api.tool_runtime: ToolRuntime,
         Api.files: Files,
         Api.prompts: Prompts,
+        Api.conversations: Conversations,
     }

     if external_apis:


@@ -27,7 +27,6 @@ from llama_stack.apis.inference import (
     CompletionResponseStreamChunk,
     Inference,
     ListOpenAIChatCompletionResponse,
-    LogProbConfig,
     Message,
     OpenAIAssistantMessageParam,
     OpenAIChatCompletion,
@@ -42,12 +41,7 @@ from llama_stack.apis.inference import (
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
     Order,
-    ResponseFormat,
-    SamplingParams,
-    StopReason,
-    ToolChoice,
-    ToolConfig,
-    ToolDefinition,
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.models import Model, ModelType
@@ -185,129 +179,6 @@ class InferenceRouter(Inference):
             raise ModelTypeError(model_id, model.model_type, expected_model_type)
         return model

-    async def chat_completion(
-        self,
-        model_id: str,
-        messages: list[Message],
-        sampling_params: SamplingParams | None = None,
-        response_format: ResponseFormat | None = None,
-        tools: list[ToolDefinition] | None = None,
-        tool_choice: ToolChoice | None = None,
-        tool_prompt_format: ToolPromptFormat | None = None,
-        stream: bool | None = False,
-        logprobs: LogProbConfig | None = None,
-        tool_config: ToolConfig | None = None,
-    ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
-        logger.debug(
-            f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
-        )
-        if sampling_params is None:
-            sampling_params = SamplingParams()
-        model = await self._get_model(model_id, ModelType.llm)
-        if tool_config:
-            if tool_choice and tool_choice != tool_config.tool_choice:
-                raise ValueError("tool_choice and tool_config.tool_choice must match")
-            if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
-                raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
-        else:
-            params = {}
-            if tool_choice:
-                params["tool_choice"] = tool_choice
-            if tool_prompt_format:
-                params["tool_prompt_format"] = tool_prompt_format
-            tool_config = ToolConfig(**params)
-
-        tools = tools or []
-        if tool_config.tool_choice == ToolChoice.none:
-            tools = []
-        elif tool_config.tool_choice == ToolChoice.auto:
-            pass
-        elif tool_config.tool_choice == ToolChoice.required:
-            pass
-        else:
-            # verify tool_choice is one of the tools
-            tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
-            if tool_config.tool_choice not in tool_names:
-                raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
-
-        params = dict(
-            model_id=model_id,
-            messages=messages,
-            sampling_params=sampling_params,
-            tools=tools,
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
-            response_format=response_format,
-            stream=stream,
-            logprobs=logprobs,
-            tool_config=tool_config,
-        )
-        provider = await self.routing_table.get_provider_impl(model_id)
-        prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
-        if stream:
-            response_stream = await provider.chat_completion(**params)
-            return self.stream_tokens_and_compute_metrics(
-                response=response_stream,
-                prompt_tokens=prompt_tokens,
-                model=model,
-                tool_prompt_format=tool_config.tool_prompt_format,
-            )
-
-        response = await provider.chat_completion(**params)
-        metrics = await self.count_tokens_and_compute_metrics(
-            response=response,
-            prompt_tokens=prompt_tokens,
-            model=model,
-            tool_prompt_format=tool_config.tool_prompt_format,
-        )
-        # these metrics will show up in the client response.
-        response.metrics = (
-            metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
-        )
-        return response
-
-    async def completion(
-        self,
-        model_id: str,
-        content: InterleavedContent,
-        sampling_params: SamplingParams | None = None,
-        response_format: ResponseFormat | None = None,
-        stream: bool | None = False,
-        logprobs: LogProbConfig | None = None,
-    ) -> AsyncGenerator:
-        if sampling_params is None:
-            sampling_params = SamplingParams()
-        logger.debug(
-            f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
-        )
-        model = await self._get_model(model_id, ModelType.llm)
-        provider = await self.routing_table.get_provider_impl(model_id)
-        params = dict(
-            model_id=model_id,
-            content=content,
-            sampling_params=sampling_params,
-            response_format=response_format,
-            stream=stream,
-            logprobs=logprobs,
-        )
-        prompt_tokens = await self._count_tokens(content)
-        response = await provider.completion(**params)
-        if stream:
-            return self.stream_tokens_and_compute_metrics(
-                response=response,
-                prompt_tokens=prompt_tokens,
-                model=model,
-            )
-        metrics = await self.count_tokens_and_compute_metrics(
-            response=response, prompt_tokens=prompt_tokens, model=model
-        )
-        response.metrics = metrics if response.metrics is None else response.metrics + metrics
-        return response

     async def openai_completion(
         self,
         model: str,
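
Note: with the legacy `chat_completion`/`completion` router paths deleted, the OpenAI-compatible methods are the remaining entry points. A rough sketch of the surviving call shape, using only the `openai_completion(self, model: str, ...)` opening visible above; the `prompt` parameter name is an assumption about the rest of that signature:

async def demo(router) -> None:
    # `router` is an InferenceRouter wired up by the stack.
    completion = await router.openai_completion(
        model="llama3.2:3b",
        prompt="Say hi",  # assumed: OpenAI-completions-style field
    )
    print(completion.choices[0].text)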


@@ -11,7 +11,7 @@ from llama_stack.apis.common.content_types import (
     InterleavedContent,
 )
 from llama_stack.apis.tools import (
-    ListToolsResponse,
+    ListToolDefsResponse,
     RAGDocument,
     RAGQueryConfig,
     RAGQueryResult,
@@ -86,6 +86,6 @@ class ToolRuntimeRouter(ToolRuntime):
     async def list_runtime_tools(
         self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
-    ) -> ListToolsResponse:
+    ) -> ListToolDefsResponse:
         logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
         return await self.routing_table.list_tools(tool_group_id)


@@ -8,7 +8,7 @@ from typing import Any
 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.common.errors import ToolGroupNotFoundError
-from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
+from llama_stack.apis.tools import ListToolDefsResponse, ListToolGroupsResponse, ToolDef, ToolGroup, ToolGroups
 from llama_stack.core.datatypes import AuthenticationRequiredError, ToolGroupWithOwner
 from llama_stack.log import get_logger
@@ -27,7 +27,7 @@ def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name
 class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
-    toolgroups_to_tools: dict[str, list[Tool]] = {}
+    toolgroups_to_tools: dict[str, list[ToolDef]] = {}
     tool_to_toolgroup: dict[str, str] = {}

     # overridden
@@ -43,7 +43,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
             routing_key = self.tool_to_toolgroup[routing_key]
         return await super().get_provider_impl(routing_key, provider_id)

-    async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
+    async def list_tools(self, toolgroup_id: str | None = None) -> ListToolDefsResponse:
         if toolgroup_id:
             if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
                 toolgroup_id = group_id
@@ -68,30 +68,19 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
                 continue
             all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier])

-        return ListToolsResponse(data=all_tools)
+        return ListToolDefsResponse(data=all_tools)

     async def _index_tools(self, toolgroup: ToolGroup):
         provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
         tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)
-        # TODO: kill this Tool vs ToolDef distinction
         tooldefs = tooldefs_response.data
-        tools = []
         for t in tooldefs:
-            tools.append(
-                Tool(
-                    identifier=t.name,
-                    toolgroup_id=toolgroup.identifier,
-                    description=t.description or "",
-                    parameters=t.parameters or [],
-                    metadata=t.metadata,
-                    provider_id=toolgroup.provider_id,
-                )
-            )
+            t.toolgroup_id = toolgroup.identifier
-        self.toolgroups_to_tools[toolgroup.identifier] = tools
-        for tool in tools:
-            self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier
+        self.toolgroups_to_tools[toolgroup.identifier] = tooldefs
+        for tool in tooldefs:
+            self.tool_to_toolgroup[tool.name] = toolgroup.identifier

     async def list_tool_groups(self) -> ListToolGroupsResponse:
         return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
@@ -102,12 +91,12 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
             raise ToolGroupNotFoundError(toolgroup_id)
         return tool_group

-    async def get_tool(self, tool_name: str) -> Tool:
+    async def get_tool(self, tool_name: str) -> ToolDef:
         if tool_name in self.tool_to_toolgroup:
             toolgroup_id = self.tool_to_toolgroup[tool_name]
             tools = self.toolgroups_to_tools[toolgroup_id]
             for tool in tools:
-                if tool.identifier == tool_name:
+                if tool.name == tool_name:
                     return tool
         raise ValueError(f"Tool '{tool_name}' not found")
@@ -132,7 +121,6 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
         # baked in some of the code and tests right now.
         if not toolgroup.mcp_endpoint:
             await self._index_tools(toolgroup)
-        return toolgroup

     async def unregister_toolgroup(self, toolgroup_id: str) -> None:
         await self.unregister_object(await self.get_tool_group(toolgroup_id))
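
Note: for callers, the Tool-to-ToolDef migration mostly means a renamed attribute on listed tools. A sketch of the client-visible change, assuming a configured llama-stack client and an example toolgroup id (the same rename appears in the UI change at the end of this diff):

tools = client.tools.list(toolgroup_id="builtin::rag")
names = [tool.name for tool in tools]  # previously: tool.identifier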


@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import argparse
 import asyncio
 import concurrent.futures
 import functools
@@ -12,7 +11,6 @@ import inspect
 import json
 import logging  # allow-direct-logging
 import os
-import ssl
 import sys
 import traceback
 import warnings
@@ -35,7 +33,6 @@ from pydantic import BaseModel, ValidationError
 from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
 from llama_stack.apis.common.responses import PaginatedResponse
-from llama_stack.cli.utils import add_config_distro_args, get_config_from_args
 from llama_stack.core.access_control.access_control import AccessDeniedError
 from llama_stack.core.datatypes import (
     AuthenticationRequiredError,
@@ -55,7 +52,6 @@ from llama_stack.core.stack import (
     Stack,
     cast_image_name_to_string,
     replace_env_vars,
-    validate_env_pair,
 )
 from llama_stack.core.utils.config import redact_sensitive_fields
 from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
@@ -257,7 +253,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
             return result
         except Exception as e:
-            if logger.isEnabledFor(logging.DEBUG):
+            if logger.isEnabledFor(logging.INFO):
                 logger.exception(f"Error executing endpoint {route=} {method=}")
             else:
                 logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
@@ -333,23 +329,18 @@ class ClientVersionMiddleware:
         return await self.app(scope, receive, send)

-def create_app(
-    config_file: str | None = None,
-    env_vars: list[str] | None = None,
-) -> StackApp:
+def create_app() -> StackApp:
     """Create and configure the FastAPI application.

-    Args:
-        config_file: Path to config file. If None, uses LLAMA_STACK_CONFIG env var or default resolution.
-        env_vars: List of environment variables in KEY=value format.
-        disable_version_check: Whether to disable version checking. If None, uses LLAMA_STACK_DISABLE_VERSION_CHECK env var.
+    This factory function reads configuration from environment variables:
+    - LLAMA_STACK_CONFIG: Path to config file (required)

     Returns:
         Configured StackApp instance.
     """
-    config_file = config_file or os.getenv("LLAMA_STACK_CONFIG")
+    config_file = os.getenv("LLAMA_STACK_CONFIG")
     if config_file is None:
-        raise ValueError("No config file provided and LLAMA_STACK_CONFIG env var is not set")
+        raise ValueError("LLAMA_STACK_CONFIG environment variable is required")

     config_file = resolve_config_or_distro(config_file, Mode.RUN)
@@ -361,16 +352,6 @@ def create_app(
         logger_config = LoggingConfig(**cfg)
     logger = get_logger(name=__name__, category="core::server", config=logger_config)

-    if env_vars:
-        for env_pair in env_vars:
-            try:
-                key, value = validate_env_pair(env_pair)
-                logger.info(f"Setting environment variable {key} => {value}")
-                os.environ[key] = value
-            except ValueError as e:
-                logger.error(f"Error: {str(e)}")
-                raise ValueError(f"Invalid environment variable format: {env_pair}") from e
-
     config = replace_env_vars(config_contents)
     config = StackRunConfig(**cast_image_name_to_string(config))
@@ -451,6 +432,7 @@
         apis_to_serve.add("inspect")
         apis_to_serve.add("providers")
         apis_to_serve.add("prompts")
+        apis_to_serve.add("conversations")

     for api_str in apis_to_serve:
         api = Api(api_str)
@@ -493,101 +475,6 @@
     return app

-def main(args: argparse.Namespace | None = None):
-    """Start the LlamaStack server."""
-    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
-    add_config_distro_args(parser)
-    parser.add_argument(
-        "--port",
-        type=int,
-        default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
-        help="Port to listen on",
-    )
-    parser.add_argument(
-        "--env",
-        action="append",
-        help="Environment variables in KEY=value format. Can be specified multiple times.",
-    )
-
-    # Determine whether the server args are being passed by the "run" command, if this is the case
-    # the args will be passed as a Namespace object to the main function, otherwise they will be
-    # parsed from the command line
-    if args is None:
-        args = parser.parse_args()
-
-    config_or_distro = get_config_from_args(args)
-
-    try:
-        app = create_app(
-            config_file=config_or_distro,
-            env_vars=args.env,
-        )
-    except Exception as e:
-        logger.error(f"Error creating app: {str(e)}")
-        sys.exit(1)
-
-    config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
-    with open(config_file) as fp:
-        config_contents = yaml.safe_load(fp)
-    if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
-        logger_config = LoggingConfig(**cfg)
-    else:
-        logger_config = None
-    config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
-
-    import uvicorn
-
-    # Configure SSL if certificates are provided
-    port = args.port or config.server.port
-    ssl_config = None
-    keyfile = config.server.tls_keyfile
-    certfile = config.server.tls_certfile
-
-    if keyfile and certfile:
-        ssl_config = {
-            "ssl_keyfile": keyfile,
-            "ssl_certfile": certfile,
-        }
-        if config.server.tls_cafile:
-            ssl_config["ssl_ca_certs"] = config.server.tls_cafile
-            ssl_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED
-            logger.info(
-                f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
-            )
-        else:
-            logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
-
-    listen_host = config.server.host or ["::", "0.0.0.0"]
-    logger.info(f"Listening on {listen_host}:{port}")
-
-    uvicorn_config = {
-        "app": app,
-        "host": listen_host,
-        "port": port,
-        "lifespan": "on",
-        "log_level": logger.getEffectiveLevel(),
-        "log_config": logger_config,
-    }
-    if ssl_config:
-        uvicorn_config.update(ssl_config)
-
-    # We need to catch KeyboardInterrupt because uvicorn's signal handling
-    # re-raises SIGINT signals using signal.raise_signal(), which Python
-    # converts to KeyboardInterrupt. Without this catch, we'd get a confusing
-    # stack trace when using Ctrl+C or kill -2 (SIGINT).
-    # SIGTERM (kill -15) works fine without this because Python doesn't
-    # have a default handler for it.
-    #
-    # Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
-    # signal handling but this is quite intrusive and not worth the effort.
-    try:
-        asyncio.run(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
-    except (KeyboardInterrupt, SystemExit):
-        logger.info("Received interrupt signal, shutting down gracefully...")

 def _log_run_config(run_config: StackRunConfig):
     """Logs the run config with redacted fields and disabled providers removed."""
     logger.info("Run configuration:")
@@ -614,7 +501,3 @@ def remove_disabled_providers(obj):
         return [item for item in (remove_disabled_providers(i) for i in obj) if item is not None]
     else:
         return obj
-
-
-if __name__ == "__main__":
-    main()
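
Note: with `main()` and the `__main__` hook removed, `python -m llama_stack.core.server.server` no longer starts anything; `llama stack run` is the entry point (see the build and start-script changes elsewhere in this commit). For programmatic embedding, a sketch using the new zero-argument factory; the "starter" config value is an assumption, and 8321 is the default port from the removed code:

import os

import uvicorn

# create_app() now reads this instead of taking a config_file argument
os.environ["LLAMA_STACK_CONFIG"] = "starter"

from llama_stack.core.server.server import create_app

uvicorn.run(create_app(), host="0.0.0.0", port=8321)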


@@ -15,6 +15,7 @@ import yaml
 from llama_stack.apis.agents import Agents
 from llama_stack.apis.benchmarks import Benchmarks
+from llama_stack.apis.conversations import Conversations
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.eval import Eval
@@ -34,6 +35,7 @@ from llama_stack.apis.telemetry import Telemetry
 from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
 from llama_stack.apis.vector_dbs import VectorDBs
 from llama_stack.apis.vector_io import VectorIO
+from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
 from llama_stack.core.datatypes import Provider, StackRunConfig
 from llama_stack.core.distribution import get_provider_registry
 from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
@@ -73,6 +75,7 @@ class LlamaStack(
     RAGToolRuntime,
     Files,
     Prompts,
+    Conversations,
 ):
     pass
@@ -312,6 +315,12 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
     )
     impls[Api.prompts] = prompts_impl

+    conversations_impl = ConversationServiceImpl(
+        ConversationServiceConfig(run_config=run_config),
+        deps=impls,
+    )
+    impls[Api.conversations] = conversations_impl
+

 class Stack:
     def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
@@ -342,6 +351,8 @@ class Stack:
         if Api.prompts in impls:
             await impls[Api.prompts].initialize()
+        if Api.conversations in impls:
+            await impls[Api.conversations].initialize()

         await register_resources(self.run_config, impls)
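
Note: once initialized, conversations behaves like the other internal APIs (inspect, providers, prompts) that are always added to apis_to_serve. A hedged sketch of exercising it over HTTP; the route path is an assumption based on the OpenAI-style naming, not something shown in this diff:

import requests

# Hypothetical endpoint path; adjust to the actually served route.
resp = requests.post(
    "http://localhost:8321/v1/conversations",
    json={"metadata": {"topic": "demo"}},
)
print(resp.json().get("id"))  # conv_...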


@@ -116,7 +116,7 @@ if [[ "$env_type" == "venv" ]]; then
         yaml_config_arg=""
     fi

-    $PYTHON_BINARY -m llama_stack.core.server.server \
+    llama stack run \
         $yaml_config_arg \
         --port "$port" \
         $env_vars \


@@ -39,7 +39,7 @@ class DistributionRegistry(Protocol):

 REGISTER_PREFIX = "distributions:registry"
-KEY_VERSION = "v9"
+KEY_VERSION = "v10"
 KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
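
Note: bumping KEY_VERSION invalidates every previously stored registry entry at once, the usual lever when stored object shapes change (here, plausibly the Tool-to-ToolDef migration above): old v9 keys simply stop matching. For illustration, using the constants as committed:

# Keys are namespaced by version, so v9 records become unreachable:
key = KEY_FORMAT.format(type="tool_group", identifier="builtin::rag")
# -> "distributions:registry:v10::tool_group:builtin::rag"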


@@ -81,7 +81,7 @@ def tool_chat_page():
     for toolgroup_id in toolgroup_selection:
         tools = client.tools.list(toolgroup_id=toolgroup_id)
-        grouped_tools[toolgroup_id] = [tool.identifier for tool in tools]
+        grouped_tools[toolgroup_id] = [tool.name for tool in tools]
         total_tools += len(tools)

     st.markdown(f"Active Tools: 🛠 {total_tools}")