chore: OpenAIMixin implements ModelsProtocolPrivate (#3662)

# What does this PR do?

add ModelsProtocolPrivate methods to OpenAIMixin

this will allow providers using OpenAIMixin to use a common interface


## Test Plan

ci w/ new tests
This commit is contained in:
Matthew Farrellee 2025-10-03 00:32:02 -04:00 committed by GitHub
parent 14a94e9894
commit 0a41c4ead0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 243 additions and 11 deletions

View file

@ -0,0 +1,57 @@
{
"request": {
"method": "POST",
"url": "http://0.0.0.0:11434/v1/v1/chat/completions",
"headers": {},
"body": {
"model": "llama-guard3:1b",
"messages": [
{
"role": "user",
"content": "Task: Check if there is unsafe content in 'Assistant' messages in conversations according our safety policy with the below categories.\n\n<BEGIN UNSAFE CONTENT CATEGORIES>\nS1: Violent Crimes.\nS2: Non-Violent Crimes.\nS3: Sex Crimes.\nS4: Child Exploitation.\nS5: Defamation.\nS6: Specialized Advice.\nS7: Privacy.\nS8: Intellectual Property.\nS9: Indiscriminate Weapons.\nS10: Hate.\nS11: Self-Harm.\nS12: Sexual Content.\nS13: Elections.\nS14: Code Interpreter Abuse.\n<END UNSAFE CONTENT CATEGORIES>\n\n<BEGIN CONVERSATION>\n\nUser: Give me a sentence that contains the word: hello\n\nAssistant: \n\n<END CONVERSATION>\n\nProvide your safety assessment for ONLY THE LAST Assistant message in the above conversation:\n - First line must read 'safe' or 'unsafe'.\n - If unsafe, a second line must include a comma-separated list of violated categories."
}
],
"stream": false,
"temperature": 0.0
},
"endpoint": "/v1/chat/completions",
"model": "llama-guard3:1b"
},
"response": {
"body": {
"__type__": "openai.types.chat.chat_completion.ChatCompletion",
"__data__": {
"id": "chatcmpl-317",
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "safe",
"refusal": null,
"role": "assistant",
"annotations": null,
"audio": null,
"function_call": null,
"tool_calls": null
}
}
],
"created": 1759351124,
"model": "llama-guard3:1b",
"object": "chat.completion",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": {
"completion_tokens": 2,
"prompt_tokens": 397,
"total_tokens": 399,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
"is_streaming": false
}
}

View file

@ -0,0 +1,40 @@
{
"request": {
"method": "POST",
"url": "http://localhost:11434/api/generate",
"headers": {},
"body": {
"model": "llama3.2:3b-instruct-fp16",
"options": {
"temperature": 0.0001,
"top_p": 0.9
},
"stream": true
},
"endpoint": "/api/generate",
"model": "llama3.2:3b-instruct-fp16"
},
"response": {
"body": [
{
"__type__": "ollama._types.GenerateResponse",
"__data__": {
"model": "llama3.2:3b-instruct-fp16",
"created_at": "2025-10-01T20:38:48.732564955Z",
"done": true,
"done_reason": "load",
"total_duration": null,
"load_duration": null,
"prompt_eval_count": null,
"prompt_eval_duration": null,
"eval_count": null,
"eval_duration": null,
"response": "",
"thinking": null,
"context": null
}
}
],
"is_streaming": true
}
}

View file

@ -362,6 +362,124 @@ class TestOpenAIMixinAllowedModels:
assert not await mixin.check_model_availability("another-mock-model-id")
class TestOpenAIMixinModelRegistration:
"""Test cases for model registration functionality"""
async def test_register_model_success(self, mixin, mock_client_with_models, mock_client_context):
"""Test successful model registration when model is available"""
model = Model(
provider_id="test-provider",
provider_resource_id="some-mock-model-id",
identifier="test-model",
model_type=ModelType.llm,
)
with mock_client_context(mixin, mock_client_with_models):
result = await mixin.register_model(model)
assert result == model
assert result.provider_id == "test-provider"
assert result.provider_resource_id == "some-mock-model-id"
assert result.identifier == "test-model"
assert result.model_type == ModelType.llm
mock_client_with_models.models.list.assert_called_once()
async def test_register_model_not_available(self, mixin, mock_client_with_models, mock_client_context):
"""Test model registration failure when model is not available from provider"""
model = Model(
provider_id="test-provider",
provider_resource_id="non-existent-model",
identifier="test-model",
model_type=ModelType.llm,
)
with mock_client_context(mixin, mock_client_with_models):
with pytest.raises(
ValueError, match="Model non-existent-model is not available from provider test-provider"
):
await mixin.register_model(model)
mock_client_with_models.models.list.assert_called_once()
async def test_register_model_with_allowed_models_filter(self, mixin, mock_client_with_models, mock_client_context):
"""Test model registration with allowed_models filtering"""
mixin.allowed_models = {"some-mock-model-id"}
# Test with allowed model
allowed_model = Model(
provider_id="test-provider",
provider_resource_id="some-mock-model-id",
identifier="allowed-model",
model_type=ModelType.llm,
)
# Test with disallowed model
disallowed_model = Model(
provider_id="test-provider",
provider_resource_id="final-mock-model-id",
identifier="disallowed-model",
model_type=ModelType.llm,
)
with mock_client_context(mixin, mock_client_with_models):
result = await mixin.register_model(allowed_model)
assert result == allowed_model
with pytest.raises(
ValueError, match="Model final-mock-model-id is not available from provider test-provider"
):
await mixin.register_model(disallowed_model)
mock_client_with_models.models.list.assert_called_once()
async def test_register_embedding_model(self, mixin_with_embeddings, mock_client_context):
"""Test registration of embedding models with metadata"""
mock_embedding_model = MagicMock(id="text-embedding-3-small")
mock_models = [mock_embedding_model]
mock_client = MagicMock()
async def mock_models_list():
for model in mock_models:
yield model
mock_client.models.list.return_value = mock_models_list()
embedding_model = Model(
provider_id="test-provider",
provider_resource_id="text-embedding-3-small",
identifier="embedding-test",
model_type=ModelType.embedding,
)
with mock_client_context(mixin_with_embeddings, mock_client):
result = await mixin_with_embeddings.register_model(embedding_model)
assert result == embedding_model
assert result.model_type == ModelType.embedding
async def test_unregister_model(self, mixin):
"""Test model unregistration (should be no-op)"""
# unregister_model should not raise any exceptions and return None
result = await mixin.unregister_model("any-model-id")
assert result is None
async def test_should_refresh_models(self, mixin):
"""Test should_refresh_models method (should always return False)"""
result = await mixin.should_refresh_models()
assert result is False
async def test_register_model_error_propagation(self, mixin, mock_client_with_exception, mock_client_context):
"""Test that errors from provider API are properly propagated during registration"""
model = Model(
provider_id="test-provider",
provider_resource_id="some-model",
identifier="test-model",
model_type=ModelType.llm,
)
with mock_client_context(mixin, mock_client_with_exception):
# The exception from the API should be propagated
with pytest.raises(Exception, match="API Error"):
await mixin.register_model(model)
class ProviderDataValidator(BaseModel):
"""Validator for provider data in tests"""