mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-18 07:18:53 +00:00
feat: Enable setting a default embedding model in the stack (#3803)
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.12) (push) Failing after 1s
Python Package Build Test / build (3.13) (push) Failing after 1s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 3s
Vector IO Integration Tests / test-matrix (push) Failing after 4s
Unit Tests / unit-tests (3.12) (push) Failing after 4s
Test External API and Providers / test-external (venv) (push) Failing after 4s
Unit Tests / unit-tests (3.13) (push) Failing after 5s
API Conformance Tests / check-schema-compatibility (push) Successful in 11s
UI Tests / ui-tests (22) (push) Successful in 40s
Pre-commit / pre-commit (push) Successful in 1m28s
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.12) (push) Failing after 1s
Python Package Build Test / build (3.13) (push) Failing after 1s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 3s
Vector IO Integration Tests / test-matrix (push) Failing after 4s
Unit Tests / unit-tests (3.12) (push) Failing after 4s
Test External API and Providers / test-external (venv) (push) Failing after 4s
Unit Tests / unit-tests (3.13) (push) Failing after 5s
API Conformance Tests / check-schema-compatibility (push) Successful in 11s
UI Tests / ui-tests (22) (push) Successful in 40s
Pre-commit / pre-commit (push) Successful in 1m28s
# What does this PR do? Enables automatic embedding model detection for vector stores and by using a `default_configured` boolean that can be defined in the `run.yaml`. <!-- If resolving an issue, uncomment and update the line below --> <!-- Closes #[issue-number] --> ## Test Plan - Unit tests - Integration tests - Simple example below: Spin up the stack: ```bash uv run llama stack build --distro starter --image-type venv --run ``` Then test with OpenAI's client: ```python from openai import OpenAI client = OpenAI(base_url="http://localhost:8321/v1/", api_key="none") vs = client.vector_stores.create() ``` Previously you needed: ```python vs = client.vector_stores.create( extra_body={ "embedding_model": "sentence-transformers/all-MiniLM-L6-v2", "embedding_dimension": 384, } ) ``` The `extra_body` is now unnecessary. --------- Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
d875e427bf
commit
ef4bc70bbe
29 changed files with 553 additions and 403 deletions
|
@ -98,6 +98,30 @@ REGISTRY_REFRESH_TASK = None
|
|||
TEST_RECORDING_CONTEXT = None
|
||||
|
||||
|
||||
async def validate_default_embedding_model(impls: dict[Api, Any]):
|
||||
"""Validate that at most one embedding model is marked as default."""
|
||||
if Api.models not in impls:
|
||||
return
|
||||
|
||||
models_impl = impls[Api.models]
|
||||
response = await models_impl.list_models()
|
||||
models_list = response.data if hasattr(response, "data") else response
|
||||
|
||||
default_embedding_models = []
|
||||
for model in models_list:
|
||||
if model.model_type == "embedding" and model.metadata.get("default_configured") is True:
|
||||
default_embedding_models.append(model.identifier)
|
||||
|
||||
if len(default_embedding_models) > 1:
|
||||
raise ValueError(
|
||||
f"Multiple embedding models marked as default_configured=True: {default_embedding_models}. "
|
||||
"Only one embedding model can be marked as default."
|
||||
)
|
||||
|
||||
if default_embedding_models:
|
||||
logger.info(f"Default embedding model configured: {default_embedding_models[0]}")
|
||||
|
||||
|
||||
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
||||
for rsrc, api, register_method, list_method in RESOURCES:
|
||||
objects = getattr(run_config, rsrc)
|
||||
|
@ -128,6 +152,8 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
|||
f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
|
||||
)
|
||||
|
||||
await validate_default_embedding_model(impls)
|
||||
|
||||
|
||||
class EnvVarError(Exception):
|
||||
def __init__(self, var_name: str, path: str = ""):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue