Revert "Revert "add model type to APIs" (#605)"

This reverts commit 47b2dc8ae3.
This commit is contained in:
Dinesh Yeduguru 2024-12-11 10:18:00 -08:00 committed by GitHub
parent 47b2dc8ae3
commit 310c15bada
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 77 additions and 13 deletions

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
@ -20,6 +21,11 @@ class CommonModelFields(BaseModel):
)
class ModelType(Enum):
llm = "llm"
embedding_model = "embedding"
@json_schema_type
class Model(CommonModelFields, Resource):
type: Literal[ResourceType.model.value] = ResourceType.model.value
@ -34,11 +40,14 @@ class Model(CommonModelFields, Resource):
model_config = ConfigDict(protected_namespaces=())
model_type: ModelType = Field(default=ModelType.llm)
class ModelInput(CommonModelFields):
model_id: str
provider_id: Optional[str] = None
provider_model_id: Optional[str] = None
model_type: Optional[ModelType] = ModelType.llm
model_config = ConfigDict(protected_namespaces=())
@ -59,6 +68,7 @@ class Models(Protocol):
provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> Model: ...
@webmethod(route="/models/unregister", method="POST")