Merge branch 'main' into load-quickstart-py

This commit is contained in:
Francisco Arceo 2025-07-21 22:51:49 -04:00 committed by GitHub
commit 6411b053ec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 647 additions and 165 deletions

View file

@ -276,8 +276,8 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
config = parse_and_maybe_upgrade_config(config_dict) config = parse_and_maybe_upgrade_config(config_dict)
if config.external_providers_dir and not config.external_providers_dir.exists(): if config.external_providers_dir and not config.external_providers_dir.exists():
config.external_providers_dir.mkdir(exist_ok=True) config.external_providers_dir.mkdir(exist_ok=True)
run_args = formulate_run_args(args.image_type, args.image_name, config, args.template) run_args = formulate_run_args(args.image_type, args.image_name)
run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", run_config]) run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", str(run_config)])
run_command(run_args) run_command(run_args)

View file

@ -82,39 +82,6 @@ class StackRun(Subcommand):
return ImageType.CONDA.value, args.image_name return ImageType.CONDA.value, args.image_name
return args.image_type, args.image_name return args.image_type, args.image_name
def _resolve_config_and_template(self, args: argparse.Namespace) -> tuple[Path | None, str | None]:
"""Resolve config file path and template name from args.config"""
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
if not args.config:
return None, None
config_file = Path(args.config)
has_yaml_suffix = args.config.endswith(".yaml")
template_name = None
if not config_file.exists() and not has_yaml_suffix:
# check if this is a template
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
if config_file.exists():
template_name = args.config
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to ~/.llama dir
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
if not config_file.exists():
self.parser.error(
f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
)
if not config_file.is_file():
self.parser.error(
f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
)
return config_file, template_name
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None: def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
import yaml import yaml
@ -125,8 +92,15 @@ class StackRun(Subcommand):
self._start_ui_development_server(args.port) self._start_ui_development_server(args.port)
image_type, image_name = self._get_image_type_and_name(args) image_type, image_name = self._get_image_type_and_name(args)
# Resolve config file and template name first if args.config:
config_file, template_name = self._resolve_config_and_template(args) try:
from llama_stack.distribution.utils.config_resolution import Mode, resolve_config_or_template
config_file = resolve_config_or_template(args.config, Mode.RUN)
except ValueError as e:
self.parser.error(str(e))
else:
config_file = None
# Check if config is required based on image type # Check if config is required based on image type
if (image_type in [ImageType.CONDA.value, ImageType.VENV.value]) and not config_file: if (image_type in [ImageType.CONDA.value, ImageType.VENV.value]) and not config_file:
@ -164,18 +138,14 @@ class StackRun(Subcommand):
if callable(getattr(args, arg)): if callable(getattr(args, arg)):
continue continue
if arg == "config": if arg == "config":
if template_name: server_args.config = str(config_file)
server_args.template = str(template_name)
else:
# Set the config file path
server_args.config = str(config_file)
else: else:
setattr(server_args, arg, getattr(args, arg)) setattr(server_args, arg, getattr(args, arg))
# Run the server # Run the server
server_main(server_args) server_main(server_args)
else: else:
run_args = formulate_run_args(image_type, image_name, config, template_name) run_args = formulate_run_args(image_type, image_name)
run_args.extend([str(args.port)]) run_args.extend([str(args.port)])

31
llama_stack/cli/utils.py Normal file
View file

@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
def add_config_template_args(parser: argparse.ArgumentParser):
"""Add unified config/template arguments with backward compatibility."""
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"config",
nargs="?",
help="Configuration file path or template name",
)
# Backward compatibility arguments (deprecated)
group.add_argument(
"--config",
dest="config",
help="(DEPRECATED) Use positional argument [config] instead. Configuration file path",
)
group.add_argument(
"--template",
dest="config",
help="(DEPRECATED) Use positional argument [config] instead. Template name",
)

View file

