mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
ensure embedding models have dimensions
This commit is contained in:
parent
2a478fb1d5
commit
27a6918834
2 changed files with 24 additions and 4 deletions
|
@ -36,6 +36,11 @@ class ProviderModelEntry(BaseModel):
|
||||||
model_type: ModelType = ModelType.llm
|
model_type: ModelType = ModelType.llm
|
||||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
def __init__(self, **data):
|
||||||
|
super().__init__(**data)
|
||||||
|
if self.model_type == ModelType.embedding and "embedding_dimension" not in self.metadata:
|
||||||
|
raise ValueError("Embedding models must specify 'embedding_dimension' in metadata")
|
||||||
|
|
||||||
|
|
||||||
def get_huggingface_repo(model_descriptor: str) -> str | None:
|
def get_huggingface_repo(model_descriptor: str) -> str | None:
|
||||||
for model in all_registered_models():
|
for model in all_registered_models():
|
||||||
|
|
|
@ -17,7 +17,7 @@ class TestProviderDataValidator(BaseModel):
|
||||||
MODEL_ENTRIES_WITHOUT_ALIASES = [
|
MODEL_ENTRIES_WITHOUT_ALIASES = [
|
||||||
ProviderModelEntry(model_type=ModelType.llm, provider_model_id="test-llm-model", aliases=[]),
|
ProviderModelEntry(model_type=ModelType.llm, provider_model_id="test-llm-model", aliases=[]),
|
||||||
ProviderModelEntry(model_type=ModelType.embedding, provider_model_id="test-text-embedding-model", aliases=[], metadata={"embedding_dimension": 1536, "context_length": 8192}),
|
ProviderModelEntry(model_type=ModelType.embedding, provider_model_id="test-text-embedding-model", aliases=[], metadata={"embedding_dimension": 1536, "context_length": 8192}),
|
||||||
]
|
]
|
||||||
|
|
||||||
class TestLiteLLMAdapterWithModelEntries(LiteLLMOpenAIMixin):
|
class TestLiteLLMAdapterWithModelEntries(LiteLLMOpenAIMixin):
|
||||||
def __init__(self, config: TestConfig):
|
def __init__(self, config: TestConfig):
|
||||||
|
@ -38,7 +38,6 @@ def adapter_with_model_entries():
|
||||||
|
|
||||||
return adapter
|
return adapter
|
||||||
|
|
||||||
|
|
||||||
async def test_model_types_are_correct(adapter_with_model_entries):
|
async def test_model_types_are_correct(adapter_with_model_entries):
|
||||||
"""Test that model types are correct"""
|
"""Test that model types are correct"""
|
||||||
model_entries = adapter_with_model_entries.model_entries
|
model_entries = adapter_with_model_entries.model_entries
|
||||||
|
@ -55,3 +54,19 @@ async def test_model_types_are_correct(adapter_with_model_entries):
|
||||||
embedding_models = [m for m in models if m.model_type == ModelType.embedding]
|
embedding_models = [m for m in models if m.model_type == ModelType.embedding]
|
||||||
assert len(embedding_models) == len(embedding_model_entries)
|
assert len(embedding_models) == len(embedding_model_entries)
|
||||||
|
|
||||||
|
def test_embedding_metadata_is_required():
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
entry1 = ProviderModelEntry(
|
||||||
|
model_type=ModelType.embedding,
|
||||||
|
provider_model_id="test-text-embedding-model",
|
||||||
|
aliases=[],
|
||||||
|
metadata={}
|
||||||
|
)
|
||||||
|
|
||||||
|
entry2 = ProviderModelEntry(
|
||||||
|
model_type=ModelType.embedding,
|
||||||
|
provider_model_id="test-text-embedding-model",
|
||||||
|
aliases=[],
|
||||||
|
metadata={"embedding_dimension": 1536}
|
||||||
|
)
|
||||||
|
assert entry2.metadata["embedding_dimension"] == 1536
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue