Further generalize Xi's changes (#88)

* Further generalize Xi's changes

- introduce a slightly more general notion of an AutoRouted provider
- each AutoRouted provider is associated with a RoutingTable provider,
  e.g. inference -> models
- introduce the analogous safety -> shields and memory -> memory_banks
  correspondences (see the sketch below)
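
A minimal sketch of the correspondence these bullets describe. The Api
values follow llama-stack's API names, but AUTO_ROUTED_APIS and
routing_table_api_for are hypothetical names for illustration only:

from enum import Enum


class Api(Enum):
    # The three auto-routed APIs and their routing-table counterparts.
    inference = "inference"
    safety = "safety"
    memory = "memory"
    models = "models"
    shields = "shields"
    memory_banks = "memory_banks"


# Hypothetical mapping: each AutoRouted API is paired with the
# RoutingTable API that records which provider serves a given key.
AUTO_ROUTED_APIS = {
    Api.inference: Api.models,       # route by model
    Api.safety: Api.shields,         # route by shield
    Api.memory: Api.memory_banks,    # route by memory bank
}


def routing_table_api_for(auto_routed_api: Api) -> Api:
    # Hypothetical helper: look up the RoutingTable API behind an
    # AutoRouted provider; raises KeyError for non-routed APIs.
    return AUTO_ROUTED_APIS[auto_routed_api]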

* typo

* Basic build and run succeeded
Author: Ashwin Bharambe (committed by GitHub)
Date:   2024-09-22 16:31:18 -07:00
Commit: c1ab66f1e6
Parent: b8914bb56f
21 changed files with 597 additions and 418 deletions


@@ -5,15 +5,11 @@
 # the root directory of this source tree.

 import asyncio
-import json
-from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import List, Optional

 import fire
 import httpx
-from llama_stack.distribution.datatypes import RemoteProviderConfig
 from termcolor import cprint

 from .models import *  # noqa: F403
@@ -29,18 +25,18 @@ class ModelsClient(Models):
     async def shutdown(self) -> None:
         pass

-    async def list_models(self) -> ModelsListResponse:
+    async def list_models(self) -> List[ModelServingSpec]:
         async with httpx.AsyncClient() as client:
             response = await client.get(
                 f"{self.base_url}/models/list",
                 headers={"Content-Type": "application/json"},
             )
             response.raise_for_status()
-            return ModelsListResponse(**response.json())
+            return [ModelServingSpec(**x) for x in response.json()]

-    async def get_model(self, core_model_id: str) -> ModelsGetResponse:
+    async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
         async with httpx.AsyncClient() as client:
-            response = await client.post(
+            response = await client.get(
                 f"{self.base_url}/models/get",
                 json={
                     "core_model_id": core_model_id,
@@ -48,7 +44,10 @@ class ModelsClient(Models):
                 headers={"Content-Type": "application/json"},
             )
             response.raise_for_status()
-            return ModelsGetResponse(**response.json())
+            j = response.json()
+            if j is None:
+                return None
+            return ModelServingSpec(**j)


 async def run_main(host: str, port: int, stream: bool):
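
With the wrapper response types gone, callers of this client work with a
List[ModelServingSpec] and an Optional[ModelServingSpec] directly. A
minimal usage sketch, assuming ModelsClient and ModelServingSpec are
imported from the module above; the endpoint and model id are
placeholders:

import asyncio

from termcolor import cprint


async def demo() -> None:
    # Placeholder endpoint; point this at a running distribution server.
    client = ModelsClient("http://localhost:5000")

    # list_models() now yields the specs directly, with no
    # ModelsListResponse wrapper to unpack.
    for spec in await client.list_models():
        cprint(f"served by provider {spec.provider_config.provider_id}", "green")

    # get_model() now returns Optional[ModelServingSpec]; an unknown
    # model comes back as None instead of an empty response object.
    spec = await client.get_model("placeholder-model-id")
    if spec is None:
        cprint("model not found", "red")


asyncio.run(demo())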


@@ -4,14 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List, Optional, Protocol
+from typing import List, Optional, Protocol

 from llama_models.llama3.api.datatypes import Model
 from llama_models.schema_utils import json_schema_type, webmethod
-from llama_stack.distribution.datatypes import GenericProviderConfig
 from pydantic import BaseModel, Field

+from llama_stack.distribution.datatypes import GenericProviderConfig
+

 @json_schema_type
 class ModelServingSpec(BaseModel):
@@ -21,25 +22,11 @@ class ModelServingSpec(BaseModel):
     provider_config: GenericProviderConfig = Field(
         description="Provider config for the model, including provider_id, and corresponding config. ",
     )
-    api: str = Field(
-        description="The API that this model is serving (e.g. inference / safety).",
-        default="inference",
-    )
-
-
-@json_schema_type
-class ModelsListResponse(BaseModel):
-    models_list: List[ModelServingSpec]
-
-
-@json_schema_type
-class ModelsGetResponse(BaseModel):
-    core_model_spec: Optional[ModelServingSpec] = None


 class Models(Protocol):
     @webmethod(route="/models/list", method="GET")
-    async def list_models(self) -> ModelsListResponse: ...
+    async def list_models(self) -> List[ModelServingSpec]: ...

-    @webmethod(route="/models/get", method="POST")
-    async def get_model(self, core_model_id: str) -> ModelsGetResponse: ...
+    @webmethod(route="/models/get", method="GET")
+    async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]: ...
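
On the server side, an implementation of the Models protocol now returns
the specs themselves rather than wrapper objects. A minimal in-memory
sketch against the protocol above; the registry dict and its keying by
core model id are assumptions for illustration:

from typing import Dict, List, Optional


class InMemoryModels(Models):
    # Hypothetical implementation: a static registry mapping a core
    # model id to the spec describing which provider serves it.
    def __init__(self, registry: Dict[str, ModelServingSpec]) -> None:
        self.registry = registry

    async def list_models(self) -> List[ModelServingSpec]:
        # No ModelsListResponse wrapper: return the specs directly.
        return list(self.registry.values())

    async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
        # None (rather than an empty ModelsGetResponse) signals an
        # unknown model, which the client mirrors by returning None.
        return self.registry.get(core_model_id)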