Further generalize Xi's changes (#88)

* Further generalize Xi's changes

- introduce a slightly more general notion of an AutoRouted provider
- the AutoRouted provider is associated with a RoutingTable provider
- e.g. inference -> models
- Introduced safety -> shields and memory -> memory_banks
  correspondences

* typo

* Basic build and run succeeded
This commit is contained in:
Ashwin Bharambe 2024-09-22 16:31:18 -07:00 committed by GitHub
parent b8914bb56f
commit c1ab66f1e6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 597 additions and 418 deletions

View file

@ -3,9 +3,5 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
@json_schema_type
class BuiltinImplConfig(BaseModel): ...
from .memory_banks import * # noqa: F401 F403

View file

@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import List, Optional
import fire
import httpx
from termcolor import cprint
from .memory_banks import * # noqa: F403
class MemoryBanksClient(MemoryBanks):
def __init__(self, base_url: str):
self.base_url = base_url
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def list_memory_banks(self) -> List[MemoryBankSpec]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/memory_banks/list",
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
return [MemoryBankSpec(**x) for x in response.json()]
async def get_memory_bank(
self, bank_type: MemoryBankType
) -> Optional[MemoryBankSpec]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/memory_banks/get",
json={
"bank_type": bank_type,
},
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
j = response.json()
if j is None:
return None
return MemoryBankSpec(**j)
async def run_main(host: str, port: int, stream: bool):
client = MemoryBanksClient(f"http://{host}:{port}")
response = await client.list_memory_banks()
cprint(f"list_memory_banks response={response}", "green")
def main(host: str, port: int, stream: bool = True):
asyncio.run(run_main(host, port, stream))
if __name__ == "__main__":
fire.Fire(main)

View file

@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.memory import MemoryBankType
from llama_stack.distribution.datatypes import GenericProviderConfig
@json_schema_type
class MemoryBankSpec(BaseModel):
bank_type: MemoryBankType
provider_config: GenericProviderConfig = Field(
description="Provider config for the model, including provider_id, and corresponding config. ",
)
class MemoryBanks(Protocol):
@webmethod(route="/memory_banks/list", method="GET")
async def list_memory_banks(self) -> List[MemoryBankSpec]: ...
@webmethod(route="/memory_banks/get", method="GET")
async def get_memory_bank(
self, bank_type: MemoryBankType
) -> Optional[MemoryBankSpec]: ...

View file

