From c985ea6326a9b429057d0e12b21ec0691074da35 Mon Sep 17 00:00:00 2001 From: Divya <117009486+divyaruhil@users.noreply.github.com> Date: Mon, 12 May 2025 23:28:22 +0530 Subject: [PATCH] fix: Adding Embedding model to watsonx inference (#2118) # What does this PR do? Adds the `inline::sentence-transformers` inference provider and the `all-MiniLM-L6-v2` embedding model to the watsonx distribution template. Issue link: https://github.com/meta-llama/llama-stack/issues/2117 ## Test Plan Once this change is applied, users will be able to use the Sentence Transformers embedding model `all-MiniLM-L6-v2` with the watsonx distribution. --- .../remote_hosted_distro/watsonx.md | 2 +- llama_stack/templates/dependencies.json | 4 ++- llama_stack/templates/watsonx/build.yaml | 1 + llama_stack/templates/watsonx/run.yaml | 8 ++++++ llama_stack/templates/watsonx/watsonx.py | 27 ++++++++++++++++--- 5 files changed, 36 insertions(+), 6 deletions(-) diff --git a/docs/source/distributions/remote_hosted_distro/watsonx.md b/docs/source/distributions/remote_hosted_distro/watsonx.md index b7c89e9b0..d8d327bb5 100644 --- a/docs/source/distributions/remote_hosted_distro/watsonx.md +++ b/docs/source/distributions/remote_hosted_distro/watsonx.md @@ -18,7 +18,7 @@ The `llamastack/distribution-watsonx` distribution consists of the following pro | agents | `inline::meta-reference` | | datasetio | `remote::huggingface`, `inline::localfs` | | eval | `inline::meta-reference` | -| inference | `remote::watsonx` | +| inference | `remote::watsonx`, `inline::sentence-transformers` | | safety | `inline::llama-guard` | | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | | telemetry | `inline::meta-reference` | diff --git a/llama_stack/templates/dependencies.json b/llama_stack/templates/dependencies.json index 31f2b93f1..35cbc8878 100644 --- a/llama_stack/templates/dependencies.json +++ b/llama_stack/templates/dependencies.json @@ -833,6 +833,8 @@ "tqdm", "transformers", "tree_sitter", - "uvicorn" + "uvicorn", + "sentence-transformers --no-deps", + "torch torchvision --index-url https://download.pytorch.org/whl/cpu" ] } diff --git a/llama_stack/templates/watsonx/build.yaml 
b/llama_stack/templates/watsonx/build.yaml index 23a1ffa74..638b16029 100644 --- a/llama_stack/templates/watsonx/build.yaml +++ b/llama_stack/templates/watsonx/build.yaml @@ -4,6 +4,7 @@ distribution_spec: providers: inference: - remote::watsonx + - inline::sentence-transformers vector_io: - inline::faiss safety: diff --git a/llama_stack/templates/watsonx/run.yaml b/llama_stack/templates/watsonx/run.yaml index 82d3b2c6e..50904b7e9 100644 --- a/llama_stack/templates/watsonx/run.yaml +++ b/llama_stack/templates/watsonx/run.yaml @@ -18,6 +18,9 @@ providers: url: ${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com} api_key: ${env.WATSONX_API_KEY:} project_id: ${env.WATSONX_PROJECT_ID:} + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers + config: {} vector_io: - provider_id: faiss provider_type: inline::faiss @@ -191,6 +194,11 @@ models: provider_id: watsonx provider_model_id: meta-llama/llama-guard-3-11b-vision model_type: llm +- metadata: + embedding_dimension: 384 + model_id: all-MiniLM-L6-v2 + provider_id: sentence-transformers + model_type: embedding shields: [] vector_dbs: [] datasets: [] diff --git a/llama_stack/templates/watsonx/watsonx.py b/llama_stack/templates/watsonx/watsonx.py index f16593051..802aaf8f1 100644 --- a/llama_stack/templates/watsonx/watsonx.py +++ b/llama_stack/templates/watsonx/watsonx.py @@ -6,7 +6,11 @@ from pathlib import Path -from llama_stack.distribution.datatypes import Provider, ToolGroupInput +from llama_stack.apis.models.models import ModelType +from llama_stack.distribution.datatypes import ModelInput, Provider, ToolGroupInput +from llama_stack.providers.inline.inference.sentence_transformers import ( + SentenceTransformersInferenceConfig, +) from llama_stack.providers.remote.inference.watsonx import WatsonXConfig from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry 
@@ -14,7 +18,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin def get_distribution_template() -> DistributionTemplate: providers = { - "inference": ["remote::watsonx"], + "inference": ["remote::watsonx", "inline::sentence-transformers"], "vector_io": ["inline::faiss"], "safety": ["inline::llama-guard"], "agents": ["inline::meta-reference"], @@ -36,6 +40,12 @@ def get_distribution_template() -> DistributionTemplate: config=WatsonXConfig.sample_run_config(), ) + embedding_provider = Provider( + provider_id="sentence-transformers", + provider_type="inline::sentence-transformers", + config=SentenceTransformersInferenceConfig.sample_run_config(), + ) + available_models = { "watsonx": MODEL_ENTRIES, } @@ -50,6 +60,15 @@ def get_distribution_template() -> DistributionTemplate: ), ] + embedding_model = ModelInput( + model_id="all-MiniLM-L6-v2", + provider_id="sentence-transformers", + model_type=ModelType.embedding, + metadata={ + "embedding_dimension": 384, + }, + ) + default_models = get_model_registry(available_models) return DistributionTemplate( name="watsonx", @@ -62,9 +81,9 @@ def get_distribution_template() -> DistributionTemplate: run_configs={ "run.yaml": RunConfigSettings( provider_overrides={ - "inference": [inference_provider], + "inference": [inference_provider, embedding_provider], }, - default_models=default_models, + default_models=default_models + [embedding_model], default_tool_groups=default_tool_groups, ), },