From 9b0f783e5479cfada6c61df2a0fdac1096546426 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 24 Feb 2025 14:43:21 -0800 Subject: [PATCH] test: add a ci-tests distro template for running e2e tests (#1237) --- distributions/dependencies.json | 35 ++++ llama_stack/templates/ci-tests/__init__.py | 7 + llama_stack/templates/ci-tests/build.yaml | 33 ++++ llama_stack/templates/ci-tests/ci_tests.py | 123 ++++++++++++++ llama_stack/templates/ci-tests/run.yaml | 169 +++++++++++++++++++ llama_stack/templates/fireworks/fireworks.py | 2 +- 6 files changed, 368 insertions(+), 1 deletion(-) create mode 100644 llama_stack/templates/ci-tests/__init__.py create mode 100644 llama_stack/templates/ci-tests/build.yaml create mode 100644 llama_stack/templates/ci-tests/ci_tests.py create mode 100644 llama_stack/templates/ci-tests/run.yaml diff --git a/distributions/dependencies.json b/distributions/dependencies.json index 9e468f08d..18a2484f2 100644 --- a/distributions/dependencies.json +++ b/distributions/dependencies.json @@ -66,6 +66,41 @@ "sentence-transformers --no-deps", "torch torchvision --index-url https://download.pytorch.org/whl/cpu" ], + "ci-tests": [ + "aiosqlite", + "autoevals", + "blobfile", + "chardet", + "chromadb-client", + "datasets", + "fastapi", + "fire", + "fireworks-ai", + "httpx", + "matplotlib", + "mcp", + "nltk", + "numpy", + "openai", + "opentelemetry-exporter-otlp-proto-http", + "opentelemetry-sdk", + "pandas", + "pillow", + "psycopg2-binary", + "pymongo", + "pypdf", + "redis", + "requests", + "scikit-learn", + "scipy", + "sentencepiece", + "sqlite-vec", + "tqdm", + "transformers", + "uvicorn", + "sentence-transformers --no-deps", + "torch torchvision --index-url https://download.pytorch.org/whl/cpu" + ], "dell": [ "aiohttp", "aiosqlite", diff --git a/llama_stack/templates/ci-tests/__init__.py b/llama_stack/templates/ci-tests/__init__.py new file mode 100644 index 000000000..b309587f5 --- /dev/null +++ b/llama_stack/templates/ci-tests/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .ci_tests import get_distribution_template # noqa: F401 diff --git a/llama_stack/templates/ci-tests/build.yaml b/llama_stack/templates/ci-tests/build.yaml new file mode 100644 index 000000000..a5c615f2f --- /dev/null +++ b/llama_stack/templates/ci-tests/build.yaml @@ -0,0 +1,33 @@ +version: '2' +distribution_spec: + description: Distribution for running e2e tests in CI + providers: + inference: + - remote::fireworks + - inline::sentence-transformers + vector_io: + - inline::sqlite-vec + - remote::chromadb + - remote::pgvector + safety: + - inline::llama-guard + agents: + - inline::meta-reference + telemetry: + - inline::meta-reference + eval: + - inline::meta-reference + datasetio: + - remote::huggingface + - inline::localfs + scoring: + - inline::basic + - inline::llm-as-judge + - inline::braintrust + tool_runtime: + - remote::brave-search + - remote::tavily-search + - inline::code-interpreter + - inline::rag-runtime + - remote::model-context-protocol +image_type: conda diff --git a/llama_stack/templates/ci-tests/ci_tests.py b/llama_stack/templates/ci-tests/ci_tests.py new file mode 100644 index 000000000..992d9936e --- /dev/null +++ b/llama_stack/templates/ci-tests/ci_tests.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +from llama_stack.apis.models.models import ModelType +from llama_stack.distribution.datatypes import ( + ModelInput, + Provider, + ShieldInput, + ToolGroupInput, +) +from llama_stack.models.llama.sku_list import all_registered_models +from llama_stack.providers.inline.inference.sentence_transformers import ( + SentenceTransformersInferenceConfig, +) +from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig +from llama_stack.providers.remote.inference.fireworks.config import FireworksImplConfig +from llama_stack.providers.remote.inference.fireworks.models import MODEL_ENTRIES +from llama_stack.templates.template import DistributionTemplate, RunConfigSettings + + +def get_distribution_template() -> DistributionTemplate: + providers = { + "inference": ["remote::fireworks", "inline::sentence-transformers"], + "vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"], + "safety": ["inline::llama-guard"], + "agents": ["inline::meta-reference"], + "telemetry": ["inline::meta-reference"], + "eval": ["inline::meta-reference"], + "datasetio": ["remote::huggingface", "inline::localfs"], + "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], + "tool_runtime": [ + "remote::brave-search", + "remote::tavily-search", + "inline::code-interpreter", + "inline::rag-runtime", + "remote::model-context-protocol", + ], + } + name = "ci-tests" + inference_provider = Provider( + provider_id="fireworks", + provider_type="remote::fireworks", + config=FireworksImplConfig.sample_run_config(), + ) + vector_io_provider = Provider( + provider_id="sqlite-vec", + provider_type="inline::sqlite-vec", + config=SQLiteVectorIOConfig.sample_run_config(f"distributions/{name}"), + ) + embedding_provider = Provider( + provider_id="sentence-transformers", + provider_type="inline::sentence-transformers", + config=SentenceTransformersInferenceConfig.sample_run_config(), + ) + + core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()} + default_models = [ + ModelInput( + model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id, + provider_model_id=m.provider_model_id, + provider_id="fireworks", + metadata=m.metadata, + model_type=m.model_type, + ) + for m in MODEL_ENTRIES + ] + default_tool_groups = [ + ToolGroupInput( + toolgroup_id="builtin::websearch", + provider_id="tavily-search", + ), + ToolGroupInput( + toolgroup_id="builtin::rag", + provider_id="rag-runtime", + ), + ToolGroupInput( + toolgroup_id="builtin::code_interpreter", + provider_id="code-interpreter", + ), + ] + embedding_model = ModelInput( + model_id="all-MiniLM-L6-v2", + provider_id="sentence-transformers", + model_type=ModelType.embedding, + metadata={ + "embedding_dimension": 384, + }, + ) + + return DistributionTemplate( + name=name, + distro_type="self_hosted", + description="Distribution for running e2e tests in CI", + container_image=None, + template_path=None, + providers=providers, + default_models=default_models, + run_configs={ + "run.yaml": RunConfigSettings( + provider_overrides={ + "inference": [inference_provider, embedding_provider], + "vector_io": [vector_io_provider], + }, + default_models=default_models + [embedding_model], + default_tool_groups=default_tool_groups, + default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")], + ), + }, + run_config_env_vars={ + "LLAMA_STACK_PORT": ( + "5001", + "Port for the Llama Stack distribution server", + ), + "FIREWORKS_API_KEY": ( + "", + "Fireworks API Key", + ), + }, + ) diff --git a/llama_stack/templates/ci-tests/run.yaml b/llama_stack/templates/ci-tests/run.yaml new file mode 100644 index 000000000..6696c8041 --- /dev/null +++ b/llama_stack/templates/ci-tests/run.yaml @@ -0,0 +1,169 @@ +version: '2' +image_name: ci-tests +apis: +- agents +- datasetio +- eval +- inference +- safety +- scoring +- telemetry +- tool_runtime +- vector_io +providers: + inference: + - provider_id: fireworks + provider_type: remote::fireworks + config: + url: https://api.fireworks.ai/inference/v1 + api_key: ${env.FIREWORKS_API_KEY} + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers + config: {} + vector_io: + - provider_id: sqlite-vec + provider_type: inline::sqlite-vec + config: + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/sqlite_vec.db + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: {} + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/agents_store.db + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + service_name: ${env.OTEL_SERVICE_NAME:llama-stack} + sinks: ${env.TELEMETRY_SINKS:console,sqlite} + sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ci-tests/trace_store.db} + eval: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: {} + datasetio: + - provider_id: huggingface + provider_type: remote::huggingface + config: {} + - provider_id: localfs + provider_type: inline::localfs + config: {} + scoring: + - provider_id: basic + provider_type: inline::basic + config: {} + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + config: {} + - provider_id: braintrust + provider_type: inline::braintrust + config: + openai_api_key: ${env.OPENAI_API_KEY:} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:} + max_results: 3 + - provider_id: code-interpreter + provider_type: inline::code-interpreter + config: {} + - provider_id: rag-runtime + provider_type: inline::rag-runtime + config: {} + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol + config: {} +metadata_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/registry.db +models: +- metadata: {} + model_id: meta-llama/Llama-3.1-8B-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p1-8b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.1-70B-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p1-70b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.1-405B-Instruct-FP8 + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-1B-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-3B-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p2-3b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-11B-Vision-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-90B-Vision-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.3-70B-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-Guard-3-8B + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-guard-3-8b + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-Guard-3-11B-Vision + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision + model_type: llm +- metadata: + embedding_dimension: 768 + context_length: 8192 + model_id: nomic-ai/nomic-embed-text-v1.5 + provider_id: fireworks + provider_model_id: nomic-ai/nomic-embed-text-v1.5 + model_type: embedding +- metadata: + embedding_dimension: 384 + model_id: all-MiniLM-L6-v2 + provider_id: sentence-transformers + model_type: embedding +shields: +- shield_id: meta-llama/Llama-Guard-3-8B +vector_dbs: [] +datasets: [] +scoring_fns: [] +benchmarks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::rag + provider_id: rag-runtime +- toolgroup_id: builtin::code_interpreter + provider_id: code-interpreter +server: + port: 8321 diff --git a/llama_stack/templates/fireworks/fireworks.py b/llama_stack/templates/fireworks/fireworks.py index 4457296b0..c78664dde 100644 --- a/llama_stack/templates/fireworks/fireworks.py +++ b/llama_stack/templates/fireworks/fireworks.py @@ -18,7 +18,7 @@ from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig -from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig +from llama_stack.providers.remote.inference.fireworks.config import FireworksImplConfig from llama_stack.providers.remote.inference.fireworks.models import MODEL_ENTRIES from llama_stack.templates.template import DistributionTemplate, RunConfigSettings