@ -5,15 +5,11 @@
# the root directory of this source tree.
import asyncio
import json
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import List, Optional
import fire
import httpx
from llama_stack.distribution.datatypes import RemoteProviderConfig
from termcolor import cprint
from .models import * # noqa: F403
@ -29,18 +25,18 @@ class ModelsClient(Models):
async def shutdown(self) -> None:
pass
async def list_models(self) -> ModelsListResponse:
async def list_models(self) -> List[ModelServingSpec]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/models/list",
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
return ModelsListResponse(**response.json())
return [ModelServingSpec(**x) for x in response.json()]
async def get_model(self, core_model_id: str) -> ModelsGetResponse:
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
async with httpx.AsyncClient() as client:
response = await client.post(
response = await client.get(
f"{self.base_url}/models/get",
json={
"core_model_id": core_model_id,
@ -48,7 +44,10 @@ class ModelsClient(Models):
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
return ModelsGetResponse(**response.json())
j = response.json()
if j is None:
return None
return ModelServingSpec(**j)
async def run_main(host: str, port: int, stream: bool):

View file

@ -4,14 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional, Protocol
from typing import List, Optional, Protocol
from llama_models.llama3.api.datatypes import Model
from llama_models.schema_utils import json_schema_type, webmethod
from llama_stack.distribution.datatypes import GenericProviderConfig
from pydantic import BaseModel, Field
from llama_stack.distribution.datatypes import GenericProviderConfig
@json_schema_type
class ModelServingSpec(BaseModel):
@ -21,25 +22,11 @@ class ModelServingSpec(BaseModel):
provider_config: GenericProviderConfig = Field(
description="Provider config for the model, including provider_id, and corresponding config. ",
)
api: str = Field(
description="The API that this model is serving (e.g. inference / safety).",
default="inference",
)
@json_schema_type
class ModelsListResponse(BaseModel):
models_list: List[ModelServingSpec]
@json_schema_type
class ModelsGetResponse(BaseModel):
core_model_spec: Optional[ModelServingSpec] = None
class Models(Protocol):
@webmethod(route="/models/list", method="GET")
async def list_models(self) -> ModelsListResponse: ...
async def list_models(self) -> List[ModelServingSpec]: ...
@webmethod(route="/models/get", method="POST")
async def get_model(self, core_model_id: str) -> ModelsGetResponse: ...
@webmethod(route="/models/get", method="GET")
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]: ...

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .shields import * # noqa: F401 F403

View file

@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import List, Optional
import fire
import httpx
from termcolor import cprint
from .shields import * # noqa: F403
class ShieldsClient(Shields):
def __init__(self, base_url: str):
self.base_url = base_url
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def list_shields(self) -> List[ShieldSpec]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/shields/list",
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
return [ShieldSpec(**x) for x in response.json()]
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/shields/get",
json={
"shield_type": shield_type,
},
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
j = response.json()
if j is None:
return None
return ShieldSpec(**j)
async def run_main(host: str, port: int, stream: bool):
client = ShieldsClient(f"http://{host}:{port}")
response = await client.list_shields()
cprint(f"list_shields response={response}", "green")
def main(host: str, port: int, stream: bool = True):
asyncio.run(run_main(host, port, stream))
if __name__ == "__main__":
fire.Fire(main)

View file

@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.distribution.datatypes import GenericProviderConfig
@json_schema_type
class ShieldSpec(BaseModel):
shield_type: str
provider_config: GenericProviderConfig = Field(
description="Provider config for the model, including provider_id, and corresponding config. ",
)
class Shields(Protocol):
@webmethod(route="/shields/list", method="GET")
async def list_shields(self) -> List[ShieldSpec]: ...
@webmethod(route="/shields/get", method="GET")
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]: ...

View file

@ -6,7 +6,7 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type
@ -19,8 +19,12 @@ class Api(Enum):
safety = "safety"
agents = "agents"
memory = "memory"
telemetry = "telemetry"
models = "models"
shields = "shields"
memory_banks = "memory_banks"
@json_schema_type
@ -44,27 +48,36 @@ class ProviderSpec(BaseModel):
)
class RoutingTable(Protocol):
def get_routing_keys(self) -> List[str]: ...
def get_provider_impl(self, routing_key: str) -> Any: ...
class GenericProviderConfig(BaseModel):
provider_id: str
config: Dict[str, Any]
@json_schema_type
class ProviderRoutingEntry(GenericProviderConfig):
class RoutableProviderConfig(GenericProviderConfig):
routing_key: str
class RoutingTableConfig(BaseModel):
entries: List[RoutableProviderConfig] = Field(...)
keys: Optional[List[str]] = Field(
default=None,
)
# Example: /inference, /safety
@json_schema_type
class RouterProviderSpec(ProviderSpec):
class AutoRoutedProviderSpec(ProviderSpec):
provider_id: str = "router"
config_class: str = ""
docker_image: Optional[str] = None
routing_table: List[ProviderRoutingEntry] = Field(
default_factory=list,
description="Routing table entries corresponding to the API",
)
routing_table_api: Api
module: str = Field(
...,
description="""
@ -79,18 +92,17 @@ class RouterProviderSpec(ProviderSpec):
@property
def pip_packages(self) -> List[str]:
raise AssertionError("Should not be called on RouterProviderSpec")
raise AssertionError("Should not be called on AutoRoutedProviderSpec")
# Example: /models, /shields
@json_schema_type
class BuiltinProviderSpec(ProviderSpec):
provider_id: str = "builtin"
class RoutingTableProviderSpec(ProviderSpec):
provider_id: str = "routing_table"
config_class: str = ""
docker_image: Optional[str] = None
api_dependencies: List[Api] = []
provider_data_validator: Optional[str] = Field(
default=None,
)
inner_specs: List[ProviderSpec]
module: str = Field(
...,
description="""
@ -99,10 +111,7 @@ class BuiltinProviderSpec(ProviderSpec):
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
""",
)
pip_packages: List[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
pip_packages: List[str] = Field(default_factory=list)
@json_schema_type
@ -130,10 +139,6 @@ Fully-qualified name of the module to import. The module is expected to have:
provider_data_validator: Optional[str] = Field(
default=None,
)
supported_model_ids: List[str] = Field(
default_factory=list,
description="The list of model ids that this adapter supports",
)
@json_schema_type
@ -243,9 +248,6 @@ in the runtime configuration to help route to the correct provider.""",
)
ProviderMapEntry = GenericProviderConfig
@json_schema_type
class StackRunConfig(BaseModel):
built_at: datetime
@ -269,25 +271,17 @@ this could be just a hash
description="""
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
)
provider_map: Dict[str, ProviderMapEntry] = Field(
api_providers: Dict[str, GenericProviderConfig] = Field(
description="""
Provider configurations for each of the APIs provided by this package.
Given an API, you can specify a single provider or a "routing table". Each entry in the routing
table has a (routing_key, provider_config) tuple. How the key is interpreted is API-specific.
As examples:
- the "inference" API interprets the routing_key as a "model"
- the "memory" API interprets the routing_key as the type of a "memory bank"
The key may support wild-cards alsothe routing_key to route to the correct provider.""",
""",
)
provider_routing_table: Dict[str, List[ProviderRoutingEntry]] = Field(
routing_tables: Dict[str, RoutingTableConfig] = Field(
default_factory=dict,
description="""
API: List[ProviderRoutingEntry] map. Each ProviderRoutingEntry is a (routing_key, provider_config) tuple.
E.g. The following is a ProviderRoutingEntry for inference API:
E.g. The following is a ProviderRoutingEntry for models:
- routing_key: Meta-Llama3.1-8B-Instruct
provider_id: meta-reference
config:

View file

@ -8,11 +8,15 @@ import importlib
import inspect
from typing import Dict, List
from pydantic import BaseModel
from llama_stack.apis.agents import Agents
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.models import Models
from llama_stack.apis.safety import Safety
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec
@ -30,6 +34,28 @@ def stack_apis() -> List[Api]:
return [v for v in Api]
class AutoRoutedApiInfo(BaseModel):
routing_table_api: Api
router_api: Api
def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
return [
AutoRoutedApiInfo(
routing_table_api=Api.models,
router_api=Api.inference,
),
AutoRoutedApiInfo(
routing_table_api=Api.shields,
router_api=Api.safety,
),
AutoRoutedApiInfo(
routing_table_api=Api.memory_banks,
router_api=Api.memory,
),
]
def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
apis = {}
@ -40,6 +66,8 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
Api.memory: Memory,
Api.telemetry: Telemetry,
Api.models: Models,
Api.shields: Shields,
Api.memory_banks: MemoryBanks,
}
for api, protocol in protocols.items():
@ -68,7 +96,13 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
ret = {}
routing_table_apis = set(
x.routing_table_api for x in builtin_automatically_routed_apis()
)
for api in stack_apis():
if api in routing_table_apis:
continue
name = api.name.lower()
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
ret[api] = {

View file

@ -4,25 +4,47 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Tuple
from typing import Any, List, Tuple
from llama_stack.distribution.datatypes import Api, ProviderRoutingEntry
from llama_stack.distribution.datatypes import * # noqa: F403
async def get_router_impl(
api: str, provider_routing_table: Dict[str, List[ProviderRoutingEntry]]
):
from .routers import InferenceRouter, MemoryRouter
from .routing_table import RoutingTable
async def get_routing_table_impl(
api: Api,
inner_impls: List[Tuple[str, Any]],
routing_table_config: RoutingTableConfig,
_deps,
) -> Any:
from .routing_tables import (
MemoryBanksRoutingTable,
ModelsRoutingTable,
ShieldsRoutingTable,
)
api2routers = {
"memory": MemoryRouter,
"inference": InferenceRouter,
api_to_tables = {
"memory_banks": MemoryBanksRoutingTable,
"models": ModelsRoutingTable,
"shields": ShieldsRoutingTable,
}
if api.value not in api_to_tables:
raise ValueError(f"API {api.value} not found in router map")
# initialize routing table with concrete provider impls
routing_table = RoutingTable(provider_routing_table)
impl = api2routers[api](routing_table)
impl = api_to_tables[api.value](inner_impls, routing_table_config)
await impl.initialize()
return impl
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
from .routers import InferenceRouter, MemoryRouter, SafetyRouter
api_to_routers = {
"memory": MemoryRouter,
"inference": InferenceRouter,
"safety": SafetyRouter,
}
if api.value not in api_to_routers:
raise ValueError(f"API {api.value} not found in router map")
impl = api_to_routers[api.value](routing_table)
await impl.initialize()
return impl

View file

@ -4,17 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncGenerator, Dict, List, Tuple
from typing import Any, AsyncGenerator, Dict, List
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.datatypes import RoutingTable
from .routing_table import RoutingTable
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from types import MethodType
from termcolor import cprint
from llama_stack.apis.safety import * # noqa: F403
class MemoryRouter(Memory):
@ -24,22 +20,24 @@ class MemoryRouter(Memory):
self,
routing_table: RoutingTable,
) -> None:
self.api = Api.memory.value
self.routing_table = routing_table
self.bank_id_to_type = {}
async def initialize(self) -> None:
await self.routing_table.initialize(self.api)
pass
async def shutdown(self) -> None:
await self.routing_table.shutdown(self.api)
pass
def get_provider_from_bank_id(self, bank_id: str) -> Any:
bank_type = self.bank_id_to_type.get(bank_id)
if not bank_type:
raise ValueError(f"Could not find bank type for {bank_id}")
return self.routing_table.get_provider_impl(self.api, bank_type)
provider = self.routing_table.get_provider_impl(bank_type)
if not provider:
raise ValueError(f"Could not find provider for {bank_type}")
return provider
async def create_memory_bank(
self,
@ -48,14 +46,15 @@ class MemoryRouter(Memory):
url: Optional[URL] = None,
) -> MemoryBank:
bank_type = config.type
bank = await self.routing_table.get_provider_impl(
self.api, bank_type
provider = await self.routing_table.get_provider_impl(
bank_type
).create_memory_bank(name, config, url)
self.bank_id_to_type[bank.bank_id] = bank_type
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
return await self.get_provider_from_bank_id(bank_id).get_memory_bank(bank_id)
provider = self.get_provider_from_bank_id(bank_id)
return await provider.get_memory_bank(bank_id)
async def insert_documents(
self,
@ -85,34 +84,31 @@ class InferenceRouter(Inference):
self,
routing_table: RoutingTable,
) -> None:
self.api = Api.inference.value
self.routing_table = routing_table
async def initialize(self) -> None:
await self.routing_table.initialize(self.api)
pass
async def shutdown(self) -> None:
await self.routing_table.shutdown(self.api)
pass
async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = [],
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
# TODO: we need to fix streaming response to align provider implementations with Protocol.
async for chunk in self.routing_table.get_provider_impl(
self.api, model
).chat_completion(
async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
@ -128,7 +124,7 @@ class InferenceRouter(Inference):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
return await self.routing_table.get_provider_impl(self.api, model).completion(
return await self.routing_table.get_provider_impl(model).completion(
model=model,
content=content,
sampling_params=sampling_params,
@ -141,7 +137,33 @@ class InferenceRouter(Inference):
model: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
return await self.routing_table.get_provider_impl(self.api, model).embeddings(
return await self.routing_table.get_provider_impl(model).embeddings(
model=model,
contents=contents,
)
class SafetyRouter(Safety):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
self.routing_table = routing_table
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def run_shield(
self,
shield_type: str,
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
return await self.routing_table.get_provider_impl(shield_type).run_shield(
shield_type=shield_type,
messages=messages,
params=params,
)

View file

@ -1,60 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
from llama_stack.distribution.datatypes import (
Api,
GenericProviderConfig,
ProviderRoutingEntry,
)
from llama_stack.distribution.distribution import api_providers
from llama_stack.distribution.utils.dynamic import instantiate_provider
from termcolor import cprint
class RoutingTable:
def __init__(self, provider_routing_table: Dict[str, List[ProviderRoutingEntry]]):
self.provider_routing_table = provider_routing_table
# map {api: {routing_key: impl}}, e.g. {'inference': {'8b': <MetaReferenceImpl>, '70b': <OllamaImpl>}}
self.api2routes = {}
async def initialize(self, api_str: str) -> None:
"""Initialize the routing table with concrete provider impls"""
if api_str not in self.provider_routing_table:
raise ValueError(f"API {api_str} not found in routing table")
providers = api_providers()[Api(api_str)]
routing_list = self.provider_routing_table[api_str]
self.api2routes[api_str] = {}
for rt_entry in routing_list:
rt_key = rt_entry.routing_key
provider_id = rt_entry.provider_id
impl = await instantiate_provider(
providers[provider_id],
deps=[],
provider_config=GenericProviderConfig(
provider_id=provider_id, config=rt_entry.config
),
)
cprint(f"impl = {impl}", "red")
self.api2routes[api_str][rt_key] = impl
cprint(f"> Initialized implementations for {api_str} in routing table", "blue")
async def shutdown(self, api_str: str) -> None:
"""Shutdown the routing table"""
if api_str not in self.api2routes:
return
for impl in self.api2routes[api_str].values():
await impl.shutdown()
def get_provider_impl(self, api: str, routing_key: str) -> Any:
"""Get the provider impl for a given api and routing key"""
return self.api2routes[api][routing_key]

View file

@ -0,0 +1,118 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, List, Optional, Tuple
from llama_models.sku_list import resolve_model
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
class CommonRoutingTableImpl(RoutingTable):
def __init__(
self,
inner_impls: List[Tuple[str, Any]],
routing_table_config: RoutingTableConfig,
) -> None:
self.providers = {k: v for k, v in inner_impls}
self.routing_keys = list(self.providers.keys())
self.routing_table_config = routing_table_config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
for p in self.providers.values():
await p.shutdown()
async def get_provider_impl(self, routing_key: str) -> Optional[Any]:
return self.providers.get(routing_key)
async def get_routing_keys(self) -> List[str]:
return self.routing_keys
async def get_provider_config(
self, routing_key: str
) -> Optional[GenericProviderConfig]:
for entry in self.routing_table_config.entries:
if entry.routing_key == routing_key:
return entry
return None
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> List[ModelServingSpec]:
specs = []
for entry in self.routing_table_config.entries:
model_id = entry.routing_key
specs.append(
ModelServingSpec(
llama_model=resolve_model(model_id),
provider_config=entry,
)
)
return specs
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
for entry in self.routing_table_config.entries:
if entry.routing_key == core_model_id:
return ModelServingSpec(
llama_model=resolve_model(core_model_id),
provider_config=entry,
)
return None
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> List[ShieldSpec]:
specs = []
for entry in self.routing_table_config.entries:
specs.append(
ShieldSpec(
shield_type=entry.routing_key,
provider_config=entry,
)
)
return specs
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
for entry in self.routing_table_config.entries:
if entry.routing_key == shield_type:
return ShieldSpec(
shield_type=entry.routing_key,
provider_config=entry,
)
return None
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
async def list_memory_banks(self) -> List[MemoryBankSpec]:
specs = []
for entry in self.routing_table_config.entries:
specs.append(
MemoryBankSpec(
bank_type=entry.routing_key,
provider_config=entry,
)
)
return specs
async def get_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]:
for entry in self.routing_table_config.entries:
if entry.routing_key == bank_type:
return MemoryBankSpec(
bank_type=entry.routing_key,
provider_config=entry,
)
return None

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import asyncio
import importlib
import inspect
import json
import signal
@ -36,6 +35,9 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
@ -43,18 +45,15 @@ from llama_stack.providers.utils.telemetry.tracing import (
SpanStatus,
start_trace,
)
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import api_endpoints, api_providers
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.utils.dynamic import (
instantiate_builtin_provider,
instantiate_provider,
instantiate_router,
from llama_stack.distribution.distribution import (
api_endpoints,
api_providers,
builtin_automatically_routed_apis,
)
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.utils.dynamic import instantiate_provider
def is_async_iterator_type(typ):
@ -292,9 +291,7 @@ def snake_to_camel(snake_str):
return "".join(word.capitalize() for word in snake_str.split("_"))
async def resolve_impls_with_routing(
stack_run_config: StackRunConfig,
) -> Dict[Api, Any]:
async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
"""
Does two things:
- flatmaps, sorts and resolves the providers in dependency order
@ -302,48 +299,80 @@ async def resolve_impls_with_routing(
"""
all_providers = api_providers()
specs = {}
configs = {}
for api_str in stack_run_config.apis_to_serve:
for api_str, config in run_config.api_providers.items():
api = Api(api_str)
# TODO: check that these APIs are not in the routing table part of the config
providers = all_providers[api]
# check for routing table, we need to pass routing table to the router implementation
if api_str in stack_run_config.provider_routing_table:
specs[api] = RouterProviderSpec(
api=api,
module=f"llama_stack.distribution.routers",
api_dependencies=[],
routing_table=stack_run_config.provider_routing_table[api_str],
)
else:
if api_str in stack_run_config.provider_map:
provider_map_entry = stack_run_config.provider_map[api_str]
provider_id = provider_map_entry.provider_id
else:
# not defined in config, will be a builtin provider, assign builtin provider id
provider_id = "builtin"
if provider_id not in providers:
if config.provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api}`"
f"Unknown provider `{config.provider_id}` is not available for API `{api}`"
)
specs[api] = providers[provider_id]
specs[api] = providers[config.provider_id]
configs[api] = config
apis_to_serve = run_config.apis_to_serve or set(
list(specs.keys()) + list(run_config.routing_tables.keys())
)
print("apis_to_serve", apis_to_serve)
for info in builtin_automatically_routed_apis():
source_api = info.routing_table_api
assert (
source_api not in specs
), f"Routing table API {source_api} specified in wrong place?"
assert (
info.router_api not in specs
), f"Auto-routed API {info.router_api} specified in wrong place?"
if info.router_api.value not in apis_to_serve:
continue
if source_api.value not in run_config.routing_tables:
raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
routing_table = run_config.routing_tables[source_api.value]
providers = all_providers[info.router_api]
inner_specs = []
for rt_entry in routing_table.entries:
if rt_entry.provider_id not in providers:
raise ValueError(
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
)
inner_specs.append(providers[rt_entry.provider_id])
specs[source_api] = RoutingTableProviderSpec(
api=source_api,
module="llama_stack.distribution.routers",
api_dependencies=[],
inner_specs=inner_specs,
)
configs[source_api] = routing_table
specs[info.router_api] = AutoRoutedProviderSpec(
api=info.router_api,
module="llama_stack.distribution.routers",
routing_table_api=source_api,
api_dependencies=[source_api],
)
configs[info.router_api] = {}
sorted_specs = topological_sort(specs.values())
print(f"Resolved {len(sorted_specs)} providers in topological order")
for spec in sorted_specs:
print(f" {spec.api}: {spec.provider_id}")
print("")
impls = {}
for spec in sorted_specs:
api = spec.api
deps = {api: impls[api] for api in spec.api_dependencies}
if api.value in stack_run_config.provider_map:
provider_config = stack_run_config.provider_map[api.value]
impl = await instantiate_provider(spec, deps, provider_config)
elif api.value in stack_run_config.provider_routing_table:
impl = await instantiate_router(
spec, api.value, stack_run_config.provider_routing_table
)
else:
impl = await instantiate_builtin_provider(spec, stack_run_config)
impl = await instantiate_provider(spec, deps, configs[api])
impls[api] = impl
return impls, specs
@ -355,16 +384,23 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
app = FastAPI()
# impls, specs = asyncio.run(resolve_impls(config.provider_map))
impls, specs = asyncio.run(resolve_impls_with_routing(config))
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
all_endpoints = api_endpoints()
apis_to_serve = config.apis_to_serve or list(config.provider_map.keys())
if config.apis_to_serve:
apis_to_serve = set(config.apis_to_serve)
for inf in builtin_automatically_routed_apis():
if inf.router_api.value in apis_to_serve:
apis_to_serve.add(inf.routing_table_api)
else:
apis_to_serve = set(impls.keys())
for api_str in apis_to_serve:
api = Api(api_str)
endpoints = all_endpoints[api]
impl = impls[api]
@ -391,7 +427,11 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
create_dynamic_typed_route(
impl_method,
endpoint.method,
provider_spec.provider_data_validator,
(
provider_spec.provider_data_validator
if not isinstance(provider_spec, RoutingTableProviderSpec)
else None
),
)
)

View file

@ -16,36 +16,11 @@ def instantiate_class_type(fully_qualified_name):
return getattr(module, class_name)
async def instantiate_router(
provider_spec: RouterProviderSpec,
api: str,
provider_routing_table: Dict[str, Any],
):
module = importlib.import_module(provider_spec.module)
fn = getattr(module, "get_router_impl")
impl = await fn(api, provider_routing_table)
impl.__provider_spec__ = provider_spec
return impl
async def instantiate_builtin_provider(
provider_spec: BuiltinProviderSpec,
run_config: StackRunConfig,
):
print("!!! instantiate_builtin_provider")
module = importlib.import_module(provider_spec.module)
fn = getattr(module, "get_builtin_impl")
impl = await fn(run_config)
impl.__provider_spec__ = provider_spec
return impl
# returns a class implementing the protocol corresponding to the Api
async def instantiate_provider(
provider_spec: ProviderSpec,
deps: Dict[str, Any],
provider_config: ProviderMapEntry,
provider_config: Union[GenericProviderConfig, RoutingTable],
):
module = importlib.import_module(provider_spec.module)
@ -60,6 +35,29 @@ async def instantiate_provider(
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config)
args = [config, deps]
elif isinstance(provider_spec, AutoRoutedProviderSpec):
method = "get_auto_router_impl"
config = None
args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
elif isinstance(provider_spec, RoutingTableProviderSpec):
method = "get_routing_table_impl"
assert isinstance(provider_config, RoutingTableConfig)
routing_table = provider_config
inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
inner_impls = []
for routing_entry in routing_table.entries:
impl = await instantiate_provider(
inner_specs[routing_entry.provider_id],
deps,
routing_entry,
)
inner_impls.append((routing_entry.routing_key, impl))
config = None
args = [provider_spec.api, inner_impls, routing_table, deps]
else:
method = "get_provider_impl"

View file

@ -1,21 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec, StackRunConfig
from .config import BuiltinImplConfig # noqa
async def get_builtin_impl(config: StackRunConfig):
from .models import BuiltinModelsImpl
assert isinstance(config, StackRunConfig), f"Unexpected config type: {type(config)}"
impl = BuiltinModelsImpl(config)
await impl.initialize()
return impl

View file

@ -1,113 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import AsyncIterator, Union
from llama_models.llama3.api.datatypes import StopReason
from llama_models.sku_list import resolve_model
from llama_stack.distribution.distribution import Api, api_providers
from llama_stack.apis.models import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.datatypes import CoreModelId, Model
from llama_models.sku_list import resolve_model
from llama_stack.distribution.datatypes import * # noqa: F403
from termcolor import cprint
class BuiltinModelsImpl(Models):
def __init__(
self,
config: StackRunConfig,
) -> None:
self.run_config = config
self.models = {}
# check against inference & safety api
apis_with_models = [Api.inference, Api.safety]
all_providers = api_providers()
for api in apis_with_models:
# check against provider_map (simple case single model)
if api.value in config.provider_map:
providers_for_api = all_providers[api]
provider_spec = config.provider_map[api.value]
core_model_id = provider_spec.config
# get supported model ids from the provider
supported_model_ids = self.get_supported_model_ids(
api.value, provider_spec, providers_for_api
)
for model_id in supported_model_ids:
self.models[model_id] = ModelServingSpec(
llama_model=resolve_model(model_id),
provider_config=provider_spec,
api=api.value,
)
# check against provider_routing_table (router with multiple models)
# with routing table, we use the routing_key as the supported models
if api.value in config.provider_routing_table:
routing_table = config.provider_routing_table[api.value]
for rt_entry in routing_table:
model_id = rt_entry.routing_key
self.models[model_id] = ModelServingSpec(
llama_model=resolve_model(model_id),
provider_config=GenericProviderConfig(
provider_id=rt_entry.provider_id,
config=rt_entry.config,
),
api=api.value,
)
print("BuiltinModelsImpl models", self.models)
def get_supported_model_ids(
self,
api: str,
provider_spec: GenericProviderConfig,
providers_for_api: Dict[str, ProviderSpec],
) -> List[str]:
serving_models_list = []
if api == Api.inference.value:
provider_id = provider_spec.provider_id
if provider_id == "meta-reference":
serving_models_list.append(provider_spec.config["model"])
if provider_id in {
remote_provider_id("ollama"),
remote_provider_id("fireworks"),
remote_provider_id("together"),
}:
adapter_supported_models = providers_for_api[
provider_id
].adapter.supported_model_ids
serving_models_list.extend(adapter_supported_models)
elif api == Api.safety.value:
if provider_spec.config and "llama_guard_shield" in provider_spec.config:
llama_guard_shield = provider_spec.config["llama_guard_shield"]
serving_models_list.append(llama_guard_shield["model"])
if provider_spec.config and "prompt_guard_shield" in provider_spec.config:
prompt_guard_shield = provider_spec.config["prompt_guard_shield"]
serving_models_list.append(prompt_guard_shield["model"])
else:
raise NotImplementedError(f"Unsupported api {api} for builtin models")
return serving_models_list
async def initialize(self) -> None:
pass
async def list_models(self) -> ModelsListResponse:
return ModelsListResponse(models_list=list(self.models.values()))
async def get_model(self, core_model_id: str) -> ModelsGetResponse:
if core_model_id in self.models:
return ModelsGetResponse(core_model_spec=self.models[core_model_id])
print(f"Cannot find {core_model_id} in model registry")
return ModelsGetResponse()

View file

@ -6,17 +6,14 @@
from typing import Optional
from llama_models.datatypes import ModelFamily
from llama_models.schema_utils import json_schema_type
from llama_models.datatypes import * # noqa: F403
from llama_models.sku_list import all_registered_models, resolve_model
from llama_stack.apis.inference import * # noqa: F401, F403
from pydantic import BaseModel, Field, field_validator
from llama_stack.apis.inference import QuantizationConfig
@json_schema_type
class MetaReferenceImplConfig(BaseModel):
model: str = Field(
default="Meta-Llama3.1-8B-Instruct",
@ -34,6 +31,7 @@ class MetaReferenceImplConfig(BaseModel):
m.descriptor()
for m in all_registered_models()
if m.model_family == ModelFamily.llama3_1
or m.core_model_id == CoreModelId.llama_guard_3_8b
]
if model not in permitted_models:
model_list = "\n\t".join(permitted_models)

View file

@ -32,10 +32,6 @@ def available_providers() -> List[ProviderSpec]:
adapter_id="ollama",
pip_packages=["ollama"],
module="llama_stack.providers.adapters.inference.ollama",
supported_model_ids=[
"Meta-Llama3.1-8B-Instruct",
"Meta-Llama3.1-70B-Instruct",
],
),
),
remote_provider_spec(
@ -56,11 +52,6 @@ def available_providers() -> List[ProviderSpec]:
],
module="llama_stack.providers.adapters.inference.fireworks",
config_class="llama_stack.providers.adapters.inference.fireworks.FireworksImplConfig",
supported_model_ids=[
"Meta-Llama3.1-8B-Instruct",
"Meta-Llama3.1-70B-Instruct",
"Meta-Llama3.1-405B-Instruct",
],
),
),
remote_provider_spec(
@ -73,11 +64,6 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.adapters.inference.together",
config_class="llama_stack.providers.adapters.inference.together.TogetherImplConfig",
header_extractor_class="llama_stack.providers.adapters.inference.together.TogetherHeaderExtractor",
supported_model_ids=[
"Meta-Llama3.1-8B-Instruct",
"Meta-Llama3.1-70B-Instruct",
"Meta-Llama3.1-405B-Instruct",
],
),
),
]

View file

@ -1,22 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.distribution.datatypes import * # noqa: F403
def available_providers() -> List[ProviderSpec]:
return [
BuiltinProviderSpec(
api=Api.models,
provider_id="builtin",
pip_packages=[],
module="llama_stack.providers.impls.builtin.models",
config_class="llama_stack.providers.impls.builtin.models.BuiltinImplConfig",
api_dependencies=[],
)
]