mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
feat: allow user to register model alias explicitly, tests
# What does this PR do? Context: https://github.com/llamastack/llama-stack/discussions/3483 This PR enables the registering `provider_model_id` as the model identifier without breaking backward compatibility. ## Test Plan todo # What does this PR do? ## Test Plan
This commit is contained in:
parent
ac1414b571
commit
83a229554b
20 changed files with 236 additions and 92 deletions
|
@ -10,9 +10,12 @@ from typing import Any, Literal, Protocol, runtime_checkable
|
|||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class CommonModelFields(BaseModel):
|
||||
metadata: dict[str, Any] = Field(
|
||||
|
@ -68,11 +71,36 @@ class Model(CommonModelFields, Resource):
|
|||
|
||||
|
||||
class ModelInput(CommonModelFields):
|
||||
model_id: str
|
||||
provider_id: str | None = None
|
||||
"""A model input for registering a model.
|
||||
|
||||
:param provider_model_id: The identifier of the model in the provider.
|
||||
:param provider_id: The identifier of the provider.
|
||||
:param model_type: The type of model to register.
|
||||
:param model_id: The identifier of the model to register. If model_id == provider_model_id, provider_id/provider_model_id will be used as the identifier. Otherwise,
|
||||
model_id will be used as the identifier.
|
||||
The behavior of this field will soon change to "always use model_id as the identifier".
|
||||
:param use_provider_model_id_as_id: Set to true to use provider_model_id as the identifier. Use model_id if you want to use a different identifier.
|
||||
"""
|
||||
|
||||
provider_model_id: str | None = None
|
||||
provider_id: str | None = None
|
||||
model_type: ModelType | None = ModelType.llm
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
# TODO: update behavior of this field to always be the identifier
|
||||
model_id: str | None = None
|
||||
use_provider_model_id_as_id: bool = False
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
if self.model_id is None and self.provider_model_id is None:
|
||||
raise ValueError("provider_model_id must be provided")
|
||||
|
||||
if self.model_id == self.provider_model_id:
|
||||
logger.warning(
|
||||
f"`model_id` is now optional. The behavior of this field will change if model_id == provider_model_id. Please remove `model_id` and use `provider_model_id` instead.: {self.model_id}"
|
||||
)
|
||||
|
||||
if self.use_provider_model_id_as_id and self.model_id:
|
||||
raise ValueError(f"use_provider_model_id_as_id and model_id cannot be provided together: {self.model_id}")
|
||||
|
||||
|
||||
class ListModelsResponse(BaseModel):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue