mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-16 07:12:37 +00:00
fix test
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com> updating structure of default Signed-off-by: Francisco Javier Arceo <farceo@redhat.com> fix model id creation Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
b3addc94d1
commit
7ffd20d112
10 changed files with 119 additions and 62 deletions
|
|
@ -11,7 +11,7 @@ from unittest.mock import AsyncMock
|
|||
import pytest
|
||||
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.core.datatypes import StackRunConfig, VectorStoresConfig
|
||||
from llama_stack.core.datatypes import DefaultEmbeddingModel, StackRunConfig, VectorStoresConfig
|
||||
from llama_stack.core.stack import validate_vector_stores_config
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
|
|
@ -20,7 +20,15 @@ class TestVectorStoresValidation:
|
|||
async def test_validate_missing_model(self):
|
||||
"""Test validation fails when model not found."""
|
||||
run_config = StackRunConfig(
|
||||
image_name="test", providers={}, vector_stores=VectorStoresConfig(embedding_model_id="missing")
|
||||
image_name="test",
|
||||
providers={},
|
||||
vector_stores=VectorStoresConfig(
|
||||
default_provider_id="faiss",
|
||||
default_embedding_model=DefaultEmbeddingModel(
|
||||
provider_id="p",
|
||||
model_id="missing",
|
||||
),
|
||||
),
|
||||
)
|
||||
mock_models = AsyncMock()
|
||||
mock_models.list_models.return_value = []
|
||||
|
|
@ -31,12 +39,20 @@ class TestVectorStoresValidation:
|
|||
async def test_validate_success(self):
|
||||
"""Test validation passes with valid model."""
|
||||
run_config = StackRunConfig(
|
||||
image_name="test", providers={}, vector_stores=VectorStoresConfig(embedding_model_id="valid")
|
||||
image_name="test",
|
||||
providers={},
|
||||
vector_stores=VectorStoresConfig(
|
||||
default_provider_id="faiss",
|
||||
default_embedding_model=DefaultEmbeddingModel(
|
||||
provider_id="p",
|
||||
model_id="valid",
|
||||
),
|
||||
),
|
||||
)
|
||||
mock_models = AsyncMock()
|
||||
mock_models.list_models.return_value = [
|
||||
Model(
|
||||
identifier="valid",
|
||||
identifier="p/valid", # Must match provider_id/model_id format
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 768},
|
||||
provider_id="p",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue