mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
chore: OpenAIMixin implements ModelsProtocolPrivate (#3662)
# What does this PR do? add ModelsProtocolPrivate methods to OpenAIMixin this will allow providers using OpenAIMixin to use a common interface ## Test Plan ci w/ new tests
This commit is contained in:
parent
14a94e9894
commit
0a41c4ead0
8 changed files with 243 additions and 11 deletions
|
@ -25,9 +25,6 @@ from llama_stack.apis.inference import (
|
|||
ToolPromptFormat,
|
||||
TopKSamplingStrategy,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
|
@ -44,7 +41,6 @@ from .config import CerebrasImplConfig
|
|||
|
||||
class CerebrasInferenceAdapter(
|
||||
OpenAIMixin,
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
):
|
||||
def __init__(self, config: CerebrasImplConfig) -> None:
|
||||
|
|
|
@ -44,7 +44,7 @@ from .config import FireworksImplConfig
|
|||
logger = get_logger(name=__name__, category="inference::fireworks")
|
||||
|
||||
|
||||
class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
class FireworksInferenceAdapter(OpenAIMixin, Inference, NeedsRequestProviderData):
|
||||
embedding_model_metadata = {
|
||||
"nomic-ai/nomic-embed-text-v1.5": {"embedding_dimension": 768, "context_length": 8192},
|
||||
"accounts/fireworks/models/qwen3-embedding-8b": {"embedding_dimension": 4096, "context_length": 40960},
|
||||
|
|
|
@ -29,7 +29,6 @@ from llama_stack.apis.models import Model
|
|||
from llama_stack.apis.models.models import ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.sku_list import all_registered_models
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
build_hf_repo_model_entry,
|
||||
|
@ -65,7 +64,6 @@ def build_hf_repo_model_entries():
|
|||
class _HfAdapter(
|
||||
OpenAIMixin,
|
||||
Inference,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
url: str
|
||||
api_key: SecretStr
|
||||
|
|
|
@ -47,7 +47,7 @@ from .config import TogetherImplConfig
|
|||
logger = get_logger(name=__name__, category="inference::together")
|
||||
|
||||
|
||||
class TogetherInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
class TogetherInferenceAdapter(OpenAIMixin, Inference, NeedsRequestProviderData):
|
||||
embedding_model_metadata = {
|
||||
"togethercomputer/m2-bert-80M-32k-retrieval": {"embedding_dimension": 768, "context_length": 32768},
|
||||
"BAAI/bge-large-en-v1.5": {"embedding_dimension": 1024, "context_length": 512},
|
||||
|
|
|
@ -26,14 +26,14 @@ from llama_stack.apis.inference import (
|
|||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import localize_image_content
|
||||
|
||||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
||||
class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
|
||||
class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
|
||||
"""
|
||||
Mixin class that provides OpenAI-specific functionality for inference providers.
|
||||
This class handles direct OpenAI API calls using the AsyncOpenAI client.
|
||||
|
@ -73,6 +73,9 @@ class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
|
|||
# Optional field name in provider data to look for API key, which takes precedence
|
||||
provider_data_api_key_field: str | None = None
|
||||
|
||||
# automatically set by the resolver when instantiating the provider
|
||||
__provider_id__: str
|
||||
|
||||
@abstractmethod
|
||||
def get_api_key(self) -> str:
|
||||
"""
|
||||
|
@ -356,6 +359,24 @@ class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
|
|||
usage=usage,
|
||||
)
|
||||
|
||||
###
|
||||
# ModelsProtocolPrivate implementation - provide model management functionality
|
||||
#
|
||||
# async def register_model(self, model: Model) -> Model: ...
|
||||
# async def unregister_model(self, model_id: str) -> None: ...
|
||||
#
|
||||
# async def list_models(self) -> list[Model] | None: ...
|
||||
# async def should_refresh_models(self) -> bool: ...
|
||||
##
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
if not await self.check_model_availability(model.provider_model_id):
|
||||
raise ValueError(f"Model {model.provider_model_id} is not available from provider {self.__provider_id__}")
|
||||
return model
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
return None
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
"""
|
||||
List available models from the provider's /v1/models endpoint augmented with static embedding model metadata.
|
||||
|
@ -400,5 +421,7 @@ class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
|
|||
"""
|
||||
if not self._model_cache:
|
||||
await self.list_models()
|
||||
|
||||
return model in self._model_cache
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return False
|
||||
|
|
57
tests/integration/recordings/responses/39576bcd7ed6.json
Normal file
57
tests/integration/recordings/responses/39576bcd7ed6.json
Normal file
|
@ -0,0 +1,57 @@
|
|||
{
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"url": "http://0.0.0.0:11434/v1/v1/chat/completions",
|
||||
"headers": {},
|
||||
"body": {
|
||||
"model": "llama-guard3:1b",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Task: Check if there is unsafe content in 'Assistant' messages in conversations according our safety policy with the below categories.\n\n<BEGIN UNSAFE CONTENT CATEGORIES>\nS1: Violent Crimes.\nS2: Non-Violent Crimes.\nS3: Sex Crimes.\nS4: Child Exploitation.\nS5: Defamation.\nS6: Specialized Advice.\nS7: Privacy.\nS8: Intellectual Property.\nS9: Indiscriminate Weapons.\nS10: Hate.\nS11: Self-Harm.\nS12: Sexual Content.\nS13: Elections.\nS14: Code Interpreter Abuse.\n<END UNSAFE CONTENT CATEGORIES>\n\n<BEGIN CONVERSATION>\n\nUser: Give me a sentence that contains the word: hello\n\nAssistant: \n\n<END CONVERSATION>\n\nProvide your safety assessment for ONLY THE LAST Assistant message in the above conversation:\n - First line must read 'safe' or 'unsafe'.\n - If unsafe, a second line must include a comma-separated list of violated categories."
|
||||
}
|
||||
],
|
||||
"stream": false,
|
||||
"temperature": 0.0
|
||||
},
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"model": "llama-guard3:1b"
|
||||
},
|
||||
"response": {
|
||||
"body": {
|
||||
"__type__": "openai.types.chat.chat_completion.ChatCompletion",
|
||||
"__data__": {
|
||||
"id": "chatcmpl-317",
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "safe",
|
||||
"refusal": null,
|
||||
"role": "assistant",
|
||||
"annotations": null,
|
||||
"audio": null,
|
||||
"function_call": null,
|
||||
"tool_calls": null
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1759351124,
|
||||
"model": "llama-guard3:1b",
|
||||
"object": "chat.completion",
|
||||
"service_tier": null,
|
||||
"system_fingerprint": "fp_ollama",
|
||||
"usage": {
|
||||
"completion_tokens": 2,
|
||||
"prompt_tokens": 397,
|
||||
"total_tokens": 399,
|
||||
"completion_tokens_details": null,
|
||||
"prompt_tokens_details": null
|
||||
}
|
||||
}
|
||||
},
|
||||
"is_streaming": false
|
||||
}
|
||||
}
|
40
tests/integration/recordings/responses/53d2488c9ea9.json
Normal file
40
tests/integration/recordings/responses/53d2488c9ea9.json
Normal file
|
@ -0,0 +1,40 @@
|
|||
{
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"url": "http://localhost:11434/api/generate",
|
||||
"headers": {},
|
||||
"body": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"options": {
|
||||
"temperature": 0.0001,
|
||||
"top_p": 0.9
|
||||
},
|
||||
"stream": true
|
||||
},
|
||||
"endpoint": "/api/generate",
|
||||
"model": "llama3.2:3b-instruct-fp16"
|
||||
},
|
||||
"response": {
|
||||
"body": [
|
||||
{
|
||||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"created_at": "2025-10-01T20:38:48.732564955Z",
|
||||
"done": true,
|
||||
"done_reason": "load",
|
||||
"total_duration": null,
|
||||
"load_duration": null,
|
||||
"prompt_eval_count": null,
|
||||
"prompt_eval_duration": null,
|
||||
"eval_count": null,
|
||||
"eval_duration": null,
|
||||
"response": "",
|
||||
"thinking": null,
|
||||
"context": null
|
||||
}
|
||||
}
|
||||
],
|
||||
"is_streaming": true
|
||||
}
|
||||
}
|
|
@ -362,6 +362,124 @@ class TestOpenAIMixinAllowedModels:
|
|||
assert not await mixin.check_model_availability("another-mock-model-id")
|
||||
|
||||
|
||||
class TestOpenAIMixinModelRegistration:
|
||||
"""Test cases for model registration functionality"""
|
||||
|
||||
async def test_register_model_success(self, mixin, mock_client_with_models, mock_client_context):
|
||||
"""Test successful model registration when model is available"""
|
||||
model = Model(
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="some-mock-model-id",
|
||||
identifier="test-model",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
|
||||
with mock_client_context(mixin, mock_client_with_models):
|
||||
result = await mixin.register_model(model)
|
||||
|
||||
assert result == model
|
||||
assert result.provider_id == "test-provider"
|
||||
assert result.provider_resource_id == "some-mock-model-id"
|
||||
assert result.identifier == "test-model"
|
||||
assert result.model_type == ModelType.llm
|
||||
mock_client_with_models.models.list.assert_called_once()
|
||||
|
||||
async def test_register_model_not_available(self, mixin, mock_client_with_models, mock_client_context):
|
||||
"""Test model registration failure when model is not available from provider"""
|
||||
model = Model(
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="non-existent-model",
|
||||
identifier="test-model",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
|
||||
with mock_client_context(mixin, mock_client_with_models):
|
||||
with pytest.raises(
|
||||
ValueError, match="Model non-existent-model is not available from provider test-provider"
|
||||
):
|
||||
await mixin.register_model(model)
|
||||
mock_client_with_models.models.list.assert_called_once()
|
||||
|
||||
async def test_register_model_with_allowed_models_filter(self, mixin, mock_client_with_models, mock_client_context):
|
||||
"""Test model registration with allowed_models filtering"""
|
||||
mixin.allowed_models = {"some-mock-model-id"}
|
||||
|
||||
# Test with allowed model
|
||||
allowed_model = Model(
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="some-mock-model-id",
|
||||
identifier="allowed-model",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
|
||||
# Test with disallowed model
|
||||
disallowed_model = Model(
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="final-mock-model-id",
|
||||
identifier="disallowed-model",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
|
||||
with mock_client_context(mixin, mock_client_with_models):
|
||||
result = await mixin.register_model(allowed_model)
|
||||
assert result == allowed_model
|
||||
with pytest.raises(
|
||||
ValueError, match="Model final-mock-model-id is not available from provider test-provider"
|
||||
):
|
||||
await mixin.register_model(disallowed_model)
|
||||
mock_client_with_models.models.list.assert_called_once()
|
||||
|
||||
async def test_register_embedding_model(self, mixin_with_embeddings, mock_client_context):
|
||||
"""Test registration of embedding models with metadata"""
|
||||
mock_embedding_model = MagicMock(id="text-embedding-3-small")
|
||||
mock_models = [mock_embedding_model]
|
||||
|
||||
mock_client = MagicMock()
|
||||
|
||||
async def mock_models_list():
|
||||
for model in mock_models:
|
||||
yield model
|
||||
|
||||
mock_client.models.list.return_value = mock_models_list()
|
||||
|
||||
embedding_model = Model(
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="text-embedding-3-small",
|
||||
identifier="embedding-test",
|
||||
model_type=ModelType.embedding,
|
||||
)
|
||||
|
||||
with mock_client_context(mixin_with_embeddings, mock_client):
|
||||
result = await mixin_with_embeddings.register_model(embedding_model)
|
||||
assert result == embedding_model
|
||||
assert result.model_type == ModelType.embedding
|
||||
|
||||
async def test_unregister_model(self, mixin):
|
||||
"""Test model unregistration (should be no-op)"""
|
||||
# unregister_model should not raise any exceptions and return None
|
||||
result = await mixin.unregister_model("any-model-id")
|
||||
assert result is None
|
||||
|
||||
async def test_should_refresh_models(self, mixin):
|
||||
"""Test should_refresh_models method (should always return False)"""
|
||||
result = await mixin.should_refresh_models()
|
||||
assert result is False
|
||||
|
||||
async def test_register_model_error_propagation(self, mixin, mock_client_with_exception, mock_client_context):
|
||||
"""Test that errors from provider API are properly propagated during registration"""
|
||||
model = Model(
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="some-model",
|
||||
identifier="test-model",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
|
||||
with mock_client_context(mixin, mock_client_with_exception):
|
||||
# The exception from the API should be propagated
|
||||
with pytest.raises(Exception, match="API Error"):
|
||||
await mixin.register_model(model)
|
||||
|
||||
|
||||
class ProviderDataValidator(BaseModel):
|
||||
"""Validator for provider data in tests"""
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue