From aa65610e756c3bd2b962dafa1df8b034937986d3 Mon Sep 17 00:00:00 2001
From: snova-edwardm
Date: Mon, 27 Jan 2025 15:46:30 -0800
Subject: [PATCH] Sambanova - LlamaGuard (#886)

# What does this PR do?

- Fix the SambaNovaImpl loading issue
- Add LlamaGuard model support for inference

## Test Plan

Run the following unit test scripts; the results are shown after each command.

### Embedding

```
pytest -s -v --providers inference=sambanova llama_stack/providers/tests/inference/test_embeddings.py --inference-model meta-llama/Llama-3.2-11B-Vision-Instruct --env SAMBANOVA_API_KEY={SAMBANOVA_API_KEY}
```

```
llama_stack/providers/tests/inference/test_embeddings.py::TestEmbeddings::test_embeddings[-sambanova] SKIPPED (This test is only applicable for embedding models)
llama_stack/providers/tests/inference/test_embeddings.py::TestEmbeddings::test_batch_embeddings[-sambanova] SKIPPED (This test is only applicable for embedding models)
========================== 2 skipped, 1 warning in 0.32s ==========================
```

### Vision

```
pytest -s -v --providers inference=sambanova llama_stack/providers/tests/inference/test_vision_inference.py --inference-model meta-llama/Llama-3.2-11B-Vision-Instruct --env SAMBANOVA_API_KEY={SAMBANOVA_API_KEY}
```

```
llama_stack/providers/tests/inference/test_vision_inference.py::TestVisionModelInference::test_vision_chat_completion_non_streaming[-sambanova-image0-expected_strings0] PASSED
llama_stack/providers/tests/inference/test_vision_inference.py::TestVisionModelInference::test_vision_chat_completion_non_streaming[-sambanova-image1-expected_strings1] PASSED
llama_stack/providers/tests/inference/test_vision_inference.py::TestVisionModelInference::test_vision_chat_completion_streaming[-sambanova] PASSED
========================== 3 passed, 1 warning in 2.68s ===========================
```

### Text

```
pytest -s -v --providers inference=sambanova llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming --env SAMBANOVA_API_KEY={SAMBANOVA_API_KEY}
```

```
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[-sambanova] PASSED
========================== 1 passed, 1 warning in 0.46s ===========================
```

```
pytest -s -v --providers inference=sambanova llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming --env SAMBANOVA_API_KEY={SAMBANOVA_API_KEY}
```

```
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[-sambanova] PASSED
========================== 1 passed, 1 warning in 0.48s ===========================
```
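Beyond the unit tests, here is a minimal sketch of querying the newly supported LlamaGuard model through a running SambaNova distro. The server URL and the prompt are assumptions; the calls follow standard `llama-stack-client` usage.

```python
# Minimal sketch: call the newly supported LlamaGuard model through a
# running SambaNova distro. The base_url (port 5001) and the prompt are
# assumptions, not part of this patch.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5001")

response = client.inference.chat_completion(
    model_id="meta-llama/Llama-Guard-3-8B",
    messages=[{"role": "user", "content": "How do I bake bread at home?"}],
)
# Llama Guard answers with a safety verdict ("safe", or "unsafe" plus a
# violated-category code) rather than a normal assistant reply.
print(response.completion_message.content)
```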
## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Ran pre-commit to handle lint / formatting issues.
- [x] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [x] Updated relevant documentation.
- [x] Wrote necessary unit or integration tests.
---
 README.md                                     |  2 +
 distributions/sambanova/build.yaml            | 20 +----
 distributions/sambanova/run.yaml              | 84 +------------------
 .../remote/inference/sambanova/config.py      |  2 +-
 .../remote/inference/sambanova/sambanova.py   | 23 +++--
 llama_stack/templates/sambanova/build.yaml    | 10 +++
 llama_stack/templates/sambanova/run.yaml      | 60 +++++++++++--
 7 files changed, 84 insertions(+), 117 deletions(-)

diff --git a/README.md b/README.md
index 17acd0096..53235389f 100644
--- a/README.md
+++ b/README.md
@@ -36,6 +36,7 @@ Here is a list of the various API providers and available distributions to devel
 | **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
 |:------------------------------------------------------------------------------------------:|:----------------------:|:------------------:|:------------------:|:------------------:|:------------------:|:------------------:|
 | Meta Reference | Single Node | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
+| SambaNova | Hosted | | :heavy_check_mark: | | | |
 | Cerebras | Hosted | | :heavy_check_mark: | | | |
 | Fireworks | Hosted | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | |
 | AWS Bedrock | Hosted | | :heavy_check_mark: | | :heavy_check_mark: | |
@@ -57,6 +58,7 @@ A Llama Stack Distribution (or "distro") is a pre-configured bundle of provider
 |:---------------------------------------------:|:-------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------------------------------------:|
 | Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-gpu.html) |
 | Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) |
+| SambaNova | [llamastack/distribution-sambanova](https://hub.docker.com/repository/docker/llamastack/distribution-sambanova/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/sambanova.html) |
 | Cerebras | [llamastack/distribution-cerebras](https://hub.docker.com/repository/docker/llamastack/distribution-cerebras/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/cerebras.html) |
 | Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/ollama.html) |
 | TGI | [llamastack/distribution-tgi](https://hub.docker.com/repository/docker/llamastack/distribution-tgi/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/tgi.html) |
diff --git a/distributions/sambanova/build.yaml b/distributions/sambanova/build.yaml
index d6da478d1..dbf013d2d 100644
--- a/distributions/sambanova/build.yaml
+++ b/distributions/sambanova/build.yaml
@@ -1,19 +1 @@
-version: '2'
-name: sambanova
-distribution_spec:
-  description: Use SambaNova.AI for running LLM inference
-  docker_image: null
-  providers:
-    inference:
-    - remote::sambanova
-    memory:
-    - inline::faiss
-    - remote::chromadb
-    - remote::pgvector
-    safety:
-    - inline::llama-guard
-    agents:
-    - inline::meta-reference
-    telemetry:
-    - inline::meta-reference
-image_type: conda
+../../llama_stack/templates/sambanova/build.yaml
diff --git a/distributions/sambanova/run.yaml b/distributions/sambanova/run.yaml
index 03c8ea44f..385282c67 100644
--- a/distributions/sambanova/run.yaml
+++ b/distributions/sambanova/run.yaml
@@ -1,83 +1 @@
-version: '2'
-image_name: sambanova
-docker_image: null
-conda_env: sambanova
-apis:
-- agents
-- inference
-- memory
-- safety
-- telemetry
-providers:
-  inference:
-  - provider_id: sambanova
-    provider_type: remote::sambanova
-    config:
-      url: https://api.sambanova.ai/v1/
-      api_key: ${env.SAMBANOVA_API_KEY}
-  memory:
-  - provider_id: faiss
-    provider_type: inline::faiss
-    config:
-      kvstore:
-        type: sqlite
-        namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/faiss_store.db
-  safety:
-  - provider_id: llama-guard
-    provider_type: inline::llama-guard
-    config: {}
-  agents:
-  - provider_id: meta-reference
-    provider_type: inline::meta-reference
-    config:
-      persistence_store:
-        type: sqlite
-        namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/agents_store.db
-  telemetry:
-  - provider_id: meta-reference
-    provider_type: inline::meta-reference
-    config: {}
-metadata_store:
-  namespace: null
-  type: sqlite
-  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/registry.db
-models:
-- metadata: {}
-  model_id: meta-llama/Llama-3.1-8B-Instruct
-  provider_id: null
-  provider_model_id: Meta-Llama-3.1-8B-Instruct
-- metadata: {}
-  model_id: meta-llama/Llama-3.1-70B-Instruct
-  provider_id: null
-  provider_model_id: Meta-Llama-3.1-70B-Instruct
-- metadata: {}
-  model_id: meta-llama/Llama-3.1-405B-Instruct
-  provider_id: null
-  provider_model_id: Meta-Llama-3.1-405B-Instruct
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-1B-Instruct
-  provider_id: null
-  provider_model_id: Meta-Llama-3.2-1B-Instruct
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-3B-Instruct
-  provider_id: null
-  provider_model_id: Meta-Llama-3.2-3B-Instruct
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
-  provider_id: null
-  provider_model_id: Llama-3.2-11B-Vision-Instruct
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
-  provider_id: null
-  provider_model_id: Llama-3.2-90B-Vision-Instruct
-shields:
-- params: null
-  shield_id: meta-llama/Llama-Guard-3-8B
-  provider_id: null
-  provider_shield_id: null
-memory_banks: []
-datasets: []
-scoring_fns: []
-eval_tasks: []
+../../llama_stack/templates/sambanova/run.yaml
diff --git a/llama_stack/providers/remote/inference/sambanova/config.py b/llama_stack/providers/remote/inference/sambanova/config.py
index e7454404b..1798841df 100644
--- a/llama_stack/providers/remote/inference/sambanova/config.py
+++ b/llama_stack/providers/remote/inference/sambanova/config.py
@@ -22,7 +22,7 @@ class SambaNovaImplConfig(BaseModel):
     )
 
     @classmethod
-    def sample_run_config(cls) -> Dict[str, Any]:
+    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
         return {
             "url": "https://api.sambanova.ai/v1",
             "api_key": "${env.SAMBANOVA_API_KEY}",
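The widened `sample_run_config` signature appears to be what fixes the loading issue: distribution template generation calls each provider's `sample_run_config` with keyword arguments, which the old zero-argument signature rejected. A small sketch under that assumption; `__distro_dir__` is an assumed example of such a kwarg.

```python
# Sketch: the sample config now tolerates the kwargs that template
# generation passes along. "__distro_dir__" is an assumption here and
# is simply ignored by this provider.
from llama_stack.providers.remote.inference.sambanova.config import (
    SambaNovaImplConfig,
)

cfg = SambaNovaImplConfig.sample_run_config(
    __distro_dir__="~/.llama/distributions/sambanova"
)
assert cfg["url"] == "https://api.sambanova.ai/v1"
assert cfg["api_key"] == "${env.SAMBANOVA_API_KEY}"
```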
diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py
index 9c203a8d0..da446567a 100644
--- a/llama_stack/providers/remote/inference/sambanova/sambanova.py
+++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py
@@ -7,7 +7,12 @@
 import json
 from typing import AsyncGenerator
 
-from llama_models.datatypes import CoreModelId, SamplingStrategy
+from llama_models.datatypes import (
+    CoreModelId,
+    GreedySamplingStrategy,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
+)
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI
@@ -60,6 +65,10 @@ MODEL_ALIASES = [
         "Llama-3.2-90B-Vision-Instruct",
         CoreModelId.llama3_2_90b_vision_instruct.value,
     ),
+    build_model_alias(
+        "Meta-Llama-Guard-3-8B",
+        CoreModelId.llama_guard_3_8b.value,
+    ),
 ]
 
 
@@ -197,12 +206,12 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
         else:
             params["max_completion_tokens"] = sampling_params.max_tokens
 
-        if sampling_params.strategy == SamplingStrategy.top_p:
-            params["top_p"] = sampling_params.top_p
-        elif sampling_params.strategy == "top_k":
-            params["extra_body"]["top_k"] = sampling_params.top_k
-        elif sampling_params.strategy == "greedy":
-            params["temperature"] = sampling_params.temperature
+        if isinstance(sampling_params.strategy, TopPSamplingStrategy):
+            params["top_p"] = sampling_params.strategy.top_p
+        if isinstance(sampling_params.strategy, TopKSamplingStrategy):
+            params["extra_body"]["top_k"] = sampling_params.strategy.top_k
+        if isinstance(sampling_params.strategy, GreedySamplingStrategy):
+            params["temperature"] = 0.0
 
         return params
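The strategy handling now dispatches on the typed strategy classes that `llama_models` uses in place of the old enum/string values, and reads the parameters off the strategy object itself. A self-contained sketch of the same pattern; the `TopPSamplingStrategy(...)` construction at the end is illustrative only.

```python
# Self-contained sketch of the typed-strategy dispatch used in the diff above.
from llama_models.datatypes import (
    GreedySamplingStrategy,
    TopKSamplingStrategy,
    TopPSamplingStrategy,
)

def strategy_request_params(strategy) -> dict:
    """Map a typed sampling strategy onto OpenAI-client request params."""
    params: dict = {"extra_body": {}}
    if isinstance(strategy, TopPSamplingStrategy):
        params["top_p"] = strategy.top_p
    elif isinstance(strategy, TopKSamplingStrategy):
        # top_k is not a first-class OpenAI parameter, so it travels in
        # extra_body, which the openai client forwards verbatim.
        params["extra_body"]["top_k"] = strategy.top_k
    elif isinstance(strategy, GreedySamplingStrategy):
        # Greedy decoding maps to temperature 0.
        params["temperature"] = 0.0
    return params

print(strategy_request_params(TopPSamplingStrategy(temperature=0.7, top_p=0.9)))
```

Since a request carries exactly one strategy object, the chained `if`s in the diff and the `elif`s here behave identically.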
diff --git a/llama_stack/templates/sambanova/build.yaml b/llama_stack/templates/sambanova/build.yaml
index ca5ffe618..0966bfdd9 100644
--- a/llama_stack/templates/sambanova/build.yaml
+++ b/llama_stack/templates/sambanova/build.yaml
@@ -14,9 +14,19 @@ distribution_spec:
     - inline::meta-reference
     telemetry:
     - inline::meta-reference
+    eval:
+    - inline::meta-reference
+    datasetio:
+    - remote::huggingface
+    - inline::localfs
+    scoring:
+    - inline::basic
+    - inline::llm-as-judge
+    - inline::braintrust
     tool_runtime:
     - remote::brave-search
     - remote::tavily-search
    - inline::code-interpreter
     - inline::rag-runtime
+    - remote::model-context-protocol
 image_type: conda
diff --git a/llama_stack/templates/sambanova/run.yaml b/llama_stack/templates/sambanova/run.yaml
index 31f47e0c1..c63b5d217 100644
--- a/llama_stack/templates/sambanova/run.yaml
+++ b/llama_stack/templates/sambanova/run.yaml
@@ -2,8 +2,11 @@ version: '2'
 image_name: sambanova
 apis:
 - agents
+- datasetio
+- eval
 - inference
 - safety
+- scoring
 - telemetry
 - tool_runtime
 - vector_io
@@ -22,12 +25,6 @@ providers:
         type: sqlite
         namespace: null
         db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/faiss_store.db
-  - provider_id: chromadb
-    provider_type: remote::chromadb
-    config: {}
-  - provider_id: pgvector
-    provider_type: remote::pgvector
-    config: {}
   safety:
   - provider_id: llama-guard
     provider_type: inline::llama-guard
@@ -47,6 +44,28 @@ providers:
       service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/sambanova/trace_store.db}
+  eval:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config: {}
+  datasetio:
+  - provider_id: huggingface
+    provider_type: remote::huggingface
+    config: {}
+  - provider_id: localfs
+    provider_type: inline::localfs
+    config: {}
+  scoring:
+  - provider_id: basic
+    provider_type: inline::basic
+    config: {}
+  - provider_id: llm-as-judge
+    provider_type: inline::llm-as-judge
+    config: {}
+  - provider_id: braintrust
+    provider_type: inline::braintrust
+    config:
+      openai_api_key: ${env.OPENAI_API_KEY:}
   tool_runtime:
   - provider_id: brave-search
     provider_type: remote::brave-search
@@ -64,42 +83,69 @@ providers:
   - provider_id: rag-runtime
     provider_type: inline::rag-runtime
     config: {}
+  - provider_id: model-context-protocol
+    provider_type: remote::model-context-protocol
+    config: {}
 metadata_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/registry.db
 models:
 - metadata: {}
   model_id: meta-llama/Llama-3.1-8B-Instruct
+  provider_id: sambanova
   provider_model_id: Meta-Llama-3.1-8B-Instruct
   model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.1-70B-Instruct
+  model_type: llm
+  provider_id: sambanova
   provider_model_id: Meta-Llama-3.1-70B-Instruct
   model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
+  provider_id: sambanova
   provider_model_id: Meta-Llama-3.1-405B-Instruct
   model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.2-1B-Instruct
+  provider_id: sambanova
   provider_model_id: Meta-Llama-3.2-1B-Instruct
   model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.2-3B-Instruct
+  provider_id: sambanova
   provider_model_id: Meta-Llama-3.2-3B-Instruct
   model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
+  provider_id: sambanova
   provider_model_id: Llama-3.2-11B-Vision-Instruct
   model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
+  provider_id: sambanova
   provider_model_id: Llama-3.2-90B-Vision-Instruct
   model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
+  provider_id: sambanova
+  provider_model_id: Llama-3.2-90B-Vision-Instruct
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-Guard-3-8B
+  provider_id: sambanova
+  provider_model_id: Llama-Guard-3-8B
+  model_type: llm
 shields:
 - shield_id: meta-llama/Llama-Guard-3-8B
 vector_dbs: []
 datasets: []
 scoring_fns: []
 eval_tasks: []
-tool_groups: []
+tool_groups:
+- toolgroup_id: builtin::websearch
+  provider_id: tavily-search
+- toolgroup_id: builtin::rag
+  provider_id: rag-runtime
+- toolgroup_id: builtin::code_interpreter
+  provider_id: code-interpreter
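With both the `meta-llama/Llama-Guard-3-8B` model and the matching shield registered in the run config above, the safety API can be exercised end to end. A hedged sketch, assuming standard `llama-stack-client` usage; the base URL and message content are placeholders.

```python
# Sketch: run the Llama Guard shield registered in run.yaml above.
# base_url and the message content are placeholder assumptions.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5001")

result = client.safety.run_shield(
    shield_id="meta-llama/Llama-Guard-3-8B",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    params={},
)
# violation is None for safe content; otherwise it carries the category
# and the configured violation level.
print(result.violation)
```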