skeleton models api

This commit is contained in:
Xi Yan 2024-09-19 16:26:24 -07:00
parent 59af1c8fec
commit 68131afc86
9 changed files with 233 additions and 10 deletions

View file

@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio

import fire
import httpx
from termcolor import cprint

from .models import *  # noqa: F403


class ModelsClient(Models):
    def __init__(self, base_url: str):
        self.base_url = base_url

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def list_models(self) -> ModelsListResponse:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{self.base_url}/models/list",
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            return ModelsListResponse(**response.json())

    async def get_model(self, core_model_id: str) -> ModelsGetResponse:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self.base_url}/models/get",
                json={"core_model_id": core_model_id},
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            return ModelsGetResponse(**response.json())


async def run_main(host: str, port: int, stream: bool):
    client = ModelsClient(f"http://{host}:{port}")

    response = await client.list_models()
    cprint(f"list_models response={response}", "green")

    response = await client.get_model("Meta-Llama3.1-8B-Instruct")
    cprint(f"get_model response={response}", "blue")

    response = await client.get_model("Llama-Guard-3-8B")
    cprint(f"get_model response={response}", "red")


def main(host: str, port: int, stream: bool = True):
    asyncio.run(run_main(host, port, stream))


if __name__ == "__main__":
    fire.Fire(main)
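
A quick way to exercise this client from a shell (a usage sketch: the filename client.py is an assumption, and a Llama Stack server exposing the models API must already be listening on the given host and port):

python client.py localhost 5000  # assumes the file above is saved as client.py

fire maps the positional arguments onto main(host, port), which prints the list_models response and the two get_model responses in color.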

View file

@@ -1,14 +1,51 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional, Protocol

from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field


@json_schema_type
class ModelSpec(BaseModel):
    llama_model_metadata: Model = Field(
        description="All metadata associated with this llama model (defined in llama_models.models.sku_list)"
    )
    providers_spec: Dict[str, Any] = Field(
        default_factory=dict,
        description="Map of API to the concrete provider specs. E.g. {}".format(
            {
                "inference": {
                    "provider_type": "remote::8080",
                    "url": "localhost::5555",
                    "api_token": "hf_xxx",
                },
            }
        ),
    )


@json_schema_type
class ModelsListResponse(BaseModel):
    models_list: List[ModelSpec]


@json_schema_type
class ModelsGetResponse(BaseModel):
    core_model_spec: Optional[ModelSpec] = None


@json_schema_type
class ModelsRegisterResponse(BaseModel):
    core_model_spec: Optional[ModelSpec] = None


class Models(Protocol):
    @webmethod(route="/models/list", method="GET")
    async def list_models(self) -> ModelsListResponse: ...

    @webmethod(route="/models/get", method="POST")
    async def get_model(self, core_model_id: str) -> ModelsGetResponse: ...
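
To make the response shapes concrete, here is a small sketch that builds a ModelSpec and wraps it in the response types (the model id mirrors the dummy spec used later in this commit):

from llama_models.sku_list import resolve_model

# "Meta-Llama3.1-8B" is an example id; resolve_model returns its Model metadata
spec = ModelSpec(
    llama_model_metadata=resolve_model("Meta-Llama3.1-8B"),
    providers_spec={"inference": {"provider_type": "meta-reference"}},
)

list_response = ModelsListResponse(models_list=[spec])
get_response = ModelsGetResponse(core_model_spec=spec)
missing = ModelsGetResponse()  # core_model_spec defaults to None for unknown ids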

View file

@@ -20,6 +20,7 @@ class Api(Enum):
    agents = "agents"
    memory = "memory"
    telemetry = "telemetry"
    models = "models"

@json_schema_type

View file

@@ -11,6 +11,7 @@ from typing import Dict, List
from llama_stack.apis.agents import Agents
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.models import Models
from llama_stack.apis.safety import Safety
from llama_stack.apis.telemetry import Telemetry

@@ -38,6 +39,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
        Api.agents: Agents,
        Api.memory: Memory,
        Api.telemetry: Telemetry,
        Api.models: Models,
    }

    for api, protocol in protocols.items():
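
With the protocol registered here, the models routes should surface in the endpoint map. A quick check, as a sketch (it assumes ApiEndpoint carries route and method fields, matching what the webmethod declarations provide):

from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.distribution import api_endpoints

# field names assumed from the webmethod declarations on the Models protocol
for endpoint in api_endpoints()[Api.models]:
    print(endpoint.method, endpoint.route)  # expect GET /models/list, POST /models/get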

View file

@@ -35,9 +35,6 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute
from llama_stack.providers.utils.telemetry.tracing import (
    end_trace,

@@ -45,6 +42,9 @@ from llama_stack.providers.utils.telemetry.tracing import (
    SpanStatus,
    start_trace,
)
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated

from llama_stack.distribution.datatypes import *  # noqa: F403
from llama_stack.distribution.distribution import api_endpoints, api_providers

@@ -333,7 +333,9 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
    app = FastAPI()

    print(config)
    impls, specs = asyncio.run(resolve_impls(config.provider_map))
    print(impls)

    if Api.telemetry in impls:
        setup_logger(impls[Api.telemetry])

View file

@@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict

from llama_stack.distribution.datatypes import Api, ProviderSpec

from .config import BuiltinImplConfig  # noqa


async def get_provider_impl(config: BuiltinImplConfig, deps: Dict[Api, ProviderSpec]):
    from .models import BuiltinModelsImpl

    assert isinstance(
        config, BuiltinImplConfig
    ), f"Unexpected config type: {type(config)}"

    impl = BuiltinModelsImpl(config)
    await impl.initialize()
    return impl
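
A minimal sketch of how the distribution resolver would call this entry point (the module path comes from the registry entry below; that the package re-exports these names assumes this file is its __init__):

import asyncio

# path taken from the registry's module/config_class strings
from llama_stack.providers.impls.builtin.models import (
    BuiltinImplConfig,
    get_provider_impl,
)

# deps is empty because the registry declares api_dependencies=[]
impl = asyncio.run(get_provider_impl(BuiltinImplConfig(), deps={}))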

View file

@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel


@json_schema_type
class BuiltinImplConfig(BaseModel): ...

View file

@@ -0,0 +1,46 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict

from llama_models.sku_list import resolve_model

from llama_stack.apis.models import *  # noqa: F403

from .config import BuiltinImplConfig

DUMMY_MODELS_SPEC = ModelSpec(
    llama_model_metadata=resolve_model("Meta-Llama3.1-8B"),
    providers_spec={"inference": {"provider_type": "meta-reference"}},
)


class BuiltinModelsImpl(Models):
    def __init__(
        self,
        config: BuiltinImplConfig,
    ) -> None:
        self.config = config
        self.models_list = [DUMMY_MODELS_SPEC]

    async def initialize(self) -> None:
        pass

    async def list_models(self) -> ModelsListResponse:
        return ModelsListResponse(models_list=self.models_list)

    async def get_model(self, core_model_id: str) -> ModelsGetResponse:
        return ModelsGetResponse(core_model_spec=DUMMY_MODELS_SPEC)

    async def register_model(
        self, model_id: str, api: str, provider_spec: Dict[str, str]
    ) -> ModelsRegisterResponse:
        return ModelsRegisterResponse()
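
The skeleton can also be exercised end to end without a server — a sketch, using the module paths implied by the registry entry below:

import asyncio

# submodule paths inferred from "from .models import BuiltinModelsImpl" and
# "from .config import BuiltinImplConfig" in the provider package
from llama_stack.providers.impls.builtin.models.config import BuiltinImplConfig
from llama_stack.providers.impls.builtin.models.models import BuiltinModelsImpl


async def demo() -> None:
    impl = BuiltinModelsImpl(BuiltinImplConfig())
    await impl.initialize()
    print(await impl.list_models())
    print(await impl.get_model("Meta-Llama3.1-8B"))
    print(await impl.register_model("Meta-Llama3.1-8B", "inference", {}))


asyncio.run(demo())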

View file

@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List

from llama_stack.distribution.datatypes import *  # noqa: F403


def available_providers() -> List[ProviderSpec]:
    return [
        InlineProviderSpec(
            api=Api.models,
            provider_id="builtin",
            pip_packages=[],
            module="llama_stack.providers.impls.builtin.models",
            config_class="llama_stack.providers.impls.builtin.models.BuiltinImplConfig",
            api_dependencies=[],
        )
    ]
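
As a sanity check, the registry entry can be inspected directly (a sketch; the printed fields are exactly the ones set above):

for spec in available_providers():
    print(spec.api, spec.provider_id, spec.module)

Registries like this are what lets the distribution layer discover that the models API now has a builtin inline provider.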