Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-25 21:57:45 +00:00)

Commit 5119671f02: Merge branch 'main' into playground-ui
52 changed files with 938 additions and 773 deletions
.github/workflows/coverage-badge.yml (vendored, 3 changed lines)
@@ -15,6 +15,9 @@ on:
jobs:
  unit-tests:
    permissions:
      contents: write # for peter-evans/create-pull-request to create branch
      pull-requests: write # for peter-evans/create-pull-request to create a PR
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
.github/workflows/python-build-test.yml (vendored, 2 changed lines)
@@ -20,7 +20,7 @@ jobs:
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      - name: Install uv
-       uses: astral-sh/setup-uv@bd01e18f51369d5a26f1651c3cb451d3417e3bba # v6.3.1
+       uses: astral-sh/setup-uv@7edac99f961f18b581bbd960d59d049f04c0002f # v6.4.1
        with:
          python-version: ${{ matrix.python-version }}
          activate-environment: true
@@ -10,8 +10,13 @@ If in doubt, please open a [discussion](https://github.com/meta-llama/llama-stac

**I'd like to contribute!**

All issues are actionable (please report if they are not.) Pick one and start working on it. Thank you.
If you need help or guidance, comment on the issue. Issues that are extra friendly to new contributors are tagged with "contributor friendly".
If you are new to the project, start by looking at the issues tagged with "good first issue". If you're interested
leave a comment on the issue and a triager will assign it to you.

Please avoid picking up too many issues at once. This helps you stay focused and ensures that others in the community also have opportunities to contribute.
- Try to work on only 1–2 issues at a time, especially if you're still getting familiar with the codebase.
- Before taking an issue, check if it's already assigned or being actively discussed.
- If you're blocked or can't continue with an issue, feel free to unassign yourself or leave a comment so others can step in.

**I have a bug!**
@@ -41,6 +46,15 @@ If you need help or guidance, comment on the issue. Issues that are extra friend
4. Make sure your code lints using `pre-commit`.
5. If you haven't already, complete the Contributor License Agreement ("CLA").
6. Ensure your pull request follows the [conventional commits format](https://www.conventionalcommits.org/en/v1.0.0/).
7. Ensure your pull request follows the [coding style](#coding-style).

Please keep pull requests (PRs) small and focused. If you have a large set of changes, consider splitting them into logically grouped, smaller PRs to facilitate review and testing.

> [!TIP]
> As a general guideline:
> - Experienced contributors should try to keep no more than 5 open PRs at a time.
> - New contributors are encouraged to have only one open PR at a time until they're familiar with the codebase and process.

## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
@@ -140,7 +154,9 @@ uv sync
* Don't use unicode characters in the codebase. ASCII-only is preferred for compatibility or
  readability reasons.
* Providers configuration class should be Pydantic Field class. It should have a `description` field
-  that describes the configuration. These descriptions will be used to generate the provider documentation.
+  that describes the configuration. These descriptions will be used to generate the provider
+  documentation.
* When possible, use keyword arguments only when calling functions.

## Common Tasks
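To make the coding-style point about provider configuration classes concrete, here is a minimal sketch of the pattern, assuming a hypothetical provider; the class name, field names, and defaults are illustrative and not part of this change. The convention is a Pydantic model whose `Field`s carry a `description`, which the provider documentation is generated from (compare the `*CompatConfig` classes removed later in this diff).

```python
from typing import Any

from pydantic import BaseModel, Field


class ExampleProviderConfig(BaseModel):
    # Hypothetical config class following the documented convention.
    api_key: str | None = Field(
        default=None,
        description="API key used to authenticate against the example service.",
    )
    url: str = Field(
        default="https://api.example.com/v1",
        description="Base URL of the example service.",
    )

    @classmethod
    def sample_run_config(cls, api_key: str = "${env.EXAMPLE_API_KEY}", **kwargs) -> dict[str, Any]:
        # The Field descriptions above are what the provider docs generator picks up.
        return {"url": "https://api.example.com/v1", "api_key": api_key}
```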
docs/_static/llama-stack-spec.html (vendored, 3 changed lines)
@@ -13596,9 +13596,6 @@
          }
        },
        "additionalProperties": false,
-       "required": [
-         "name"
-       ],
        "title": "OpenaiCreateVectorStoreRequest"
      },
      "VectorStoreFileCounts": {
docs/_static/llama-stack-spec.yaml (vendored, 2 changed lines)
@@ -9497,8 +9497,6 @@ components:
      description: >-
        The ID of the provider to use for this vector store.
      additionalProperties: false
-     required:
-       - name
      title: OpenaiCreateVectorStoreRequest
    VectorStoreFileCounts:
      type: object
docs/source/getting_started/demo_script.py (new file, 62 lines)
@@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack_client import Agent, AgentEventLogger, RAGDocument, LlamaStackClient

vector_db_id = "my_demo_vector_db"
client = LlamaStackClient(base_url="http://localhost:8321")

models = client.models.list()

# Select the first LLM and first embedding models
model_id = next(m for m in models if m.model_type == "llm").identifier
embedding_model_id = (
    em := next(m for m in models if m.model_type == "embedding")
).identifier
embedding_dimension = em.metadata["embedding_dimension"]

_ = client.vector_dbs.register(
    vector_db_id=vector_db_id,
    embedding_model=embedding_model_id,
    embedding_dimension=embedding_dimension,
    provider_id="faiss",
)
source = "https://www.paulgraham.com/greatwork.html"
print("rag_tool> Ingesting document:", source)
document = RAGDocument(
    document_id="document_1",
    content=source,
    mime_type="text/html",
    metadata={},
)
client.tool_runtime.rag_tool.insert(
    documents=[document],
    vector_db_id=vector_db_id,
    chunk_size_in_tokens=50,
)
agent = Agent(
    client,
    model=model_id,
    instructions="You are a helpful assistant",
    tools=[
        {
            "name": "builtin::rag/knowledge_search",
            "args": {"vector_db_ids": [vector_db_id]},
        }
    ],
)

prompt = "How do you do great work?"
print("prompt>", prompt)

response = agent.create_turn(
    messages=[{"role": "user", "content": prompt}],
    session_id=agent.create_session("rag_session"),
    stream=True,
)

for log in AgentEventLogger().log(response):
    log.print()
@@ -24,63 +24,8 @@ ENABLE_OLLAMA=ollama OLLAMA_INFERENCE_MODEL=llama3.2:3b uv run --with llama-stac
#### Step 3: Run the demo
Now open up a new terminal and copy the following script into a file named `demo_script.py`.

(The inline ```python listing of the demo script, identical to the new docs/source/getting_started/demo_script.py shown above, is removed here in favor of the literalinclude directive below.)

```{literalinclude} ./demo_script.py
:language: python
```
We will use `uv` to run the script
```
@@ -7,13 +7,10 @@ This section contains documentation for all available providers for the **infere
- [remote::anthropic](remote_anthropic.md)
- [remote::bedrock](remote_bedrock.md)
- [remote::cerebras](remote_cerebras.md)
- [remote::cerebras-openai-compat](remote_cerebras-openai-compat.md)
- [remote::databricks](remote_databricks.md)
- [remote::fireworks](remote_fireworks.md)
- [remote::fireworks-openai-compat](remote_fireworks-openai-compat.md)
- [remote::gemini](remote_gemini.md)
- [remote::groq](remote_groq.md)
- [remote::groq-openai-compat](remote_groq-openai-compat.md)
- [remote::hf::endpoint](remote_hf_endpoint.md)
- [remote::hf::serverless](remote_hf_serverless.md)
- [remote::llama-openai-compat](remote_llama-openai-compat.md)

@@ -23,9 +20,7 @@ This section contains documentation for all available providers for the **infere
- [remote::passthrough](remote_passthrough.md)
- [remote::runpod](remote_runpod.md)
- [remote::sambanova](remote_sambanova.md)
- [remote::sambanova-openai-compat](remote_sambanova-openai-compat.md)
- [remote::tgi](remote_tgi.md)
- [remote::together](remote_together.md)
- [remote::together-openai-compat](remote_together-openai-compat.md)
- [remote::vllm](remote_vllm.md)
- [remote::watsonx](remote_watsonx.md)
@@ -17,7 +17,7 @@ That means you'll get fast and efficient vector retrieval.
To use PGVector in your Llama Stack project, follow these steps:

1. Install the necessary dependencies.
-2. Configure your Llama Stack project to use Faiss.
+2. Configure your Llama Stack project to use pgvector. (e.g. remote::pgvector).
3. Start storing and querying vectors.

## Installation
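As a rough sketch of step 2, a vector DB can be routed to the pgvector provider the same way the demo script above routes one to faiss. The `provider_id` value assumes a `remote::pgvector` provider registered under the id `pgvector` in your run config, and the embedding model name and dimension below are placeholders; in practice discover them from `client.models.list()` as the demo script does.

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

client.vector_dbs.register(
    vector_db_id="my_pgvector_db",
    embedding_model="all-MiniLM-L6-v2",  # placeholder embedding model id
    embedding_dimension=384,             # placeholder dimension for that model
    provider_id="pgvector",  # assumes a remote::pgvector provider with this id in run.yaml
)
```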
@@ -34,6 +34,7 @@ class VectorDBInput(BaseModel):
    vector_db_id: str
    embedding_model: str
    embedding_dimension: int
    provider_id: str | None = None
    provider_vector_db_id: str | None = None
@@ -338,7 +338,7 @@ class VectorIO(Protocol):
    @webmethod(route="/openai/v1/vector_stores", method="POST")
    async def openai_create_vector_store(
        self,
-       name: str,
+       name: str | None = None,
        file_ids: list[str] | None = None,
        expires_after: dict[str, Any] | None = None,
        chunking_strategy: dict[str, Any] | None = None,
@@ -276,8 +276,8 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
        config = parse_and_maybe_upgrade_config(config_dict)
        if config.external_providers_dir and not config.external_providers_dir.exists():
            config.external_providers_dir.mkdir(exist_ok=True)
-       run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
-       run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", run_config])
+       run_args = formulate_run_args(args.image_type, args.image_name)
+       run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", str(run_config)])
        run_command(run_args)
@@ -82,39 +82,6 @@ class StackRun(Subcommand):
            return ImageType.CONDA.value, args.image_name
        return args.image_type, args.image_name

    def _resolve_config_and_template(self, args: argparse.Namespace) -> tuple[Path | None, str | None]:
        """Resolve config file path and template name from args.config"""
        from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR

        if not args.config:
            return None, None

        config_file = Path(args.config)
        has_yaml_suffix = args.config.endswith(".yaml")
        template_name = None

        if not config_file.exists() and not has_yaml_suffix:
            # check if this is a template
            config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
            if config_file.exists():
                template_name = args.config

        if not config_file.exists() and not has_yaml_suffix:
            # check if it's a build config saved to ~/.llama dir
            config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")

        if not config_file.exists():
            self.parser.error(
                f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
            )

        if not config_file.is_file():
            self.parser.error(
                f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
            )

        return config_file, template_name

    def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
        import yaml

@@ -125,8 +92,15 @@
            self._start_ui_development_server(args.port)
        image_type, image_name = self._get_image_type_and_name(args)

-       # Resolve config file and template name first
-       config_file, template_name = self._resolve_config_and_template(args)
+       if args.config:
+           try:
+               from llama_stack.distribution.utils.config_resolution import Mode, resolve_config_or_template
+
+               config_file = resolve_config_or_template(args.config, Mode.RUN)
+           except ValueError as e:
+               self.parser.error(str(e))
+       else:
+           config_file = None

        # Check if config is required based on image type
        if (image_type in [ImageType.CONDA.value, ImageType.VENV.value]) and not config_file:

@@ -164,18 +138,14 @@
                if callable(getattr(args, arg)):
                    continue
                if arg == "config":
-                   if template_name:
-                       server_args.template = str(template_name)
-                   else:
-                       # Set the config file path
-                       server_args.config = str(config_file)
+                   server_args.config = str(config_file)
                else:
                    setattr(server_args, arg, getattr(args, arg))

            # Run the server
            server_main(server_args)
        else:
-           run_args = formulate_run_args(image_type, image_name, config, template_name)
+           run_args = formulate_run_args(image_type, image_name)

            run_args.extend([str(args.port)])
llama_stack/cli/utils.py (new file, 31 lines)
@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse


def add_config_template_args(parser: argparse.ArgumentParser):
    """Add unified config/template arguments with backward compatibility."""
    group = parser.add_mutually_exclusive_group(required=True)

    group.add_argument(
        "config",
        nargs="?",
        help="Configuration file path or template name",
    )

    # Backward compatibility arguments (deprecated)
    group.add_argument(
        "--config",
        dest="config",
        help="(DEPRECATED) Use positional argument [config] instead. Configuration file path",
    )

    group.add_argument(
        "--template",
        dest="config",
        help="(DEPRECATED) Use positional argument [config] instead. Template name",
    )
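A small usage sketch of the helper above: the positional form and both deprecated flags share the same `config` destination, so downstream code only needs to look at `args.config`. The parser name and argument values are illustrative.

```python
import argparse

from llama_stack.cli.utils import add_config_template_args

parser = argparse.ArgumentParser(prog="example")
add_config_template_args(parser)

# Positional form (preferred) and the deprecated flags all populate args.config.
print(parser.parse_args(["ollama"]).config)                 # -> "ollama"
print(parser.parse_args(["--config", "run.yaml"]).config)   # -> "run.yaml"
print(parser.parse_args(["--template", "ollama"]).config)   # -> "ollama"
```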
@@ -214,9 +214,7 @@ class VectorIORouter(VectorIO):
        vector_store_id: str,
    ) -> VectorStoreObject:
        logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
-       # Route based on vector store ID
-       provider = self.routing_table.get_provider_impl(vector_store_id)
-       return await provider.openai_retrieve_vector_store(vector_store_id)
+       return await self.routing_table.openai_retrieve_vector_store(vector_store_id)

    async def openai_update_vector_store(
        self,

@@ -226,9 +224,7 @@ class VectorIORouter(VectorIO):
        metadata: dict[str, Any] | None = None,
    ) -> VectorStoreObject:
        logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
-       # Route based on vector store ID
-       provider = self.routing_table.get_provider_impl(vector_store_id)
-       return await provider.openai_update_vector_store(
+       return await self.routing_table.openai_update_vector_store(
            vector_store_id=vector_store_id,
            name=name,
            expires_after=expires_after,

@@ -240,12 +236,7 @@ class VectorIORouter(VectorIO):
        vector_store_id: str,
    ) -> VectorStoreDeleteResponse:
        logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
-       # Route based on vector store ID
-       provider = self.routing_table.get_provider_impl(vector_store_id)
-       result = await provider.openai_delete_vector_store(vector_store_id)
-       # drop from registry
-       await self.routing_table.unregister_vector_db(vector_store_id)
-       return result
+       return await self.routing_table.openai_delete_vector_store(vector_store_id)

    async def openai_search_vector_store(
        self,

@@ -258,9 +249,7 @@ class VectorIORouter(VectorIO):
        search_mode: str | None = "vector",
    ) -> VectorStoreSearchResponsePage:
        logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
-       # Route based on vector store ID
-       provider = self.routing_table.get_provider_impl(vector_store_id)
-       return await provider.openai_search_vector_store(
+       return await self.routing_table.openai_search_vector_store(
            vector_store_id=vector_store_id,
            query=query,
            filters=filters,

@@ -278,9 +267,7 @@ class VectorIORouter(VectorIO):
        chunking_strategy: VectorStoreChunkingStrategy | None = None,
    ) -> VectorStoreFileObject:
        logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
-       # Route based on vector store ID
-       provider = self.routing_table.get_provider_impl(vector_store_id)
-       return await provider.openai_attach_file_to_vector_store(
+       return await self.routing_table.openai_attach_file_to_vector_store(
            vector_store_id=vector_store_id,
            file_id=file_id,
            attributes=attributes,

@@ -297,9 +284,7 @@ class VectorIORouter(VectorIO):
        filter: VectorStoreFileStatus | None = None,
    ) -> list[VectorStoreFileObject]:
        logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
-       # Route based on vector store ID
-       provider = self.routing_table.get_provider_impl(vector_store_id)
-       return await provider.openai_list_files_in_vector_store(
+       return await self.routing_table.openai_list_files_in_vector_store(
            vector_store_id=vector_store_id,
            limit=limit,
            order=order,

@@ -314,9 +299,7 @@ class VectorIORouter(VectorIO):
        file_id: str,
    ) -> VectorStoreFileObject:
        logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}")
-       # Route based on vector store ID
-       provider = self.routing_table.get_provider_impl(vector_store_id)
-       return await provider.openai_retrieve_vector_store_file(
+       return await self.routing_table.openai_retrieve_vector_store_file(
            vector_store_id=vector_store_id,
            file_id=file_id,
        )

@@ -327,9 +310,7 @@ class VectorIORouter(VectorIO):
        file_id: str,
    ) -> VectorStoreFileContentsResponse:
        logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
-       # Route based on vector store ID
-       provider = self.routing_table.get_provider_impl(vector_store_id)
-       return await provider.openai_retrieve_vector_store_file_contents(
+       return await self.routing_table.openai_retrieve_vector_store_file_contents(
            vector_store_id=vector_store_id,
            file_id=file_id,
        )

@@ -341,9 +322,7 @@ class VectorIORouter(VectorIO):
        attributes: dict[str, Any],
    ) -> VectorStoreFileObject:
        logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}")
-       # Route based on vector store ID
-       provider = self.routing_table.get_provider_impl(vector_store_id)
-       return await provider.openai_update_vector_store_file(
+       return await self.routing_table.openai_update_vector_store_file(
            vector_store_id=vector_store_id,
            file_id=file_id,
            attributes=attributes,

@@ -355,9 +334,7 @@ class VectorIORouter(VectorIO):
        file_id: str,
    ) -> VectorStoreFileDeleteResponse:
        logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}")
-       # Route based on vector store ID
-       provider = self.routing_table.get_provider_impl(vector_store_id)
-       return await provider.openai_delete_vector_store_file(
+       return await self.routing_table.openai_delete_vector_store_file(
            vector_store_id=vector_store_id,
            file_id=file_id,
        )
|
|||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.scoring_functions import ScoringFn
|
||||
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
from llama_stack.distribution.access_control.datatypes import Action
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AccessRule,
|
||||
RoutableObject,
|
||||
|
@ -209,6 +210,20 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
await self.dist_registry.register(obj)
|
||||
return obj
|
||||
|
||||
async def assert_action_allowed(
|
||||
self,
|
||||
action: Action,
|
||||
type: str,
|
||||
identifier: str,
|
||||
) -> None:
|
||||
"""Fetch a registered object by type/identifier and enforce the given action permission."""
|
||||
obj = await self.get_object_by_identifier(type, identifier)
|
||||
if obj is None:
|
||||
raise ValueError(f"{type.capitalize()} '{identifier}' not found")
|
||||
user = get_authenticated_user()
|
||||
if not is_action_allowed(self.policy, action, obj, user):
|
||||
raise AccessDeniedError(action, obj, user)
|
||||
|
||||
async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
|
||||
objs = await self.dist_registry.get_all()
|
||||
filtered_objs = [obj for obj in objs if obj.type == type]
|
||||
|
|
|
@@ -4,11 +4,24 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from pydantic import TypeAdapter

from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.apis.vector_io.vector_io import (
    SearchRankingOptions,
    VectorStoreChunkingStrategy,
    VectorStoreDeleteResponse,
    VectorStoreFileContentsResponse,
    VectorStoreFileDeleteResponse,
    VectorStoreFileObject,
    VectorStoreFileStatus,
    VectorStoreObject,
    VectorStoreSearchResponsePage,
)
from llama_stack.distribution.datatypes import (
    VectorDBWithOwner,
)

@@ -74,3 +87,135 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
        if existing_vector_db is None:
            raise ValueError(f"Vector DB {vector_db_id} not found")
        await self.unregister_object(existing_vector_db)

    async def openai_retrieve_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreObject:
        await self.assert_action_allowed("read", "vector_db", vector_store_id)
        return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store(vector_store_id)

    async def openai_update_vector_store(
        self,
        vector_store_id: str,
        name: str | None = None,
        expires_after: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> VectorStoreObject:
        await self.assert_action_allowed("update", "vector_db", vector_store_id)
        return await self.get_provider_impl(vector_store_id).openai_update_vector_store(
            vector_store_id=vector_store_id,
            name=name,
            expires_after=expires_after,
            metadata=metadata,
        )

    async def openai_delete_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreDeleteResponse:
        await self.assert_action_allowed("delete", "vector_db", vector_store_id)
        result = await self.get_provider_impl(vector_store_id).openai_delete_vector_store(vector_store_id)
        await self.unregister_vector_db(vector_store_id)
        return result

    async def openai_search_vector_store(
        self,
        vector_store_id: str,
        query: str | list[str],
        filters: dict[str, Any] | None = None,
        max_num_results: int | None = 10,
        ranking_options: SearchRankingOptions | None = None,
        rewrite_query: bool | None = False,
        search_mode: str | None = "vector",
    ) -> VectorStoreSearchResponsePage:
        await self.assert_action_allowed("read", "vector_db", vector_store_id)
        return await self.get_provider_impl(vector_store_id).openai_search_vector_store(
            vector_store_id=vector_store_id,
            query=query,
            filters=filters,
            max_num_results=max_num_results,
            ranking_options=ranking_options,
            rewrite_query=rewrite_query,
            search_mode=search_mode,
        )

    async def openai_attach_file_to_vector_store(
        self,
        vector_store_id: str,
        file_id: str,
        attributes: dict[str, Any] | None = None,
        chunking_strategy: VectorStoreChunkingStrategy | None = None,
    ) -> VectorStoreFileObject:
        await self.assert_action_allowed("update", "vector_db", vector_store_id)
        return await self.get_provider_impl(vector_store_id).openai_attach_file_to_vector_store(
            vector_store_id=vector_store_id,
            file_id=file_id,
            attributes=attributes,
            chunking_strategy=chunking_strategy,
        )

    async def openai_list_files_in_vector_store(
        self,
        vector_store_id: str,
        limit: int | None = 20,
        order: str | None = "desc",
        after: str | None = None,
        before: str | None = None,
        filter: VectorStoreFileStatus | None = None,
    ) -> list[VectorStoreFileObject]:
        await self.assert_action_allowed("read", "vector_db", vector_store_id)
        return await self.get_provider_impl(vector_store_id).openai_list_files_in_vector_store(
            vector_store_id=vector_store_id,
            limit=limit,
            order=order,
            after=after,
            before=before,
            filter=filter,
        )

    async def openai_retrieve_vector_store_file(
        self,
        vector_store_id: str,
        file_id: str,
    ) -> VectorStoreFileObject:
        await self.assert_action_allowed("read", "vector_db", vector_store_id)
        return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file(
            vector_store_id=vector_store_id,
            file_id=file_id,
        )

    async def openai_retrieve_vector_store_file_contents(
        self,
        vector_store_id: str,
        file_id: str,
    ) -> VectorStoreFileContentsResponse:
        await self.assert_action_allowed("read", "vector_db", vector_store_id)
        return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file_contents(
            vector_store_id=vector_store_id,
            file_id=file_id,
        )

    async def openai_update_vector_store_file(
        self,
        vector_store_id: str,
        file_id: str,
        attributes: dict[str, Any],
    ) -> VectorStoreFileObject:
        await self.assert_action_allowed("update", "vector_db", vector_store_id)
        return await self.get_provider_impl(vector_store_id).openai_update_vector_store_file(
            vector_store_id=vector_store_id,
            file_id=file_id,
            attributes=attributes,
        )

    async def openai_delete_vector_store_file(
        self,
        vector_store_id: str,
        file_id: str,
    ) -> VectorStoreFileDeleteResponse:
        await self.assert_action_allowed("delete", "vector_db", vector_store_id)
        return await self.get_provider_impl(vector_store_id).openai_delete_vector_store_file(
            vector_store_id=vector_store_id,
            file_id=file_id,
        )
@@ -32,6 +32,7 @@ from openai import BadRequestError
from pydantic import BaseModel, ValidationError

from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.cli.utils import add_config_template_args
from llama_stack.distribution.access_control.access_control import AccessDeniedError
from llama_stack.distribution.datatypes import (
    AuthenticationRequiredError,

@@ -53,6 +54,7 @@ from llama_stack.distribution.stack import (
    validate_env_pair,
)
from llama_stack.distribution.utils.config import redact_sensitive_fields
from llama_stack.distribution.utils.config_resolution import Mode, resolve_config_or_template
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
@@ -377,20 +379,8 @@ class ClientVersionMiddleware:
def main(args: argparse.Namespace | None = None):
    """Start the LlamaStack server."""
    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
-   parser.add_argument(
-       "--yaml-config",
-       dest="config",
-       help="(Deprecated) Path to YAML configuration file - use --config instead",
-   )
-   parser.add_argument(
-       "--config",
-       dest="config",
-       help="Path to YAML configuration file",
-   )
-   parser.add_argument(
-       "--template",
-       help="One of the template names in llama_stack/templates (e.g., tgi, fireworks, remote-vllm, etc.)",
-   )
+   add_config_template_args(parser)
    parser.add_argument(
        "--port",
        type=int,
@@ -409,20 +399,7 @@ def main(args: argparse.Namespace | None = None):
    if args is None:
        args = parser.parse_args()

-   log_line = ""
-   if hasattr(args, "config") and args.config:
-       # if the user provided a config file, use it, even if template was specified
-       config_file = Path(args.config)
-       if not config_file.exists():
-           raise ValueError(f"Config file {config_file} does not exist")
-       log_line = f"Using config file: {config_file}"
-   elif hasattr(args, "template") and args.template:
-       config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
-       if not config_file.exists():
-           raise ValueError(f"Template {args.template} does not exist")
-       log_line = f"Using template {args.template} config file: {config_file}"
-   else:
-       raise ValueError("Either --config or --template must be provided")
+   config_file = resolve_config_or_template(args.config, Mode.RUN)

    logger_config = None
    with open(config_file) as fp:
@@ -442,9 +419,6 @@ def main(args: argparse.Namespace | None = None):
    config = replace_env_vars(config_contents)
    config = StackRunConfig(**cast_image_name_to_string(config))

-   # now that the logger is initialized, print the line about which type of config we are using.
-   logger.info(log_line)

    _log_run_config(run_config=config)

    app = FastAPI(
@@ -592,12 +566,29 @@ def main(args: argparse.Namespace | None = None):
        "port": port,
        "lifespan": "on",
        "log_level": logger.getEffectiveLevel(),
        "log_config": logger_config,
    }
    if ssl_config:
        uvicorn_config.update(ssl_config)

    # Run uvicorn in the existing event loop to preserve background tasks
-   loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
+   # We need to catch KeyboardInterrupt because uvicorn's signal handling
+   # re-raises SIGINT signals using signal.raise_signal(), which Python
+   # converts to KeyboardInterrupt. Without this catch, we'd get a confusing
+   # stack trace when using Ctrl+C or kill -2 (SIGINT).
+   # SIGTERM (kill -15) works fine without this because Python doesn't
+   # have a default handler for it.
+   #
+   # Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
+   # signal handling but this is quite intrusive and not worth the effort.
+   try:
+       loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
+   except (KeyboardInterrupt, SystemExit):
+       logger.info("Received interrupt signal, shutting down gracefully...")
+   finally:
+       if not loop.is_closed():
+           logger.debug("Closing event loop")
+           loop.close()


def _log_run_config(run_config: StackRunConfig):
@@ -117,7 +117,7 @@ if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
  set -x

  if [ -n "$yaml_config" ]; then
-   yaml_config_arg="--config $yaml_config"
+   yaml_config_arg="$yaml_config"
  else
    yaml_config_arg=""
  fi
llama_stack/distribution/utils/config_resolution.py (new file, 125 lines)
@@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import StrEnum
from pathlib import Path

from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger

logger = get_logger(name=__name__, category="config_resolution")


TEMPLATE_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "templates"


class Mode(StrEnum):
    RUN = "run"
    BUILD = "build"


def resolve_config_or_template(
    config_or_template: str,
    mode: Mode = Mode.RUN,
) -> Path:
    """
    Resolve a config/template argument to a concrete config file path.

    Args:
        config_or_template: User input (file path, template name, or built distribution)
        mode: Mode resolving for ("run", "build", "server")

    Returns:
        Path to the resolved config file

    Raises:
        ValueError: If resolution fails
    """

    # Strategy 1: Try as file path first
    config_path = Path(config_or_template)
    if config_path.exists() and config_path.is_file():
        logger.info(f"Using file path: {config_path}")
        return config_path.resolve()

    # Strategy 2: Try as template name (if no .yaml extension)
    if not config_or_template.endswith(".yaml"):
        template_config = _get_template_config_path(config_or_template, mode)
        if template_config.exists():
            logger.info(f"Using template: {template_config}")
            return template_config

    # Strategy 3: Try as built distribution name
    distrib_config = DISTRIBS_BASE_DIR / f"llamastack-{config_or_template}" / f"{config_or_template}-{mode}.yaml"
    if distrib_config.exists():
        logger.info(f"Using built distribution: {distrib_config}")
        return distrib_config

    distrib_config = DISTRIBS_BASE_DIR / f"{config_or_template}" / f"{config_or_template}-{mode}.yaml"
    if distrib_config.exists():
        logger.info(f"Using built distribution: {distrib_config}")
        return distrib_config

    # Strategy 4: Failed - provide helpful error
    raise ValueError(_format_resolution_error(config_or_template, mode))


def _get_template_config_path(template_name: str, mode: Mode) -> Path:
    """Get the config file path for a template."""
    return TEMPLATE_DIR / template_name / f"{mode}.yaml"


def _format_resolution_error(config_or_template: str, mode: Mode) -> str:
    """Format a helpful error message for resolution failures."""
    from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR

    template_path = _get_template_config_path(config_or_template, mode)
    distrib_path = DISTRIBS_BASE_DIR / f"llamastack-{config_or_template}" / f"{config_or_template}-{mode}.yaml"
    distrib_path2 = DISTRIBS_BASE_DIR / f"{config_or_template}" / f"{config_or_template}-{mode}.yaml"

    available_templates = _get_available_templates()
    templates_str = ", ".join(available_templates) if available_templates else "none found"

    return f"""Could not resolve config or template '{config_or_template}'.

Tried the following locations:
  1. As file path: {Path(config_or_template).resolve()}
  2. As template: {template_path}
  3. As built distribution: ({distrib_path}, {distrib_path2})

Available templates: {templates_str}

Did you mean one of these templates?
{_format_template_suggestions(available_templates, config_or_template)}
"""


def _get_available_templates() -> list[str]:
    """Get list of available template names."""
    if not TEMPLATE_DIR.exists() and not DISTRIBS_BASE_DIR.exists():
        return []

    return list(
        set(
            [d.name for d in TEMPLATE_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")]
            + [d.name for d in DISTRIBS_BASE_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")]
        )
    )


def _format_template_suggestions(templates: list[str], user_input: str) -> str:
    """Format template suggestions for error messages, showing closest matches first."""
    if not templates:
        return "  (no templates found)"

    import difflib

    # Get up to 3 closest matches with similarity threshold of 0.3 (lower = more permissive)
    close_matches = difflib.get_close_matches(user_input, templates, n=3, cutoff=0.3)
    display_templates = close_matches if close_matches else templates[:3]

    suggestions = [f"  - {t}" for t in display_templates]
    return "\n".join(suggestions)
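For a quick sense of how the new resolver is used by its callers elsewhere in this diff (the `llama stack run` command and the server), here is a minimal sketch; the template name is only an example and assumes llama-stack is installed.

```python
from llama_stack.distribution.utils.config_resolution import Mode, resolve_config_or_template

# "ollama" is an illustrative input; any file path, template name, or
# built-distribution name accepted by the resolver works here.
try:
    config_file = resolve_config_or_template("ollama", Mode.RUN)
    print(f"resolved run config: {config_file}")
except ValueError as err:
    # The resolver raises ValueError listing every location it tried.
    print(err)
```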
@@ -21,7 +21,7 @@ from pathlib import Path
from llama_stack.distribution.utils.image_types import LlamaStackImageType


-def formulate_run_args(image_type, image_name, config, template_name) -> list:
+def formulate_run_args(image_type: str, image_name: str) -> list[str]:
    env_name = ""

    if image_type == LlamaStackImageType.CONDA.value:
@@ -10,6 +10,7 @@ import re
import secrets
import string
import uuid
+import warnings
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
@@ -911,8 +912,16 @@ async def load_data_from_url(url: str) -> str:

async def get_raw_document_text(document: Document) -> str:
-   if not document.mime_type.startswith("text/"):
+   # Handle deprecated text/yaml mime type with warning
+   if document.mime_type == "text/yaml":
+       warnings.warn(
+           "The 'text/yaml' MIME type is deprecated. Please use 'application/yaml' instead.",
+           DeprecationWarning,
+           stacklevel=2,
+       )
+   elif not (document.mime_type.startswith("text/") or document.mime_type == "application/yaml"):
        raise ValueError(f"Unexpected document mime type: {document.mime_type}")

    if isinstance(document.content, URL):
        return await load_data_from_url(document.content.uri)
    elif isinstance(document.content, str):
@@ -128,6 +128,11 @@ class AgentPersistence:
            except Exception as e:
                log.error(f"Error parsing turn: {e}")
                continue

+       # The kvstore does not guarantee order, so we sort by started_at
+       # to ensure consistent ordering of turns.
+       turns.sort(key=lambda t: t.started_at)
+
        return turns

    async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:
@@ -224,17 +224,6 @@ def available_providers() -> list[ProviderSpec]:
            description="Groq inference provider for ultra-fast inference using Groq's LPU technology.",
        ),
    ),
    remote_provider_spec(
        api=Api.inference,
        adapter=AdapterSpec(
            adapter_type="fireworks-openai-compat",
            pip_packages=["litellm"],
            module="llama_stack.providers.remote.inference.fireworks_openai_compat",
            config_class="llama_stack.providers.remote.inference.fireworks_openai_compat.config.FireworksCompatConfig",
            provider_data_validator="llama_stack.providers.remote.inference.fireworks_openai_compat.config.FireworksProviderDataValidator",
            description="Fireworks AI OpenAI-compatible provider for using Fireworks models with OpenAI API format.",
        ),
    ),
    remote_provider_spec(
        api=Api.inference,
        adapter=AdapterSpec(

@@ -246,50 +235,6 @@ def available_providers() -> list[ProviderSpec]:
            description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.",
        ),
    ),
    remote_provider_spec(
        api=Api.inference,
        adapter=AdapterSpec(
            adapter_type="together-openai-compat",
            pip_packages=["litellm"],
            module="llama_stack.providers.remote.inference.together_openai_compat",
            config_class="llama_stack.providers.remote.inference.together_openai_compat.config.TogetherCompatConfig",
            provider_data_validator="llama_stack.providers.remote.inference.together_openai_compat.config.TogetherProviderDataValidator",
            description="Together AI OpenAI-compatible provider for using Together models with OpenAI API format.",
        ),
    ),
    remote_provider_spec(
        api=Api.inference,
        adapter=AdapterSpec(
            adapter_type="groq-openai-compat",
            pip_packages=["litellm"],
            module="llama_stack.providers.remote.inference.groq_openai_compat",
            config_class="llama_stack.providers.remote.inference.groq_openai_compat.config.GroqCompatConfig",
            provider_data_validator="llama_stack.providers.remote.inference.groq_openai_compat.config.GroqProviderDataValidator",
            description="Groq OpenAI-compatible provider for using Groq models with OpenAI API format.",
        ),
    ),
    remote_provider_spec(
        api=Api.inference,
        adapter=AdapterSpec(
            adapter_type="sambanova-openai-compat",
            pip_packages=["litellm"],
            module="llama_stack.providers.remote.inference.sambanova_openai_compat",
            config_class="llama_stack.providers.remote.inference.sambanova_openai_compat.config.SambaNovaCompatConfig",
            provider_data_validator="llama_stack.providers.remote.inference.sambanova_openai_compat.config.SambaNovaProviderDataValidator",
            description="SambaNova OpenAI-compatible provider for using SambaNova models with OpenAI API format.",
        ),
    ),
    remote_provider_spec(
        api=Api.inference,
        adapter=AdapterSpec(
            adapter_type="cerebras-openai-compat",
            pip_packages=["litellm"],
            module="llama_stack.providers.remote.inference.cerebras_openai_compat",
            config_class="llama_stack.providers.remote.inference.cerebras_openai_compat.config.CerebrasCompatConfig",
            provider_data_validator="llama_stack.providers.remote.inference.cerebras_openai_compat.config.CerebrasProviderDataValidator",
            description="Cerebras OpenAI-compatible provider for using Cerebras models with OpenAI API format.",
        ),
    ),
    remote_provider_spec(
        api=Api.inference,
        adapter=AdapterSpec(
@@ -395,7 +395,7 @@ That means you'll get fast and efficient vector retrieval.
To use PGVector in your Llama Stack project, follow these steps:

1. Install the necessary dependencies.
-2. Configure your Llama Stack project to use Faiss.
+2. Configure your Llama Stack project to use pgvector. (e.g. remote::pgvector).
3. Start storing and querying vectors.

## Installation
@@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.inference import InferenceProvider

from .config import CerebrasCompatConfig


async def get_adapter_impl(config: CerebrasCompatConfig, _deps) -> InferenceProvider:
    # import dynamically so the import is used only when it is needed
    from .cerebras import CerebrasCompatInferenceAdapter

    adapter = CerebrasCompatInferenceAdapter(config)
    return adapter

@@ -1,30 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.providers.remote.inference.cerebras_openai_compat.config import CerebrasCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin

from ..cerebras.models import MODEL_ENTRIES


class CerebrasCompatInferenceAdapter(LiteLLMOpenAIMixin):
    _config: CerebrasCompatConfig

    def __init__(self, config: CerebrasCompatConfig):
        LiteLLMOpenAIMixin.__init__(
            self,
            model_entries=MODEL_ENTRIES,
            api_key_from_config=config.api_key,
            provider_data_api_key_field="cerebras_api_key",
            openai_compat_api_base=config.openai_compat_api_base,
        )
        self.config = config

    async def initialize(self):
        await super().initialize()

    async def shutdown(self):
        await super().shutdown()

@@ -1,38 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from pydantic import BaseModel, Field

from llama_stack.schema_utils import json_schema_type


class CerebrasProviderDataValidator(BaseModel):
    cerebras_api_key: str | None = Field(
        default=None,
        description="API key for Cerebras models",
    )


@json_schema_type
class CerebrasCompatConfig(BaseModel):
    api_key: str | None = Field(
        default=None,
        description="The Cerebras API key",
    )

    openai_compat_api_base: str = Field(
        default="https://api.cerebras.ai/v1",
        description="The URL for the Cerebras API server",
    )

    @classmethod
    def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY}", **kwargs) -> dict[str, Any]:
        return {
            "openai_compat_api_base": "https://api.cerebras.ai/v1",
            "api_key": api_key,
        }
@@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.inference import InferenceProvider

from .config import FireworksCompatConfig


async def get_adapter_impl(config: FireworksCompatConfig, _deps) -> InferenceProvider:
    # import dynamically so the import is used only when it is needed
    from .fireworks import FireworksCompatInferenceAdapter

    adapter = FireworksCompatInferenceAdapter(config)
    return adapter

@@ -1,38 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from pydantic import BaseModel, Field

from llama_stack.schema_utils import json_schema_type


class FireworksProviderDataValidator(BaseModel):
    fireworks_api_key: str | None = Field(
        default=None,
        description="API key for Fireworks models",
    )


@json_schema_type
class FireworksCompatConfig(BaseModel):
    api_key: str | None = Field(
        default=None,
        description="The Fireworks API key",
    )

    openai_compat_api_base: str = Field(
        default="https://api.fireworks.ai/inference/v1",
        description="The URL for the Fireworks API server",
    )

    @classmethod
    def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> dict[str, Any]:
        return {
            "openai_compat_api_base": "https://api.fireworks.ai/inference/v1",
            "api_key": api_key,
        }

@@ -1,30 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.providers.remote.inference.fireworks_openai_compat.config import FireworksCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin

from ..fireworks.models import MODEL_ENTRIES


class FireworksCompatInferenceAdapter(LiteLLMOpenAIMixin):
    _config: FireworksCompatConfig

    def __init__(self, config: FireworksCompatConfig):
        LiteLLMOpenAIMixin.__init__(
            self,
            model_entries=MODEL_ENTRIES,
            api_key_from_config=config.api_key,
            provider_data_api_key_field="fireworks_api_key",
            openai_compat_api_base=config.openai_compat_api_base,
        )
        self.config = config

    async def initialize(self):
        await super().initialize()

    async def shutdown(self):
        await super().shutdown()
@@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.inference import InferenceProvider

from .config import GroqCompatConfig


async def get_adapter_impl(config: GroqCompatConfig, _deps) -> InferenceProvider:
    # import dynamically so the import is used only when it is needed
    from .groq import GroqCompatInferenceAdapter

    adapter = GroqCompatInferenceAdapter(config)
    return adapter

@@ -1,38 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from pydantic import BaseModel, Field

from llama_stack.schema_utils import json_schema_type


class GroqProviderDataValidator(BaseModel):
    groq_api_key: str | None = Field(
        default=None,
        description="API key for Groq models",
    )


@json_schema_type
class GroqCompatConfig(BaseModel):
    api_key: str | None = Field(
        default=None,
        description="The Groq API key",
    )

    openai_compat_api_base: str = Field(
        default="https://api.groq.com/openai/v1",
        description="The URL for the Groq API server",
    )

    @classmethod
    def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> dict[str, Any]:
        return {
            "openai_compat_api_base": "https://api.groq.com/openai/v1",
            "api_key": api_key,
        }

@@ -1,30 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.providers.remote.inference.groq_openai_compat.config import GroqCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin

from ..groq.models import MODEL_ENTRIES


class GroqCompatInferenceAdapter(LiteLLMOpenAIMixin):
    _config: GroqCompatConfig

    def __init__(self, config: GroqCompatConfig):
        LiteLLMOpenAIMixin.__init__(
            self,
            model_entries=MODEL_ENTRIES,
            api_key_from_config=config.api_key,
            provider_data_api_key_field="groq_api_key",
            openai_compat_api_base=config.openai_compat_api_base,
        )
        self.config = config

    async def initialize(self):
        await super().initialize()

    async def shutdown(self):
        await super().shutdown()
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import InferenceProvider
|
||||
|
||||
from .config import SambaNovaCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SambaNovaCompatConfig, _deps) -> InferenceProvider:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .sambanova import SambaNovaCompatInferenceAdapter
|
||||
|
||||
adapter = SambaNovaCompatInferenceAdapter(config)
|
||||
return adapter
|
|
@ -1,38 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from pydantic import BaseModel, Field

from llama_stack.schema_utils import json_schema_type


class SambaNovaProviderDataValidator(BaseModel):
    sambanova_api_key: str | None = Field(
        default=None,
        description="API key for SambaNova models",
    )


@json_schema_type
class SambaNovaCompatConfig(BaseModel):
    api_key: str | None = Field(
        default=None,
        description="The SambaNova API key",
    )

    openai_compat_api_base: str = Field(
        default="https://api.sambanova.ai/v1",
        description="The URL for the SambaNova API server",
    )

    @classmethod
    def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]:
        return {
            "openai_compat_api_base": "https://api.sambanova.ai/v1",
            "api_key": api_key,
        }
@ -1,30 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.providers.remote.inference.sambanova_openai_compat.config import SambaNovaCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin

from ..sambanova.models import MODEL_ENTRIES


class SambaNovaCompatInferenceAdapter(LiteLLMOpenAIMixin):
    _config: SambaNovaCompatConfig

    def __init__(self, config: SambaNovaCompatConfig):
        LiteLLMOpenAIMixin.__init__(
            self,
            model_entries=MODEL_ENTRIES,
            api_key_from_config=config.api_key,
            provider_data_api_key_field="sambanova_api_key",
            openai_compat_api_base=config.openai_compat_api_base,
        )
        self.config = config

    async def initialize(self):
        await super().initialize()

    async def shutdown(self):
        await super().shutdown()
@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.inference import InferenceProvider

from .config import TogetherCompatConfig


async def get_adapter_impl(config: TogetherCompatConfig, _deps) -> InferenceProvider:
    # import dynamically so the import is used only when it is needed
    from .together import TogetherCompatInferenceAdapter

    adapter = TogetherCompatInferenceAdapter(config)
    return adapter
@ -1,38 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from pydantic import BaseModel, Field

from llama_stack.schema_utils import json_schema_type


class TogetherProviderDataValidator(BaseModel):
    together_api_key: str | None = Field(
        default=None,
        description="API key for Together models",
    )


@json_schema_type
class TogetherCompatConfig(BaseModel):
    api_key: str | None = Field(
        default=None,
        description="The Together API key",
    )

    openai_compat_api_base: str = Field(
        default="https://api.together.xyz/v1",
        description="The URL for the Together API server",
    )

    @classmethod
    def sample_run_config(cls, api_key: str = "${env.TOGETHER_API_KEY}", **kwargs) -> dict[str, Any]:
        return {
            "openai_compat_api_base": "https://api.together.xyz/v1",
            "api_key": api_key,
        }
@ -1,30 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.providers.remote.inference.together_openai_compat.config import TogetherCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin

from ..together.models import MODEL_ENTRIES


class TogetherCompatInferenceAdapter(LiteLLMOpenAIMixin):
    _config: TogetherCompatConfig

    def __init__(self, config: TogetherCompatConfig):
        LiteLLMOpenAIMixin.__init__(
            self,
            model_entries=MODEL_ENTRIES,
            api_key_from_config=config.api_key,
            provider_data_api_key_field="together_api_key",
            openai_compat_api_base=config.openai_compat_api_base,
        )
        self.config = config

    async def initialize(self):
        await super().initialize()

    async def shutdown(self):
        await super().shutdown()
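All of the deleted *-openai-compat adapters above hand the mixin both a static api_key_from_config and a provider_data_api_key_field, so a key supplied per request can stand in for one from the config. A minimal standalone sketch of that precedence (resolve_api_key is a hypothetical helper, not LiteLLMOpenAIMixin's actual logic):

# Hypothetical illustration: a per-request provider-data key takes precedence over
# the key from static config; raise if neither is available.
def resolve_api_key(config_api_key: str | None, provider_data: dict | None, field: str) -> str:
    if provider_data and provider_data.get(field):
        return provider_data[field]
    if config_api_key:
        return config_api_key
    raise ValueError(f"No API key found in config or request provider data field '{field}'")


# Example: a request carrying {"groq_api_key": "sk-live"} overrides the configured key.
print(resolve_api_key("sk-from-config", {"groq_api_key": "sk-live"}, "groq_api_key"))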
@ -161,7 +161,7 @@ class OpenAIVectorStoreMixin(ABC):
    async def openai_create_vector_store(
        self,
        name: str,
        name: str | None = None,
        file_ids: list[str] | None = None,
        expires_after: dict[str, Any] | None = None,
        chunking_strategy: dict[str, Any] | None = None,
@ -83,6 +83,7 @@ class SQLiteTraceStore(TraceStore):
            )
            SELECT DISTINCT trace_id, root_span_id, start_time, end_time
            FROM filtered_traces
            WHERE root_span_id IS NOT NULL
            LIMIT {limit} OFFSET {offset}
        """
@ -166,7 +167,11 @@ class SQLiteTraceStore(TraceStore):
        return spans_by_id

    async def get_trace(self, trace_id: str) -> Trace:
        query = "SELECT * FROM traces WHERE trace_id = ?"
        query = """
            SELECT *
            FROM traces t
            WHERE t.trace_id = ?
        """
        async with aiosqlite.connect(self.conn_string) as conn:
            conn.row_factory = aiosqlite.Row
            async with conn.execute(query, (trace_id,)) as cursor:
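Both SQLite changes above keep user-supplied values such as trace_id bound as query parameters; only the SQL text itself is restructured. A small self-contained aiosqlite sketch of that pattern, using an in-memory database and a made-up traces table:

import asyncio

import aiosqlite


async def main() -> None:
    async with aiosqlite.connect(":memory:") as conn:
        conn.row_factory = aiosqlite.Row
        await conn.execute("CREATE TABLE traces (trace_id TEXT, root_span_id TEXT)")
        await conn.execute("INSERT INTO traces VALUES (?, ?)", ("t1", "s1"))
        # The trace_id stays a bound parameter rather than being interpolated into the SQL.
        async with conn.execute("SELECT * FROM traces t WHERE t.trace_id = ?", ("t1",)) as cursor:
            row = await cursor.fetchone()
            print(dict(row))


asyncio.run(main())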
@ -19,12 +19,7 @@ distribution_spec:
    - remote::anthropic
    - remote::gemini
    - remote::groq
    - remote::fireworks-openai-compat
    - remote::llama-openai-compat
    - remote::together-openai-compat
    - remote::groq-openai-compat
    - remote::sambanova-openai-compat
    - remote::cerebras-openai-compat
    - remote::sambanova
    - remote::passthrough
    - inline::sentence-transformers
@ -90,36 +90,11 @@ providers:
    config:
      url: https://api.groq.com
      api_key: ${env.GROQ_API_KEY}
  - provider_id: ${env.ENABLE_FIREWORKS_OPENAI_COMPAT:=__disabled__}
    provider_type: remote::fireworks-openai-compat
    config:
      openai_compat_api_base: https://api.fireworks.ai/inference/v1
      api_key: ${env.FIREWORKS_API_KEY}
  - provider_id: ${env.ENABLE_LLAMA_OPENAI_COMPAT:=__disabled__}
    provider_type: remote::llama-openai-compat
    config:
      openai_compat_api_base: https://api.llama.com/compat/v1/
      api_key: ${env.LLAMA_API_KEY}
  - provider_id: ${env.ENABLE_TOGETHER_OPENAI_COMPAT:=__disabled__}
    provider_type: remote::together-openai-compat
    config:
      openai_compat_api_base: https://api.together.xyz/v1
      api_key: ${env.TOGETHER_API_KEY}
  - provider_id: ${env.ENABLE_GROQ_OPENAI_COMPAT:=__disabled__}
    provider_type: remote::groq-openai-compat
    config:
      openai_compat_api_base: https://api.groq.com/openai/v1
      api_key: ${env.GROQ_API_KEY}
  - provider_id: ${env.ENABLE_SAMBANOVA_OPENAI_COMPAT:=__disabled__}
    provider_type: remote::sambanova-openai-compat
    config:
      openai_compat_api_base: https://api.sambanova.ai/v1
      api_key: ${env.SAMBANOVA_API_KEY}
  - provider_id: ${env.ENABLE_CEREBRAS_OPENAI_COMPAT:=__disabled__}
    provider_type: remote::cerebras-openai-compat
    config:
      openai_compat_api_base: https://api.cerebras.ai/v1
      api_key: ${env.CEREBRAS_API_KEY}
  - provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__}
    provider_type: remote::sambanova
    config:
@ -19,12 +19,7 @@ distribution_spec:
    - remote::anthropic
    - remote::gemini
    - remote::groq
    - remote::fireworks-openai-compat
    - remote::llama-openai-compat
    - remote::together-openai-compat
    - remote::groq-openai-compat
    - remote::sambanova-openai-compat
    - remote::cerebras-openai-compat
    - remote::sambanova
    - remote::passthrough
    - inline::sentence-transformers
@ -90,36 +90,11 @@ providers:
    config:
      url: https://api.groq.com
      api_key: ${env.GROQ_API_KEY}
  - provider_id: ${env.ENABLE_FIREWORKS_OPENAI_COMPAT:=__disabled__}
    provider_type: remote::fireworks-openai-compat
    config:
      openai_compat_api_base: https://api.fireworks.ai/inference/v1
      api_key: ${env.FIREWORKS_API_KEY}
  - provider_id: ${env.ENABLE_LLAMA_OPENAI_COMPAT:=__disabled__}
    provider_type: remote::llama-openai-compat
    config:
      openai_compat_api_base: https://api.llama.com/compat/v1/
      api_key: ${env.LLAMA_API_KEY}
  - provider_id: ${env.ENABLE_TOGETHER_OPENAI_COMPAT:=__disabled__}
    provider_type: remote::together-openai-compat
    config:
      openai_compat_api_base: https://api.together.xyz/v1
      api_key: ${env.TOGETHER_API_KEY}
  - provider_id: ${env.ENABLE_GROQ_OPENAI_COMPAT:=__disabled__}
    provider_type: remote::groq-openai-compat
    config:
      openai_compat_api_base: https://api.groq.com/openai/v1
      api_key: ${env.GROQ_API_KEY}
  - provider_id: ${env.ENABLE_SAMBANOVA_OPENAI_COMPAT:=__disabled__}
    provider_type: remote::sambanova-openai-compat
    config:
      openai_compat_api_base: https://api.sambanova.ai/v1
      api_key: ${env.SAMBANOVA_API_KEY}
  - provider_id: ${env.ENABLE_CEREBRAS_OPENAI_COMPAT:=__disabled__}
    provider_type: remote::cerebras-openai-compat
    config:
      openai_compat_api_base: https://api.cerebras.ai/v1
      api_key: ${env.CEREBRAS_API_KEY}
  - provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__}
    provider_type: remote::sambanova
    config:
@ -15,7 +15,7 @@ set -Eeuo pipefail
PORT=8321
OLLAMA_PORT=11434
MODEL_ALIAS="llama3.2:3b"
SERVER_IMAGE="llamastack/distribution-ollama:0.2.2"
SERVER_IMAGE="docker.io/llamastack/distribution-ollama:0.2.2"
WAIT_TIMEOUT=300

log(){ printf "\e[1;32m%s\e[0m\n" "$*"; }
@ -165,7 +165,7 @@ log "🦙 Starting Ollama…"
$ENGINE run -d "${PLATFORM_OPTS[@]}" --name ollama-server \
  --network llama-net \
  -p "${OLLAMA_PORT}:${OLLAMA_PORT}" \
  ollama/ollama > /dev/null 2>&1
  docker.io/ollama/ollama > /dev/null 2>&1

if ! wait_for_service "http://localhost:${OLLAMA_PORT}/" "Ollama" "$WAIT_TIMEOUT" "Ollama daemon"; then
  log "❌ Ollama daemon did not become ready in ${WAIT_TIMEOUT}s; dumping container logs:"
@ -47,6 +47,9 @@ def setup_telemetry_data(llama_stack_client, text_model_id):
    if len(traces) < 4:
        pytest.fail(f"Failed to create sufficient telemetry data after 30s. Got {len(traces)} traces.")

    # Wait for 5 seconds to ensure traces have completed logging
    time.sleep(5)

    yield
@ -11,17 +11,15 @@ from unittest.mock import AsyncMock
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.models import Model
from llama_stack.apis.shields.shields import Shield
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
from llama_stack.distribution.routing_tables.scoring_functions import ScoringFunctionsRoutingTable
from llama_stack.distribution.routing_tables.shields import ShieldsRoutingTable
from llama_stack.distribution.routing_tables.toolgroups import ToolGroupsRoutingTable
from llama_stack.distribution.routing_tables.vector_dbs import VectorDBsRoutingTable


class Impl:
@ -54,17 +52,6 @@ class SafetyImpl(Impl):
        return shield


class VectorDBImpl(Impl):
    def __init__(self):
        super().__init__(Api.vector_io)

    async def register_vector_db(self, vector_db: VectorDB):
        return vector_db

    async def unregister_vector_db(self, vector_db_id: str):
        return vector_db_id


class DatasetsImpl(Impl):
    def __init__(self):
        super().__init__(Api.datasetio)
@ -173,36 +160,6 @@ async def test_shields_routing_table(cached_disk_dist_registry):
    assert "test-shield-2" in shield_ids


async def test_vectordbs_routing_table(cached_disk_dist_registry):
    table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
    await table.initialize()

    m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
    await m_table.initialize()
    await m_table.register_model(
        model_id="test-model",
        provider_id="test_provider",
        metadata={"embedding_dimension": 128},
        model_type=ModelType.embedding,
    )

    # Register multiple vector databases and verify listing
    await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model")
    await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model")
    vector_dbs = await table.list_vector_dbs()

    assert len(vector_dbs.data) == 2
    vector_db_ids = {v.identifier for v in vector_dbs.data}
    assert "test-vectordb" in vector_db_ids
    assert "test-vectordb-2" in vector_db_ids

    await table.unregister_vector_db(vector_db_id="test-vectordb")
    await table.unregister_vector_db(vector_db_id="test-vectordb-2")

    vector_dbs = await table.list_vector_dbs()
    assert len(vector_dbs.data) == 0


async def test_datasets_routing_table(cached_disk_dist_registry):
    table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {})
    await table.initialize()
5
tests/unit/distribution/routing_tables/__init__.py
Normal file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
274
tests/unit/distribution/routing_tables/test_vector_dbs.py
Normal file

@ -0,0 +1,274 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Unit tests for the routing tables vector_dbs

import time
from unittest.mock import AsyncMock

import pytest

from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
from llama_stack.apis.vector_io.vector_io import (
    VectorStoreContent,
    VectorStoreDeleteResponse,
    VectorStoreFileContentsResponse,
    VectorStoreFileCounts,
    VectorStoreFileDeleteResponse,
    VectorStoreFileObject,
    VectorStoreObject,
    VectorStoreSearchResponsePage,
)
from llama_stack.distribution.access_control.datatypes import AccessRule, Scope
from llama_stack.distribution.datatypes import User
from llama_stack.distribution.request_headers import request_provider_data_context
from llama_stack.distribution.routing_tables.vector_dbs import VectorDBsRoutingTable
from tests.unit.distribution.routers.test_routing_tables import Impl, InferenceImpl, ModelsRoutingTable


class VectorDBImpl(Impl):
    def __init__(self):
        super().__init__(Api.vector_io)

    async def register_vector_db(self, vector_db: VectorDB):
        return vector_db

    async def unregister_vector_db(self, vector_db_id: str):
        return vector_db_id

    async def openai_retrieve_vector_store(self, vector_store_id):
        return VectorStoreObject(
            id=vector_store_id,
            name="Test Store",
            created_at=int(time.time()),
            file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
        )

    async def openai_update_vector_store(self, vector_store_id, **kwargs):
        return VectorStoreObject(
            id=vector_store_id,
            name="Updated Store",
            created_at=int(time.time()),
            file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
        )

    async def openai_delete_vector_store(self, vector_store_id):
        return VectorStoreDeleteResponse(id=vector_store_id, object="vector_store.deleted", deleted=True)

    async def openai_search_vector_store(self, vector_store_id, query, **kwargs):
        return VectorStoreSearchResponsePage(
            object="vector_store.search_results.page", search_query="query", data=[], has_more=False, next_page=None
        )

    async def openai_attach_file_to_vector_store(self, vector_store_id, file_id, **kwargs):
        return VectorStoreFileObject(
            id=file_id,
            status="completed",
            chunking_strategy={"type": "auto"},
            created_at=int(time.time()),
            vector_store_id=vector_store_id,
        )

    async def openai_list_files_in_vector_store(self, vector_store_id, **kwargs):
        return [
            VectorStoreFileObject(
                id="1",
                status="completed",
                chunking_strategy={"type": "auto"},
                created_at=int(time.time()),
                vector_store_id=vector_store_id,
            )
        ]

    async def openai_retrieve_vector_store_file(self, vector_store_id, file_id):
        return VectorStoreFileObject(
            id=file_id,
            status="completed",
            chunking_strategy={"type": "auto"},
            created_at=int(time.time()),
            vector_store_id=vector_store_id,
        )

    async def openai_retrieve_vector_store_file_contents(self, vector_store_id, file_id):
        return VectorStoreFileContentsResponse(
            file_id=file_id,
            filename="Sample File name",
            attributes={"key": "value"},
            content=[VectorStoreContent(type="text", text="Sample content")],
        )

    async def openai_update_vector_store_file(self, vector_store_id, file_id, **kwargs):
        return VectorStoreFileObject(
            id=file_id,
            status="completed",
            chunking_strategy={"type": "auto"},
            created_at=int(time.time()),
            vector_store_id=vector_store_id,
        )

    async def openai_delete_vector_store_file(self, vector_store_id, file_id):
        return VectorStoreFileDeleteResponse(id=file_id, deleted=True)


async def test_vectordbs_routing_table(cached_disk_dist_registry):
    table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
    await table.initialize()

    m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
    await m_table.initialize()
    await m_table.register_model(
        model_id="test-model",
        provider_id="test_provider",
        metadata={"embedding_dimension": 128},
        model_type=ModelType.embedding,
    )

    # Register multiple vector databases and verify listing
    await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model")
    await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model")
    vector_dbs = await table.list_vector_dbs()

    assert len(vector_dbs.data) == 2
    vector_db_ids = {v.identifier for v in vector_dbs.data}
    assert "test-vectordb" in vector_db_ids
    assert "test-vectordb-2" in vector_db_ids

    await table.unregister_vector_db(vector_db_id="test-vectordb")
    await table.unregister_vector_db(vector_db_id="test-vectordb-2")

    vector_dbs = await table.list_vector_dbs()
    assert len(vector_dbs.data) == 0


async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry):
    impl = VectorDBImpl()
    impl.openai_retrieve_vector_store = AsyncMock(return_value="OK")
    table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=[])
    m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[])
    authorized_table = "vs1"
    authorized_team = "team1"
    unauthorized_team = "team2"

    await m_table.initialize()
    await m_table.register_model(
        model_id="test-model",
        provider_id="test_provider",
        metadata={"embedding_dimension": 128},
        model_type=ModelType.embedding,
    )

    authorized_user = User(principal="alice", attributes={"roles": [authorized_team]})
    with request_provider_data_context({}, authorized_user):
        _ = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model")

    # Authorized reader
    with request_provider_data_context({}, authorized_user):
        res = await table.openai_retrieve_vector_store(authorized_table)
    assert res == "OK"

    # Authorized updater
    impl.openai_update_vector_store_file = AsyncMock(return_value="UPDATED")
    with request_provider_data_context({}, authorized_user):
        res = await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"})
    assert res == "UPDATED"

    # Unauthorized reader
    unauthorized_user = User(principal="eve", attributes={"roles": [unauthorized_team]})
    with request_provider_data_context({}, unauthorized_user):
        with pytest.raises(ValueError):
            await table.openai_retrieve_vector_store(authorized_table)

    # Unauthorized updater
    with request_provider_data_context({}, unauthorized_user):
        with pytest.raises(ValueError):
            await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"})

    # Authorized deleter
    impl.openai_delete_vector_store_file = AsyncMock(return_value="DELETED")
    with request_provider_data_context({}, authorized_user):
        res = await table.openai_delete_vector_store_file(authorized_table, file_id="file1")
    assert res == "DELETED"

    # Unauthorized deleter
    with request_provider_data_context({}, unauthorized_user):
        with pytest.raises(ValueError):
            await table.openai_delete_vector_store_file(authorized_table, file_id="file1")


async def test_openai_vector_stores_routing_table_actions(cached_disk_dist_registry):
    impl = VectorDBImpl()

    policy = [
        AccessRule(permit=Scope(actions=["create", "read", "update", "delete"]), when="user with admin in roles"),
        AccessRule(permit=Scope(actions=["read"]), when="user with reader in roles"),
    ]

    table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=policy)
    m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[])

    vector_db_id = "vs1"
    file_id = "file-1"

    admin_user = User(principal="admin", attributes={"roles": ["admin"]})
    read_only_user = User(principal="reader", attributes={"roles": ["reader"]})
    no_access_user = User(principal="outsider", attributes={"roles": ["no_access"]})

    await m_table.initialize()
    await m_table.register_model(
        model_id="test-model",
        provider_id="test_provider",
        metadata={"embedding_dimension": 128},
        model_type=ModelType.embedding,
    )

    with request_provider_data_context({}, admin_user):
        await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model")

    read_methods = [
        (table.openai_retrieve_vector_store, (vector_db_id,), {}),
        (table.openai_search_vector_store, (vector_db_id, "query"), {}),
        (table.openai_list_files_in_vector_store, (vector_db_id,), {}),
        (table.openai_retrieve_vector_store_file, (vector_db_id, file_id), {}),
        (table.openai_retrieve_vector_store_file_contents, (vector_db_id, file_id), {}),
    ]
    update_methods = [
        (table.openai_update_vector_store, (vector_db_id,), {"name": "Updated DB"}),
        (table.openai_attach_file_to_vector_store, (vector_db_id, file_id), {}),
        (table.openai_update_vector_store_file, (vector_db_id, file_id), {"attributes": {"key": "value"}}),
    ]
    delete_methods = [
        (table.openai_delete_vector_store_file, (vector_db_id, file_id), {}),
        (table.openai_delete_vector_store, (vector_db_id,), {}),
    ]

    for user in [admin_user, read_only_user]:
        with request_provider_data_context({}, user):
            for method, args, kwargs in read_methods:
                result = await method(*args, **kwargs)
                assert result is not None, f"Read operation failed with user {user.principal}"

    with request_provider_data_context({}, no_access_user):
        for method, args, kwargs in read_methods:
            with pytest.raises(ValueError):
                await method(*args, **kwargs)

    with request_provider_data_context({}, admin_user):
        for method, args, kwargs in update_methods:
            result = await method(*args, **kwargs)
            assert result is not None, "Update operation failed with admin user"

    with request_provider_data_context({}, admin_user):
        for method, args, kwargs in delete_methods:
            result = await method(*args, **kwargs)
            assert result is not None, "Delete operation failed with admin user"

    for user in [read_only_user, no_access_user]:
        with request_provider_data_context({}, user):
            for method, args, kwargs in delete_methods:
                with pytest.raises(ValueError):
                    await method(*args, **kwargs)
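The new access-control tests rely on rules such as AccessRule(permit=Scope(actions=["read"]), when="user with reader in roles"). Purely as an illustration of that kind of role-gated action check (the is_allowed helper below is hypothetical and far simpler than llama-stack's actual policy evaluation):

# Hypothetical illustration of role-based action checks like those exercised above.
def is_allowed(action: str, user_roles: set[str], rules: list[tuple[set[str], str]]) -> bool:
    # Each rule pairs a set of permitted actions with a required role; first match wins.
    for actions, required_role in rules:
        if action in actions and required_role in user_roles:
            return True
    return False


rules = [({"create", "read", "update", "delete"}, "admin"), ({"read"}, "reader")]
print(is_allowed("read", {"reader"}, rules))    # True
print(is_allowed("delete", {"reader"}, rules))  # False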
176
tests/unit/providers/agent/test_get_raw_document_text.py
Normal file

@ -0,0 +1,176 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import warnings
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from llama_stack.apis.agents import Document
from llama_stack.apis.common.content_types import URL, TextContentItem
from llama_stack.providers.inline.agents.meta_reference.agent_instance import get_raw_document_text


async def test_get_raw_document_text_supports_text_mime_types():
    """Test that the function accepts text/* mime types."""
    document = Document(content="Sample text content", mime_type="text/plain")

    result = await get_raw_document_text(document)
    assert result == "Sample text content"


async def test_get_raw_document_text_supports_yaml_mime_type():
    """Test that the function accepts application/yaml mime type."""
    yaml_content = """
    name: test
    version: 1.0
    items:
      - item1
      - item2
    """

    document = Document(content=yaml_content, mime_type="application/yaml")

    result = await get_raw_document_text(document)
    assert result == yaml_content


async def test_get_raw_document_text_supports_deprecated_text_yaml_with_warning():
    """Test that the function accepts text/yaml but emits a deprecation warning."""
    yaml_content = """
    name: test
    version: 1.0
    items:
      - item1
      - item2
    """

    document = Document(content=yaml_content, mime_type="text/yaml")

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        result = await get_raw_document_text(document)

        # Check that result is correct
        assert result == yaml_content

        # Check that exactly one warning was issued
        assert len(w) == 1
        assert issubclass(w[0].category, DeprecationWarning)
        assert "text/yaml" in str(w[0].message)
        assert "application/yaml" in str(w[0].message)
        assert "deprecated" in str(w[0].message).lower()


async def test_get_raw_document_text_deprecated_text_yaml_with_url():
    """Test that text/yaml works with URL content and emits warning."""
    yaml_content = "name: test\nversion: 1.0"

    with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load:
        mock_load.return_value = yaml_content

        document = Document(content=URL(uri="https://example.com/config.yaml"), mime_type="text/yaml")

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            result = await get_raw_document_text(document)

            # Check that result is correct
            assert result == yaml_content
            mock_load.assert_called_once_with("https://example.com/config.yaml")

            # Check that deprecation warning was issued
            assert len(w) == 1
            assert issubclass(w[0].category, DeprecationWarning)
            assert "text/yaml" in str(w[0].message)


async def test_get_raw_document_text_deprecated_text_yaml_with_text_content_item():
    """Test that text/yaml works with TextContentItem and emits warning."""
    yaml_content = "key: value\nlist:\n - item1\n - item2"

    document = Document(content=TextContentItem(text=yaml_content), mime_type="text/yaml")

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        result = await get_raw_document_text(document)

        # Check that result is correct
        assert result == yaml_content

        # Check that deprecation warning was issued
        assert len(w) == 1
        assert issubclass(w[0].category, DeprecationWarning)
        assert "text/yaml" in str(w[0].message)


async def test_get_raw_document_text_rejects_unsupported_mime_types():
    """Test that the function rejects unsupported mime types."""
    document = Document(
        content="Some content",
        mime_type="application/json",  # Not supported
    )

    with pytest.raises(ValueError, match="Unexpected document mime type: application/json"):
        await get_raw_document_text(document)


async def test_get_raw_document_text_with_url_content():
    """Test that the function handles URL content correctly."""
    mock_response = AsyncMock()
    mock_response.text = "Content from URL"

    with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load:
        mock_load.return_value = "Content from URL"

        document = Document(content=URL(uri="https://example.com/test.txt"), mime_type="text/plain")

        result = await get_raw_document_text(document)
        assert result == "Content from URL"
        mock_load.assert_called_once_with("https://example.com/test.txt")


async def test_get_raw_document_text_with_yaml_url():
    """Test that the function handles YAML URLs correctly."""
    yaml_content = "name: test\nversion: 1.0"

    with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load:
        mock_load.return_value = yaml_content

        document = Document(content=URL(uri="https://example.com/config.yaml"), mime_type="application/yaml")

        result = await get_raw_document_text(document)
        assert result == yaml_content
        mock_load.assert_called_once_with("https://example.com/config.yaml")


async def test_get_raw_document_text_with_text_content_item():
    """Test that the function handles TextContentItem correctly."""
    document = Document(content=TextContentItem(text="Text content item"), mime_type="text/plain")

    result = await get_raw_document_text(document)
    assert result == "Text content item"


async def test_get_raw_document_text_with_yaml_text_content_item():
    """Test that the function handles YAML TextContentItem correctly."""
    yaml_content = "key: value\nlist:\n - item1\n - item2"

    document = Document(content=TextContentItem(text=yaml_content), mime_type="application/yaml")

    result = await get_raw_document_text(document)
    assert result == yaml_content


async def test_get_raw_document_text_rejects_unexpected_content_type():
    """Test that the function rejects unexpected document content types."""
    # Create a mock document that bypasses Pydantic validation
    mock_document = MagicMock(spec=Document)
    mock_document.mime_type = "text/plain"
    mock_document.content = 123  # Unexpected content type (not str, URL, or TextContentItem)

    with pytest.raises(ValueError, match="Unexpected document content type: <class 'int'>"):
        await get_raw_document_text(mock_document)