@ -214,9 +214,7 @@ class VectorIORouter(VectorIO):
vector_store_id: str, vector_store_id: str,
) -> VectorStoreObject: ) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}") logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
# Route based on vector store ID return await self.routing_table.openai_retrieve_vector_store(vector_store_id)
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store(vector_store_id)
async def openai_update_vector_store( async def openai_update_vector_store(
self, self,
@ -226,9 +224,7 @@ class VectorIORouter(VectorIO):
metadata: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None,
) -> VectorStoreObject: ) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}") logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
# Route based on vector store ID return await self.routing_table.openai_update_vector_store(
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
name=name, name=name,
expires_after=expires_after, expires_after=expires_after,
@ -240,12 +236,7 @@ class VectorIORouter(VectorIO):
vector_store_id: str, vector_store_id: str,
) -> VectorStoreDeleteResponse: ) -> VectorStoreDeleteResponse:
logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}") logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
# Route based on vector store ID return await self.routing_table.openai_delete_vector_store(vector_store_id)
provider = self.routing_table.get_provider_impl(vector_store_id)
result = await provider.openai_delete_vector_store(vector_store_id)
# drop from registry
await self.routing_table.unregister_vector_db(vector_store_id)
return result
async def openai_search_vector_store( async def openai_search_vector_store(
self, self,
@ -258,9 +249,7 @@ class VectorIORouter(VectorIO):
search_mode: str | None = "vector", search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage: ) -> VectorStoreSearchResponsePage:
logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}") logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
# Route based on vector store ID return await self.routing_table.openai_search_vector_store(
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_search_vector_store(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
query=query, query=query,
filters=filters, filters=filters,
@ -278,9 +267,7 @@ class VectorIORouter(VectorIO):
chunking_strategy: VectorStoreChunkingStrategy | None = None, chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject: ) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}") logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
# Route based on vector store ID return await self.routing_table.openai_attach_file_to_vector_store(
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_attach_file_to_vector_store(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
file_id=file_id, file_id=file_id,
attributes=attributes, attributes=attributes,
@ -297,9 +284,7 @@ class VectorIORouter(VectorIO):
filter: VectorStoreFileStatus | None = None, filter: VectorStoreFileStatus | None = None,
) -> list[VectorStoreFileObject]: ) -> list[VectorStoreFileObject]:
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}") logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
# Route based on vector store ID return await self.routing_table.openai_list_files_in_vector_store(
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
limit=limit, limit=limit,
order=order, order=order,
@ -314,9 +299,7 @@ class VectorIORouter(VectorIO):
file_id: str, file_id: str,
) -> VectorStoreFileObject: ) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}") logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}")
# Route based on vector store ID return await self.routing_table.openai_retrieve_vector_store_file(
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
file_id=file_id, file_id=file_id,
) )
@ -327,9 +310,7 @@ class VectorIORouter(VectorIO):
file_id: str, file_id: str,
) -> VectorStoreFileContentsResponse: ) -> VectorStoreFileContentsResponse:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}") logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
# Route based on vector store ID return await self.routing_table.openai_retrieve_vector_store_file_contents(
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_contents(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
file_id=file_id, file_id=file_id,
) )
@ -341,9 +322,7 @@ class VectorIORouter(VectorIO):
attributes: dict[str, Any], attributes: dict[str, Any],
) -> VectorStoreFileObject: ) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}") logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}")
# Route based on vector store ID return await self.routing_table.openai_update_vector_store_file(
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store_file(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
file_id=file_id, file_id=file_id,
attributes=attributes, attributes=attributes,
@ -355,9 +334,7 @@ class VectorIORouter(VectorIO):
file_id: str, file_id: str,
) -> VectorStoreFileDeleteResponse: ) -> VectorStoreFileDeleteResponse:
logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}") logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}")
# Route based on vector store ID return await self.routing_table.openai_delete_vector_store_file(
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_delete_vector_store_file(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
file_id=file_id, file_id=file_id,
) )

View file

@ -9,6 +9,7 @@ from typing import Any
from llama_stack.apis.resource import ResourceType from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.distribution.access_control.datatypes import Action
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
AccessRule, AccessRule,
RoutableObject, RoutableObject,
@ -209,6 +210,20 @@ class CommonRoutingTableImpl(RoutingTable):
await self.dist_registry.register(obj) await self.dist_registry.register(obj)
return obj return obj
async def assert_action_allowed(
self,
action: Action,
type: str,
identifier: str,
) -> None:
"""Fetch a registered object by type/identifier and enforce the given action permission."""
obj = await self.get_object_by_identifier(type, identifier)
if obj is None:
raise ValueError(f"{type.capitalize()} '{identifier}' not found")
user = get_authenticated_user()
if not is_action_allowed(self.policy, action, obj, user):
raise AccessDeniedError(action, obj, user)
async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]: async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all() objs = await self.dist_registry.get_all()
filtered_objs = [obj for obj in objs if obj.type == type] filtered_objs = [obj for obj in objs if obj.type == type]

View file

@ -4,11 +4,24 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any
from pydantic import TypeAdapter from pydantic import TypeAdapter
from llama_stack.apis.models import ModelType from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType from llama_stack.apis.resource import ResourceType
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.apis.vector_io.vector_io import (
SearchRankingOptions,
VectorStoreChunkingStrategy,
VectorStoreDeleteResponse,
VectorStoreFileContentsResponse,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreFileStatus,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
VectorDBWithOwner, VectorDBWithOwner,
) )
@ -74,3 +87,135 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
if existing_vector_db is None: if existing_vector_db is None:
raise ValueError(f"Vector DB {vector_db_id} not found") raise ValueError(f"Vector DB {vector_db_id} not found")
await self.unregister_object(existing_vector_db) await self.unregister_object(existing_vector_db)
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store(vector_store_id)
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_update_vector_store(
vector_store_id=vector_store_id,
name=name,
expires_after=expires_after,
metadata=metadata,
)
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
result = await self.get_provider_impl(vector_store_id).openai_delete_vector_store(vector_store_id)
await self.unregister_vector_db(vector_store_id)
return result
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_search_vector_store(
vector_store_id=vector_store_id,
query=query,
filters=filters,
max_num_results=max_num_results,
ranking_options=ranking_options,
rewrite_query=rewrite_query,
search_mode=search_mode,
)
async def openai_attach_file_to_vector_store(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_attach_file_to_vector_store(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
chunking_strategy=chunking_strategy,
)
async def openai_list_files_in_vector_store(
self,
vector_store_id: str,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
filter: VectorStoreFileStatus | None = None,
) -> list[VectorStoreFileObject]:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_list_files_in_vector_store(
vector_store_id=vector_store_id,
limit=limit,
order=order,
after=after,
before=before,
filter=filter,
)
async def openai_retrieve_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def openai_retrieve_vector_store_file_contents(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file_contents(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def openai_update_vector_store_file(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any],
) -> VectorStoreFileObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_update_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
)
async def openai_delete_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileDeleteResponse:
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_delete_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)

View file

@ -32,6 +32,7 @@ from openai import BadRequestError
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.cli.utils import add_config_template_args
from llama_stack.distribution.access_control.access_control import AccessDeniedError from llama_stack.distribution.access_control.access_control import AccessDeniedError
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
AuthenticationRequiredError, AuthenticationRequiredError,
@ -53,6 +54,7 @@ from llama_stack.distribution.stack import (
validate_env_pair, validate_env_pair,
) )
from llama_stack.distribution.utils.config import redact_sensitive_fields from llama_stack.distribution.utils.config import redact_sensitive_fields
from llama_stack.distribution.utils.config_resolution import Mode, resolve_config_or_template
from llama_stack.distribution.utils.context import preserve_contexts_async_generator from llama_stack.distribution.utils.context import preserve_contexts_async_generator
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
@ -377,20 +379,8 @@ class ClientVersionMiddleware:
def main(args: argparse.Namespace | None = None): def main(args: argparse.Namespace | None = None):
"""Start the LlamaStack server.""" """Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.") parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
parser.add_argument(
"--yaml-config", add_config_template_args(parser)
dest="config",
help="(Deprecated) Path to YAML configuration file - use --config instead",
)
parser.add_argument(
"--config",
dest="config",
help="Path to YAML configuration file",
)
parser.add_argument(
"--template",
help="One of the template names in llama_stack/templates (e.g., tgi, fireworks, remote-vllm, etc.)",
)
parser.add_argument( parser.add_argument(
"--port", "--port",
type=int, type=int,
@ -409,20 +399,7 @@ def main(args: argparse.Namespace | None = None):
if args is None: if args is None:
args = parser.parse_args() args = parser.parse_args()
log_line = "" config_file = resolve_config_or_template(args.config, Mode.RUN)
if hasattr(args, "config") and args.config:
# if the user provided a config file, use it, even if template was specified
config_file = Path(args.config)
if not config_file.exists():
raise ValueError(f"Config file {config_file} does not exist")
log_line = f"Using config file: {config_file}"
elif hasattr(args, "template") and args.template:
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
if not config_file.exists():
raise ValueError(f"Template {args.template} does not exist")
log_line = f"Using template {args.template} config file: {config_file}"
else:
raise ValueError("Either --config or --template must be provided")
logger_config = None logger_config = None
with open(config_file) as fp: with open(config_file) as fp:
@ -442,9 +419,6 @@ def main(args: argparse.Namespace | None = None):
config = replace_env_vars(config_contents) config = replace_env_vars(config_contents)
config = StackRunConfig(**cast_image_name_to_string(config)) config = StackRunConfig(**cast_image_name_to_string(config))
# now that the logger is initialized, print the line about which type of config we are using.
logger.info(log_line)
_log_run_config(run_config=config) _log_run_config(run_config=config)
app = FastAPI( app = FastAPI(
@ -592,12 +566,29 @@ def main(args: argparse.Namespace | None = None):
"port": port, "port": port,
"lifespan": "on", "lifespan": "on",
"log_level": logger.getEffectiveLevel(), "log_level": logger.getEffectiveLevel(),
"log_config": logger_config,
} }
if ssl_config: if ssl_config:
uvicorn_config.update(ssl_config) uvicorn_config.update(ssl_config)
# Run uvicorn in the existing event loop to preserve background tasks # Run uvicorn in the existing event loop to preserve background tasks
loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve()) # We need to catch KeyboardInterrupt because uvicorn's signal handling
# re-raises SIGINT signals using signal.raise_signal(), which Python
# converts to KeyboardInterrupt. Without this catch, we'd get a confusing
# stack trace when using Ctrl+C or kill -2 (SIGINT).
# SIGTERM (kill -15) works fine without this because Python doesn't
# have a default handler for it.
#
# Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
# signal handling but this is quite intrusive and not worth the effort.
try:
loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
except (KeyboardInterrupt, SystemExit):
logger.info("Received interrupt signal, shutting down gracefully...")
finally:
if not loop.is_closed():
logger.debug("Closing event loop")
loop.close()
def _log_run_config(run_config: StackRunConfig): def _log_run_config(run_config: StackRunConfig):

View file

@ -117,7 +117,7 @@ if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
set -x set -x
if [ -n "$yaml_config" ]; then if [ -n "$yaml_config" ]; then
yaml_config_arg="--config $yaml_config" yaml_config_arg="$yaml_config"
else else
yaml_config_arg="" yaml_config_arg=""
fi fi

View file

@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import StrEnum
from pathlib import Path
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="config_resolution")
TEMPLATE_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "templates"
class Mode(StrEnum):
RUN = "run"
BUILD = "build"
def resolve_config_or_template(
config_or_template: str,
mode: Mode = Mode.RUN,
) -> Path:
"""
Resolve a config/template argument to a concrete config file path.
Args:
config_or_template: User input (file path, template name, or built distribution)
mode: Mode resolving for ("run", "build", "server")
Returns:
Path to the resolved config file
Raises:
ValueError: If resolution fails
"""
# Strategy 1: Try as file path first
config_path = Path(config_or_template)
if config_path.exists() and config_path.is_file():
logger.info(f"Using file path: {config_path}")
return config_path.resolve()
# Strategy 2: Try as template name (if no .yaml extension)
if not config_or_template.endswith(".yaml"):
template_config = _get_template_config_path(config_or_template, mode)
if template_config.exists():
logger.info(f"Using template: {template_config}")
return template_config
# Strategy 3: Try as built distribution name
distrib_config = DISTRIBS_BASE_DIR / f"llamastack-{config_or_template}" / f"{config_or_template}-{mode}.yaml"
if distrib_config.exists():
logger.info(f"Using built distribution: {distrib_config}")
return distrib_config
distrib_config = DISTRIBS_BASE_DIR / f"{config_or_template}" / f"{config_or_template}-{mode}.yaml"
if distrib_config.exists():
logger.info(f"Using built distribution: {distrib_config}")
return distrib_config
# Strategy 4: Failed - provide helpful error
raise ValueError(_format_resolution_error(config_or_template, mode))
def _get_template_config_path(template_name: str, mode: Mode) -> Path:
"""Get the config file path for a template."""
return TEMPLATE_DIR / template_name / f"{mode}.yaml"
def _format_resolution_error(config_or_template: str, mode: Mode) -> str:
"""Format a helpful error message for resolution failures."""
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
template_path = _get_template_config_path(config_or_template, mode)
distrib_path = DISTRIBS_BASE_DIR / f"llamastack-{config_or_template}" / f"{config_or_template}-{mode}.yaml"
distrib_path2 = DISTRIBS_BASE_DIR / f"{config_or_template}" / f"{config_or_template}-{mode}.yaml"
available_templates = _get_available_templates()
templates_str = ", ".join(available_templates) if available_templates else "none found"
return f"""Could not resolve config or template '{config_or_template}'.
Tried the following locations:
1. As file path: {Path(config_or_template).resolve()}
2. As template: {template_path}
3. As built distribution: ({distrib_path}, {distrib_path2})
Available templates: {templates_str}
Did you mean one of these templates?
{_format_template_suggestions(available_templates, config_or_template)}
"""
def _get_available_templates() -> list[str]:
"""Get list of available template names."""
if not TEMPLATE_DIR.exists() and not DISTRIBS_BASE_DIR.exists():
return []
return list(
set(
[d.name for d in TEMPLATE_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")]
+ [d.name for d in DISTRIBS_BASE_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")]
)
)
def _format_template_suggestions(templates: list[str], user_input: str) -> str:
"""Format template suggestions for error messages, showing closest matches first."""
if not templates:
return " (no templates found)"
import difflib
# Get up to 3 closest matches with similarity threshold of 0.3 (lower = more permissive)
close_matches = difflib.get_close_matches(user_input, templates, n=3, cutoff=0.3)
display_templates = close_matches if close_matches else templates[:3]
suggestions = [f" - {t}" for t in display_templates]
return "\n".join(suggestions)

View file

@ -21,7 +21,7 @@ from pathlib import Path
from llama_stack.distribution.utils.image_types import LlamaStackImageType from llama_stack.distribution.utils.image_types import LlamaStackImageType
def formulate_run_args(image_type, image_name, config, template_name) -> list: def formulate_run_args(image_type: str, image_name: str) -> list[str]:
env_name = "" env_name = ""
if image_type == LlamaStackImageType.CONDA.value: if image_type == LlamaStackImageType.CONDA.value:

View file

@ -50,6 +50,7 @@ def setup_telemetry_data(llama_stack_client, text_model_id):
yield yield
@pytest.mark.skip(reason="Skipping telemetry tests for now")
def test_query_traces_basic(llama_stack_client): def test_query_traces_basic(llama_stack_client):
"""Test basic trace querying functionality with proper data validation.""" """Test basic trace querying functionality with proper data validation."""
all_traces = llama_stack_client.telemetry.query_traces(limit=5) all_traces = llama_stack_client.telemetry.query_traces(limit=5)
@ -105,6 +106,7 @@ def test_query_traces_basic(llama_stack_client):
assert hasattr(trace, "root_span_id") and trace.root_span_id, "Each trace should have non-empty root_span_id" assert hasattr(trace, "root_span_id") and trace.root_span_id, "Each trace should have non-empty root_span_id"
@pytest.mark.skip(reason="Skipping telemetry tests for now")
def test_query_spans_basic(llama_stack_client): def test_query_spans_basic(llama_stack_client):
"""Test basic span querying functionality with proper validation.""" """Test basic span querying functionality with proper validation."""
spans = llama_stack_client.telemetry.query_spans(attribute_filters=[], attributes_to_return=[]) spans = llama_stack_client.telemetry.query_spans(attribute_filters=[], attributes_to_return=[])
@ -153,6 +155,7 @@ def test_query_spans_basic(llama_stack_client):
assert hasattr(span, attr) and getattr(span, attr), f"All spans should have non-empty {attr}" assert hasattr(span, attr) and getattr(span, attr), f"All spans should have non-empty {attr}"
@pytest.mark.skip(reason="Skipping telemetry tests for now")
def test_telemetry_pagination(llama_stack_client): def test_telemetry_pagination(llama_stack_client):
"""Test pagination in telemetry queries.""" """Test pagination in telemetry queries."""
# Get total count of traces # Get total count of traces

View file

@ -11,17 +11,15 @@ from unittest.mock import AsyncMock
from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
from llama_stack.apis.datatypes import Api from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import Model, ModelType from llama_stack.apis.models import Model
from llama_stack.apis.shields.shields import Shield from llama_stack.apis.shields.shields import Shield
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
from llama_stack.distribution.routing_tables.scoring_functions import ScoringFunctionsRoutingTable from llama_stack.distribution.routing_tables.scoring_functions import ScoringFunctionsRoutingTable
from llama_stack.distribution.routing_tables.shields import ShieldsRoutingTable from llama_stack.distribution.routing_tables.shields import ShieldsRoutingTable
from llama_stack.distribution.routing_tables.toolgroups import ToolGroupsRoutingTable from llama_stack.distribution.routing_tables.toolgroups import ToolGroupsRoutingTable
from llama_stack.distribution.routing_tables.vector_dbs import VectorDBsRoutingTable
class Impl: class Impl:
@ -54,17 +52,6 @@ class SafetyImpl(Impl):
return shield return shield
class VectorDBImpl(Impl):
def __init__(self):
super().__init__(Api.vector_io)
async def register_vector_db(self, vector_db: VectorDB):
return vector_db
async def unregister_vector_db(self, vector_db_id: str):
return vector_db_id
class DatasetsImpl(Impl): class DatasetsImpl(Impl):
def __init__(self): def __init__(self):
super().__init__(Api.datasetio) super().__init__(Api.datasetio)
@ -173,36 +160,6 @@ async def test_shields_routing_table(cached_disk_dist_registry):
assert "test-shield-2" in shield_ids assert "test-shield-2" in shield_ids
async def test_vectordbs_routing_table(cached_disk_dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
# Register multiple vector databases and verify listing
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model")
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model")
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 2
vector_db_ids = {v.identifier for v in vector_dbs.data}
assert "test-vectordb" in vector_db_ids
assert "test-vectordb-2" in vector_db_ids
await table.unregister_vector_db(vector_db_id="test-vectordb")
await table.unregister_vector_db(vector_db_id="test-vectordb-2")
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 0
async def test_datasets_routing_table(cached_disk_dist_registry): async def test_datasets_routing_table(cached_disk_dist_registry):
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {}) table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {})
await table.initialize() await table.initialize()

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,274 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Unit tests for the routing tables vector_dbs
import time
from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
from llama_stack.apis.vector_io.vector_io import (
VectorStoreContent,
VectorStoreDeleteResponse,
VectorStoreFileContentsResponse,
VectorStoreFileCounts,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.distribution.access_control.datatypes import AccessRule, Scope
from llama_stack.distribution.datatypes import User
from llama_stack.distribution.request_headers import request_provider_data_context
from llama_stack.distribution.routing_tables.vector_dbs import VectorDBsRoutingTable
from tests.unit.distribution.routers.test_routing_tables import Impl, InferenceImpl, ModelsRoutingTable
class VectorDBImpl(Impl):
def __init__(self):
super().__init__(Api.vector_io)
async def register_vector_db(self, vector_db: VectorDB):
return vector_db
async def unregister_vector_db(self, vector_db_id: str):
return vector_db_id
async def openai_retrieve_vector_store(self, vector_store_id):
return VectorStoreObject(
id=vector_store_id,
name="Test Store",
created_at=int(time.time()),
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
)
async def openai_update_vector_store(self, vector_store_id, **kwargs):
return VectorStoreObject(
id=vector_store_id,
name="Updated Store",
created_at=int(time.time()),
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
)
async def openai_delete_vector_store(self, vector_store_id):
return VectorStoreDeleteResponse(id=vector_store_id, object="vector_store.deleted", deleted=True)
async def openai_search_vector_store(self, vector_store_id, query, **kwargs):
return VectorStoreSearchResponsePage(
object="vector_store.search_results.page", search_query="query", data=[], has_more=False, next_page=None
)
async def openai_attach_file_to_vector_store(self, vector_store_id, file_id, **kwargs):
return VectorStoreFileObject(
id=file_id,
status="completed",
chunking_strategy={"type": "auto"},
created_at=int(time.time()),
vector_store_id=vector_store_id,
)
async def openai_list_files_in_vector_store(self, vector_store_id, **kwargs):
return [
VectorStoreFileObject(
id="1",
status="completed",
chunking_strategy={"type": "auto"},
created_at=int(time.time()),
vector_store_id=vector_store_id,
)
]
async def openai_retrieve_vector_store_file(self, vector_store_id, file_id):
return VectorStoreFileObject(
id=file_id,
status="completed",
chunking_strategy={"type": "auto"},
created_at=int(time.time()),
vector_store_id=vector_store_id,
)
async def openai_retrieve_vector_store_file_contents(self, vector_store_id, file_id):
return VectorStoreFileContentsResponse(
file_id=file_id,
filename="Sample File name",
attributes={"key": "value"},
content=[VectorStoreContent(type="text", text="Sample content")],
)
async def openai_update_vector_store_file(self, vector_store_id, file_id, **kwargs):
return VectorStoreFileObject(
id=file_id,
status="completed",
chunking_strategy={"type": "auto"},
created_at=int(time.time()),
vector_store_id=vector_store_id,
)
async def openai_delete_vector_store_file(self, vector_store_id, file_id):
return VectorStoreFileDeleteResponse(id=file_id, deleted=True)
async def test_vectordbs_routing_table(cached_disk_dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
# Register multiple vector databases and verify listing
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model")
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model")
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 2
vector_db_ids = {v.identifier for v in vector_dbs.data}
assert "test-vectordb" in vector_db_ids
assert "test-vectordb-2" in vector_db_ids
await table.unregister_vector_db(vector_db_id="test-vectordb")
await table.unregister_vector_db(vector_db_id="test-vectordb-2")
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 0
async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry):
impl = VectorDBImpl()
impl.openai_retrieve_vector_store = AsyncMock(return_value="OK")
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=[])
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[])
authorized_table = "vs1"
authorized_team = "team1"
unauthorized_team = "team2"
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
authorized_user = User(principal="alice", attributes={"roles": [authorized_team]})
with request_provider_data_context({}, authorized_user):
_ = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model")
# Authorized reader
with request_provider_data_context({}, authorized_user):
res = await table.openai_retrieve_vector_store(authorized_table)
assert res == "OK"
# Authorized updater
impl.openai_update_vector_store_file = AsyncMock(return_value="UPDATED")
with request_provider_data_context({}, authorized_user):
res = await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"})
assert res == "UPDATED"
# Unauthorized reader
unauthorized_user = User(principal="eve", attributes={"roles": [unauthorized_team]})
with request_provider_data_context({}, unauthorized_user):
with pytest.raises(ValueError):
await table.openai_retrieve_vector_store(authorized_table)
# Unauthorized updater
with request_provider_data_context({}, unauthorized_user):
with pytest.raises(ValueError):
await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"})
# Authorized deleter
impl.openai_delete_vector_store_file = AsyncMock(return_value="DELETED")
with request_provider_data_context({}, authorized_user):
res = await table.openai_delete_vector_store_file(authorized_table, file_id="file1")
assert res == "DELETED"
# Unauthorized deleter
with request_provider_data_context({}, unauthorized_user):
with pytest.raises(ValueError):
await table.openai_delete_vector_store_file(authorized_table, file_id="file1")
async def test_openai_vector_stores_routing_table_actions(cached_disk_dist_registry):
impl = VectorDBImpl()
policy = [
AccessRule(permit=Scope(actions=["create", "read", "update", "delete"]), when="user with admin in roles"),
AccessRule(permit=Scope(actions=["read"]), when="user with reader in roles"),
]
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=policy)
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[])
vector_db_id = "vs1"
file_id = "file-1"
admin_user = User(principal="admin", attributes={"roles": ["admin"]})
read_only_user = User(principal="reader", attributes={"roles": ["reader"]})
no_access_user = User(principal="outsider", attributes={"roles": ["no_access"]})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
with request_provider_data_context({}, admin_user):
await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model")
read_methods = [
(table.openai_retrieve_vector_store, (vector_db_id,), {}),
(table.openai_search_vector_store, (vector_db_id, "query"), {}),
(table.openai_list_files_in_vector_store, (vector_db_id,), {}),
(table.openai_retrieve_vector_store_file, (vector_db_id, file_id), {}),
(table.openai_retrieve_vector_store_file_contents, (vector_db_id, file_id), {}),
]
update_methods = [
(table.openai_update_vector_store, (vector_db_id,), {"name": "Updated DB"}),
(table.openai_attach_file_to_vector_store, (vector_db_id, file_id), {}),
(table.openai_update_vector_store_file, (vector_db_id, file_id), {"attributes": {"key": "value"}}),
]
delete_methods = [
(table.openai_delete_vector_store_file, (vector_db_id, file_id), {}),
(table.openai_delete_vector_store, (vector_db_id,), {}),
]
for user in [admin_user, read_only_user]:
with request_provider_data_context({}, user):
for method, args, kwargs in read_methods:
result = await method(*args, **kwargs)
assert result is not None, f"Read operation failed with user {user.principal}"
with request_provider_data_context({}, no_access_user):
for method, args, kwargs in read_methods:
with pytest.raises(ValueError):
await method(*args, **kwargs)
with request_provider_data_context({}, admin_user):
for method, args, kwargs in update_methods:
result = await method(*args, **kwargs)
assert result is not None, "Update operation failed with admin user"
with request_provider_data_context({}, admin_user):
for method, args, kwargs in delete_methods:
result = await method(*args, **kwargs)
assert result is not None, "Delete operation failed with admin user"
for user in [read_only_user, no_access_user]:
with request_provider_data_context({}, user):
for method, args, kwargs in delete_methods:
with pytest.raises(ValueError):
await method(*args, **kwargs)

View file

@ -14,7 +14,6 @@ from llama_stack.apis.common.content_types import URL, TextContentItem
from llama_stack.providers.inline.agents.meta_reference.agent_instance import get_raw_document_text from llama_stack.providers.inline.agents.meta_reference.agent_instance import get_raw_document_text
@pytest.mark.asyncio
async def test_get_raw_document_text_supports_text_mime_types(): async def test_get_raw_document_text_supports_text_mime_types():
"""Test that the function accepts text/* mime types.""" """Test that the function accepts text/* mime types."""
document = Document(content="Sample text content", mime_type="text/plain") document = Document(content="Sample text content", mime_type="text/plain")
@ -23,7 +22,6 @@ async def test_get_raw_document_text_supports_text_mime_types():
assert result == "Sample text content" assert result == "Sample text content"
@pytest.mark.asyncio
async def test_get_raw_document_text_supports_yaml_mime_type(): async def test_get_raw_document_text_supports_yaml_mime_type():
"""Test that the function accepts application/yaml mime type.""" """Test that the function accepts application/yaml mime type."""
yaml_content = """ yaml_content = """
@ -40,7 +38,6 @@ async def test_get_raw_document_text_supports_yaml_mime_type():
assert result == yaml_content assert result == yaml_content
@pytest.mark.asyncio
async def test_get_raw_document_text_supports_deprecated_text_yaml_with_warning(): async def test_get_raw_document_text_supports_deprecated_text_yaml_with_warning():
"""Test that the function accepts text/yaml but emits a deprecation warning.""" """Test that the function accepts text/yaml but emits a deprecation warning."""
yaml_content = """ yaml_content = """
@ -68,7 +65,6 @@ async def test_get_raw_document_text_supports_deprecated_text_yaml_with_warning(
assert "deprecated" in str(w[0].message).lower() assert "deprecated" in str(w[0].message).lower()
@pytest.mark.asyncio
async def test_get_raw_document_text_deprecated_text_yaml_with_url(): async def test_get_raw_document_text_deprecated_text_yaml_with_url():
"""Test that text/yaml works with URL content and emits warning.""" """Test that text/yaml works with URL content and emits warning."""
yaml_content = "name: test\nversion: 1.0" yaml_content = "name: test\nversion: 1.0"
@ -92,7 +88,6 @@ async def test_get_raw_document_text_deprecated_text_yaml_with_url():
assert "text/yaml" in str(w[0].message) assert "text/yaml" in str(w[0].message)
@pytest.mark.asyncio
async def test_get_raw_document_text_deprecated_text_yaml_with_text_content_item(): async def test_get_raw_document_text_deprecated_text_yaml_with_text_content_item():
"""Test that text/yaml works with TextContentItem and emits warning.""" """Test that text/yaml works with TextContentItem and emits warning."""
yaml_content = "key: value\nlist:\n - item1\n - item2" yaml_content = "key: value\nlist:\n - item1\n - item2"
@ -112,7 +107,6 @@ async def test_get_raw_document_text_deprecated_text_yaml_with_text_content_item
assert "text/yaml" in str(w[0].message) assert "text/yaml" in str(w[0].message)
@pytest.mark.asyncio
async def test_get_raw_document_text_rejects_unsupported_mime_types(): async def test_get_raw_document_text_rejects_unsupported_mime_types():
"""Test that the function rejects unsupported mime types.""" """Test that the function rejects unsupported mime types."""
document = Document( document = Document(
@ -124,7 +118,6 @@ async def test_get_raw_document_text_rejects_unsupported_mime_types():
await get_raw_document_text(document) await get_raw_document_text(document)
@pytest.mark.asyncio
async def test_get_raw_document_text_with_url_content(): async def test_get_raw_document_text_with_url_content():
"""Test that the function handles URL content correctly.""" """Test that the function handles URL content correctly."""
mock_response = AsyncMock() mock_response = AsyncMock()
@ -140,7 +133,6 @@ async def test_get_raw_document_text_with_url_content():
mock_load.assert_called_once_with("https://example.com/test.txt") mock_load.assert_called_once_with("https://example.com/test.txt")
@pytest.mark.asyncio
async def test_get_raw_document_text_with_yaml_url(): async def test_get_raw_document_text_with_yaml_url():
"""Test that the function handles YAML URLs correctly.""" """Test that the function handles YAML URLs correctly."""
yaml_content = "name: test\nversion: 1.0" yaml_content = "name: test\nversion: 1.0"
@ -155,7 +147,6 @@ async def test_get_raw_document_text_with_yaml_url():
mock_load.assert_called_once_with("https://example.com/config.yaml") mock_load.assert_called_once_with("https://example.com/config.yaml")
@pytest.mark.asyncio
async def test_get_raw_document_text_with_text_content_item(): async def test_get_raw_document_text_with_text_content_item():
"""Test that the function handles TextContentItem correctly.""" """Test that the function handles TextContentItem correctly."""
document = Document(content=TextContentItem(text="Text content item"), mime_type="text/plain") document = Document(content=TextContentItem(text="Text content item"), mime_type="text/plain")
@ -164,7 +155,6 @@ async def test_get_raw_document_text_with_text_content_item():
assert result == "Text content item" assert result == "Text content item"
@pytest.mark.asyncio
async def test_get_raw_document_text_with_yaml_text_content_item(): async def test_get_raw_document_text_with_yaml_text_content_item():
"""Test that the function handles YAML TextContentItem correctly.""" """Test that the function handles YAML TextContentItem correctly."""
yaml_content = "key: value\nlist:\n - item1\n - item2" yaml_content = "key: value\nlist:\n - item1\n - item2"
@ -175,7 +165,6 @@ async def test_get_raw_document_text_with_yaml_text_content_item():
assert result == yaml_content assert result == yaml_content
@pytest.mark.asyncio
async def test_get_raw_document_text_rejects_unexpected_content_type(): async def test_get_raw_document_text_rejects_unexpected_content_type():
"""Test that the function rejects unexpected document content types.""" """Test that the function rejects unexpected document content types."""
# Create a mock document that bypasses Pydantic validation # Create a mock document that bypasses Pydantic validation