From 18b3dbcacc5de19847b24e140f5c59740507bb9d Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Wed, 18 Sep 2024 10:00:29 -0700
Subject: [PATCH] wip

---
 llama_stack/apis/models/clients.py           | 72 ++++++++++++++
 .../impls/meta_reference/models/__init__.py  | 25 +++++
 .../impls/meta_reference/models/config.py    | 18 ++++
 .../impls/meta_reference/models/models.py    | 99 +++++++++++++++++++
 llama_stack/providers/registry/models.py     | 25 +++++
 5 files changed, 239 insertions(+)
 create mode 100644 llama_stack/apis/models/clients.py
 create mode 100644 llama_stack/providers/impls/meta_reference/models/__init__.py
 create mode 100644 llama_stack/providers/impls/meta_reference/models/config.py
 create mode 100644 llama_stack/providers/impls/meta_reference/models/models.py
 create mode 100644 llama_stack/providers/registry/models.py

diff --git a/llama_stack/apis/models/clients.py b/llama_stack/apis/models/clients.py
new file mode 100644
index 000000000..85009eb3d
--- /dev/null
+++ b/llama_stack/apis/models/clients.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+import json
+from pathlib import Path
+
+from typing import Any, Dict, List, Optional
+
+import fire
+import httpx
+
+from llama_stack.distribution.datatypes import RemoteProviderConfig
+from termcolor import cprint
+
+from .api import *  # noqa: F403
+
+
+class ModelsClient(Models):
+    def __init__(self, base_url: str):
+        self.base_url = base_url
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def list_models(self) -> ModelsListResponse:
+        async with httpx.AsyncClient() as client:
+            response = await client.get(
+                f"{self.base_url}/models/list",
+                headers={"Content-Type": "application/json"},
+            )
+            response.raise_for_status()
+            return ModelsListResponse(**response.json())
+
+    async def get_model(self, model_id: str) -> ModelsGetResponse:
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                f"{self.base_url}/models/get",
+                json={
+                    "model_id": model_id,
+                },
+                headers={"Content-Type": "application/json"},
+            )
+            response.raise_for_status()
+            return ModelsGetResponse(**response.json())
+
+
+async def run_main(host: str, port: int, stream: bool):
+    client = ModelsClient(f"http://{host}:{port}")
+
+    response = await client.list_models()
+    cprint(f"list_models response={response}", "green")
+
+    response = await client.get_model("Meta-Llama3.1-8B-Instruct")
+    cprint(f"get_model response={response}", "blue")
+
+    response = await client.get_model("Llama-Guard-3-8B")
+    cprint(f"get_model response={response}", "red")
+
+
+def main(host: str, port: int, stream: bool = True):
+    asyncio.run(run_main(host, port, stream))
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
diff --git a/llama_stack/providers/impls/meta_reference/models/__init__.py b/llama_stack/providers/impls/meta_reference/models/__init__.py
new file mode 100644
index 000000000..585da0933
--- /dev/null
+++ b/llama_stack/providers/impls/meta_reference/models/__init__.py
@@ -0,0 +1,25 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Dict
+
+from llama_stack.distribution.datatypes import Api, ProviderSpec
+
+from .config import MetaReferenceImplConfig  # noqa
+
+
+async def get_provider_impl(
+    config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec]
+):
+    from .models import MetaReferenceModelsImpl
+
+    assert isinstance(
+        config, MetaReferenceImplConfig
+    ), f"Unexpected config type: {type(config)}"
+
+    impl = MetaReferenceModelsImpl(config, deps[Api.inference], deps[Api.safety])
+    await impl.initialize()
+    return impl
diff --git a/llama_stack/providers/impls/meta_reference/models/config.py b/llama_stack/providers/impls/meta_reference/models/config.py
new file mode 100644
index 000000000..f9a80de15
--- /dev/null
+++ b/llama_stack/providers/impls/meta_reference/models/config.py
@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Optional
+
+from llama_models.datatypes import ModelFamily
+
+from llama_models.schema_utils import json_schema_type
+from llama_models.sku_list import all_registered_models, resolve_model
+
+from pydantic import BaseModel, Field, field_validator
+
+
+@json_schema_type
+class MetaReferenceImplConfig(BaseModel): ...
diff --git a/llama_stack/providers/impls/meta_reference/models/models.py b/llama_stack/providers/impls/meta_reference/models/models.py
new file mode 100644
index 000000000..a7843f0fc
--- /dev/null
+++ b/llama_stack/providers/impls/meta_reference/models/models.py
@@ -0,0 +1,99 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+from typing import AsyncIterator, Dict, Union
+
+from llama_models.llama3.api.datatypes import StopReason
+from llama_models.sku_list import resolve_model
+
+from llama_stack.apis.models import *  # noqa: F403
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_models.datatypes import CoreModelId, Model
+from llama_models.sku_list import resolve_model
+from llama_stack.apis.inference import Inference
+from llama_stack.apis.safety import Safety
+
+from llama_stack.providers.impls.meta_reference.inference.inference import (
+    MetaReferenceInferenceImpl,
+)
+from llama_stack.providers.impls.meta_reference.safety.safety import (
+    MetaReferenceSafetyImpl,
+)
+
+from .config import MetaReferenceImplConfig
+
+
+class MetaReferenceModelsImpl(Models):
+    def __init__(
+        self,
+        config: MetaReferenceImplConfig,
+        inference_api: Inference,
+        safety_api: Safety,
+    ) -> None:
+        self.config = config
+        self.inference_api = inference_api
+        self.safety_api = safety_api
+
+        self.models_list = []
+        # TODO, make the inference route provider and use router provider to do the lookup dynamically
+        if isinstance(
+            self.inference_api,
+            MetaReferenceInferenceImpl,
+        ):
+            model = resolve_model(self.inference_api.config.model)
+            self.models_list.append(
+                ModelSpec(
+                    llama_model_metadata=model,
+                    providers_spec={
+                        "inference": [{"provider_type": "meta-reference"}],
+                    },
+                )
+            )
+
+        if isinstance(
+            self.safety_api,
+            MetaReferenceSafetyImpl,
+        ):
+            shield_cfg = self.safety_api.config.llama_guard_shield
+            if shield_cfg is not None:
+                model = resolve_model(shield_cfg.model)
+                self.models_list.append(
+                    ModelSpec(
+                        llama_model_metadata=model,
+                        providers_spec={
+                            "safety": [{"provider_type": "meta-reference"}],
+                        },
+                    )
+                )
+            shield_cfg = self.safety_api.config.prompt_guard_shield
+            if shield_cfg is not None:
+                model = resolve_model(shield_cfg.model)
+                self.models_list.append(
+                    ModelSpec(
+                        llama_model_metadata=model,
+                        providers_spec={
+                            "safety": [{"provider_type": "meta-reference"}],
+                        },
+                    )
+                )
+
+    async def initialize(self) -> None:
+        pass
+
+    async def list_models(self) -> ModelsListResponse:
+        return ModelsListResponse(models_list=self.models_list)
+
+    async def get_model(self, model_id: str) -> ModelsGetResponse:
+        for model in self.models_list:
+            if model.llama_model_metadata.core_model_id.value == model_id:
+                return ModelsGetResponse(core_model_spec=model)
+        return ModelsGetResponse()
+
+    async def register_model(
+        self, model_id: str, api: str, provider_spec: Dict[str, str]
+    ) -> ModelsRegisterResponse:
+        return ModelsRegisterResponse()
diff --git a/llama_stack/providers/registry/models.py b/llama_stack/providers/registry/models.py
new file mode 100644
index 000000000..4b4738b33
--- /dev/null
+++ b/llama_stack/providers/registry/models.py
@@ -0,0 +1,25 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from llama_stack.distribution.datatypes import *  # noqa: F403
+
+
+def available_providers() -> List[ProviderSpec]:
+    return [
+        InlineProviderSpec(
+            api=Api.models,
+            provider_id="meta-reference",
+            pip_packages=[],
+            module="llama_stack.providers.impls.meta_reference.models",
+            config_class="llama_stack.providers.impls.meta_reference.models.MetaReferenceImplConfig",
+            api_dependencies=[
+                Api.inference,
+                Api.safety,
+            ],
+        )
+    ]
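
Not part of the patch itself: a minimal usage sketch for the new ModelsClient, assuming a Llama Stack
server that serves the models API is already running and reachable at http://localhost:5000 (host and
port are placeholders). The fire-based entry point added in clients.py can be invoked directly, or the
client can be driven programmatically:

    # CLI invocation of the module added by this patch (host/port are assumptions)
    #   python -m llama_stack.apis.models.clients localhost 5000

    # Programmatic use of the same client
    import asyncio

    from llama_stack.apis.models.clients import ModelsClient

    async def check() -> None:
        # assumed server address; replace with the address of a running stack server
        client = ModelsClient("http://localhost:5000")
        print(await client.list_models())
        print(await client.get_model("Llama-Guard-3-8B"))

    asyncio.run(check())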