Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
script for running client sdk tests (#895)
# What does this PR do?

Create a script for running all client-sdk tests on the Async Library client, with the option to generate a report.

## Test Plan

```
python llama_stack/scripts/run_client_sdk_tests.py --templates together fireworks --report
```

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
This commit is contained in:

parent a3d8c49459
commit 531940aea9

4 changed files with 74 additions and 3 deletions
```diff
@@ -13,6 +13,7 @@ A Llama Stack API is described as a collection of REST endpoints. We currently s
 - **DatasetIO**: interface with datasets and data loaders
 - **Scoring**: evaluate outputs of the system
 - **Eval**: generate outputs (via Inference or Agents) and perform scoring
+- **VectorIO**: perform operations on vector stores, such as adding documents, searching, and deleting documents
 - **Telemetry**: collect telemetry data from the system

 We are working on adding a few more APIs to complete the application lifecycle. These will include:

@@ -41,6 +42,7 @@ Some of these APIs are associated with a set of **Resources**. Here is the mappi
 - **Safety** is associated with `Shield` resources.
 - **Tool Runtime** is associated with `ToolGroup` resources.
 - **DatasetIO** is associated with `Dataset` resources.
+- **VectorIO** is associated with `VectorDB` resources.
 - **Scoring** is associated with `ScoringFunction` resources.
 - **Eval** is associated with `Model` and `Benchmark` resources.
```
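For context, the new VectorIO ↔ `VectorDB` mapping means vector stores are registered like any other resource. A minimal sketch using the `llama-stack-client` Python package; the base URL, ids, and embedding model below are illustrative assumptions, and the exact keyword arguments of `vector_dbs.register` may differ by version:

```python
from llama_stack_client import LlamaStackClient

# Base URL, ids, and model name are illustrative assumptions.
client = LlamaStackClient(base_url="http://localhost:8321")

# Register a VectorDB resource so VectorIO operations can target it.
client.vector_dbs.register(
    vector_db_id="my_documents",         # hypothetical id
    embedding_model="all-MiniLM-L6-v2",  # assumed embedding model
    embedding_dimension=384,
    provider_id="faiss",                 # matches the inline::faiss provider below
)
```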
```diff
@@ -10,7 +10,7 @@ conda_env: ollama
 apis:
 - agents
 - inference
-- memory
+- vector_io
 - safety
 - telemetry
 providers:

@@ -19,7 +19,7 @@ providers:
     provider_type: remote::ollama
     config:
       url: ${env.OLLAMA_URL:http://localhost:11434}
-  memory:
+  vector_io:
   - provider_id: faiss
     provider_type: inline::faiss
     config:
```
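The `${env.OLLAMA_URL:http://localhost:11434}` value above uses the config's `${env.NAME:default}` substitution syntax: read the named environment variable, fall back to the default after the colon. A rough Python sketch of that semantics; this is illustrative, not the stack's actual parser, and the regex and function name are my own:

```python
import os
import re

# Matches ${env.NAME:default} or ${env.NAME}; illustrative only.
_ENV_PATTERN = re.compile(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*)(?::([^}]*))?\}")


def substitute_env(value: str) -> str:
    """Replace ${env.NAME:default} placeholders with environment values."""

    def repl(m: re.Match) -> str:
        name, default = m.group(1), m.group(2)
        return os.environ.get(name, default if default is not None else "")

    return _ENV_PATTERN.sub(repl, value)


print(substitute_env("${env.OLLAMA_URL:http://localhost:11434}"))
# -> value of $OLLAMA_URL if set, otherwise http://localhost:11434
```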
llama_stack/scripts/run_client_sdk_tests.py (new file, 69 additions)
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse
import os
from pathlib import Path

import pytest

"""
Script for running the client-sdk tests on AsyncLlamaStackAsLibraryClient
with templates.

Assumed directory structure:
- llama-stack
  - llama_stack
    - scripts
  - tests
    - client-sdk

Example command:

cd llama-stack
export TOGETHER_API_KEY=<..>
export FIREWORKS_API_KEY=<..>
python llama_stack/scripts/run_client_sdk_tests.py --templates together fireworks --report
"""

REPO_ROOT = Path(__file__).parent.parent.parent
CLIENT_SDK_TESTS_RELATIVE_PATH = "tests/client-sdk/"


def main(parser: argparse.ArgumentParser):
    args = parser.parse_args()
    templates_dir = REPO_ROOT / "llama_stack" / "templates"
    # If --templates was given, restrict the run to those template directories.
    user_specified_templates = (
        [templates_dir / t for t in args.templates] if args.templates else []
    )
    for d in templates_dir.iterdir():
        if d.is_dir() and d.name != "__pycache__":
            template_configs = list(d.rglob("run.yaml"))
            if len(template_configs) == 0:
                continue
            config = template_configs[0]
            if user_specified_templates:
                if not any(config.parent == t for t in user_specified_templates):
                    continue
            # The client-sdk fixtures pick up the template via this env var.
            os.environ["LLAMA_STACK_CONFIG"] = str(config)
            # Build the argument list conditionally so we never pass an empty
            # string to pytest, and pass the tests path as a plain string.
            pytest_args = ["--report"] if args.report else []
            pytest.main(
                pytest_args
                + [
                    "-s",
                    "-v",
                    str(REPO_ROOT / CLIENT_SDK_TESTS_RELATIVE_PATH),
                ]
            )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="llama_test",
    )
    parser.add_argument("--templates", nargs="+")
    parser.add_argument("--report", action="store_true")
    main(parser)
```
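The tests invoked here are expected to construct the library client from the exported `LLAMA_STACK_CONFIG`. A minimal sketch of that pattern, assuming the `AsyncLlamaStackAsLibraryClient` named in the docstring; the import path and fixture shape are assumptions, not the actual conftest:

```python
import asyncio
import os

# Import path assumed for this era of the repo; adjust if it has moved.
from llama_stack.distribution.library_client import AsyncLlamaStackAsLibraryClient


async def smoke_check() -> None:
    # run_client_sdk_tests.py exports LLAMA_STACK_CONFIG before invoking
    # pytest; a test session reads it back to build the library client.
    config_path = os.environ["LLAMA_STACK_CONFIG"]
    client = AsyncLlamaStackAsLibraryClient(config_path)
    await client.initialize()
    models = await client.models.list()
    print([m.identifier for m in models])


asyncio.run(smoke_check())
```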
```diff
@@ -175,7 +175,7 @@ class Report:
             "|:-----|:-----|:-----|:-----|:-----|",
         ]
         provider = [p for p in providers if p.api == str(api_group.name)]
-        provider_str = provider[0].provider_type if provider else ""
+        provider_str = ",".join(provider) if provider else ""
         for api, capa_map in API_MAPS[api_group].items():
             for capa, tests in capa_map.items():
                 for test_name in tests:
```
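For readers unfamiliar with this file: the loop above emits one markdown table row per test under each capability. A rough, self-contained sketch of that row-building pattern; the map shape, column names, and result lookup are assumptions, not the real `API_MAPS` structure:

```python
# Illustrative stand-ins; the real API_MAPS shape, columns, and result
# lookup in report.py differ.
api_map = {"chat_completion": {"streaming": ["test_streaming_basic"]}}
results = {"test_streaming_basic": "passed"}  # hypothetical test outcomes

rows = [
    "| Provider | API | Capability | Test | Status |",
    "|:-----|:-----|:-----|:-----|:-----|",
]
provider_str = "remote::ollama"  # e.g. the joined provider types from above
for api, capa_map in api_map.items():
    for capa, tests in capa_map.items():
        for test_name in tests:
            status = results.get(test_name, "missing")
            rows.append(
                f"| {provider_str} | {api} | {capa} | {test_name} | {status} |"
            )

print("\n".join(rows))
```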