script for running client sdk tests (#895)

# What does this PR do?
Adds a script that runs all client-sdk tests against the async library
client (`AsyncLlamaStackAsLibraryClient`), with an option to generate a report.


## Test Plan

```
python llama_stack/scripts/run_client_sdk_tests.py --templates together fireworks --report
```



## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
Sixian Yi 2025-02-19 22:38:06 -08:00 committed by GitHub
parent a3d8c49459
commit 531940aea9
4 changed files with 74 additions and 3 deletions


@ -13,6 +13,7 @@ A Llama Stack API is described as a collection of REST endpoints. We currently s
- **DatasetIO**: interface with datasets and data loaders
- **Scoring**: evaluate outputs of the system
- **Eval**: generate outputs (via Inference or Agents) and perform scoring
- **VectorIO**: perform operations on vector stores, such as adding documents, searching, and deleting documents
- **Telemetry**: collect telemetry data from the system
We are working on adding a few more APIs to complete the application lifecycle. These will include:
@ -41,6 +42,7 @@ Some of these APIs are associated with a set of **Resources**. Here is the mappi
- **Safety** is associated with `Shield` resources.
- **Tool Runtime** is associated with `ToolGroup` resources.
- **DatasetIO** is associated with `Dataset` resources.
- **VectorIO** is associated with `VectorDB` resources.
- **Scoring** is associated with `ScoringFunction` resources.
- **Eval** is associated with `Model` and `Benchmark` resources.
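
The API-to-resource mapping listed above can be pictured as a simple lookup table. The sketch below is illustrative only: `API_RESOURCES` and `resources_for` are hypothetical names that are not part of Llama Stack, and the table covers only the entries visible in this hunk.

```python
# Hypothetical lookup table: API name -> associated resource types.
# Limited to the entries visible in the documentation hunk above.
API_RESOURCES = {
    "Safety": ["Shield"],
    "Tool Runtime": ["ToolGroup"],
    "DatasetIO": ["Dataset"],
    "VectorIO": ["VectorDB"],
    "Scoring": ["ScoringFunction"],
    "Eval": ["Model", "Benchmark"],
}


def resources_for(api: str) -> list:
    """Return the resource types listed for an API, or [] if none are listed."""
    return API_RESOURCES.get(api, [])


print(resources_for("VectorIO"))
print(resources_for("Eval"))
```

An API with no entry in the table, such as Telemetry here, simply yields an empty list.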


@ -10,7 +10,7 @@ conda_env: ollama
apis:
- agents
- inference
- memory
- vector_io
- safety
- telemetry
providers:
@ -19,7 +19,7 @@ providers:
    provider_type: remote::ollama
    config:
      url: ${env.OLLAMA_URL:http://localhost:11434}
  memory:
  vector_io:
  - provider_id: faiss
    provider_type: inline::faiss
    config:
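
The `url` value in this config uses a `${env.NAME:default}` placeholder. As a rough illustration of how such a placeholder could be resolved, here is a minimal sketch. It is not Llama Stack's actual implementation: the function name `resolve_env_placeholders` and the regular expression are assumptions modelled only on the single line shown above.

```python
import os
import re

# Matches ${env.NAME} or ${env.NAME:default}. This pattern is an assumption
# based on the `url` line in the config hunk above.
_ENV_PATTERN = re.compile(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*)(?::([^}]*))?\}")


def resolve_env_placeholders(value, environ=None):
    """Replace each placeholder with the variable's value, else its default."""
    env = os.environ if environ is None else environ

    def _substitute(match):
        name, default = match.group(1), match.group(2)
        if name in env:
            return env[name]
        if default is not None:
            return default
        raise KeyError(f"environment variable {name} is not set and has no default")

    return _ENV_PATTERN.sub(_substitute, value)


# With an empty environment, the default after the first colon is used.
print(resolve_env_placeholders("${env.OLLAMA_URL:http://localhost:11434}", {}))
```

Because the default is everything between the first colon and the closing brace, URLs containing further colons are handled without special casing.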


@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""
Script for running client-sdk on AsyncLlamaStackAsLibraryClient with templates

Assuming directory structure:
- llama-stack
    - llama_stack
        - scripts
    - tests
        - client-sdk

Example command:

cd llama-stack
export TOGETHER_API_KEY=<..>
export FIREWORKS_API_KEY=<..>
python llama_stack/scripts/run_client_sdk_tests.py --templates together fireworks --report
"""

import argparse
import os
from pathlib import Path

import pytest

# Repository root: this file lives at llama_stack/scripts/run_client_sdk_tests.py
REPO_ROOT = Path(__file__).parent.parent.parent
CLIENT_SDK_TESTS_RELATIVE_PATH = "tests/client-sdk/"


def main(parser: argparse.ArgumentParser):
    args = parser.parse_args()
    templates_dir = REPO_ROOT / "llama_stack" / "templates"
    user_specified_templates = (
        [templates_dir / t for t in args.templates] if args.templates else []
    )
    for d in templates_dir.iterdir():
        if d.is_dir() and d.name != "__pycache__":
            template_configs = list(d.rglob("run.yaml"))
            if len(template_configs) == 0:
                continue
            config = template_configs[0]
            # When --templates is given, run only the listed templates.
            if user_specified_templates:
                if not any(config.parent == t for t in user_specified_templates):
                    continue
            # Point this test run at the template's run.yaml.
            os.environ["LLAMA_STACK_CONFIG"] = str(config)
            # Build the optional flag as a list so that no empty-string
            # argument reaches pytest when --report is not set.
            pytest_args = ["--report"] if args.report else []
            pytest.main(
                [
                    *pytest_args,
                    "-s",
                    "-v",
                    str(REPO_ROOT / CLIENT_SDK_TESTS_RELATIVE_PATH),
                ]
            )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="llama_test",
    )
    parser.add_argument("--templates", nargs="+")
    parser.add_argument("--report", action="store_true")
    main(parser)
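
The template-selection step of the script can be isolated into a small function that is easy to exercise without running pytest. The following is a sketch under stated assumptions: `select_run_configs` is a hypothetical helper that does not exist in the repository, and it sorts directory entries for deterministic output, which the script itself does not do.

```python
import tempfile
from pathlib import Path


def select_run_configs(templates_dir, names=None):
    """Return one run.yaml per template directory, optionally filtered by name.

    Mirrors the script's filter: a template is kept only when the directory
    containing its first run.yaml is one of the requested template paths.
    """
    wanted = {templates_dir / n for n in names} if names else None
    selected = []
    for d in sorted(templates_dir.iterdir()):
        if not d.is_dir() or d.name == "__pycache__":
            continue
        configs = sorted(d.rglob("run.yaml"))
        if not configs:
            continue  # a template without a run.yaml is skipped
        if wanted is not None and configs[0].parent not in wanted:
            continue
        selected.append(configs[0])
    return selected


# Usage against a throwaway directory tree.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for name in ("together", "fireworks", "ollama"):
        (root / name).mkdir()
        (root / name / "run.yaml").write_text("version: '2'\n")
    (root / "__pycache__").mkdir()
    (root / "empty").mkdir()
    picked = [p.parent.name for p in select_run_configs(root, ["together", "fireworks"])]
    print(picked)  # ['fireworks', 'together']
```

Separating selection from execution in this way would let the filter be unit-tested on its own, while the loop that sets `LLAMA_STACK_CONFIG` and calls `pytest.main` stays unchanged.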


@ -175,7 +175,7 @@ class Report:
"|:-----|:-----|:-----|:-----|:-----|",
]
provider = [p for p in providers if p.api == str(api_group.name)]
provider_str = provider[0].provider_type if provider else ""
provider_str = ",".join(provider) if provider else ""
for api, capa_map in API_MAPS[api_group].items():
for capa, tests in capa_map.items():
for test_name in tests:
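
The hunk above changes how the provider column of the report is built. Since `str.join` accepts only strings, one way to produce a comma-separated label is to extract `provider_type` from each matching entry first. The sketch below illustrates that idea only. `ProviderEntry` and `provider_label` are hypothetical stand-ins; the real provider objects are assumed to expose at least the `api` and `provider_type` attributes seen in the hunk.

```python
from dataclasses import dataclass


@dataclass
class ProviderEntry:
    # Hypothetical stand-in: only `api` and `provider_type` appear in the hunk.
    api: str
    provider_type: str


def provider_label(providers, api_name):
    """Comma-join the provider types registered for one API ('' if none)."""
    types = [p.provider_type for p in providers if p.api == api_name]
    return ",".join(types)


providers = [
    ProviderEntry(api="inference", provider_type="remote::ollama"),
    ProviderEntry(api="vector_io", provider_type="inline::faiss"),
]
print(provider_label(providers, "vector_io"))
```

Joining an empty list already returns an empty string, so no separate fallback branch is needed in this version.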