mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
chore: OpenAIMixin implements ModelsProtocolPrivate (#3662)
# What does this PR do? add ModelsProtocolPrivate methods to OpenAIMixin this will allow providers using OpenAIMixin to use a common interface ## Test Plan ci w/ new tests
This commit is contained in:
parent
14a94e9894
commit
0a41c4ead0
8 changed files with 243 additions and 11 deletions
|
@ -25,9 +25,6 @@ from llama_stack.apis.inference import (
|
||||||
ToolPromptFormat,
|
ToolPromptFormat,
|
||||||
TopKSamplingStrategy,
|
TopKSamplingStrategy,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.utils.inference.model_registry import (
|
|
||||||
ModelRegistryHelper,
|
|
||||||
)
|
|
||||||
from llama_stack.providers.utils.inference.openai_compat import (
|
from llama_stack.providers.utils.inference.openai_compat import (
|
||||||
get_sampling_options,
|
get_sampling_options,
|
||||||
process_chat_completion_response,
|
process_chat_completion_response,
|
||||||
|
@ -44,7 +41,6 @@ from .config import CerebrasImplConfig
|
||||||
|
|
||||||
class CerebrasInferenceAdapter(
|
class CerebrasInferenceAdapter(
|
||||||
OpenAIMixin,
|
OpenAIMixin,
|
||||||
ModelRegistryHelper,
|
|
||||||
Inference,
|
Inference,
|
||||||
):
|
):
|
||||||
def __init__(self, config: CerebrasImplConfig) -> None:
|
def __init__(self, config: CerebrasImplConfig) -> None:
|
||||||
|
|
|
@ -44,7 +44,7 @@ from .config import FireworksImplConfig
|
||||||
logger = get_logger(name=__name__, category="inference::fireworks")
|
logger = get_logger(name=__name__, category="inference::fireworks")
|
||||||
|
|
||||||
|
|
||||||
class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
class FireworksInferenceAdapter(OpenAIMixin, Inference, NeedsRequestProviderData):
|
||||||
embedding_model_metadata = {
|
embedding_model_metadata = {
|
||||||
"nomic-ai/nomic-embed-text-v1.5": {"embedding_dimension": 768, "context_length": 8192},
|
"nomic-ai/nomic-embed-text-v1.5": {"embedding_dimension": 768, "context_length": 8192},
|
||||||
"accounts/fireworks/models/qwen3-embedding-8b": {"embedding_dimension": 4096, "context_length": 40960},
|
"accounts/fireworks/models/qwen3-embedding-8b": {"embedding_dimension": 4096, "context_length": 40960},
|
||||||
|
|
|
@ -29,7 +29,6 @@ from llama_stack.apis.models import Model
|
||||||
from llama_stack.apis.models.models import ModelType
|
from llama_stack.apis.models.models import ModelType
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.models.llama.sku_list import all_registered_models
|
from llama_stack.models.llama.sku_list import all_registered_models
|
||||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
|
||||||
from llama_stack.providers.utils.inference.model_registry import (
|
from llama_stack.providers.utils.inference.model_registry import (
|
||||||
ModelRegistryHelper,
|
ModelRegistryHelper,
|
||||||
build_hf_repo_model_entry,
|
build_hf_repo_model_entry,
|
||||||
|
@ -65,7 +64,6 @@ def build_hf_repo_model_entries():
|
||||||
class _HfAdapter(
|
class _HfAdapter(
|
||||||
OpenAIMixin,
|
OpenAIMixin,
|
||||||
Inference,
|
Inference,
|
||||||
ModelsProtocolPrivate,
|
|
||||||
):
|
):
|
||||||
url: str
|
url: str
|
||||||
api_key: SecretStr
|
api_key: SecretStr
|
||||||
|
|
|
@ -47,7 +47,7 @@ from .config import TogetherImplConfig
|
||||||
logger = get_logger(name=__name__, category="inference::together")
|
logger = get_logger(name=__name__, category="inference::together")
|
||||||
|
|
||||||
|
|
||||||
class TogetherInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
class TogetherInferenceAdapter(OpenAIMixin, Inference, NeedsRequestProviderData):
|
||||||
embedding_model_metadata = {
|
embedding_model_metadata = {
|
||||||
"togethercomputer/m2-bert-80M-32k-retrieval": {"embedding_dimension": 768, "context_length": 32768},
|
"togethercomputer/m2-bert-80M-32k-retrieval": {"embedding_dimension": 768, "context_length": 32768},
|
||||||
"BAAI/bge-large-en-v1.5": {"embedding_dimension": 1024, "context_length": 512},
|
"BAAI/bge-large-en-v1.5": {"embedding_dimension": 1024, "context_length": 512},
|
||||||
|
|
|
@ -26,14 +26,14 @@ from llama_stack.apis.inference import (
|
||||||
from llama_stack.apis.models import ModelType
|
from llama_stack.apis.models import ModelType
|
||||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||||
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import localize_image_content
|
from llama_stack.providers.utils.inference.prompt_adapter import localize_image_content
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="providers::utils")
|
logger = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
|
|
||||||
class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
|
class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
|
||||||
"""
|
"""
|
||||||
Mixin class that provides OpenAI-specific functionality for inference providers.
|
Mixin class that provides OpenAI-specific functionality for inference providers.
|
||||||
This class handles direct OpenAI API calls using the AsyncOpenAI client.
|
This class handles direct OpenAI API calls using the AsyncOpenAI client.
|
||||||
|
@ -73,6 +73,9 @@ class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
|
||||||
# Optional field name in provider data to look for API key, which takes precedence
|
# Optional field name in provider data to look for API key, which takes precedence
|
||||||
provider_data_api_key_field: str | None = None
|
provider_data_api_key_field: str | None = None
|
||||||
|
|
||||||
|
# automatically set by the resolver when instantiating the provider
|
||||||
|
__provider_id__: str
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_api_key(self) -> str:
|
def get_api_key(self) -> str:
|
||||||
"""
|
"""
|
||||||
|
@ -356,6 +359,24 @@ class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
|
||||||
usage=usage,
|
usage=usage,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
###
|
||||||
|
# ModelsProtocolPrivate implementation - provide model management functionality
|
||||||
|
#
|
||||||
|
# async def register_model(self, model: Model) -> Model: ...
|
||||||
|
# async def unregister_model(self, model_id: str) -> None: ...
|
||||||
|
#
|
||||||
|
# async def list_models(self) -> list[Model] | None: ...
|
||||||
|
# async def should_refresh_models(self) -> bool: ...
|
||||||
|
##
|
||||||
|
|
||||||
|
async def register_model(self, model: Model) -> Model:
|
||||||
|
if not await self.check_model_availability(model.provider_model_id):
|
||||||
|
raise ValueError(f"Model {model.provider_model_id} is not available from provider {self.__provider_id__}")
|
||||||
|
return model
|
||||||
|
|
||||||
|
async def unregister_model(self, model_id: str) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
async def list_models(self) -> list[Model] | None:
|
async def list_models(self) -> list[Model] | None:
|
||||||
"""
|
"""
|
||||||
List available models from the provider's /v1/models endpoint augmented with static embedding model metadata.
|
List available models from the provider's /v1/models endpoint augmented with static embedding model metadata.
|
||||||
|
@ -400,5 +421,7 @@ class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
|
||||||
"""
|
"""
|
||||||
if not self._model_cache:
|
if not self._model_cache:
|
||||||
await self.list_models()
|
await self.list_models()
|
||||||
|
|
||||||
return model in self._model_cache
|
return model in self._model_cache
|
||||||
|
|
||||||
|
async def should_refresh_models(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
57
tests/integration/recordings/responses/39576bcd7ed6.json
Normal file
57
tests/integration/recordings/responses/39576bcd7ed6.json
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
{
|
||||||
|
"request": {
|
||||||
|
"method": "POST",
|
||||||
|
"url": "http://0.0.0.0:11434/v1/v1/chat/completions",
|
||||||
|
"headers": {},
|
||||||
|
"body": {
|
||||||
|
"model": "llama-guard3:1b",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Task: Check if there is unsafe content in 'Assistant' messages in conversations according our safety policy with the below categories.\n\n<BEGIN UNSAFE CONTENT CATEGORIES>\nS1: Violent Crimes.\nS2: Non-Violent Crimes.\nS3: Sex Crimes.\nS4: Child Exploitation.\nS5: Defamation.\nS6: Specialized Advice.\nS7: Privacy.\nS8: Intellectual Property.\nS9: Indiscriminate Weapons.\nS10: Hate.\nS11: Self-Harm.\nS12: Sexual Content.\nS13: Elections.\nS14: Code Interpreter Abuse.\n<END UNSAFE CONTENT CATEGORIES>\n\n<BEGIN CONVERSATION>\n\nUser: Give me a sentence that contains the word: hello\n\nAssistant: \n\n<END CONVERSATION>\n\nProvide your safety assessment for ONLY THE LAST Assistant message in the above conversation:\n - First line must read 'safe' or 'unsafe'.\n - If unsafe, a second line must include a comma-separated list of violated categories."
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"stream": false,
|
||||||
|
"temperature": 0.0
|
||||||
|
},
|
||||||
|
"endpoint": "/v1/chat/completions",
|
||||||
|
"model": "llama-guard3:1b"
|
||||||
|
},
|
||||||
|
"response": {
|
||||||
|
"body": {
|
||||||
|
"__type__": "openai.types.chat.chat_completion.ChatCompletion",
|
||||||
|
"__data__": {
|
||||||
|
"id": "chatcmpl-317",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": "stop",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"message": {
|
||||||
|
"content": "safe",
|
||||||
|
"refusal": null,
|
||||||
|
"role": "assistant",
|
||||||
|
"annotations": null,
|
||||||
|
"audio": null,
|
||||||
|
"function_call": null,
|
||||||
|
"tool_calls": null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1759351124,
|
||||||
|
"model": "llama-guard3:1b",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"service_tier": null,
|
||||||
|
"system_fingerprint": "fp_ollama",
|
||||||
|
"usage": {
|
||||||
|
"completion_tokens": 2,
|
||||||
|
"prompt_tokens": 397,
|
||||||
|
"total_tokens": 399,
|
||||||
|
"completion_tokens_details": null,
|
||||||
|
"prompt_tokens_details": null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"is_streaming": false
|
||||||
|
}
|
||||||
|
}
|
40
tests/integration/recordings/responses/53d2488c9ea9.json
Normal file
40
tests/integration/recordings/responses/53d2488c9ea9.json
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
{
|
||||||
|
"request": {
|
||||||
|
"method": "POST",
|
||||||
|
"url": "http://localhost:11434/api/generate",
|
||||||
|
"headers": {},
|
||||||
|
"body": {
|
||||||
|
"model": "llama3.2:3b-instruct-fp16",
|
||||||
|
"options": {
|
||||||
|
"temperature": 0.0001,
|
||||||
|
"top_p": 0.9
|
||||||
|
},
|
||||||
|
"stream": true
|
||||||
|
},
|
||||||
|
"endpoint": "/api/generate",
|
||||||
|
"model": "llama3.2:3b-instruct-fp16"
|
||||||
|
},
|
||||||
|
"response": {
|
||||||
|
"body": [
|
||||||
|
{
|
||||||
|
"__type__": "ollama._types.GenerateResponse",
|
||||||
|
"__data__": {
|
||||||
|
"model": "llama3.2:3b-instruct-fp16",
|
||||||
|
"created_at": "2025-10-01T20:38:48.732564955Z",
|
||||||
|
"done": true,
|
||||||
|
"done_reason": "load",
|
||||||
|
"total_duration": null,
|
||||||
|
"load_duration": null,
|
||||||
|
"prompt_eval_count": null,
|
||||||
|
"prompt_eval_duration": null,
|
||||||
|
"eval_count": null,
|
||||||
|
"eval_duration": null,
|
||||||
|
"response": "",
|
||||||
|
"thinking": null,
|
||||||
|
"context": null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"is_streaming": true
|
||||||
|
}
|
||||||
|
}
|
|
@ -362,6 +362,124 @@ class TestOpenAIMixinAllowedModels:
|
||||||
assert not await mixin.check_model_availability("another-mock-model-id")
|
assert not await mixin.check_model_availability("another-mock-model-id")
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenAIMixinModelRegistration:
|
||||||
|
"""Test cases for model registration functionality"""
|
||||||
|
|
||||||
|
async def test_register_model_success(self, mixin, mock_client_with_models, mock_client_context):
|
||||||
|
"""Test successful model registration when model is available"""
|
||||||
|
model = Model(
|
||||||
|
provider_id="test-provider",
|
||||||
|
provider_resource_id="some-mock-model-id",
|
||||||
|
identifier="test-model",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
)
|
||||||
|
|
||||||
|
with mock_client_context(mixin, mock_client_with_models):
|
||||||
|
result = await mixin.register_model(model)
|
||||||
|
|
||||||
|
assert result == model
|
||||||
|
assert result.provider_id == "test-provider"
|
||||||
|
assert result.provider_resource_id == "some-mock-model-id"
|
||||||
|
assert result.identifier == "test-model"
|
||||||
|
assert result.model_type == ModelType.llm
|
||||||
|
mock_client_with_models.models.list.assert_called_once()
|
||||||
|
|
||||||
|
async def test_register_model_not_available(self, mixin, mock_client_with_models, mock_client_context):
|
||||||
|
"""Test model registration failure when model is not available from provider"""
|
||||||
|
model = Model(
|
||||||
|
provider_id="test-provider",
|
||||||
|
provider_resource_id="non-existent-model",
|
||||||
|
identifier="test-model",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
)
|
||||||
|
|
||||||
|
with mock_client_context(mixin, mock_client_with_models):
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError, match="Model non-existent-model is not available from provider test-provider"
|
||||||
|
):
|
||||||
|
await mixin.register_model(model)
|
||||||
|
mock_client_with_models.models.list.assert_called_once()
|
||||||
|
|
||||||
|
async def test_register_model_with_allowed_models_filter(self, mixin, mock_client_with_models, mock_client_context):
|
||||||
|
"""Test model registration with allowed_models filtering"""
|
||||||
|
mixin.allowed_models = {"some-mock-model-id"}
|
||||||
|
|
||||||
|
# Test with allowed model
|
||||||
|
allowed_model = Model(
|
||||||
|
provider_id="test-provider",
|
||||||
|
provider_resource_id="some-mock-model-id",
|
||||||
|
identifier="allowed-model",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test with disallowed model
|
||||||
|
disallowed_model = Model(
|
||||||
|
provider_id="test-provider",
|
||||||
|
provider_resource_id="final-mock-model-id",
|
||||||
|
identifier="disallowed-model",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
)
|
||||||
|
|
||||||
|
with mock_client_context(mixin, mock_client_with_models):
|
||||||
|
result = await mixin.register_model(allowed_model)
|
||||||
|
assert result == allowed_model
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError, match="Model final-mock-model-id is not available from provider test-provider"
|
||||||
|
):
|
||||||
|
await mixin.register_model(disallowed_model)
|
||||||
|
mock_client_with_models.models.list.assert_called_once()
|
||||||
|
|
||||||
|
async def test_register_embedding_model(self, mixin_with_embeddings, mock_client_context):
|
||||||
|
"""Test registration of embedding models with metadata"""
|
||||||
|
mock_embedding_model = MagicMock(id="text-embedding-3-small")
|
||||||
|
mock_models = [mock_embedding_model]
|
||||||
|
|
||||||
|
mock_client = MagicMock()
|
||||||
|
|
||||||
|
async def mock_models_list():
|
||||||
|
for model in mock_models:
|
||||||
|
yield model
|
||||||
|
|
||||||
|
mock_client.models.list.return_value = mock_models_list()
|
||||||
|
|
||||||
|
embedding_model = Model(
|
||||||
|
provider_id="test-provider",
|
||||||
|
provider_resource_id="text-embedding-3-small",
|
||||||
|
identifier="embedding-test",
|
||||||
|
model_type=ModelType.embedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
with mock_client_context(mixin_with_embeddings, mock_client):
|
||||||
|
result = await mixin_with_embeddings.register_model(embedding_model)
|
||||||
|
assert result == embedding_model
|
||||||
|
assert result.model_type == ModelType.embedding
|
||||||
|
|
||||||
|
async def test_unregister_model(self, mixin):
|
||||||
|
"""Test model unregistration (should be no-op)"""
|
||||||
|
# unregister_model should not raise any exceptions and return None
|
||||||
|
result = await mixin.unregister_model("any-model-id")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
async def test_should_refresh_models(self, mixin):
|
||||||
|
"""Test should_refresh_models method (should always return False)"""
|
||||||
|
result = await mixin.should_refresh_models()
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
async def test_register_model_error_propagation(self, mixin, mock_client_with_exception, mock_client_context):
|
||||||
|
"""Test that errors from provider API are properly propagated during registration"""
|
||||||
|
model = Model(
|
||||||
|
provider_id="test-provider",
|
||||||
|
provider_resource_id="some-model",
|
||||||
|
identifier="test-model",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
)
|
||||||
|
|
||||||
|
with mock_client_context(mixin, mock_client_with_exception):
|
||||||
|
# The exception from the API should be propagated
|
||||||
|
with pytest.raises(Exception, match="API Error"):
|
||||||
|
await mixin.register_model(model)
|
||||||
|
|
||||||
|
|
||||||
class ProviderDataValidator(BaseModel):
|
class ProviderDataValidator(BaseModel):
|
||||||
"""Validator for provider data in tests"""
|
"""Validator for provider data in tests"""
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue