mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
Merge 309f06829c
into a1301911e4
This commit is contained in:
commit
2c52ab8944
2 changed files with 78 additions and 1 deletions
|
@ -35,6 +35,11 @@ class ProviderModelEntry(BaseModel):
|
|||
llama_model: str | None = None
|
||||
model_type: ModelType = ModelType.llm
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
def __init__(self, **data):
    """Build the entry, then enforce the embedding-metadata requirement.

    Field handling is delegated to the parent initializer; the extra check
    runs afterwards so it sees the final ``model_type`` and ``metadata``.

    Raises:
        ValueError: if ``model_type`` is ``ModelType.embedding`` and
            ``metadata`` has no ``"embedding_dimension"`` key.
    """
    # NOTE(review): the check lives in __init__, so any construction path
    # that does not call __init__ would skip it -- confirm whether a
    # pydantic validator would be the better home.
    super().__init__(**data)

    # Only embedding models carry the extra requirement.
    if self.model_type != ModelType.embedding:
        return
    if "embedding_dimension" in self.metadata:
        return
    raise ValueError("Embedding models must specify 'embedding_dimension' in metadata")
|
||||
|
||||
|
||||
def get_huggingface_repo(model_descriptor: str) -> str | None:
|
||||
|
@ -103,7 +108,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
Model(
|
||||
identifier=id,
|
||||
provider_resource_id=entry.provider_model_id,
|
||||
model_type=ModelType.llm,
|
||||
model_type=entry.model_type or ModelType.llm,
|
||||
metadata=entry.metadata,
|
||||
provider_id=self.__provider_id__,
|
||||
)
|
||||
|
|
72
tests/unit/providers/inference/test_model_registry.py
Normal file
72
tests/unit/providers/inference/test_model_registry.py
Normal file
|
@ -0,0 +1,72 @@
|
|||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.core.request_headers import request_provider_data_context
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry, ModelType
|
||||
|
||||
# Minimal stand-in for a provider config: a single optional API key that
# defaults to None.
class TestConfig(BaseModel):
    api_key: str | None = None
|
||||
|
||||
# Shape of per-request provider data: a single optional "test_api_key" field
# that defaults to None.
class TestProviderDataValidator(BaseModel):
    test_api_key: str | None = None
|
||||
|
||||
# Registry entries shared by the tests below: one LLM entry and one embedding
# entry, both with an empty alias list. The embedding entry supplies the
# "embedding_dimension" metadata key.
MODEL_ENTRIES_WITHOUT_ALIASES = [
    ProviderModelEntry(
        model_type=ModelType.llm,
        provider_model_id="test-llm-model",
        aliases=[],
    ),
    ProviderModelEntry(
        model_type=ModelType.embedding,
        provider_model_id="test-text-embedding-model",
        aliases=[],
        metadata={"embedding_dimension": 1536, "context_length": 8192},
    ),
]
|
||||
|
||||
class TestLiteLLMAdapterWithModelEntries(LiteLLMOpenAIMixin):
    """LiteLLM-backed adapter preloaded with MODEL_ENTRIES_WITHOUT_ALIASES."""

    def __init__(self, config: TestConfig):
        """Forward fixed test settings plus the config's API key to the mixin."""
        super().__init__(
            litellm_provider_name="test",
            model_entries=MODEL_ENTRIES_WITHOUT_ALIASES,
            api_key_from_config=config.api_key,
            provider_data_api_key_field="test_api_key",
            openai_compat_api_base=None,
        )
|
||||
|
||||
@pytest.fixture
def adapter_with_model_entries():
    """Return a test adapter built from a default ``TestConfig``.

    The default config leaves ``api_key`` as ``None``; the adapter's
    ``__provider_id__`` is set to ``"test-provider"`` before it is returned.
    """
    test_adapter = TestLiteLLMAdapterWithModelEntries(TestConfig())
    test_adapter.__provider_id__ = "test-provider"
    return test_adapter
|
||||
|
||||
async def test_model_types_are_correct(adapter_with_model_entries):
    """Check that listed models keep the model type of their registry entry.

    The adapter is configured with exactly one LLM entry and one embedding
    entry; ``list_models()`` must report the same number of each type.
    """

    def count_of(items, wanted_type):
        # Number of items whose model_type equals the requested type.
        return sum(1 for item in items if item.model_type == wanted_type)

    entries = adapter_with_model_entries.model_entries

    llm_entry_count = count_of(entries, ModelType.llm)
    assert llm_entry_count == 1

    embedding_entry_count = count_of(entries, ModelType.embedding)
    assert embedding_entry_count == 1

    listed = await adapter_with_model_entries.list_models()

    assert count_of(listed, ModelType.llm) == llm_entry_count
    assert count_of(listed, ModelType.embedding) == embedding_entry_count
|
||||
|
||||
def test_embedding_metadata_is_required():
    """Embedding entries must declare ``embedding_dimension`` in metadata.

    Negative case: constructing an embedding entry with empty metadata raises
    ``ValueError`` whose message names the missing key.
    Positive case: the same entry with the key present is accepted and the
    value is kept as given.
    """
    # Keep only the failing call inside the context manager so nothing after
    # it is silently skipped, and pin the message so an unrelated ValueError
    # cannot satisfy the test.
    with pytest.raises(ValueError, match="embedding_dimension"):
        ProviderModelEntry(
            model_type=ModelType.embedding,
            provider_model_id="test-text-embedding-model",
            aliases=[],
            metadata={},
        )

    valid_entry = ProviderModelEntry(
        model_type=ModelType.embedding,
        provider_model_id="test-text-embedding-model",
        aliases=[],
        metadata={"embedding_dimension": 1536},
    )
    assert valid_entry.metadata["embedding_dimension"] == 1536
|
Loading…
Add table
Add a link
Reference in a new issue