# refactor(test): introduce --stack-config and simplify options (#1404)
You now run the integration tests with these options:
```bash
Custom options:
  --stack-config=STACK_CONFIG
                        a 'pointer' to the stack. this can either be:
                        (a) a template name like `fireworks`, or
                        (b) a path to a run.yaml file, or
                        (c) an adhoc config spec, e.g.
                        `inference=fireworks,safety=llama-guard,agents=meta-reference`
  --env=ENV             Set environment variables, e.g. --env KEY=value
  --text-model=TEXT_MODEL
                        comma-separated list of text models. Fixture name: text_model_id
  --vision-model=VISION_MODEL
                        comma-separated list of vision models. Fixture name: vision_model_id
  --embedding-model=EMBEDDING_MODEL
                        comma-separated list of embedding models. Fixture name: embedding_model_id
  --safety-shield=SAFETY_SHIELD
                        comma-separated list of safety shields. Fixture name: shield_id
  --judge-model=JUDGE_MODEL
                        comma-separated list of judge models. Fixture name: judge_model_id
  --embedding-dimension=EMBEDDING_DIMENSION
                        Output dimensionality of the embedding model to use for testing.
                        Default: 384
  --record-responses    Record new API responses instead of using cached ones.
  --report=REPORT       Path where the test report should be written, e.g.
                        --report=/path/to/report.md
```
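To illustrate the three forms of `--stack-config` (the run.yaml path below is a placeholder, and the test directory is taken from the Test Plan example):

```bash
# (a) a distribution template name
pytest -s -v tests/integration/inference --stack-config=fireworks

# (b) a path to a run.yaml file (hypothetical path)
pytest -s -v tests/integration/inference --stack-config=/path/to/run.yaml

# (c) an adhoc config spec naming a provider per API
pytest -s -v tests/integration/inference \
  --stack-config=inference=fireworks,safety=llama-guard,agents=meta-reference
```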
Importantly, if you don't specify any of the models (text-model, vision-model, etc.), the relevant tests will get **skipped!**

This makes running tests somewhat more verbose, since every relevant option must be specified explicitly, as sketched below. We will make this easier by adding wrapper yaml configs.
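A sketch of a run that exercises text, vision, and embedding tests at once (the vision-model identifier is an assumption for illustration; the text model appears in the Test Plan and the embedding model elsewhere in this diff):

```bash
# Every model option must be passed explicitly, or the corresponding tests are skipped.
# Model IDs below are illustrative, not defaults.
pytest -s -v tests/integration \
  --stack-config=fireworks \
  --text-model=meta-llama/Llama-3.2-3B-Instruct \
  --vision-model=meta-llama/Llama-3.2-11B-Vision-Instruct \
  --embedding-model=all-MiniLM-L6-v2 \
  --embedding-dimension=384
```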
## Test Plan
Example:
```bash
ashwin@ashwin-mbp ~/local/llama-stack/tests/integration (unify_tests) $
LLAMA_STACK_CONFIG=fireworks pytest -s -v inference/test_text_inference.py \
--text-model meta-llama/Llama-3.2-3B-Instruct
```
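The same run expressed with the new flag (assuming `--stack-config=fireworks` is equivalent to the `LLAMA_STACK_CONFIG=fireworks` environment variable above):

```bash
pytest -s -v inference/test_text_inference.py \
  --stack-config=fireworks \
  --text-model=meta-llama/Llama-3.2-3B-Instruct
```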
This commit is contained in:
parent `a0d6b165b0` · commit `2fe976ed0a`
15 changed files with 536 additions and 1144 deletions
One of the deleted files, a standalone RAG-tool endpoint test, reconstructed from the diff (`@@ -1,101 +0,0 @@`; the file path is not shown in this view):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from typing import List

import pytest
import requests
from pydantic import TypeAdapter

from llama_stack.apis.tools import (
    DefaultRAGQueryGeneratorConfig,
    RAGDocument,
    RAGQueryConfig,
    RAGQueryResult,
)
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.providers.utils.memory.vector_store import interleaved_content_as_str


class TestRAGToolEndpoints:
    @pytest.fixture
    def base_url(self) -> str:
        return "http://localhost:8321/v1"  # Adjust port if needed

    @pytest.fixture
    def sample_documents(self) -> List[RAGDocument]:
        return [
            RAGDocument(
                document_id="doc1",
                content="Python is a high-level programming language.",
                metadata={"category": "programming", "difficulty": "beginner"},
            ),
            RAGDocument(
                document_id="doc2",
                content="Machine learning is a subset of artificial intelligence.",
                metadata={"category": "AI", "difficulty": "advanced"},
            ),
            RAGDocument(
                document_id="doc3",
                content="Data structures are fundamental to computer science.",
                metadata={"category": "computer science", "difficulty": "intermediate"},
            ),
        ]

    @pytest.mark.asyncio
    async def test_rag_workflow(self, base_url: str, sample_documents: List[RAGDocument]):
        # Register a vector DB to hold the documents
        vector_db_payload = {
            "vector_db_id": "test_vector_db",
            "embedding_model": "all-MiniLM-L6-v2",
            "embedding_dimension": 384,
        }

        response = requests.post(f"{base_url}/vector-dbs", json=vector_db_payload)
        assert response.status_code == 200
        vector_db = VectorDB(**response.json())

        # Insert the sample documents through the RAG tool runtime
        insert_payload = {
            "documents": [json.loads(doc.model_dump_json()) for doc in sample_documents],
            "vector_db_id": vector_db.identifier,
            "chunk_size_in_tokens": 512,
        }

        response = requests.post(
            f"{base_url}/tool-runtime/rag-tool/insert-documents",
            json=insert_payload,
        )
        assert response.status_code == 200

        # Query the inserted documents
        query = "What is Python?"
        query_config = RAGQueryConfig(
            query_generator_config=DefaultRAGQueryGeneratorConfig(),
            max_tokens_in_context=4096,
            max_chunks=2,
        )

        query_payload = {
            "content": query,
            "query_config": json.loads(query_config.model_dump_json()),
            "vector_db_ids": [vector_db.identifier],
        }

        response = requests.post(
            f"{base_url}/tool-runtime/rag-tool/query-context",
            json=query_payload,
        )
        assert response.status_code == 200
        result = response.json()
        result = TypeAdapter(RAGQueryResult).validate_python(result)

        content_str = interleaved_content_as_str(result.content)
        print(f"content: {content_str}")
        assert len(content_str) > 0
        assert "Python" in content_str

        # Clean up: delete the vector DB
        response = requests.delete(f"{base_url}/vector-dbs/{vector_db.identifier}")
        assert response.status_code == 200
```
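For a quick manual spot-check of the same endpoints this test exercised, the requests can be replayed with curl; a minimal sketch, assuming a stack server on the default port 8321 (the `query_config` body is omitted here, whereas the test passed one explicitly):

```bash
# Register a vector DB (payload mirrors the deleted test)
curl -X POST http://localhost:8321/v1/vector-dbs \
  -H 'Content-Type: application/json' \
  -d '{"vector_db_id": "test_vector_db", "embedding_model": "all-MiniLM-L6-v2", "embedding_dimension": 384}'

# Query context for a question against that vector DB
curl -X POST http://localhost:8321/v1/tool-runtime/rag-tool/query-context \
  -H 'Content-Type: application/json' \
  -d '{"content": "What is Python?", "vector_db_ids": ["test_vector_db"]}'
```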