Further generalize Xi's changes

- introduce a slightly more general notion of an AutoRouted provider
- the AutoRouted provider is associated with a RoutingTable provider
- e.g. inference -> models
- introduce the safety -> shields and memory -> memory_banks
  correspondences
This commit is contained in:
Ashwin Bharambe 2024-09-22 12:06:43 -07:00
parent b8914bb56f
commit e1966b90d9
19 changed files with 559 additions and 388 deletions

View file

@ -5,15 +5,11 @@
# the root directory of this source tree.
import asyncio
import json
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import List, Optional
import fire
import httpx
from llama_stack.distribution.datatypes import RemoteProviderConfig
from termcolor import cprint
from .models import * # noqa: F403
@ -29,18 +25,18 @@ class ModelsClient(Models):
async def shutdown(self) -> None:
    """Shut the client down; this implementation does nothing."""
    pass
async def list_models(self) -> List[ModelServingSpec]:
    """Return the serving specs listed by the remote ``/models/list`` endpoint.

    Returns:
        One ``ModelServingSpec`` per item of the decoded JSON response; each
        item is unpacked as keyword arguments to the constructor.

    Raises:
        httpx.HTTPStatusError: If the server replies with an error status.
    """
    async with httpx.AsyncClient() as client:
        response = await client.get(
            f"{self.base_url}/models/list",
            headers={"Content-Type": "application/json"},
        )
        response.raise_for_status()
        return [ModelServingSpec(**x) for x in response.json()]
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
    """Fetch the serving spec for one model from the ``/models/get`` endpoint.

    Args:
        core_model_id: Identifier of the model to look up.

    Returns:
        The ``ModelServingSpec`` built from the response body, or ``None``
        when the server answers with a JSON ``null``.

    Raises:
        httpx.HTTPStatusError: If the server replies with an error status.
    """
    async with httpx.AsyncClient() as client:
        # httpx's ``get()`` has no ``json=`` keyword (GET requests carry no
        # body), so the identifier is sent as a query parameter instead.
        response = await client.get(
            f"{self.base_url}/models/get",
            params={"core_model_id": core_model_id},
            headers={"Content-Type": "application/json"},
        )
        response.raise_for_status()
        payload = response.json()
        if payload is None:
            return None
        return ModelServingSpec(**payload)
async def run_main(host: str, port: int, stream: bool):