Merge branch 'main' into playground-ui

This commit is contained in:
Francisco Arceo 2025-07-22 16:37:18 -04:00 committed by GitHub
commit 5119671f02
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
52 changed files with 938 additions and 773 deletions

View file

@ -15,6 +15,9 @@ on:
jobs:
unit-tests:
permissions:
contents: write # for peter-evans/create-pull-request to create branch
pull-requests: write # for peter-evans/create-pull-request to create a PR
runs-on: ubuntu-latest
steps:
- name: Checkout repository

View file

@ -20,7 +20,7 @@ jobs:
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Install uv
uses: astral-sh/setup-uv@bd01e18f51369d5a26f1651c3cb451d3417e3bba # v6.3.1
uses: astral-sh/setup-uv@7edac99f961f18b581bbd960d59d049f04c0002f # v6.4.1
with:
python-version: ${{ matrix.python-version }}
activate-environment: true

View file

@ -10,8 +10,13 @@ If in doubt, please open a [discussion](https://github.com/meta-llama/llama-stac
**I'd like to contribute!**
All issues are actionable (please report if they are not). Pick one and start working on it. Thank you.
If you need help or guidance, comment on the issue. Issues that are extra friendly to new contributors are tagged with "contributor friendly".
If you are new to the project, start by looking at the issues tagged with "good first issue". If you're interested,
leave a comment on the issue and a triager will assign it to you.
Please avoid picking up too many issues at once. This helps you stay focused and ensures that others in the community also have opportunities to contribute.
- Try to work on only 1-2 issues at a time, especially if you're still getting familiar with the codebase.
- Before taking an issue, check if it's already assigned or being actively discussed.
- If you're blocked or can't continue with an issue, feel free to unassign yourself or leave a comment so others can step in.
**I have a bug!**
@ -41,6 +46,15 @@ If you need help or guidance, comment on the issue. Issues that are extra friend
4. Make sure your code lints using `pre-commit`.
5. If you haven't already, complete the Contributor License Agreement ("CLA").
6. Ensure your pull request follows the [conventional commits format](https://www.conventionalcommits.org/en/v1.0.0/).
7. Ensure your pull request follows the [coding style](#coding-style).
Please keep pull requests (PRs) small and focused. If you have a large set of changes, consider splitting them into logically grouped, smaller PRs to facilitate review and testing.
> [!TIP]
> As a general guideline:
> - Experienced contributors should try to keep no more than 5 open PRs at a time.
> - New contributors are encouraged to have only one open PR at a time until they're familiar with the codebase and process.
## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
@ -140,7 +154,9 @@ uv sync
* Don't use unicode characters in the codebase. ASCII-only is preferred for compatibility or
readability reasons.
* Providers configuration class should be Pydantic Field class. It should have a `description` field
that describes the configuration. These descriptions will be used to generate the provider documentation.
that describes the configuration. These descriptions will be used to generate the provider
documentation.
* When possible, use keyword arguments only when calling functions.
## Common Tasks

View file

@ -13596,9 +13596,6 @@
}
},
"additionalProperties": false,
"required": [
"name"
],
"title": "OpenaiCreateVectorStoreRequest"
},
"VectorStoreFileCounts": {

View file

@ -9497,8 +9497,6 @@ components:
description: >-
The ID of the provider to use for this vector store.
additionalProperties: false
required:
- name
title: OpenaiCreateVectorStoreRequest
VectorStoreFileCounts:
type: object

View file

@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack_client import Agent, AgentEventLogger, RAGDocument, LlamaStackClient
vector_db_id = "my_demo_vector_db"
client = LlamaStackClient(base_url="http://localhost:8321")
models = client.models.list()
# Select the first LLM and first embedding model
model_id = next(m for m in models if m.model_type == "llm").identifier
embedding_model_id = (
em := next(m for m in models if m.model_type == "embedding")
).identifier
embedding_dimension = em.metadata["embedding_dimension"]
_ = client.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
provider_id="faiss",
)
source = "https://www.paulgraham.com/greatwork.html"
print("rag_tool> Ingesting document:", source)
document = RAGDocument(
document_id="document_1",
content=source,
mime_type="text/html",
metadata={},
)
client.tool_runtime.rag_tool.insert(
documents=[document],
vector_db_id=vector_db_id,
chunk_size_in_tokens=50,
)
agent = Agent(
client,
model=model_id,
instructions="You are a helpful assistant",
tools=[
{
"name": "builtin::rag/knowledge_search",
"args": {"vector_db_ids": [vector_db_id]},
}
],
)
prompt = "How do you do great work?"
print("prompt>", prompt)
response = agent.create_turn(
messages=[{"role": "user", "content": prompt}],
session_id=agent.create_session("rag_session"),
stream=True,
)
for log in AgentEventLogger().log(response):
log.print()

View file

@ -24,63 +24,8 @@ ENABLE_OLLAMA=ollama OLLAMA_INFERENCE_MODEL=llama3.2:3b uv run --with llama-stac
#### Step 3: Run the demo
Now open up a new terminal and copy the following script into a file named `demo_script.py`.
```python
from llama_stack_client import Agent, AgentEventLogger, RAGDocument, LlamaStackClient
vector_db_id = "my_demo_vector_db"
client = LlamaStackClient(base_url="http://localhost:8321")
models = client.models.list()
# Select the first LLM and first embedding model
model_id = next(m for m in models if m.model_type == "llm").identifier
embedding_model_id = (
em := next(m for m in models if m.model_type == "embedding")
).identifier
embedding_dimension = em.metadata["embedding_dimension"]
_ = client.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
provider_id="faiss",
)
source = "https://www.paulgraham.com/greatwork.html"
print("rag_tool> Ingesting document:", source)
document = RAGDocument(
document_id="document_1",
content=source,
mime_type="text/html",
metadata={},
)
client.tool_runtime.rag_tool.insert(
documents=[document],
vector_db_id=vector_db_id,
chunk_size_in_tokens=50,
)
agent = Agent(
client,
model=model_id,
instructions="You are a helpful assistant",
tools=[
{
"name": "builtin::rag/knowledge_search",
"args": {"vector_db_ids": [vector_db_id]},
}
],
)
prompt = "How do you do great work?"
print("prompt>", prompt)
response = agent.create_turn(
messages=[{"role": "user", "content": prompt}],
session_id=agent.create_session("rag_session"),
stream=True,
)
for log in AgentEventLogger().log(response):
log.print()
```{literalinclude} ./demo_script.py
:language: python
```
We will use `uv` to run the script
```

View file

@ -7,13 +7,10 @@ This section contains documentation for all available providers for the **infere
- [remote::anthropic](remote_anthropic.md)
- [remote::bedrock](remote_bedrock.md)
- [remote::cerebras](remote_cerebras.md)
- [remote::cerebras-openai-compat](remote_cerebras-openai-compat.md)
- [remote::databricks](remote_databricks.md)
- [remote::fireworks](remote_fireworks.md)
- [remote::fireworks-openai-compat](remote_fireworks-openai-compat.md)
- [remote::gemini](remote_gemini.md)
- [remote::groq](remote_groq.md)
- [remote::groq-openai-compat](remote_groq-openai-compat.md)
- [remote::hf::endpoint](remote_hf_endpoint.md)
- [remote::hf::serverless](remote_hf_serverless.md)
- [remote::llama-openai-compat](remote_llama-openai-compat.md)
@ -23,9 +20,7 @@ This section contains documentation for all available providers for the **infere
- [remote::passthrough](remote_passthrough.md)
- [remote::runpod](remote_runpod.md)
- [remote::sambanova](remote_sambanova.md)
- [remote::sambanova-openai-compat](remote_sambanova-openai-compat.md)
- [remote::tgi](remote_tgi.md)
- [remote::together](remote_together.md)
- [remote::together-openai-compat](remote_together-openai-compat.md)
- [remote::vllm](remote_vllm.md)
- [remote::watsonx](remote_watsonx.md)

View file

@ -17,7 +17,7 @@ That means you'll get fast and efficient vector retrieval.
To use PGVector in your Llama Stack project, follow these steps:
1. Install the necessary dependencies.
2. Configure your Llama Stack project to use Faiss.
2. Configure your Llama Stack project to use pgvector. (e.g. remote::pgvector).
3. Start storing and querying vectors.
## Installation

View file

@ -34,6 +34,7 @@ class VectorDBInput(BaseModel):
vector_db_id: str
embedding_model: str
embedding_dimension: int
provider_id: str | None = None
provider_vector_db_id: str | None = None
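
For context on the new field: a hedged sketch of constructing the updated input model. The import path and the concrete values are assumptions for illustration, not taken from this diff.

```python
# Sketch only: the new optional provider_id on VectorDBInput.
# Import path and values are illustrative assumptions.
from llama_stack.apis.vector_dbs import VectorDBInput  # assumed location

vdb = VectorDBInput(
    vector_db_id="my_demo_vector_db",
    embedding_model="all-MiniLM-L6-v2",
    embedding_dimension=384,
    provider_id="faiss",  # new field; can also be omitted and defaults to None
)
print(vdb.model_dump())
```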

View file

@ -338,7 +338,7 @@ class VectorIO(Protocol):
@webmethod(route="/openai/v1/vector_stores", method="POST")
async def openai_create_vector_store(
self,
name: str,
name: str | None = None,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
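
In practice this makes `name` optional when creating a vector store, matching the OpenAPI change above that drops `name` from the request's required list. A minimal sketch of a caller, assuming `vector_io` is some object implementing this protocol and that the remaining (truncated) parameters are likewise optional:

```python
# Sketch: creating vector stores with and without a name.
# `vector_io` is a hypothetical VectorIO implementation, not defined in this diff.
async def create_stores(vector_io) -> None:
    unnamed = await vector_io.openai_create_vector_store()  # no name required anymore
    named = await vector_io.openai_create_vector_store(name="docs", file_ids=["file_1"])
    print(unnamed, named)
```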

View file

@ -276,8 +276,8 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
config = parse_and_maybe_upgrade_config(config_dict)
if config.external_providers_dir and not config.external_providers_dir.exists():
config.external_providers_dir.mkdir(exist_ok=True)
run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", run_config])
run_args = formulate_run_args(args.image_type, args.image_name)
run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", str(run_config)])
run_command(run_args)

View file

@ -82,39 +82,6 @@ class StackRun(Subcommand):
return ImageType.CONDA.value, args.image_name
return args.image_type, args.image_name
def _resolve_config_and_template(self, args: argparse.Namespace) -> tuple[Path | None, str | None]:
"""Resolve config file path and template name from args.config"""
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
if not args.config:
return None, None
config_file = Path(args.config)
has_yaml_suffix = args.config.endswith(".yaml")
template_name = None
if not config_file.exists() and not has_yaml_suffix:
# check if this is a template
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
if config_file.exists():
template_name = args.config
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to ~/.llama dir
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
if not config_file.exists():
self.parser.error(
f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
)
if not config_file.is_file():
self.parser.error(
f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
)
return config_file, template_name
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
import yaml
@ -125,8 +92,15 @@ class StackRun(Subcommand):
self._start_ui_development_server(args.port)
image_type, image_name = self._get_image_type_and_name(args)
# Resolve config file and template name first
config_file, template_name = self._resolve_config_and_template(args)
if args.config:
try:
from llama_stack.distribution.utils.config_resolution import Mode, resolve_config_or_template
config_file = resolve_config_or_template(args.config, Mode.RUN)
except ValueError as e:
self.parser.error(str(e))
else:
config_file = None
# Check if config is required based on image type
if (image_type in [ImageType.CONDA.value, ImageType.VENV.value]) and not config_file:
@ -164,18 +138,14 @@ class StackRun(Subcommand):
if callable(getattr(args, arg)):
continue
if arg == "config":
if template_name:
server_args.template = str(template_name)
else:
# Set the config file path
server_args.config = str(config_file)
server_args.config = str(config_file)
else:
setattr(server_args, arg, getattr(args, arg))
# Run the server
server_main(server_args)
else:
run_args = formulate_run_args(image_type, image_name, config, template_name)
run_args = formulate_run_args(image_type, image_name)
run_args.extend([str(args.port)])

llama_stack/cli/utils.py (new file, 31 lines)
View file

@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
def add_config_template_args(parser: argparse.ArgumentParser):
"""Add unified config/template arguments with backward compatibility."""
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"config",
nargs="?",
help="Configuration file path or template name",
)
# Backward compatibility arguments (deprecated)
group.add_argument(
"--config",
dest="config",
help="(DEPRECATED) Use positional argument [config] instead. Configuration file path",
)
group.add_argument(
"--template",
dest="config",
help="(DEPRECATED) Use positional argument [config] instead. Template name",
)
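
A minimal usage sketch for this helper, using only what the new file defines; the parser name and template value are illustrative:

```python
import argparse

from llama_stack.cli.utils import add_config_template_args

parser = argparse.ArgumentParser(prog="llama stack run")
add_config_template_args(parser)

# The positional form is the preferred spelling; the deprecated
# --config/--template flags write to the same `config` destination.
args = parser.parse_args(["ollama"])
print(args.config)  # -> "ollama"
```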

View file

@ -214,9 +214,7 @@ class VectorIORouter(VectorIO):
vector_store_id: str,
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store(vector_store_id)
return await self.routing_table.openai_retrieve_vector_store(vector_store_id)
async def openai_update_vector_store(
self,
@ -226,9 +224,7 @@ class VectorIORouter(VectorIO):
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store(
return await self.routing_table.openai_update_vector_store(
vector_store_id=vector_store_id,
name=name,
expires_after=expires_after,
@ -240,12 +236,7 @@ class VectorIORouter(VectorIO):
vector_store_id: str,
) -> VectorStoreDeleteResponse:
logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
result = await provider.openai_delete_vector_store(vector_store_id)
# drop from registry
await self.routing_table.unregister_vector_db(vector_store_id)
return result
return await self.routing_table.openai_delete_vector_store(vector_store_id)
async def openai_search_vector_store(
self,
@ -258,9 +249,7 @@ class VectorIORouter(VectorIO):
search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage:
logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_search_vector_store(
return await self.routing_table.openai_search_vector_store(
vector_store_id=vector_store_id,
query=query,
filters=filters,
@ -278,9 +267,7 @@ class VectorIORouter(VectorIO):
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_attach_file_to_vector_store(
return await self.routing_table.openai_attach_file_to_vector_store(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
@ -297,9 +284,7 @@ class VectorIORouter(VectorIO):
filter: VectorStoreFileStatus | None = None,
) -> list[VectorStoreFileObject]:
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store(
return await self.routing_table.openai_list_files_in_vector_store(
vector_store_id=vector_store_id,
limit=limit,
order=order,
@ -314,9 +299,7 @@ class VectorIORouter(VectorIO):
file_id: str,
) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file(
return await self.routing_table.openai_retrieve_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)
@ -327,9 +310,7 @@ class VectorIORouter(VectorIO):
file_id: str,
) -> VectorStoreFileContentsResponse:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_contents(
return await self.routing_table.openai_retrieve_vector_store_file_contents(
vector_store_id=vector_store_id,
file_id=file_id,
)
@ -341,9 +322,7 @@ class VectorIORouter(VectorIO):
attributes: dict[str, Any],
) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store_file(
return await self.routing_table.openai_update_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
@ -355,9 +334,7 @@ class VectorIORouter(VectorIO):
file_id: str,
) -> VectorStoreFileDeleteResponse:
logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_delete_vector_store_file(
return await self.routing_table.openai_delete_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)
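
The net effect of the hunks above is a delegation pattern: the router no longer resolves the provider itself and instead forwards every OpenAI-compatible vector-store call to the routing table, which can enforce access control before reaching the provider (see `assert_action_allowed` in the next files). A self-contained miniature of that pattern, using stand-in classes rather than the real llama_stack types:

```python
# Miniature of the delegation pattern; classes here are stand-ins, not llama_stack types.
class MiniRoutingTable:
    def __init__(self, providers: dict):
        self.providers = providers

    async def openai_retrieve_vector_store(self, vector_store_id: str):
        # access-control checks live here, before the provider is reached
        return await self.providers[vector_store_id].openai_retrieve_vector_store(vector_store_id)


class MiniRouter:
    def __init__(self, routing_table: MiniRoutingTable):
        self.routing_table = routing_table

    async def openai_retrieve_vector_store(self, vector_store_id: str):
        # no direct provider lookup anymore; just forward to the routing table
        return await self.routing_table.openai_retrieve_vector_store(vector_store_id)
```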

View file

@ -9,6 +9,7 @@ from typing import Any
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.distribution.access_control.datatypes import Action
from llama_stack.distribution.datatypes import (
AccessRule,
RoutableObject,
@ -209,6 +210,20 @@ class CommonRoutingTableImpl(RoutingTable):
await self.dist_registry.register(obj)
return obj
async def assert_action_allowed(
self,
action: Action,
type: str,
identifier: str,
) -> None:
"""Fetch a registered object by type/identifier and enforce the given action permission."""
obj = await self.get_object_by_identifier(type, identifier)
if obj is None:
raise ValueError(f"{type.capitalize()} '{identifier}' not found")
user = get_authenticated_user()
if not is_action_allowed(self.policy, action, obj, user):
raise AccessDeniedError(action, obj, user)
async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all()
filtered_objs = [obj for obj in objs if obj.type == type]

View file

@ -4,11 +4,24 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import TypeAdapter
from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.apis.vector_io.vector_io import (
SearchRankingOptions,
VectorStoreChunkingStrategy,
VectorStoreDeleteResponse,
VectorStoreFileContentsResponse,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreFileStatus,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.distribution.datatypes import (
VectorDBWithOwner,
)
@ -74,3 +87,135 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
if existing_vector_db is None:
raise ValueError(f"Vector DB {vector_db_id} not found")
await self.unregister_object(existing_vector_db)
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store(vector_store_id)
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_update_vector_store(
vector_store_id=vector_store_id,
name=name,
expires_after=expires_after,
metadata=metadata,
)
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
result = await self.get_provider_impl(vector_store_id).openai_delete_vector_store(vector_store_id)
await self.unregister_vector_db(vector_store_id)
return result
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_search_vector_store(
vector_store_id=vector_store_id,
query=query,
filters=filters,
max_num_results=max_num_results,
ranking_options=ranking_options,
rewrite_query=rewrite_query,
search_mode=search_mode,
)
async def openai_attach_file_to_vector_store(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_attach_file_to_vector_store(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
chunking_strategy=chunking_strategy,
)
async def openai_list_files_in_vector_store(
self,
vector_store_id: str,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
filter: VectorStoreFileStatus | None = None,
) -> list[VectorStoreFileObject]:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_list_files_in_vector_store(
vector_store_id=vector_store_id,
limit=limit,
order=order,
after=after,
before=before,
filter=filter,
)
async def openai_retrieve_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def openai_retrieve_vector_store_file_contents(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file_contents(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def openai_update_vector_store_file(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any],
) -> VectorStoreFileObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_update_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
)
async def openai_delete_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileDeleteResponse:
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_delete_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)

View file

@ -32,6 +32,7 @@ from openai import BadRequestError
from pydantic import BaseModel, ValidationError
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.cli.utils import add_config_template_args
from llama_stack.distribution.access_control.access_control import AccessDeniedError
from llama_stack.distribution.datatypes import (
AuthenticationRequiredError,
@ -53,6 +54,7 @@ from llama_stack.distribution.stack import (
validate_env_pair,
)
from llama_stack.distribution.utils.config import redact_sensitive_fields
from llama_stack.distribution.utils.config_resolution import Mode, resolve_config_or_template
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
@ -377,20 +379,8 @@ class ClientVersionMiddleware:
def main(args: argparse.Namespace | None = None):
"""Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
parser.add_argument(
"--yaml-config",
dest="config",
help="(Deprecated) Path to YAML configuration file - use --config instead",
)
parser.add_argument(
"--config",
dest="config",
help="Path to YAML configuration file",
)
parser.add_argument(
"--template",
help="One of the template names in llama_stack/templates (e.g., tgi, fireworks, remote-vllm, etc.)",
)
add_config_template_args(parser)
parser.add_argument(
"--port",
type=int,
@ -409,20 +399,7 @@ def main(args: argparse.Namespace | None = None):
if args is None:
args = parser.parse_args()
log_line = ""
if hasattr(args, "config") and args.config:
# if the user provided a config file, use it, even if template was specified
config_file = Path(args.config)
if not config_file.exists():
raise ValueError(f"Config file {config_file} does not exist")
log_line = f"Using config file: {config_file}"
elif hasattr(args, "template") and args.template:
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
if not config_file.exists():
raise ValueError(f"Template {args.template} does not exist")
log_line = f"Using template {args.template} config file: {config_file}"
else:
raise ValueError("Either --config or --template must be provided")
config_file = resolve_config_or_template(args.config, Mode.RUN)
logger_config = None
with open(config_file) as fp:
@ -442,9 +419,6 @@ def main(args: argparse.Namespace | None = None):
config = replace_env_vars(config_contents)
config = StackRunConfig(**cast_image_name_to_string(config))
# now that the logger is initialized, print the line about which type of config we are using.
logger.info(log_line)
_log_run_config(run_config=config)
app = FastAPI(
@ -592,12 +566,29 @@ def main(args: argparse.Namespace | None = None):
"port": port,
"lifespan": "on",
"log_level": logger.getEffectiveLevel(),
"log_config": logger_config,
}
if ssl_config:
uvicorn_config.update(ssl_config)
# Run uvicorn in the existing event loop to preserve background tasks
loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
# We need to catch KeyboardInterrupt because uvicorn's signal handling
# re-raises SIGINT signals using signal.raise_signal(), which Python
# converts to KeyboardInterrupt. Without this catch, we'd get a confusing
# stack trace when using Ctrl+C or kill -2 (SIGINT).
# SIGTERM (kill -15) works fine without this because Python doesn't
# have a default handler for it.
#
# Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
# signal handling but this is quite intrusive and not worth the effort.
try:
loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
except (KeyboardInterrupt, SystemExit):
logger.info("Received interrupt signal, shutting down gracefully...")
finally:
if not loop.is_closed():
logger.debug("Closing event loop")
loop.close()
def _log_run_config(run_config: StackRunConfig):

View file

@ -117,7 +117,7 @@ if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
set -x
if [ -n "$yaml_config" ]; then
yaml_config_arg="--config $yaml_config"
yaml_config_arg="$yaml_config"
else
yaml_config_arg=""
fi

View file

@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import StrEnum
from pathlib import Path
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="config_resolution")
TEMPLATE_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "templates"
class Mode(StrEnum):
RUN = "run"
BUILD = "build"
def resolve_config_or_template(
config_or_template: str,
mode: Mode = Mode.RUN,
) -> Path:
"""
Resolve a config/template argument to a concrete config file path.
Args:
config_or_template: User input (file path, template name, or built distribution)
mode: Mode resolving for ("run", "build", "server")
Returns:
Path to the resolved config file
Raises:
ValueError: If resolution fails
"""
# Strategy 1: Try as file path first
config_path = Path(config_or_template)
if config_path.exists() and config_path.is_file():
logger.info(f"Using file path: {config_path}")
return config_path.resolve()
# Strategy 2: Try as template name (if no .yaml extension)
if not config_or_template.endswith(".yaml"):
template_config = _get_template_config_path(config_or_template, mode)
if template_config.exists():
logger.info(f"Using template: {template_config}")
return template_config
# Strategy 3: Try as built distribution name
distrib_config = DISTRIBS_BASE_DIR / f"llamastack-{config_or_template}" / f"{config_or_template}-{mode}.yaml"
if distrib_config.exists():
logger.info(f"Using built distribution: {distrib_config}")
return distrib_config
distrib_config = DISTRIBS_BASE_DIR / f"{config_or_template}" / f"{config_or_template}-{mode}.yaml"
if distrib_config.exists():
logger.info(f"Using built distribution: {distrib_config}")
return distrib_config
# Strategy 4: Failed - provide helpful error
raise ValueError(_format_resolution_error(config_or_template, mode))
def _get_template_config_path(template_name: str, mode: Mode) -> Path:
"""Get the config file path for a template."""
return TEMPLATE_DIR / template_name / f"{mode}.yaml"
def _format_resolution_error(config_or_template: str, mode: Mode) -> str:
"""Format a helpful error message for resolution failures."""
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
template_path = _get_template_config_path(config_or_template, mode)
distrib_path = DISTRIBS_BASE_DIR / f"llamastack-{config_or_template}" / f"{config_or_template}-{mode}.yaml"
distrib_path2 = DISTRIBS_BASE_DIR / f"{config_or_template}" / f"{config_or_template}-{mode}.yaml"
available_templates = _get_available_templates()
templates_str = ", ".join(available_templates) if available_templates else "none found"
return f"""Could not resolve config or template '{config_or_template}'.
Tried the following locations:
1. As file path: {Path(config_or_template).resolve()}
2. As template: {template_path}
3. As built distribution: ({distrib_path}, {distrib_path2})
Available templates: {templates_str}
Did you mean one of these templates?
{_format_template_suggestions(available_templates, config_or_template)}
"""
def _get_available_templates() -> list[str]:
"""Get list of available template names."""
if not TEMPLATE_DIR.exists() and not DISTRIBS_BASE_DIR.exists():
return []
return list(
set(
[d.name for d in TEMPLATE_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")]
+ [d.name for d in DISTRIBS_BASE_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")]
)
)
def _format_template_suggestions(templates: list[str], user_input: str) -> str:
"""Format template suggestions for error messages, showing closest matches first."""
if not templates:
return " (no templates found)"
import difflib
# Get up to 3 closest matches with similarity threshold of 0.3 (lower = more permissive)
close_matches = difflib.get_close_matches(user_input, templates, n=3, cutoff=0.3)
display_templates = close_matches if close_matches else templates[:3]
suggestions = [f" - {t}" for t in display_templates]
return "\n".join(suggestions)
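
A short usage sketch for the resolver above; the template name is an example, and whether it resolves depends on what exists locally:

```python
# Sketch: resolving a run config from a template name (or a file path, or a
# built distribution) using the helper defined above.
from llama_stack.distribution.utils.config_resolution import Mode, resolve_config_or_template

try:
    run_yaml = resolve_config_or_template("ollama", Mode.RUN)
    print(f"Resolved run config: {run_yaml}")
except ValueError as err:
    # Raised with the formatted suggestions when nothing matches.
    print(err)
```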

View file

@ -21,7 +21,7 @@ from pathlib import Path
from llama_stack.distribution.utils.image_types import LlamaStackImageType
def formulate_run_args(image_type, image_name, config, template_name) -> list:
def formulate_run_args(image_type: str, image_name: str) -> list[str]:
env_name = ""
if image_type == LlamaStackImageType.CONDA.value:

View file

@ -10,6 +10,7 @@ import re
import secrets
import string
import uuid
import warnings
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
@ -911,8 +912,16 @@ async def load_data_from_url(url: str) -> str:
async def get_raw_document_text(document: Document) -> str:
if not document.mime_type.startswith("text/"):
# Handle deprecated text/yaml mime type with warning
if document.mime_type == "text/yaml":
warnings.warn(
"The 'text/yaml' MIME type is deprecated. Please use 'application/yaml' instead.",
DeprecationWarning,
stacklevel=2,
)
elif not (document.mime_type.startswith("text/") or document.mime_type == "application/yaml"):
raise ValueError(f"Unexpected document mime type: {document.mime_type}")
if isinstance(document.content, URL):
return await load_data_from_url(document.content.uri)
elif isinstance(document.content, str):
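
To make the deprecation path concrete, here is a hedged sketch of a caller that would now see a `DeprecationWarning` for `text/yaml` content; both import locations are assumptions, since the diff does not show them.

```python
# Sketch only: exercising the deprecated text/yaml branch above.
# Both import paths are assumptions; adjust them to your checkout.
import asyncio
import warnings

from llama_stack.apis.agents import Document  # assumed
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (  # assumed
    get_raw_document_text,
)

async def main() -> None:
    doc = Document(content="key: value", mime_type="text/yaml")
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        text = await get_raw_document_text(doc)  # still returns the text, with a warning
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
    print(text)

asyncio.run(main())
```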

View file

@ -128,6 +128,11 @@ class AgentPersistence:
except Exception as e:
log.error(f"Error parsing turn: {e}")
continue
# The kvstore does not guarantee order, so we sort by started_at
# to ensure consistent ordering of turns.
turns.sort(key=lambda t: t.started_at)
return turns
async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:
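
The added sort above is the whole fix: kvstore reads come back in arbitrary order, so turns are ordered by their start time before being returned. A tiny self-contained illustration of the same idea, with a stand-in class rather than the real Turn model:

```python
# Stand-in illustration of the ordering fix; not the real Turn model.
from dataclasses import dataclass
from datetime import datetime, timedelta


@dataclass
class FakeTurn:
    turn_id: str
    started_at: datetime


now = datetime.now()
turns = [FakeTurn("second", now + timedelta(seconds=5)), FakeTurn("first", now)]
turns.sort(key=lambda t: t.started_at)  # same key as the fix above
assert [t.turn_id for t in turns] == ["first", "second"]
```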

View file

@ -224,17 +224,6 @@ def available_providers() -> list[ProviderSpec]:
description="Groq inference provider for ultra-fast inference using Groq's LPU technology.",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="fireworks-openai-compat",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.fireworks_openai_compat",
config_class="llama_stack.providers.remote.inference.fireworks_openai_compat.config.FireworksCompatConfig",
provider_data_validator="llama_stack.providers.remote.inference.fireworks_openai_compat.config.FireworksProviderDataValidator",
description="Fireworks AI OpenAI-compatible provider for using Fireworks models with OpenAI API format.",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
@ -246,50 +235,6 @@ def available_providers() -> list[ProviderSpec]:
description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="together-openai-compat",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.together_openai_compat",
config_class="llama_stack.providers.remote.inference.together_openai_compat.config.TogetherCompatConfig",
provider_data_validator="llama_stack.providers.remote.inference.together_openai_compat.config.TogetherProviderDataValidator",
description="Together AI OpenAI-compatible provider for using Together models with OpenAI API format.",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="groq-openai-compat",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.groq_openai_compat",
config_class="llama_stack.providers.remote.inference.groq_openai_compat.config.GroqCompatConfig",
provider_data_validator="llama_stack.providers.remote.inference.groq_openai_compat.config.GroqProviderDataValidator",
description="Groq OpenAI-compatible provider for using Groq models with OpenAI API format.",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="sambanova-openai-compat",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.sambanova_openai_compat",
config_class="llama_stack.providers.remote.inference.sambanova_openai_compat.config.SambaNovaCompatConfig",
provider_data_validator="llama_stack.providers.remote.inference.sambanova_openai_compat.config.SambaNovaProviderDataValidator",
description="SambaNova OpenAI-compatible provider for using SambaNova models with OpenAI API format.",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="cerebras-openai-compat",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.cerebras_openai_compat",
config_class="llama_stack.providers.remote.inference.cerebras_openai_compat.config.CerebrasCompatConfig",
provider_data_validator="llama_stack.providers.remote.inference.cerebras_openai_compat.config.CerebrasProviderDataValidator",
description="Cerebras OpenAI-compatible provider for using Cerebras models with OpenAI API format.",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(

View file

@ -395,7 +395,7 @@ That means you'll get fast and efficient vector retrieval.
To use PGVector in your Llama Stack project, follow these steps:
1. Install the necessary dependencies.
2. Configure your Llama Stack project to use Faiss.
2. Configure your Llama Stack project to use pgvector. (e.g. remote::pgvector).
3. Start storing and querying vectors.
## Installation

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import InferenceProvider
from .config import CerebrasCompatConfig
async def get_adapter_impl(config: CerebrasCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .cerebras import CerebrasCompatInferenceAdapter
adapter = CerebrasCompatInferenceAdapter(config)
return adapter

View file

@ -1,30 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.remote.inference.cerebras_openai_compat.config import CerebrasCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from ..cerebras.models import MODEL_ENTRIES
class CerebrasCompatInferenceAdapter(LiteLLMOpenAIMixin):
_config: CerebrasCompatConfig
def __init__(self, config: CerebrasCompatConfig):
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
api_key_from_config=config.api_key,
provider_data_api_key_field="cerebras_api_key",
openai_compat_api_base=config.openai_compat_api_base,
)
self.config = config
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()

View file

@ -1,38 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class CerebrasProviderDataValidator(BaseModel):
cerebras_api_key: str | None = Field(
default=None,
description="API key for Cerebras models",
)
@json_schema_type
class CerebrasCompatConfig(BaseModel):
api_key: str | None = Field(
default=None,
description="The Cerebras API key",
)
openai_compat_api_base: str = Field(
default="https://api.cerebras.ai/v1",
description="The URL for the Cerebras API server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"openai_compat_api_base": "https://api.cerebras.ai/v1",
"api_key": api_key,
}

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import InferenceProvider
from .config import FireworksCompatConfig
async def get_adapter_impl(config: FireworksCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .fireworks import FireworksCompatInferenceAdapter
adapter = FireworksCompatInferenceAdapter(config)
return adapter

View file

@ -1,38 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class FireworksProviderDataValidator(BaseModel):
fireworks_api_key: str | None = Field(
default=None,
description="API key for Fireworks models",
)
@json_schema_type
class FireworksCompatConfig(BaseModel):
api_key: str | None = Field(
default=None,
description="The Fireworks API key",
)
openai_compat_api_base: str = Field(
default="https://api.fireworks.ai/inference/v1",
description="The URL for the Fireworks API server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"openai_compat_api_base": "https://api.fireworks.ai/inference/v1",
"api_key": api_key,
}

View file

@ -1,30 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.remote.inference.fireworks_openai_compat.config import FireworksCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from ..fireworks.models import MODEL_ENTRIES
class FireworksCompatInferenceAdapter(LiteLLMOpenAIMixin):
_config: FireworksCompatConfig
def __init__(self, config: FireworksCompatConfig):
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
api_key_from_config=config.api_key,
provider_data_api_key_field="fireworks_api_key",
openai_compat_api_base=config.openai_compat_api_base,
)
self.config = config
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import InferenceProvider
from .config import GroqCompatConfig
async def get_adapter_impl(config: GroqCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .groq import GroqCompatInferenceAdapter
adapter = GroqCompatInferenceAdapter(config)
return adapter

View file

@ -1,38 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class GroqProviderDataValidator(BaseModel):
groq_api_key: str | None = Field(
default=None,
description="API key for Groq models",
)
@json_schema_type
class GroqCompatConfig(BaseModel):
api_key: str | None = Field(
default=None,
description="The Groq API key",
)
openai_compat_api_base: str = Field(
default="https://api.groq.com/openai/v1",
description="The URL for the Groq API server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"openai_compat_api_base": "https://api.groq.com/openai/v1",
"api_key": api_key,
}

View file

@ -1,30 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.remote.inference.groq_openai_compat.config import GroqCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from ..groq.models import MODEL_ENTRIES
class GroqCompatInferenceAdapter(LiteLLMOpenAIMixin):
_config: GroqCompatConfig
def __init__(self, config: GroqCompatConfig):
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
api_key_from_config=config.api_key,
provider_data_api_key_field="groq_api_key",
openai_compat_api_base=config.openai_compat_api_base,
)
self.config = config
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import InferenceProvider
from .config import SambaNovaCompatConfig
async def get_adapter_impl(config: SambaNovaCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .sambanova import SambaNovaCompatInferenceAdapter
adapter = SambaNovaCompatInferenceAdapter(config)
return adapter

View file

@ -1,38 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class SambaNovaProviderDataValidator(BaseModel):
sambanova_api_key: str | None = Field(
default=None,
description="API key for SambaNova models",
)
@json_schema_type
class SambaNovaCompatConfig(BaseModel):
api_key: str | None = Field(
default=None,
description="The SambaNova API key",
)
openai_compat_api_base: str = Field(
default="https://api.sambanova.ai/v1",
description="The URL for the SambaNova API server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"openai_compat_api_base": "https://api.sambanova.ai/v1",
"api_key": api_key,
}

View file

@ -1,30 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.remote.inference.sambanova_openai_compat.config import SambaNovaCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from ..sambanova.models import MODEL_ENTRIES
class SambaNovaCompatInferenceAdapter(LiteLLMOpenAIMixin):
_config: SambaNovaCompatConfig
def __init__(self, config: SambaNovaCompatConfig):
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
api_key_from_config=config.api_key,
provider_data_api_key_field="sambanova_api_key",
openai_compat_api_base=config.openai_compat_api_base,
)
self.config = config
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import InferenceProvider
from .config import TogetherCompatConfig
async def get_adapter_impl(config: TogetherCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .together import TogetherCompatInferenceAdapter
adapter = TogetherCompatInferenceAdapter(config)
return adapter

View file

@ -1,38 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class TogetherProviderDataValidator(BaseModel):
together_api_key: str | None = Field(
default=None,
description="API key for Together models",
)
@json_schema_type
class TogetherCompatConfig(BaseModel):
api_key: str | None = Field(
default=None,
description="The Together API key",
)
openai_compat_api_base: str = Field(
default="https://api.together.xyz/v1",
description="The URL for the Together API server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.TOGETHER_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"openai_compat_api_base": "https://api.together.xyz/v1",
"api_key": api_key,
}

View file

@ -1,30 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.remote.inference.together_openai_compat.config import TogetherCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from ..together.models import MODEL_ENTRIES
class TogetherCompatInferenceAdapter(LiteLLMOpenAIMixin):
_config: TogetherCompatConfig
def __init__(self, config: TogetherCompatConfig):
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
api_key_from_config=config.api_key,
provider_data_api_key_field="together_api_key",
openai_compat_api_base=config.openai_compat_api_base,
)
self.config = config
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()

View file

@ -161,7 +161,7 @@ class OpenAIVectorStoreMixin(ABC):
async def openai_create_vector_store(
self,
name: str,
name: str | None = None,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,

View file

@ -83,6 +83,7 @@ class SQLiteTraceStore(TraceStore):
)
SELECT DISTINCT trace_id, root_span_id, start_time, end_time
FROM filtered_traces
WHERE root_span_id IS NOT NULL
LIMIT {limit} OFFSET {offset}
"""
@ -166,7 +167,11 @@ class SQLiteTraceStore(TraceStore):
return spans_by_id
async def get_trace(self, trace_id: str) -> Trace:
query = "SELECT * FROM traces WHERE trace_id = ?"
query = """
SELECT *
FROM traces t
WHERE t.trace_id = ?
"""
async with aiosqlite.connect(self.conn_string) as conn:
conn.row_factory = aiosqlite.Row
async with conn.execute(query, (trace_id,)) as cursor:

View file

@ -19,12 +19,7 @@ distribution_spec:
- remote::anthropic
- remote::gemini
- remote::groq
- remote::fireworks-openai-compat
- remote::llama-openai-compat
- remote::together-openai-compat
- remote::groq-openai-compat
- remote::sambanova-openai-compat
- remote::cerebras-openai-compat
- remote::sambanova
- remote::passthrough
- inline::sentence-transformers

View file

@ -90,36 +90,11 @@ providers:
config:
url: https://api.groq.com
api_key: ${env.GROQ_API_KEY}
- provider_id: ${env.ENABLE_FIREWORKS_OPENAI_COMPAT:=__disabled__}
provider_type: remote::fireworks-openai-compat
config:
openai_compat_api_base: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY}
- provider_id: ${env.ENABLE_LLAMA_OPENAI_COMPAT:=__disabled__}
provider_type: remote::llama-openai-compat
config:
openai_compat_api_base: https://api.llama.com/compat/v1/
api_key: ${env.LLAMA_API_KEY}
- provider_id: ${env.ENABLE_TOGETHER_OPENAI_COMPAT:=__disabled__}
provider_type: remote::together-openai-compat
config:
openai_compat_api_base: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY}
- provider_id: ${env.ENABLE_GROQ_OPENAI_COMPAT:=__disabled__}
provider_type: remote::groq-openai-compat
config:
openai_compat_api_base: https://api.groq.com/openai/v1
api_key: ${env.GROQ_API_KEY}
- provider_id: ${env.ENABLE_SAMBANOVA_OPENAI_COMPAT:=__disabled__}
provider_type: remote::sambanova-openai-compat
config:
openai_compat_api_base: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY}
- provider_id: ${env.ENABLE_CEREBRAS_OPENAI_COMPAT:=__disabled__}
provider_type: remote::cerebras-openai-compat
config:
openai_compat_api_base: https://api.cerebras.ai/v1
api_key: ${env.CEREBRAS_API_KEY}
- provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__}
provider_type: remote::sambanova
config:

View file

@ -19,12 +19,7 @@ distribution_spec:
- remote::anthropic
- remote::gemini
- remote::groq
- remote::fireworks-openai-compat
- remote::llama-openai-compat
- remote::together-openai-compat
- remote::groq-openai-compat
- remote::sambanova-openai-compat
- remote::cerebras-openai-compat
- remote::sambanova
- remote::passthrough
- inline::sentence-transformers

View file

@ -90,36 +90,11 @@ providers:
config:
url: https://api.groq.com
api_key: ${env.GROQ_API_KEY}
- provider_id: ${env.ENABLE_FIREWORKS_OPENAI_COMPAT:=__disabled__}
provider_type: remote::fireworks-openai-compat
config:
openai_compat_api_base: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY}
- provider_id: ${env.ENABLE_LLAMA_OPENAI_COMPAT:=__disabled__}
provider_type: remote::llama-openai-compat
config:
openai_compat_api_base: https://api.llama.com/compat/v1/
api_key: ${env.LLAMA_API_KEY}
- provider_id: ${env.ENABLE_TOGETHER_OPENAI_COMPAT:=__disabled__}
provider_type: remote::together-openai-compat
config:
openai_compat_api_base: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY}
- provider_id: ${env.ENABLE_GROQ_OPENAI_COMPAT:=__disabled__}
provider_type: remote::groq-openai-compat
config:
openai_compat_api_base: https://api.groq.com/openai/v1
api_key: ${env.GROQ_API_KEY}
- provider_id: ${env.ENABLE_SAMBANOVA_OPENAI_COMPAT:=__disabled__}
provider_type: remote::sambanova-openai-compat
config:
openai_compat_api_base: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY}
- provider_id: ${env.ENABLE_CEREBRAS_OPENAI_COMPAT:=__disabled__}
provider_type: remote::cerebras-openai-compat
config:
openai_compat_api_base: https://api.cerebras.ai/v1
api_key: ${env.CEREBRAS_API_KEY}
- provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__}
provider_type: remote::sambanova
config:

View file

@ -15,7 +15,7 @@ set -Eeuo pipefail
PORT=8321
OLLAMA_PORT=11434
MODEL_ALIAS="llama3.2:3b"
SERVER_IMAGE="llamastack/distribution-ollama:0.2.2"
SERVER_IMAGE="docker.io/llamastack/distribution-ollama:0.2.2"
WAIT_TIMEOUT=300
log(){ printf "\e[1;32m%s\e[0m\n" "$*"; }
@ -165,7 +165,7 @@ log "🦙 Starting Ollama…"
$ENGINE run -d "${PLATFORM_OPTS[@]}" --name ollama-server \
--network llama-net \
-p "${OLLAMA_PORT}:${OLLAMA_PORT}" \
ollama/ollama > /dev/null 2>&1
docker.io/ollama/ollama > /dev/null 2>&1
if ! wait_for_service "http://localhost:${OLLAMA_PORT}/" "Ollama" "$WAIT_TIMEOUT" "Ollama daemon"; then
log "❌ Ollama daemon did not become ready in ${WAIT_TIMEOUT}s; dumping container logs:"

View file

@ -47,6 +47,9 @@ def setup_telemetry_data(llama_stack_client, text_model_id):
if len(traces) < 4:
pytest.fail(f"Failed to create sufficient telemetry data after 30s. Got {len(traces)} traces.")
# Wait for 5 seconds to ensure traces have completed logging
time.sleep(5)
yield

View file

@ -11,17 +11,15 @@ from unittest.mock import AsyncMock
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.models import Model
from llama_stack.apis.shields.shields import Shield
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
from llama_stack.distribution.routing_tables.scoring_functions import ScoringFunctionsRoutingTable
from llama_stack.distribution.routing_tables.shields import ShieldsRoutingTable
from llama_stack.distribution.routing_tables.toolgroups import ToolGroupsRoutingTable
from llama_stack.distribution.routing_tables.vector_dbs import VectorDBsRoutingTable
class Impl:
@@ -54,17 +52,6 @@ class SafetyImpl(Impl):
return shield
class VectorDBImpl(Impl):
def __init__(self):
super().__init__(Api.vector_io)
async def register_vector_db(self, vector_db: VectorDB):
return vector_db
async def unregister_vector_db(self, vector_db_id: str):
return vector_db_id
class DatasetsImpl(Impl):
def __init__(self):
super().__init__(Api.datasetio)
@@ -173,36 +160,6 @@ async def test_shields_routing_table(cached_disk_dist_registry):
assert "test-shield-2" in shield_ids
async def test_vectordbs_routing_table(cached_disk_dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
# Register multiple vector databases and verify listing
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model")
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model")
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 2
vector_db_ids = {v.identifier for v in vector_dbs.data}
assert "test-vectordb" in vector_db_ids
assert "test-vectordb-2" in vector_db_ids
await table.unregister_vector_db(vector_db_id="test-vectordb")
await table.unregister_vector_db(vector_db_id="test-vectordb-2")
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 0
async def test_datasets_routing_table(cached_disk_dist_registry):
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {})
await table.initialize()

View file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -0,0 +1,274 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Unit tests for the vector_dbs routing table
import time
from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
from llama_stack.apis.vector_io.vector_io import (
VectorStoreContent,
VectorStoreDeleteResponse,
VectorStoreFileContentsResponse,
VectorStoreFileCounts,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.distribution.access_control.datatypes import AccessRule, Scope
from llama_stack.distribution.datatypes import User
from llama_stack.distribution.request_headers import request_provider_data_context
from llama_stack.distribution.routing_tables.vector_dbs import VectorDBsRoutingTable
from tests.unit.distribution.routers.test_routing_tables import Impl, InferenceImpl, ModelsRoutingTable
class VectorDBImpl(Impl):
def __init__(self):
super().__init__(Api.vector_io)
async def register_vector_db(self, vector_db: VectorDB):
return vector_db
async def unregister_vector_db(self, vector_db_id: str):
return vector_db_id
async def openai_retrieve_vector_store(self, vector_store_id):
return VectorStoreObject(
id=vector_store_id,
name="Test Store",
created_at=int(time.time()),
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
)
async def openai_update_vector_store(self, vector_store_id, **kwargs):
return VectorStoreObject(
id=vector_store_id,
name="Updated Store",
created_at=int(time.time()),
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
)
async def openai_delete_vector_store(self, vector_store_id):
return VectorStoreDeleteResponse(id=vector_store_id, object="vector_store.deleted", deleted=True)
async def openai_search_vector_store(self, vector_store_id, query, **kwargs):
return VectorStoreSearchResponsePage(
object="vector_store.search_results.page", search_query="query", data=[], has_more=False, next_page=None
)
async def openai_attach_file_to_vector_store(self, vector_store_id, file_id, **kwargs):
return VectorStoreFileObject(
id=file_id,
status="completed",
chunking_strategy={"type": "auto"},
created_at=int(time.time()),
vector_store_id=vector_store_id,
)
async def openai_list_files_in_vector_store(self, vector_store_id, **kwargs):
return [
VectorStoreFileObject(
id="1",
status="completed",
chunking_strategy={"type": "auto"},
created_at=int(time.time()),
vector_store_id=vector_store_id,
)
]
async def openai_retrieve_vector_store_file(self, vector_store_id, file_id):
return VectorStoreFileObject(
id=file_id,
status="completed",
chunking_strategy={"type": "auto"},
created_at=int(time.time()),
vector_store_id=vector_store_id,
)
async def openai_retrieve_vector_store_file_contents(self, vector_store_id, file_id):
return VectorStoreFileContentsResponse(
file_id=file_id,
filename="Sample File name",
attributes={"key": "value"},
content=[VectorStoreContent(type="text", text="Sample content")],
)
async def openai_update_vector_store_file(self, vector_store_id, file_id, **kwargs):
return VectorStoreFileObject(
id=file_id,
status="completed",
chunking_strategy={"type": "auto"},
created_at=int(time.time()),
vector_store_id=vector_store_id,
)
async def openai_delete_vector_store_file(self, vector_store_id, file_id):
return VectorStoreFileDeleteResponse(id=file_id, deleted=True)
async def test_vectordbs_routing_table(cached_disk_dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
# Register multiple vector databases and verify listing
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model")
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model")
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 2
vector_db_ids = {v.identifier for v in vector_dbs.data}
assert "test-vectordb" in vector_db_ids
assert "test-vectordb-2" in vector_db_ids
await table.unregister_vector_db(vector_db_id="test-vectordb")
await table.unregister_vector_db(vector_db_id="test-vectordb-2")
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 0
async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry):
impl = VectorDBImpl()
impl.openai_retrieve_vector_store = AsyncMock(return_value="OK")
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=[])
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[])
authorized_table = "vs1"
authorized_team = "team1"
unauthorized_team = "team2"
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
authorized_user = User(principal="alice", attributes={"roles": [authorized_team]})
with request_provider_data_context({}, authorized_user):
_ = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model")
# Authorized reader
with request_provider_data_context({}, authorized_user):
res = await table.openai_retrieve_vector_store(authorized_table)
assert res == "OK"
# Authorized updater
impl.openai_update_vector_store_file = AsyncMock(return_value="UPDATED")
with request_provider_data_context({}, authorized_user):
res = await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"})
assert res == "UPDATED"
# Unauthorized reader
unauthorized_user = User(principal="eve", attributes={"roles": [unauthorized_team]})
with request_provider_data_context({}, unauthorized_user):
with pytest.raises(ValueError):
await table.openai_retrieve_vector_store(authorized_table)
# Unauthorized updater
with request_provider_data_context({}, unauthorized_user):
with pytest.raises(ValueError):
await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"})
# Authorized deleter
impl.openai_delete_vector_store_file = AsyncMock(return_value="DELETED")
with request_provider_data_context({}, authorized_user):
res = await table.openai_delete_vector_store_file(authorized_table, file_id="file1")
assert res == "DELETED"
# Unauthorized deleter
with request_provider_data_context({}, unauthorized_user):
with pytest.raises(ValueError):
await table.openai_delete_vector_store_file(authorized_table, file_id="file1")
async def test_openai_vector_stores_routing_table_actions(cached_disk_dist_registry):
impl = VectorDBImpl()
policy = [
AccessRule(permit=Scope(actions=["create", "read", "update", "delete"]), when="user with admin in roles"),
AccessRule(permit=Scope(actions=["read"]), when="user with reader in roles"),
]
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=policy)
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[])
vector_db_id = "vs1"
file_id = "file-1"
admin_user = User(principal="admin", attributes={"roles": ["admin"]})
read_only_user = User(principal="reader", attributes={"roles": ["reader"]})
no_access_user = User(principal="outsider", attributes={"roles": ["no_access"]})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
with request_provider_data_context({}, admin_user):
await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model")
read_methods = [
(table.openai_retrieve_vector_store, (vector_db_id,), {}),
(table.openai_search_vector_store, (vector_db_id, "query"), {}),
(table.openai_list_files_in_vector_store, (vector_db_id,), {}),
(table.openai_retrieve_vector_store_file, (vector_db_id, file_id), {}),
(table.openai_retrieve_vector_store_file_contents, (vector_db_id, file_id), {}),
]
update_methods = [
(table.openai_update_vector_store, (vector_db_id,), {"name": "Updated DB"}),
(table.openai_attach_file_to_vector_store, (vector_db_id, file_id), {}),
(table.openai_update_vector_store_file, (vector_db_id, file_id), {"attributes": {"key": "value"}}),
]
delete_methods = [
(table.openai_delete_vector_store_file, (vector_db_id, file_id), {}),
(table.openai_delete_vector_store, (vector_db_id,), {}),
]
for user in [admin_user, read_only_user]:
with request_provider_data_context({}, user):
for method, args, kwargs in read_methods:
result = await method(*args, **kwargs)
assert result is not None, f"Read operation failed with user {user.principal}"
with request_provider_data_context({}, no_access_user):
for method, args, kwargs in read_methods:
with pytest.raises(ValueError):
await method(*args, **kwargs)
with request_provider_data_context({}, admin_user):
for method, args, kwargs in update_methods:
result = await method(*args, **kwargs)
assert result is not None, "Update operation failed with admin user"
with request_provider_data_context({}, admin_user):
for method, args, kwargs in delete_methods:
result = await method(*args, **kwargs)
assert result is not None, "Delete operation failed with admin user"
for user in [read_only_user, no_access_user]:
with request_provider_data_context({}, user):
for method, args, kwargs in delete_methods:
with pytest.raises(ValueError):
await method(*args, **kwargs)
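
The policy in this test grants the admin role every action and the reader role read-only access, expressed as AccessRule/Scope pairs whose `when` condition takes the form "user with <role> in roles". The sketch below re-implements that evaluation with plain dataclasses purely for illustration; the types are simplified stand-ins, not llama_stack's actual access-control code.

```python
from dataclasses import dataclass, field

# Simplified stand-ins for the Scope/AccessRule/User types used in the test above.
@dataclass
class Scope:
    actions: list[str]

@dataclass
class AccessRule:
    permit: Scope
    when: str  # e.g. "user with admin in roles"

@dataclass
class User:
    principal: str
    attributes: dict = field(default_factory=dict)

def is_allowed(policy: list[AccessRule], user: User, action: str) -> bool:
    roles = user.attributes.get("roles", [])
    for rule in policy:
        # Assumes the "user with <role> in roles" form used above; not a general parser.
        required_role = rule.when.split()[2]
        if required_role in roles and action in rule.permit.actions:
            return True
    return False

policy = [
    AccessRule(permit=Scope(actions=["create", "read", "update", "delete"]), when="user with admin in roles"),
    AccessRule(permit=Scope(actions=["read"]), when="user with reader in roles"),
]
reader = User(principal="reader", attributes={"roles": ["reader"]})
assert is_allowed(policy, reader, "read")
assert not is_allowed(policy, reader, "delete")
```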

View file

@@ -0,0 +1,176 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import warnings
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from llama_stack.apis.agents import Document
from llama_stack.apis.common.content_types import URL, TextContentItem
from llama_stack.providers.inline.agents.meta_reference.agent_instance import get_raw_document_text
async def test_get_raw_document_text_supports_text_mime_types():
"""Test that the function accepts text/* mime types."""
document = Document(content="Sample text content", mime_type="text/plain")
result = await get_raw_document_text(document)
assert result == "Sample text content"
async def test_get_raw_document_text_supports_yaml_mime_type():
"""Test that the function accepts application/yaml mime type."""
yaml_content = """
name: test
version: 1.0
items:
- item1
- item2
"""
document = Document(content=yaml_content, mime_type="application/yaml")
result = await get_raw_document_text(document)
assert result == yaml_content
async def test_get_raw_document_text_supports_deprecated_text_yaml_with_warning():
"""Test that the function accepts text/yaml but emits a deprecation warning."""
yaml_content = """
name: test
version: 1.0
items:
- item1
- item2
"""
document = Document(content=yaml_content, mime_type="text/yaml")
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
result = await get_raw_document_text(document)
# Check that result is correct
assert result == yaml_content
# Check that exactly one warning was issued
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
assert "text/yaml" in str(w[0].message)
assert "application/yaml" in str(w[0].message)
assert "deprecated" in str(w[0].message).lower()
async def test_get_raw_document_text_deprecated_text_yaml_with_url():
"""Test that text/yaml works with URL content and emits warning."""
yaml_content = "name: test\nversion: 1.0"
with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load:
mock_load.return_value = yaml_content
document = Document(content=URL(uri="https://example.com/config.yaml"), mime_type="text/yaml")
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
result = await get_raw_document_text(document)
# Check that result is correct
assert result == yaml_content
mock_load.assert_called_once_with("https://example.com/config.yaml")
# Check that deprecation warning was issued
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
assert "text/yaml" in str(w[0].message)
async def test_get_raw_document_text_deprecated_text_yaml_with_text_content_item():
"""Test that text/yaml works with TextContentItem and emits warning."""
yaml_content = "key: value\nlist:\n - item1\n - item2"
document = Document(content=TextContentItem(text=yaml_content), mime_type="text/yaml")
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
result = await get_raw_document_text(document)
# Check that result is correct
assert result == yaml_content
# Check that deprecation warning was issued
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
assert "text/yaml" in str(w[0].message)
async def test_get_raw_document_text_rejects_unsupported_mime_types():
"""Test that the function rejects unsupported mime types."""
document = Document(
content="Some content",
mime_type="application/json", # Not supported
)
with pytest.raises(ValueError, match="Unexpected document mime type: application/json"):
await get_raw_document_text(document)
async def test_get_raw_document_text_with_url_content():
"""Test that the function handles URL content correctly."""
mock_response = AsyncMock()
mock_response.text = "Content from URL"
with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load:
mock_load.return_value = "Content from URL"
document = Document(content=URL(uri="https://example.com/test.txt"), mime_type="text/plain")
result = await get_raw_document_text(document)
assert result == "Content from URL"
mock_load.assert_called_once_with("https://example.com/test.txt")
async def test_get_raw_document_text_with_yaml_url():
"""Test that the function handles YAML URLs correctly."""
yaml_content = "name: test\nversion: 1.0"
with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load:
mock_load.return_value = yaml_content
document = Document(content=URL(uri="https://example.com/config.yaml"), mime_type="application/yaml")
result = await get_raw_document_text(document)
assert result == yaml_content
mock_load.assert_called_once_with("https://example.com/config.yaml")
async def test_get_raw_document_text_with_text_content_item():
"""Test that the function handles TextContentItem correctly."""
document = Document(content=TextContentItem(text="Text content item"), mime_type="text/plain")
result = await get_raw_document_text(document)
assert result == "Text content item"
async def test_get_raw_document_text_with_yaml_text_content_item():
"""Test that the function handles YAML TextContentItem correctly."""
yaml_content = "key: value\nlist:\n - item1\n - item2"
document = Document(content=TextContentItem(text=yaml_content), mime_type="application/yaml")
result = await get_raw_document_text(document)
assert result == yaml_content
async def test_get_raw_document_text_rejects_unexpected_content_type():
"""Test that the function rejects unexpected document content types."""
# Create a mock document that bypasses Pydantic validation
mock_document = MagicMock(spec=Document)
mock_document.mime_type = "text/plain"
mock_document.content = 123 # Unexpected content type (not str, URL, or TextContentItem)
with pytest.raises(ValueError, match="Unexpected document content type: <class 'int'>"):
await get_raw_document_text(mock_document)
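
Taken together, these tests pin down the contract of get_raw_document_text: plain str, URL, and TextContentItem content are accepted; text/* and application/yaml MIME types pass; text/yaml still works but emits a DeprecationWarning; anything else raises ValueError. The sketch below captures that dispatch, assuming only the behavior the tests imply; it is not the actual implementation in agent_instance.py, and load_data_from_url is stubbed so the example stays self-contained.

```python
import asyncio
import warnings

async def load_data_from_url(uri: str) -> str:
    # Stub so the sketch stays self-contained; the real helper fetches the URL.
    return f"<contents of {uri}>"

async def get_raw_document_text_sketch(document) -> str:
    mime_type = document.mime_type or ""
    if mime_type == "text/yaml":
        # Still accepted, but flagged as deprecated in favor of application/yaml.
        warnings.warn(
            "text/yaml is deprecated, use application/yaml instead",
            DeprecationWarning,
            stacklevel=2,
        )
    elif not (mime_type.startswith("text/") or mime_type == "application/yaml"):
        raise ValueError(f"Unexpected document mime type: {mime_type}")

    content = document.content
    if isinstance(content, str):
        return content
    if hasattr(content, "uri"):  # URL-like content
        return await load_data_from_url(content.uri)
    if hasattr(content, "text"):  # TextContentItem-like content
        return content.text
    raise ValueError(f"Unexpected document content type: {type(content)}")

class _Doc:
    # Minimal stand-in for the agents Document type, used only for the demo below.
    def __init__(self, content, mime_type):
        self.content = content
        self.mime_type = mime_type

print(asyncio.run(get_raw_document_text_sketch(_Doc("hello", "text/plain"))))
```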