ensure embedding models have dimensions

This commit is contained in:
melonkernel 2025-09-04 14:01:51 +03:00
parent 2a478fb1d5
commit 27a6918834
2 changed files with 24 additions and 4 deletions

View file

@ -36,6 +36,11 @@ class ProviderModelEntry(BaseModel):
model_type: ModelType = ModelType.llm model_type: ModelType = ModelType.llm
metadata: dict[str, Any] = Field(default_factory=dict) metadata: dict[str, Any] = Field(default_factory=dict)
def model_post_init(self, context: Any, /) -> None:
    """Reject embedding entries that do not declare their vector size.

    The check runs from the post-init hook rather than from an overridden
    ``__init__`` so that it is not tied to the keyword constructor alone.
    Callers that build the entry with keyword arguments are unaffected:
    the inherited constructor still accepts the same ``**data``.

    NOTE(review): this relies on pydantic v2 invoking ``model_post_init``
    after field validation (including on the ``model_validate`` path) --
    confirm the project's pydantic major version before merging.

    Args:
        context: Hook argument supplied by pydantic; not used here.

    Raises:
        ValueError: If ``model_type`` is ``ModelType.embedding`` and
            ``metadata`` has no ``"embedding_dimension"`` key.
    """
    # Only embedding models carry the extra metadata requirement.
    if self.model_type != ModelType.embedding:
        return
    if "embedding_dimension" not in self.metadata:
        raise ValueError("Embedding models must specify 'embedding_dimension' in metadata")
def get_huggingface_repo(model_descriptor: str) -> str | None: def get_huggingface_repo(model_descriptor: str) -> str | None:
for model in all_registered_models(): for model in all_registered_models():

View file

@ -38,7 +38,6 @@ def adapter_with_model_entries():
return adapter return adapter
async def test_model_types_are_correct(adapter_with_model_entries): async def test_model_types_are_correct(adapter_with_model_entries):
"""Test that model types are correct""" """Test that model types are correct"""
model_entries = adapter_with_model_entries.model_entries model_entries = adapter_with_model_entries.model_entries
@ -55,3 +54,19 @@ async def test_model_types_are_correct(adapter_with_model_entries):
embedding_models = [m for m in models if m.model_type == ModelType.embedding] embedding_models = [m for m in models if m.model_type == ModelType.embedding]
assert len(embedding_models) == len(embedding_model_entries) assert len(embedding_models) == len(embedding_model_entries)
def test_embedding_metadata_is_required():
    """Embedding entries must carry 'embedding_dimension' in their metadata."""
    shared_kwargs = {
        "model_type": ModelType.embedding,
        "provider_model_id": "test-text-embedding-model",
        "aliases": [],
    }

    # Missing key: construction must fail, and the error must name the key so
    # an unrelated ValueError cannot make this test pass by accident.
    with pytest.raises(ValueError, match="embedding_dimension"):
        ProviderModelEntry(**shared_kwargs, metadata={})

    # Key present: construction succeeds and the value is kept as given.
    # This deliberately sits outside the ``raises`` block so it always runs.
    entry = ProviderModelEntry(**shared_kwargs, metadata={"embedding_dimension": 1536})
    assert entry.metadata["embedding_dimension"] == 1536