mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-11 05:38:38 +00:00
chore: Updating documentation and adding exception handling for Vector Stores in RAG Tool and updating inference to use openai and updating memory implementation to use existing libraries
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
28696c3f30
commit
ff0bd414b1
27 changed files with 926 additions and 403 deletions
|
@ -6,9 +6,7 @@ Integration tests verify complete workflows across different providers using Lla
|
|||
|
||||
```bash
|
||||
# Run all integration tests with existing recordings
|
||||
LLAMA_STACK_TEST_INFERENCE_MODE=replay \
|
||||
LLAMA_STACK_TEST_RECORDING_DIR=tests/integration/recordings \
|
||||
uv run --group test \
|
||||
uv run --group test \
|
||||
pytest -sv tests/integration/ --stack-config=starter
|
||||
```
|
||||
|
||||
|
@ -42,25 +40,35 @@ Model parameters can be influenced by the following options:
|
|||
Each of these are comma-separated lists and can be used to generate multiple parameter combinations. Note that tests will be skipped
|
||||
if no model is specified.
|
||||
|
||||
### Suites (fast selection + sane defaults)
|
||||
### Suites and Setups
|
||||
|
||||
- `--suite`: comma-separated list of named suites that both narrow which tests are collected and prefill common model options (unless you pass them explicitly).
|
||||
- `--suite`: single named suite that narrows which tests are collected.
|
||||
- Available suites:
|
||||
- `responses`: collects tests under `tests/integration/responses`; this is a separate suite because it needs a strong tool-calling model.
|
||||
- `vision`: collects only `tests/integration/inference/test_vision_inference.py`; defaults `--vision-model=ollama/llama3.2-vision:11b`, `--embedding-model=sentence-transformers/all-MiniLM-L6-v2`.
|
||||
- Explicit flags always win. For example, `--suite=responses --text-model=<X>` overrides the suite’s text model.
|
||||
- `base`: collects most tests (excludes responses and post_training)
|
||||
- `responses`: collects tests under `tests/integration/responses` (needs strong tool-calling models)
|
||||
- `vision`: collects only `tests/integration/inference/test_vision_inference.py`
|
||||
- `--setup`: global configuration that can be used with any suite. Setups prefill model/env defaults; explicit CLI flags always win.
|
||||
- Available setups:
|
||||
- `ollama`: Local Ollama provider with lightweight models (sets OLLAMA_URL, uses llama3.2:3b-instruct-fp16)
|
||||
- `vllm`: VLLM provider for efficient local inference (sets VLLM_URL, uses Llama-3.2-1B-Instruct)
|
||||
- `gpt`: OpenAI GPT models for high-quality responses (uses gpt-4o)
|
||||
- `claude`: Anthropic Claude models for high-quality responses (uses claude-3-5-sonnet)
|
||||
|
||||
Examples:
|
||||
Examples
|
||||
|
||||
```bash
|
||||
# Fast responses run with defaults
|
||||
pytest -s -v tests/integration --stack-config=server:starter --suite=responses
|
||||
# Fast responses run with a strong tool-calling model
|
||||
pytest -s -v tests/integration --stack-config=server:starter --suite=responses --setup=gpt
|
||||
|
||||
# Fast single-file vision run with defaults
|
||||
pytest -s -v tests/integration --stack-config=server:starter --suite=vision
|
||||
# Fast single-file vision run with Ollama defaults
|
||||
pytest -s -v tests/integration --stack-config=server:starter --suite=vision --setup=ollama
|
||||
|
||||
# Combine suites and override a default
|
||||
pytest -s -v tests/integration --stack-config=server:starter --suite=responses,vision --embedding-model=text-embedding-3-small
|
||||
# Base suite with VLLM for performance
|
||||
pytest -s -v tests/integration --stack-config=server:starter --suite=base --setup=vllm
|
||||
|
||||
# Override a default from setup
|
||||
pytest -s -v tests/integration --stack-config=server:starter \
|
||||
--suite=responses --setup=gpt --embedding-model=text-embedding-3-small
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
@ -127,14 +135,13 @@ pytest tests/integration/
|
|||
### RECORD Mode
|
||||
Captures API interactions for later replay:
|
||||
```bash
|
||||
LLAMA_STACK_TEST_INFERENCE_MODE=record \
|
||||
pytest tests/integration/inference/test_new_feature.py
|
||||
pytest tests/integration/inference/test_new_feature.py --inference-mode=record
|
||||
```
|
||||
|
||||
### LIVE Mode
|
||||
Tests make real API calls (but not recorded):
|
||||
```bash
|
||||
LLAMA_STACK_TEST_INFERENCE_MODE=live pytest tests/integration/
|
||||
pytest tests/integration/ --inference-mode=live
|
||||
```
|
||||
|
||||
By default, the recording directory is `tests/integration/recordings`. You can override this by setting the `LLAMA_STACK_TEST_RECORDING_DIR` environment variable.
|
||||
|
@ -155,15 +162,14 @@ cat recordings/responses/abc123.json | jq '.'
|
|||
#### Remote Re-recording (Recommended)
|
||||
Use the automated workflow script for easier re-recording:
|
||||
```bash
|
||||
./scripts/github/schedule-record-workflow.sh --test-subdirs "inference,agents"
|
||||
./scripts/github/schedule-record-workflow.sh --subdirs "inference,agents"
|
||||
```
|
||||
See the [main testing guide](../README.md#remote-re-recording-recommended) for full details.
|
||||
|
||||
#### Local Re-recording
|
||||
```bash
|
||||
# Re-record specific tests
|
||||
LLAMA_STACK_TEST_INFERENCE_MODE=record \
|
||||
pytest -s -v --stack-config=server:starter tests/integration/inference/test_modified.py
|
||||
pytest -s -v --stack-config=server:starter tests/integration/inference/test_modified.py --inference-mode=record
|
||||
```
|
||||
|
||||
Note that when re-recording tests, you must use a Stack pointing to a server (i.e., `server:starter`). This subtlety exists because the set of tests run in server are a superset of the set of tests run in the library client.
|
||||
|
|
|
@ -15,7 +15,7 @@ from dotenv import load_dotenv
|
|||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .suites import SUITE_DEFINITIONS
|
||||
from .suites import SETUP_DEFINITIONS, SUITE_DEFINITIONS
|
||||
|
||||
logger = get_logger(__name__, category="tests")
|
||||
|
||||
|
@ -63,19 +63,33 @@ def pytest_configure(config):
|
|||
key, value = env_var.split("=", 1)
|
||||
os.environ[key] = value
|
||||
|
||||
suites_raw = config.getoption("--suite")
|
||||
suites: list[str] = []
|
||||
if suites_raw:
|
||||
suites = [p.strip() for p in str(suites_raw).split(",") if p.strip()]
|
||||
unknown = [p for p in suites if p not in SUITE_DEFINITIONS]
|
||||
if unknown:
|
||||
inference_mode = config.getoption("--inference-mode")
|
||||
os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = inference_mode
|
||||
|
||||
suite = config.getoption("--suite")
|
||||
if suite:
|
||||
if suite not in SUITE_DEFINITIONS:
|
||||
raise pytest.UsageError(f"Unknown suite: {suite}. Available: {', '.join(sorted(SUITE_DEFINITIONS.keys()))}")
|
||||
|
||||
# Apply setups (global parameterizations): env + defaults
|
||||
setup = config.getoption("--setup")
|
||||
if suite and not setup:
|
||||
setup = SUITE_DEFINITIONS[suite].default_setup
|
||||
|
||||
if setup:
|
||||
if setup not in SETUP_DEFINITIONS:
|
||||
raise pytest.UsageError(
|
||||
f"Unknown suite(s): {', '.join(unknown)}. Available: {', '.join(sorted(SUITE_DEFINITIONS.keys()))}"
|
||||
f"Unknown setup '{setup}'. Available: {', '.join(sorted(SETUP_DEFINITIONS.keys()))}"
|
||||
)
|
||||
for suite in suites:
|
||||
suite_def = SUITE_DEFINITIONS.get(suite, {})
|
||||
defaults: dict = suite_def.get("defaults", {})
|
||||
for dest, value in defaults.items():
|
||||
|
||||
setup_obj = SETUP_DEFINITIONS[setup]
|
||||
logger.info(f"Applying setup '{setup}'{' for suite ' + suite if suite else ''}")
|
||||
# Apply env first
|
||||
for k, v in setup_obj.env.items():
|
||||
if k not in os.environ:
|
||||
os.environ[k] = str(v)
|
||||
# Apply defaults if not provided explicitly
|
||||
for dest, value in setup_obj.defaults.items():
|
||||
current = getattr(config.option, dest, None)
|
||||
if not current:
|
||||
setattr(config.option, dest, value)
|
||||
|
@ -120,6 +134,13 @@ def pytest_addoption(parser):
|
|||
default=384,
|
||||
help="Output dimensionality of the embedding model to use for testing. Default: 384",
|
||||
)
|
||||
|
||||
parser.addoption(
|
||||
"--inference-mode",
|
||||
help="Inference mode: { record, replay, live } (default: replay)",
|
||||
choices=["record", "replay", "live"],
|
||||
default="replay",
|
||||
)
|
||||
parser.addoption(
|
||||
"--report",
|
||||
help="Path where the test report should be written, e.g. --report=/path/to/report.md",
|
||||
|
@ -127,14 +148,18 @@ def pytest_addoption(parser):
|
|||
|
||||
available_suites = ", ".join(sorted(SUITE_DEFINITIONS.keys()))
|
||||
suite_help = (
|
||||
"Comma-separated integration test suites to narrow collection and prefill defaults. "
|
||||
"Available: "
|
||||
f"{available_suites}. "
|
||||
"Explicit CLI flags (e.g., --text-model) override suite defaults. "
|
||||
"Examples: --suite=responses or --suite=responses,vision."
|
||||
f"Single test suite to run (narrows collection). Available: {available_suites}. Example: --suite=responses"
|
||||
)
|
||||
parser.addoption("--suite", help=suite_help)
|
||||
|
||||
# Global setups for any suite
|
||||
available_setups = ", ".join(sorted(SETUP_DEFINITIONS.keys()))
|
||||
setup_help = (
|
||||
f"Global test setup configuration. Available: {available_setups}. "
|
||||
"Can be used with any suite. Example: --setup=ollama"
|
||||
)
|
||||
parser.addoption("--setup", help=setup_help)
|
||||
|
||||
|
||||
MODEL_SHORT_IDS = {
|
||||
"meta-llama/Llama-3.2-3B-Instruct": "3B",
|
||||
|
@ -221,16 +246,12 @@ pytest_plugins = ["tests.integration.fixtures.common"]
|
|||
|
||||
def pytest_ignore_collect(path: str, config: pytest.Config) -> bool:
|
||||
"""Skip collecting paths outside the selected suite roots for speed."""
|
||||
suites_raw = config.getoption("--suite")
|
||||
if not suites_raw:
|
||||
suite = config.getoption("--suite")
|
||||
if not suite:
|
||||
return False
|
||||
|
||||
names = [p.strip() for p in str(suites_raw).split(",") if p.strip()]
|
||||
roots: list[str] = []
|
||||
for name in names:
|
||||
suite_def = SUITE_DEFINITIONS.get(name)
|
||||
if suite_def:
|
||||
roots.extend(suite_def.get("roots", []))
|
||||
sobj = SUITE_DEFINITIONS.get(suite)
|
||||
roots: list[str] = sobj.get("roots", []) if isinstance(sobj, dict) else getattr(sobj, "roots", [])
|
||||
if not roots:
|
||||
return False
|
||||
|
||||
|
|
|
@ -76,6 +76,9 @@ def skip_if_doesnt_support_n(client_with_models, model_id):
|
|||
"remote::gemini",
|
||||
# https://docs.anthropic.com/en/api/openai-sdk#simple-fields
|
||||
"remote::anthropic",
|
||||
"remote::vertexai",
|
||||
# Error code: 400 - [{'error': {'code': 400, 'message': 'Unable to submit request because candidateCount must be 1 but
|
||||
# the entered value was 2. Update the candidateCount value and try again.', 'status': 'INVALID_ARGUMENT'}
|
||||
):
|
||||
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.")
|
||||
|
||||
|
|
|
@ -8,46 +8,112 @@
|
|||
# For example:
|
||||
#
|
||||
# ```bash
|
||||
# pytest tests/integration/ --suite=vision
|
||||
# pytest tests/integration/ --suite=vision --setup=ollama
|
||||
# ```
|
||||
#
|
||||
# Each suite can:
|
||||
# - restrict collection to specific roots (dirs or files)
|
||||
# - provide default CLI option values (e.g. text_model, embedding_model, etc.)
|
||||
"""
|
||||
Each suite defines what to run (roots). Suites can be run with different global setups defined in setups.py.
|
||||
Setups provide environment variables and model defaults that can be reused across multiple suites.
|
||||
|
||||
CLI examples:
|
||||
pytest tests/integration --suite=responses --setup=gpt
|
||||
pytest tests/integration --suite=vision --setup=ollama
|
||||
pytest tests/integration --suite=base --setup=vllm
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
this_dir = Path(__file__).parent
|
||||
default_roots = [
|
||||
|
||||
|
||||
class Suite(BaseModel):
|
||||
name: str
|
||||
roots: list[str]
|
||||
default_setup: str | None = None
|
||||
|
||||
|
||||
class Setup(BaseModel):
|
||||
"""A reusable test configuration with environment and CLI defaults."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
defaults: dict[str, str] = Field(default_factory=dict)
|
||||
env: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
# Global setups - can be used with any suite "technically" but in reality, some setups might work
|
||||
# only for specific test suites.
|
||||
SETUP_DEFINITIONS: dict[str, Setup] = {
|
||||
"ollama": Setup(
|
||||
name="ollama",
|
||||
description="Local Ollama provider with text + safety models",
|
||||
env={
|
||||
"OLLAMA_URL": "http://0.0.0.0:11434",
|
||||
"SAFETY_MODEL": "ollama/llama-guard3:1b",
|
||||
},
|
||||
defaults={
|
||||
"text_model": "ollama/llama3.2:3b-instruct-fp16",
|
||||
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
"safety_model": "ollama/llama-guard3:1b",
|
||||
"safety_shield": "llama-guard",
|
||||
},
|
||||
),
|
||||
"ollama-vision": Setup(
|
||||
name="ollama",
|
||||
description="Local Ollama provider with a vision model",
|
||||
env={
|
||||
"OLLAMA_URL": "http://0.0.0.0:11434",
|
||||
},
|
||||
defaults={
|
||||
"vision_model": "ollama/llama3.2-vision:11b",
|
||||
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
},
|
||||
),
|
||||
"vllm": Setup(
|
||||
name="vllm",
|
||||
description="vLLM provider with a text model",
|
||||
env={
|
||||
"VLLM_URL": "http://localhost:8000/v1",
|
||||
},
|
||||
defaults={
|
||||
"text_model": "vllm/meta-llama/Llama-3.2-1B-Instruct",
|
||||
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
},
|
||||
),
|
||||
"gpt": Setup(
|
||||
name="gpt",
|
||||
description="OpenAI GPT models for high-quality responses and tool calling",
|
||||
defaults={
|
||||
"text_model": "openai/gpt-4o",
|
||||
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
base_roots = [
|
||||
str(p)
|
||||
for p in this_dir.glob("*")
|
||||
if p.is_dir()
|
||||
and p.name not in ("__pycache__", "fixtures", "test_cases", "recordings", "responses", "post_training")
|
||||
]
|
||||
|
||||
SUITE_DEFINITIONS: dict[str, dict] = {
|
||||
"base": {
|
||||
"description": "Base suite that includes most tests but runs them with a text Ollama model",
|
||||
"roots": default_roots,
|
||||
"defaults": {
|
||||
"text_model": "ollama/llama3.2:3b-instruct-fp16",
|
||||
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
},
|
||||
},
|
||||
"responses": {
|
||||
"description": "Suite that includes only the OpenAI Responses tests; needs a strong tool-calling model",
|
||||
"roots": ["tests/integration/responses"],
|
||||
"defaults": {
|
||||
"text_model": "openai/gpt-4o",
|
||||
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
},
|
||||
},
|
||||
"vision": {
|
||||
"description": "Suite that includes only the vision tests",
|
||||
"roots": ["tests/integration/inference/test_vision_inference.py"],
|
||||
"defaults": {
|
||||
"vision_model": "ollama/llama3.2-vision:11b",
|
||||
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
},
|
||||
},
|
||||
SUITE_DEFINITIONS: dict[str, Suite] = {
|
||||
"base": Suite(
|
||||
name="base",
|
||||
roots=base_roots,
|
||||
default_setup="ollama",
|
||||
),
|
||||
"responses": Suite(
|
||||
name="responses",
|
||||
roots=["tests/integration/responses"],
|
||||
default_setup="gpt",
|
||||
),
|
||||
"vision": Suite(
|
||||
name="vision",
|
||||
roots=["tests/integration/inference/test_vision_inference.py"],
|
||||
default_setup="ollama-vision",
|
||||
),
|
||||
}
|
||||
|
|
|
@ -183,6 +183,110 @@ def test_vector_db_insert_from_url_and_query(
|
|||
assert any("llama2" in chunk.content.lower() for chunk in response2.chunks)
|
||||
|
||||
|
||||
def test_rag_tool_openai_apis(client_with_empty_registry, embedding_model_id, embedding_dimension):
|
||||
vector_db_id = "test_openai_vector_db"
|
||||
|
||||
client_with_empty_registry.vector_dbs.register(
|
||||
vector_db_id=vector_db_id,
|
||||
embedding_model=embedding_model_id,
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
|
||||
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
||||
actual_vector_db_id = available_vector_dbs[0]
|
||||
|
||||
# different document formats that should work with OpenAI APIs
|
||||
documents = [
|
||||
Document(
|
||||
document_id="text-doc",
|
||||
content="This is a plain text document about machine learning algorithms.",
|
||||
metadata={"type": "text", "category": "AI"},
|
||||
),
|
||||
Document(
|
||||
document_id="url-doc",
|
||||
content="https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/chat.rst",
|
||||
mime_type="text/plain",
|
||||
metadata={"type": "url", "source": "pytorch"},
|
||||
),
|
||||
Document(
|
||||
document_id="data-url-doc",
|
||||
content="data:text/plain;base64,VGhpcyBpcyBhIGRhdGEgVVJMIGRvY3VtZW50IGFib3V0IGRlZXAgbGVhcm5pbmcu", # "This is a data URL document about deep learning."
|
||||
metadata={"type": "data_url", "encoding": "base64"},
|
||||
),
|
||||
]
|
||||
|
||||
client_with_empty_registry.tool_runtime.rag_tool.insert(
|
||||
documents=documents,
|
||||
vector_db_id=actual_vector_db_id,
|
||||
chunk_size_in_tokens=256,
|
||||
)
|
||||
|
||||
files_list = client_with_empty_registry.files.list()
|
||||
assert len(files_list.data) >= len(documents), (
|
||||
f"Expected at least {len(documents)} files, got {len(files_list.data)}"
|
||||
)
|
||||
|
||||
vector_store_files = client_with_empty_registry.vector_io.openai_list_files_in_vector_store(
|
||||
vector_store_id=actual_vector_db_id
|
||||
)
|
||||
assert len(vector_store_files.data) >= len(documents), f"Expected at least {len(documents)} files in vector store"
|
||||
|
||||
response = client_with_empty_registry.tool_runtime.rag_tool.query(
|
||||
vector_db_ids=[actual_vector_db_id],
|
||||
content="Tell me about machine learning and deep learning",
|
||||
)
|
||||
|
||||
assert_valid_text_response(response)
|
||||
content_text = " ".join([chunk.text for chunk in response.content]).lower()
|
||||
assert "machine learning" in content_text or "deep learning" in content_text
|
||||
|
||||
|
||||
def test_rag_tool_exception_handling(client_with_empty_registry, embedding_model_id, embedding_dimension):
|
||||
vector_db_id = "test_exception_handling"
|
||||
|
||||
client_with_empty_registry.vector_dbs.register(
|
||||
vector_db_id=vector_db_id,
|
||||
embedding_model=embedding_model_id,
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
|
||||
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
||||
actual_vector_db_id = available_vector_dbs[0]
|
||||
|
||||
documents = [
|
||||
Document(
|
||||
document_id="valid-doc",
|
||||
content="This is a valid document that should be processed successfully.",
|
||||
metadata={"status": "valid"},
|
||||
),
|
||||
Document(
|
||||
document_id="invalid-url-doc",
|
||||
content="https://nonexistent-domain-12345.com/invalid.txt",
|
||||
metadata={"status": "invalid_url"},
|
||||
),
|
||||
Document(
|
||||
document_id="another-valid-doc",
|
||||
content="This is another valid document for testing resilience.",
|
||||
metadata={"status": "valid"},
|
||||
),
|
||||
]
|
||||
|
||||
client_with_empty_registry.tool_runtime.rag_tool.insert(
|
||||
documents=documents,
|
||||
vector_db_id=actual_vector_db_id,
|
||||
chunk_size_in_tokens=256,
|
||||
)
|
||||
|
||||
response = client_with_empty_registry.tool_runtime.rag_tool.query(
|
||||
vector_db_ids=[actual_vector_db_id],
|
||||
content="valid document",
|
||||
)
|
||||
|
||||
assert_valid_text_response(response)
|
||||
content_text = " ".join([chunk.text for chunk in response.content]).lower()
|
||||
assert "valid document" in content_text
|
||||
|
||||
|
||||
def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_id, embedding_dimension):
|
||||
providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
|
||||
assert len(providers) > 0
|
||||
|
@ -249,3 +353,107 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i
|
|||
"chunk_template": "This should raise a ValueError because it is missing the proper template variables",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_rag_tool_query_generation(client_with_empty_registry, embedding_model_id, embedding_dimension):
|
||||
vector_db_id = "test_query_generation_db"
|
||||
|
||||
client_with_empty_registry.vector_dbs.register(
|
||||
vector_db_id=vector_db_id,
|
||||
embedding_model=embedding_model_id,
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
|
||||
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
||||
actual_vector_db_id = available_vector_dbs[0]
|
||||
|
||||
documents = [
|
||||
Document(
|
||||
document_id="ai-doc",
|
||||
content="Artificial intelligence and machine learning are transforming technology.",
|
||||
metadata={"category": "AI"},
|
||||
),
|
||||
Document(
|
||||
document_id="banana-doc",
|
||||
content="Don't bring a banana to a knife fight.",
|
||||
metadata={"category": "wisdom"},
|
||||
),
|
||||
]
|
||||
|
||||
client_with_empty_registry.tool_runtime.rag_tool.insert(
|
||||
documents=documents,
|
||||
vector_db_id=actual_vector_db_id,
|
||||
chunk_size_in_tokens=256,
|
||||
)
|
||||
|
||||
response = client_with_empty_registry.tool_runtime.rag_tool.query(
|
||||
vector_db_ids=[actual_vector_db_id],
|
||||
content="Tell me about AI",
|
||||
)
|
||||
|
||||
assert_valid_text_response(response)
|
||||
content_text = " ".join([chunk.text for chunk in response.content]).lower()
|
||||
assert "artificial intelligence" in content_text or "machine learning" in content_text
|
||||
|
||||
|
||||
def test_rag_tool_pdf_data_url_handling(client_with_empty_registry, embedding_model_id, embedding_dimension):
|
||||
vector_db_id = "test_pdf_data_url_db"
|
||||
|
||||
client_with_empty_registry.vector_dbs.register(
|
||||
vector_db_id=vector_db_id,
|
||||
embedding_model=embedding_model_id,
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
|
||||
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
||||
actual_vector_db_id = available_vector_dbs[0]
|
||||
|
||||
sample_pdf = b"%PDF-1.3\n3 0 obj\n<</Type /Page\n/Parent 1 0 R\n/Resources 2 0 R\n/Contents 4 0 R>>\nendobj\n4 0 obj\n<</Filter /FlateDecode /Length 115>>\nstream\nx\x9c\x15\xcc1\x0e\x820\x18@\xe1\x9dS\xbcM]jk$\xd5\xd5(\x83!\x86\xa1\x17\xf8\xa3\xa5`LIh+\xd7W\xc6\xf7\r\xef\xc0\xbd\xd2\xaa\xb6,\xd5\xc5\xb1o\x0c\xa6VZ\xe3znn%\xf3o\xab\xb1\xe7\xa3:Y\xdc\x8bm\xeb\xf3&1\xc8\xd7\xd3\x97\xc82\xe6\x81\x87\xe42\xcb\x87Vb(\x12<\xdd<=}Jc\x0cL\x91\xee\xda$\xb5\xc3\xbd\xd7\xe9\x0f\x8d\x97 $\nendstream\nendobj\n1 0 obj\n<</Type /Pages\n/Kids [3 0 R ]\n/Count 1\n/MediaBox [0 0 595.28 841.89]\n>>\nendobj\n5 0 obj\n<</Type /Font\n/BaseFont /Helvetica\n/Subtype /Type1\n/Encoding /WinAnsiEncoding\n>>\nendobj\n2 0 obj\n<<\n/ProcSet [/PDF /Text /ImageB /ImageC /ImageI]\n/Font <<\n/F1 5 0 R\n>>\n/XObject <<\n>>\n>>\nendobj\n6 0 obj\n<<\n/Producer (PyFPDF 1.7.2 http://pyfpdf.googlecode.com/)\n/Title (This is a sample title.)\n/Author (Llama Stack Developers)\n/CreationDate (D:20250312165548)\n>>\nendobj\n7 0 obj\n<<\n/Type /Catalog\n/Pages 1 0 R\n/OpenAction [3 0 R /FitH null]\n/PageLayout /OneColumn\n>>\nendobj\nxref\n0 8\n0000000000 65535 f \n0000000272 00000 n \n0000000455 00000 n \n0000000009 00000 n \n0000000087 00000 n \n0000000359 00000 n \n0000000559 00000 n \n0000000734 00000 n \ntrailer\n<<\n/Size 8\n/Root 7 0 R\n/Info 6 0 R\n>>\nstartxref\n837\n%%EOF\n"
|
||||
|
||||
import base64
|
||||
|
||||
pdf_base64 = base64.b64encode(sample_pdf).decode("utf-8")
|
||||
pdf_data_url = f"data:application/pdf;base64,{pdf_base64}"
|
||||
|
||||
documents = [
|
||||
Document(
|
||||
document_id="test-pdf-data-url",
|
||||
content=pdf_data_url,
|
||||
metadata={"type": "pdf", "source": "data_url"},
|
||||
),
|
||||
]
|
||||
|
||||
client_with_empty_registry.tool_runtime.rag_tool.insert(
|
||||
documents=documents,
|
||||
vector_db_id=actual_vector_db_id,
|
||||
chunk_size_in_tokens=256,
|
||||
)
|
||||
|
||||
files_list = client_with_empty_registry.files.list()
|
||||
assert len(files_list.data) >= 1, "PDF should have been uploaded to Files API"
|
||||
|
||||
pdf_file = None
|
||||
for file in files_list.data:
|
||||
if file.filename and "test-pdf-data-url" in file.filename:
|
||||
pdf_file = file
|
||||
break
|
||||
|
||||
assert pdf_file is not None, "PDF file should be found in Files API"
|
||||
assert pdf_file.bytes == len(sample_pdf), f"File size should match original PDF ({len(sample_pdf)} bytes)"
|
||||
|
||||
file_content = client_with_empty_registry.files.retrieve_content(pdf_file.id)
|
||||
assert file_content.startswith(b"%PDF-"), "Retrieved file should be a valid PDF"
|
||||
|
||||
vector_store_files = client_with_empty_registry.vector_io.openai_list_files_in_vector_store(
|
||||
vector_store_id=actual_vector_db_id
|
||||
)
|
||||
assert len(vector_store_files.data) >= 1, "PDF should be attached to vector store"
|
||||
|
||||
response = client_with_empty_registry.tool_runtime.rag_tool.query(
|
||||
vector_db_ids=[actual_vector_db_id],
|
||||
content="sample title",
|
||||
)
|
||||
|
||||
assert_valid_text_response(response)
|
||||
content_text = " ".join([chunk.text for chunk in response.content]).lower()
|
||||
assert "sample title" in content_text or "title" in content_text
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue