Further generalize Xi's changes (#88)

* Further generalize Xi's changes

- introduce a slightly more general notion of an AutoRouted provider
- the AutoRouted provider is associated with a RoutingTable provider
- e.g. inference -> models
- introduce safety -> shields and memory -> memory_banks correspondences
  (see the sketch of this mapping just before the file diffs below)

* typo

* Basic build and run succeeded
Author: Ashwin Bharambe, 2024-09-22 16:31:18 -07:00 (committed via GitHub)
Commit: c1ab66f1e6 (parent: b8914bb56f)
21 changed files with 597 additions and 418 deletions
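
To make the correspondence in the commit message concrete, the mapping below is a minimal illustrative sketch (the dictionary name and shape are hypothetical; the actual registry of auto-routed APIs lives in distribution code that is part of this commit but not shown in this excerpt):

# Hypothetical sketch, not code from this commit: each auto-routed API is
# paired with the routing-table API that tells the router which provider
# should handle a given routing key.
AUTO_ROUTED_TO_ROUTING_TABLE = {
    "inference": "models",       # inference requests routed by the model being served
    "safety": "shields",         # safety requests routed by shield_type
    "memory": "memory_banks",    # memory requests routed by memory bank type
}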


@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .memory_banks import *  # noqa: F401 F403


@@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio

from typing import List, Optional

import fire
import httpx
from termcolor import cprint

from .memory_banks import *  # noqa: F403


class MemoryBanksClient(MemoryBanks):
    def __init__(self, base_url: str):
        self.base_url = base_url

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def list_memory_banks(self) -> List[MemoryBankSpec]:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{self.base_url}/memory_banks/list",
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            return [MemoryBankSpec(**x) for x in response.json()]

    async def get_memory_bank(
        self, bank_type: MemoryBankType
    ) -> Optional[MemoryBankSpec]:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{self.base_url}/memory_banks/get",
                json={
                    "bank_type": bank_type,
                },
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            j = response.json()
            if j is None:
                return None
            return MemoryBankSpec(**j)


async def run_main(host: str, port: int, stream: bool):
    client = MemoryBanksClient(f"http://{host}:{port}")

    response = await client.list_memory_banks()
    cprint(f"list_memory_banks response={response}", "green")


def main(host: str, port: int, stream: bool = True):
    asyncio.run(run_main(host, port, stream))


if __name__ == "__main__":
    fire.Fire(main)
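
run_main above only exercises list_memory_banks; the sketch below shows a lookup of a single bank through the same client (hypothetical usage, not part of this commit; MemoryBankType comes from llama_stack.apis.memory and the enum member used here is assumed):

# Hypothetical usage sketch for MemoryBanksClient (the enum member below is a
# placeholder; provider_id is read per the GenericProviderConfig description).
async def fetch_vector_bank(client: MemoryBanksClient) -> None:
    spec = await client.get_memory_bank(MemoryBankType.vector)  # assumed enum member
    if spec is None:
        cprint("no vector memory bank registered", "red")
    else:
        cprint(f"vector banks are served by provider {spec.provider_config.provider_id}", "green")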


@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List, Optional, Protocol

from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field

from llama_stack.apis.memory import MemoryBankType
from llama_stack.distribution.datatypes import GenericProviderConfig


@json_schema_type
class MemoryBankSpec(BaseModel):
    bank_type: MemoryBankType
    provider_config: GenericProviderConfig = Field(
        description="Provider config for the model, including provider_id, and corresponding config. ",
    )


class MemoryBanks(Protocol):
    @webmethod(route="/memory_banks/list", method="GET")
    async def list_memory_banks(self) -> List[MemoryBankSpec]: ...

    @webmethod(route="/memory_banks/get", method="GET")
    async def get_memory_bank(
        self, bank_type: MemoryBankType
    ) -> Optional[MemoryBankSpec]: ...
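
As an illustration of the protocol above, here is a minimal in-memory implementation sketch (hypothetical; the actual providers backing this routing table are defined elsewhere in the commit and are not shown in this excerpt):

# Hypothetical sketch, not part of this commit: a trivial MemoryBanks
# implementation backed by a static list of MemoryBankSpec entries.
class StaticMemoryBanks(MemoryBanks):
    def __init__(self, specs: List[MemoryBankSpec]):
        self.specs = {spec.bank_type: spec for spec in specs}

    async def list_memory_banks(self) -> List[MemoryBankSpec]:
        return list(self.specs.values())

    async def get_memory_bank(
        self, bank_type: MemoryBankType
    ) -> Optional[MemoryBankSpec]:
        return self.specs.get(bank_type)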


@@ -5,15 +5,11 @@
# the root directory of this source tree.

import asyncio
import json

-from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import List, Optional

import fire
import httpx
-from llama_stack.distribution.datatypes import RemoteProviderConfig
from termcolor import cprint

from .models import *  # noqa: F403

@@ -29,18 +25,18 @@ class ModelsClient(Models):
    async def shutdown(self) -> None:
        pass

-    async def list_models(self) -> ModelsListResponse:
+    async def list_models(self) -> List[ModelServingSpec]:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{self.base_url}/models/list",
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
-            return ModelsListResponse(**response.json())
+            return [ModelServingSpec(**x) for x in response.json()]

-    async def get_model(self, core_model_id: str) -> ModelsGetResponse:
+    async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
        async with httpx.AsyncClient() as client:
-            response = await client.post(
+            response = await client.get(
                f"{self.base_url}/models/get",
                json={
                    "core_model_id": core_model_id,

@@ -48,7 +44,10 @@ class ModelsClient(Models):
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
-            return ModelsGetResponse(**response.json())
+            j = response.json()
+            if j is None:
+                return None
+            return ModelServingSpec(**j)


async def run_main(host: str, port: int, stream: bool):


@@ -4,14 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

-from typing import Any, Dict, List, Optional, Protocol
+from typing import List, Optional, Protocol

from llama_models.llama3.api.datatypes import Model
from llama_models.schema_utils import json_schema_type, webmethod
-from llama_stack.distribution.datatypes import GenericProviderConfig
from pydantic import BaseModel, Field

+from llama_stack.distribution.datatypes import GenericProviderConfig


@json_schema_type
class ModelServingSpec(BaseModel):

@@ -21,25 +22,11 @@ class ModelServingSpec(BaseModel):
    provider_config: GenericProviderConfig = Field(
        description="Provider config for the model, including provider_id, and corresponding config. ",
    )
-    api: str = Field(
-        description="The API that this model is serving (e.g. inference / safety).",
-        default="inference",
-    )
-
-
-@json_schema_type
-class ModelsListResponse(BaseModel):
-    models_list: List[ModelServingSpec]
-
-
-@json_schema_type
-class ModelsGetResponse(BaseModel):
-    core_model_spec: Optional[ModelServingSpec] = None


class Models(Protocol):
    @webmethod(route="/models/list", method="GET")
-    async def list_models(self) -> ModelsListResponse: ...
+    async def list_models(self) -> List[ModelServingSpec]: ...

-    @webmethod(route="/models/get", method="POST")
-    async def get_model(self, core_model_id: str) -> ModelsGetResponse: ...
+    @webmethod(route="/models/get", method="GET")
+    async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]: ...
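
The net effect of this hunk for callers: list_models now returns a plain list, get_model returns an Optional spec, the /models/get route switches from POST to GET, and the ModelsListResponse/ModelsGetResponse wrappers disappear. A hedged caller-side sketch (the model id and printed fields are placeholders, not from this commit):

# Hypothetical sketch of consuming the updated Models API.
async def show_models(models: Models) -> None:
    specs = await models.list_models()                       # now List[ModelServingSpec]
    print(f"{len(specs)} model(s) being served")

    spec = await models.get_model("Llama3.1-8B-Instruct")    # placeholder model id
    if spec is None:                                         # now Optional[ModelServingSpec]
        print("model is not being served")
    else:
        print(f"served by provider {spec.provider_config.provider_id}")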


@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .shields import *  # noqa: F401 F403


@@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio

from typing import List, Optional

import fire
import httpx
from termcolor import cprint

from .shields import *  # noqa: F403


class ShieldsClient(Shields):
    def __init__(self, base_url: str):
        self.base_url = base_url

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def list_shields(self) -> List[ShieldSpec]:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{self.base_url}/shields/list",
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            return [ShieldSpec(**x) for x in response.json()]

    async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{self.base_url}/shields/get",
                json={
                    "shield_type": shield_type,
                },
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            j = response.json()
            if j is None:
                return None
            return ShieldSpec(**j)


async def run_main(host: str, port: int, stream: bool):
    client = ShieldsClient(f"http://{host}:{port}")

    response = await client.list_shields()
    cprint(f"list_shields response={response}", "green")


def main(host: str, port: int, stream: bool = True):
    asyncio.run(run_main(host, port, stream))


if __name__ == "__main__":
    fire.Fire(main)


@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List, Optional, Protocol

from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field

from llama_stack.distribution.datatypes import GenericProviderConfig


@json_schema_type
class ShieldSpec(BaseModel):
    shield_type: str
    provider_config: GenericProviderConfig = Field(
        description="Provider config for the model, including provider_id, and corresponding config. ",
    )


class Shields(Protocol):
    @webmethod(route="/shields/list", method="GET")
    async def list_shields(self) -> List[ShieldSpec]: ...

    @webmethod(route="/shields/get", method="GET")
    async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]: ...
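
The Shields routing table pairs with the auto-routed safety API the same way the Models table pairs with inference. A hedged sketch of the lookup a router might perform (the function name and error handling are illustrative; the actual router code is in other files of this commit not shown here):

# Hypothetical sketch, not part of this commit: resolve a shield_type to the
# provider config registered for it, using only the Shields API defined above.
async def resolve_shield_provider(shields: Shields, shield_type: str) -> GenericProviderConfig:
    spec = await shields.get_shield(shield_type)
    if spec is None:
        raise ValueError(f"no provider registered for shield_type={shield_type}")
    return spec.provider_config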