diff --git a/llama_stack/apis/models/clients.py b/llama_stack/apis/models/clients.py
new file mode 100644
index 000000000..4069480a8
--- /dev/null
+++ b/llama_stack/apis/models/clients.py
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+
+import fire
+import httpx
+
+from termcolor import cprint
+
+from .models import *  # noqa: F403
+
+
+class ModelsClient(Models):
+    def __init__(self, base_url: str):
+        self.base_url = base_url
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def list_models(self) -> ModelsListResponse:
+        async with httpx.AsyncClient() as client:
+            response = await client.get(
+                f"{self.base_url}/models/list",
+                headers={"Content-Type": "application/json"},
+            )
+            response.raise_for_status()
+            return ModelsListResponse(**response.json())
+
+    async def get_model(self, core_model_id: str) -> ModelsGetResponse:
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                f"{self.base_url}/models/get",
+                json={
+                    "core_model_id": core_model_id,
+                },
+                headers={"Content-Type": "application/json"},
+            )
+            response.raise_for_status()
+            return ModelsGetResponse(**response.json())
+
+
+async def run_main(host: str, port: int):
+    client = ModelsClient(f"http://{host}:{port}")
+
+    response = await client.list_models()
+    cprint(f"list_models response={response}", "green")
+
+    response = await client.get_model("Meta-Llama3.1-8B-Instruct")
+    cprint(f"get_model response={response}", "blue")
+
+    response = await client.get_model("Llama-Guard-3-8B")
+    cprint(f"get_model response={response}", "red")
+
+
+def main(host: str, port: int):
+    asyncio.run(run_main(host, port))
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
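Reviewer note: a minimal sketch of exercising this client programmatically, outside the `fire` CLI (not part of the patch; assumes a llama-stack server is already serving the Models API on localhost:5000):

```python
# Hypothetical smoke test for ModelsClient; host/port are placeholders.
import asyncio

from llama_stack.apis.models.clients import ModelsClient


async def smoke() -> None:
    client = ModelsClient("http://localhost:5000")

    listed = await client.list_models()
    print(listed.models_list)  # list of ModelSpec

    fetched = await client.get_model("Meta-Llama3.1-8B-Instruct")
    print(fetched.core_model_spec)  # None if the id is unknown


asyncio.run(smoke())
```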
{}".format( + { + "inference": { + "provider_type": "remote::8080", + "url": "localhost::5555", + "api_token": "hf_xxx", + }, + } + ), + ) + + +@json_schema_type +class ModelsListResponse(BaseModel): + models_list: List[ModelSpec] + + +@json_schema_type +class ModelsGetResponse(BaseModel): + core_model_spec: Optional[ModelSpec] = None + + +@json_schema_type +class ModelsRegisterResponse(BaseModel): + core_model_spec: Optional[ModelSpec] = None + + +class Models(Protocol): + @webmethod(route="/models/list", method="GET") + async def list_models(self) -> ModelsListResponse: ... + + @webmethod(route="/models/get", method="POST") + async def get_model(self, core_model_id: str) -> ModelsGetResponse: ... diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index e57617016..457ab0d3a 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -20,6 +20,7 @@ class Api(Enum): agents = "agents" memory = "memory" telemetry = "telemetry" + models = "models" @json_schema_type diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py index 0825121dc..3dd406ccc 100644 --- a/llama_stack/distribution/distribution.py +++ b/llama_stack/distribution/distribution.py @@ -11,6 +11,7 @@ from typing import Dict, List from llama_stack.apis.agents import Agents from llama_stack.apis.inference import Inference from llama_stack.apis.memory import Memory +from llama_stack.apis.models import Models from llama_stack.apis.safety import Safety from llama_stack.apis.telemetry import Telemetry @@ -38,6 +39,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]: Api.agents: Agents, Api.memory: Memory, Api.telemetry: Telemetry, + Api.models: Models, } for api, protocol in protocols.items(): diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 16d24cad5..583a25e1a 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -35,9 +35,6 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, StreamingResponse from fastapi.routing import APIRoute -from pydantic import BaseModel, ValidationError -from termcolor import cprint -from typing_extensions import Annotated from llama_stack.providers.utils.telemetry.tracing import ( end_trace, @@ -45,6 +42,9 @@ from llama_stack.providers.utils.telemetry.tracing import ( SpanStatus, start_trace, ) +from pydantic import BaseModel, ValidationError +from termcolor import cprint +from typing_extensions import Annotated from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.distribution import api_endpoints, api_providers @@ -333,7 +333,9 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False): app = FastAPI() + print(config) impls, specs = asyncio.run(resolve_impls(config.provider_map)) + print(impls) if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) diff --git a/llama_stack/providers/impls/builtin/models/__init__.py b/llama_stack/providers/impls/builtin/models/__init__.py new file mode 100644 index 000000000..439f2be61 --- /dev/null +++ b/llama_stack/providers/impls/builtin/models/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Dict
+
+from llama_stack.distribution.datatypes import Api, ProviderSpec
+
+from .config import BuiltinImplConfig  # noqa
+
+
+async def get_provider_impl(config: BuiltinImplConfig, deps: Dict[Api, ProviderSpec]):
+    from .models import BuiltinModelsImpl
+
+    assert isinstance(
+        config, BuiltinImplConfig
+    ), f"Unexpected config type: {type(config)}"
+
+    impl = BuiltinModelsImpl(config)
+    await impl.initialize()
+    return impl
diff --git a/llama_stack/providers/impls/builtin/models/config.py b/llama_stack/providers/impls/builtin/models/config.py
new file mode 100644
index 000000000..b24499d4e
--- /dev/null
+++ b/llama_stack/providers/impls/builtin/models/config.py
@@ -0,0 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_models.schema_utils import json_schema_type
+
+from pydantic import BaseModel
+
+
+@json_schema_type
+class BuiltinImplConfig(BaseModel): ...
diff --git a/llama_stack/providers/impls/builtin/models/models.py b/llama_stack/providers/impls/builtin/models/models.py
new file mode 100644
index 000000000..f87aab83a
--- /dev/null
+++ b/llama_stack/providers/impls/builtin/models/models.py
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Dict
+
+from llama_models.sku_list import resolve_model
+
+from llama_stack.apis.models import *  # noqa: F403
+
+from .config import BuiltinImplConfig
+
+# Placeholder spec returned by this stub implementation.
+DUMMY_MODELS_SPEC = ModelSpec(
+    llama_model_metadata=resolve_model("Meta-Llama3.1-8B"),
+    providers_spec={"inference": {"provider_type": "meta-reference"}},
+)
+
+
+class BuiltinModelsImpl(Models):
+    def __init__(
+        self,
+        config: BuiltinImplConfig,
+    ) -> None:
+        self.config = config
+        self.models_list = [DUMMY_MODELS_SPEC]
+
+    async def initialize(self) -> None:
+        pass
+
+    async def list_models(self) -> ModelsListResponse:
+        return ModelsListResponse(models_list=self.models_list)
+
+    async def get_model(self, core_model_id: str) -> ModelsGetResponse:
+        # Stub: ignores core_model_id and always returns the dummy spec.
+        return ModelsGetResponse(core_model_spec=DUMMY_MODELS_SPEC)
+
+    async def register_model(
+        self, model_id: str, api: str, provider_spec: Dict[str, str]
+    ) -> ModelsRegisterResponse:
+        return ModelsRegisterResponse()
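Reviewer note: a sketch of driving the builtin provider directly, the way the distribution resolver would (illustrative only; `deps` is empty because the provider declares no API dependencies):

```python
# Instantiate the builtin Models provider and list its (dummy) models.
import asyncio

from llama_stack.providers.impls.builtin.models import get_provider_impl
from llama_stack.providers.impls.builtin.models.config import BuiltinImplConfig


async def demo() -> None:
    impl = await get_provider_impl(BuiltinImplConfig(), deps={})
    response = await impl.list_models()
    for spec in response.models_list:
        print(spec.llama_model_metadata)


asyncio.run(demo())
```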
diff --git a/llama_stack/providers/registry/models.py b/llama_stack/providers/registry/models.py
new file mode 100644
index 000000000..3aa577cf9
--- /dev/null
+++ b/llama_stack/providers/registry/models.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from llama_stack.distribution.datatypes import *  # noqa: F403
+
+
+def available_providers() -> List[ProviderSpec]:
+    return [
+        InlineProviderSpec(
+            api=Api.models,
+            provider_id="builtin",
+            pip_packages=[],
+            module="llama_stack.providers.impls.builtin.models",
+            config_class="llama_stack.providers.impls.builtin.models.BuiltinImplConfig",
+            api_dependencies=[],
+        )
+    ]
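Reviewer note: a quick way to confirm the new API is picked up by endpoint discovery (assumes `ApiEndpoint` exposes `route` and `method`, which the server uses when registering routes):

```python
# Print the routes registered for the new Models API.
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.distribution import api_endpoints

for endpoint in api_endpoints()[Api.models]:
    print(endpoint.method, endpoint.route)
# Expected: GET /models/list, POST /models/get (and POST /models/register).
```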