mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
chore: OpenAIMixin implements ModelsProtocolPrivate
This commit is contained in:
parent
ceca3c056f
commit
ad24a2c463
8 changed files with 243 additions and 11 deletions
|
@ -362,6 +362,124 @@ class TestOpenAIMixinAllowedModels:
|
|||
assert not await mixin.check_model_availability("another-mock-model-id")
|
||||
|
||||
|
||||
class TestOpenAIMixinModelRegistration:
|
||||
"""Test cases for model registration functionality"""
|
||||
|
||||
async def test_register_model_success(self, mixin, mock_client_with_models, mock_client_context):
|
||||
"""Test successful model registration when model is available"""
|
||||
model = Model(
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="some-mock-model-id",
|
||||
identifier="test-model",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
|
||||
with mock_client_context(mixin, mock_client_with_models):
|
||||
result = await mixin.register_model(model)
|
||||
|
||||
assert result == model
|
||||
assert result.provider_id == "test-provider"
|
||||
assert result.provider_resource_id == "some-mock-model-id"
|
||||
assert result.identifier == "test-model"
|
||||
assert result.model_type == ModelType.llm
|
||||
mock_client_with_models.models.list.assert_called_once()
|
||||
|
||||
async def test_register_model_not_available(self, mixin, mock_client_with_models, mock_client_context):
|
||||
"""Test model registration failure when model is not available from provider"""
|
||||
model = Model(
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="non-existent-model",
|
||||
identifier="test-model",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
|
||||
with mock_client_context(mixin, mock_client_with_models):
|
||||
with pytest.raises(
|
||||
ValueError, match="Model non-existent-model is not available from provider test-provider"
|
||||
):
|
||||
await mixin.register_model(model)
|
||||
mock_client_with_models.models.list.assert_called_once()
|
||||
|
||||
async def test_register_model_with_allowed_models_filter(self, mixin, mock_client_with_models, mock_client_context):
|
||||
"""Test model registration with allowed_models filtering"""
|
||||
mixin.allowed_models = {"some-mock-model-id"}
|
||||
|
||||
# Test with allowed model
|
||||
allowed_model = Model(
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="some-mock-model-id",
|
||||
identifier="allowed-model",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
|
||||
# Test with disallowed model
|
||||
disallowed_model = Model(
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="final-mock-model-id",
|
||||
identifier="disallowed-model",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
|
||||
with mock_client_context(mixin, mock_client_with_models):
|
||||
result = await mixin.register_model(allowed_model)
|
||||
assert result == allowed_model
|
||||
with pytest.raises(
|
||||
ValueError, match="Model final-mock-model-id is not available from provider test-provider"
|
||||
):
|
||||
await mixin.register_model(disallowed_model)
|
||||
mock_client_with_models.models.list.assert_called_once()
|
||||
|
||||
async def test_register_embedding_model(self, mixin_with_embeddings, mock_client_context):
|
||||
"""Test registration of embedding models with metadata"""
|
||||
mock_embedding_model = MagicMock(id="text-embedding-3-small")
|
||||
mock_models = [mock_embedding_model]
|
||||
|
||||
mock_client = MagicMock()
|
||||
|
||||
async def mock_models_list():
|
||||
for model in mock_models:
|
||||
yield model
|
||||
|
||||
mock_client.models.list.return_value = mock_models_list()
|
||||
|
||||
embedding_model = Model(
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="text-embedding-3-small",
|
||||
identifier="embedding-test",
|
||||
model_type=ModelType.embedding,
|
||||
)
|
||||
|
||||
with mock_client_context(mixin_with_embeddings, mock_client):
|
||||
result = await mixin_with_embeddings.register_model(embedding_model)
|
||||
assert result == embedding_model
|
||||
assert result.model_type == ModelType.embedding
|
||||
|
||||
async def test_unregister_model(self, mixin):
|
||||
"""Test model unregistration (should be no-op)"""
|
||||
# unregister_model should not raise any exceptions and return None
|
||||
result = await mixin.unregister_model("any-model-id")
|
||||
assert result is None
|
||||
|
||||
async def test_should_refresh_models(self, mixin):
|
||||
"""Test should_refresh_models method (should always return False)"""
|
||||
result = await mixin.should_refresh_models()
|
||||
assert result is False
|
||||
|
||||
async def test_register_model_error_propagation(self, mixin, mock_client_with_exception, mock_client_context):
|
||||
"""Test that errors from provider API are properly propagated during registration"""
|
||||
model = Model(
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="some-model",
|
||||
identifier="test-model",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
|
||||
with mock_client_context(mixin, mock_client_with_exception):
|
||||
# The exception from the API should be propagated
|
||||
with pytest.raises(Exception, match="API Error"):
|
||||
await mixin.register_model(model)
|
||||
|
||||
|
||||
class ProviderDataValidator(BaseModel):
|
||||
"""Validator for provider data in tests"""
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue