mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 19:04:19 +00:00
fix: Adding Embedding model to watsonx inference (#2118)
# What does this PR do? Issue Link : https://github.com/meta-llama/llama-stack/issues/2117 ## Test Plan Once added, User will be able to use Sentence Transformer model `all-MiniLM-L6-v2`
This commit is contained in:
parent
136e6b3cf7
commit
c985ea6326
5 changed files with 36 additions and 6 deletions
|
@ -18,7 +18,7 @@ The `llamastack/distribution-watsonx` distribution consists of the following pro
|
||||||
| agents | `inline::meta-reference` |
|
| agents | `inline::meta-reference` |
|
||||||
| datasetio | `remote::huggingface`, `inline::localfs` |
|
| datasetio | `remote::huggingface`, `inline::localfs` |
|
||||||
| eval | `inline::meta-reference` |
|
| eval | `inline::meta-reference` |
|
||||||
| inference | `remote::watsonx` |
|
| inference | `remote::watsonx`, `inline::sentence-transformers` |
|
||||||
| safety | `inline::llama-guard` |
|
| safety | `inline::llama-guard` |
|
||||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||||
| telemetry | `inline::meta-reference` |
|
| telemetry | `inline::meta-reference` |
|
||||||
|
|
|
@ -833,6 +833,8 @@
|
||||||
"tqdm",
|
"tqdm",
|
||||||
"transformers",
|
"transformers",
|
||||||
"tree_sitter",
|
"tree_sitter",
|
||||||
"uvicorn"
|
"uvicorn",
|
||||||
|
"sentence-transformers --no-deps",
|
||||||
|
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ distribution_spec:
|
||||||
providers:
|
providers:
|
||||||
inference:
|
inference:
|
||||||
- remote::watsonx
|
- remote::watsonx
|
||||||
|
- inline::sentence-transformers
|
||||||
vector_io:
|
vector_io:
|
||||||
- inline::faiss
|
- inline::faiss
|
||||||
safety:
|
safety:
|
||||||
|
|
|
@ -18,6 +18,9 @@ providers:
|
||||||
url: ${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}
|
url: ${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}
|
||||||
api_key: ${env.WATSONX_API_KEY:}
|
api_key: ${env.WATSONX_API_KEY:}
|
||||||
project_id: ${env.WATSONX_PROJECT_ID:}
|
project_id: ${env.WATSONX_PROJECT_ID:}
|
||||||
|
- provider_id: sentence-transformers
|
||||||
|
provider_type: inline::sentence-transformers
|
||||||
|
config: {}
|
||||||
vector_io:
|
vector_io:
|
||||||
- provider_id: faiss
|
- provider_id: faiss
|
||||||
provider_type: inline::faiss
|
provider_type: inline::faiss
|
||||||
|
@ -191,6 +194,11 @@ models:
|
||||||
provider_id: watsonx
|
provider_id: watsonx
|
||||||
provider_model_id: meta-llama/llama-guard-3-11b-vision
|
provider_model_id: meta-llama/llama-guard-3-11b-vision
|
||||||
model_type: llm
|
model_type: llm
|
||||||
|
- metadata:
|
||||||
|
embedding_dimension: 384
|
||||||
|
model_id: all-MiniLM-L6-v2
|
||||||
|
provider_id: sentence-transformers
|
||||||
|
model_type: embedding
|
||||||
shields: []
|
shields: []
|
||||||
vector_dbs: []
|
vector_dbs: []
|
||||||
datasets: []
|
datasets: []
|
||||||
|
|
|
@ -6,7 +6,11 @@
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Provider, ToolGroupInput
|
from llama_stack.apis.models.models import ModelType
|
||||||
|
from llama_stack.distribution.datatypes import ModelInput, Provider, ToolGroupInput
|
||||||
|
from llama_stack.providers.inline.inference.sentence_transformers import (
|
||||||
|
SentenceTransformersInferenceConfig,
|
||||||
|
)
|
||||||
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
|
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
|
||||||
from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES
|
from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES
|
||||||
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
|
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
|
||||||
|
@ -14,7 +18,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
|
||||||
|
|
||||||
def get_distribution_template() -> DistributionTemplate:
|
def get_distribution_template() -> DistributionTemplate:
|
||||||
providers = {
|
providers = {
|
||||||
"inference": ["remote::watsonx"],
|
"inference": ["remote::watsonx", "inline::sentence-transformers"],
|
||||||
"vector_io": ["inline::faiss"],
|
"vector_io": ["inline::faiss"],
|
||||||
"safety": ["inline::llama-guard"],
|
"safety": ["inline::llama-guard"],
|
||||||
"agents": ["inline::meta-reference"],
|
"agents": ["inline::meta-reference"],
|
||||||
|
@ -36,6 +40,12 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
config=WatsonXConfig.sample_run_config(),
|
config=WatsonXConfig.sample_run_config(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
embedding_provider = Provider(
|
||||||
|
provider_id="sentence-transformers",
|
||||||
|
provider_type="inline::sentence-transformers",
|
||||||
|
config=SentenceTransformersInferenceConfig.sample_run_config(),
|
||||||
|
)
|
||||||
|
|
||||||
available_models = {
|
available_models = {
|
||||||
"watsonx": MODEL_ENTRIES,
|
"watsonx": MODEL_ENTRIES,
|
||||||
}
|
}
|
||||||
|
@ -50,6 +60,15 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
embedding_model = ModelInput(
|
||||||
|
model_id="all-MiniLM-L6-v2",
|
||||||
|
provider_id="sentence-transformers",
|
||||||
|
model_type=ModelType.embedding,
|
||||||
|
metadata={
|
||||||
|
"embedding_dimension": 384,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
default_models = get_model_registry(available_models)
|
default_models = get_model_registry(available_models)
|
||||||
return DistributionTemplate(
|
return DistributionTemplate(
|
||||||
name="watsonx",
|
name="watsonx",
|
||||||
|
@ -62,9 +81,9 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
run_configs={
|
run_configs={
|
||||||
"run.yaml": RunConfigSettings(
|
"run.yaml": RunConfigSettings(
|
||||||
provider_overrides={
|
provider_overrides={
|
||||||
"inference": [inference_provider],
|
"inference": [inference_provider, embedding_provider],
|
||||||
},
|
},
|
||||||
default_models=default_models,
|
default_models=default_models + [embedding_model],
|
||||||
default_tool_groups=default_tool_groups,
|
default_tool_groups=default_tool_groups,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue