diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py
index ee1d5f0ba..bebd90bc1 100644
--- a/llama_stack/apis/models/models.py
+++ b/llama_stack/apis/models/models.py
@@ -4,11 +4,72 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Protocol
+from typing import Any, Dict, List, Optional, Protocol
 
-from llama_models.schema_utils import webmethod  # noqa: F401
+from llama_models.llama3.api.datatypes import *  # noqa: F403
 
-from pydantic import BaseModel  # noqa: F401
+from llama_models.schema_utils import json_schema_type, webmethod
+from pydantic import BaseModel, Field
 
 
-class Models(Protocol): ...
+# Pairs a llama model's metadata with the provider specs configured for it,
+# keyed by API name (see the field descriptions below).
+@json_schema_type
+class ModelSpec(BaseModel):
+    llama_model_metadata: Model = Field(
+        description="All metadatas associated with llama model (defined in llama_models.models.sku_list). "
+    )
+    providers_spec: Dict[str, List[Any]] = Field(
+        default_factory=dict,
+        description="Map of API to the concrete provider specs. E.g. {}".format(
+            {
+                "inference": [
+                    {
+                        "provider_type": "remote::8080",
+                        "url": "localhost::5555",
+                        "api_token": "hf_xxx",
+                    },
+                    {
+                        "provider_type": "meta-reference",
+                        "model": "Meta-Llama3.1-8B-Instruct",
+                        "max_seq_len": 4096,
+                    },
+                ]
+            }
+        ),
+    )
+
+
+# Return type of Models.list_models; wraps a list of ModelSpec.
+@json_schema_type
+class ModelsListResponse(BaseModel):
+    models_list: List[ModelSpec]
+
+
+# Return type of Models.get_model; core_model_spec defaults to None.
+@json_schema_type
+class ModelsGetResponse(BaseModel):
+    core_model_spec: Optional[ModelSpec] = None
+
+
+# Return type of Models.register_model; core_model_spec defaults to None.
+@json_schema_type
+class ModelsRegisterResponse(BaseModel):
+    core_model_spec: Optional[ModelSpec] = None
+
+
+# Protocol for the models API: list models, fetch one by id, register one.
+class Models(Protocol):
+    @webmethod(route="/models/list", method="GET")
+    async def list_models(self) -> ModelsListResponse: ...
+
+    # NOTE(review): routed as POST although list_models uses GET — confirm intended.
+    @webmethod(route="/models/get", method="POST")
+    async def get_model(self, model_id: str) -> ModelsGetResponse: ...
+
+    # NOTE(review): provider_spec is Dict[str, str], but the ModelSpec example
+    # contains a non-string value ("max_seq_len": 4096) — confirm the value type.
+    @webmethod(route="/models/register")
+    async def register_model(
+        self, model_id: str, api: str, provider_spec: Dict[str, str]
+    ) -> ModelsRegisterResponse: ...
diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py
index e57617016..457ab0d3a 100644
--- a/llama_stack/distribution/datatypes.py
+++ b/llama_stack/distribution/datatypes.py
@@ -20,6 +20,7 @@ class Api(Enum):
     agents = "agents"
     memory = "memory"
     telemetry = "telemetry"
+    models = "models"
 
 
 @json_schema_type
diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py
index 0825121dc..3dd406ccc 100644
--- a/llama_stack/distribution/distribution.py
+++ b/llama_stack/distribution/distribution.py
@@ -11,6 +11,7 @@ from typing import Dict, List
 from llama_stack.apis.agents import Agents
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.memory import Memory
+from llama_stack.apis.models import Models
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.telemetry import Telemetry
 
@@ -38,6 +39,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
         Api.agents: Agents,
         Api.memory: Memory,
         Api.telemetry: Telemetry,
+        Api.models: Models,
     }
 
     for api, protocol in protocols.items():