Further generalize Xi's changes (#88)

* Further generalize Xi's changes

- introduce a slightly more general notion of an AutoRouted provider
- the AutoRouted provider is associated with a RoutingTable provider
- e.g. inference -> models
- Introduced safety -> shields and memory -> memory_banks
  correspondences

* typo

* Basic build and run succeeded
This commit is contained in:
Ashwin Bharambe 2024-09-22 16:31:18 -07:00 committed by GitHub
parent b8914bb56f
commit c1ab66f1e6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 597 additions and 418 deletions

View file

@ -3,9 +3,5 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
from .memory_banks import * # noqa: F401 F403
@json_schema_type
class BuiltinImplConfig(BaseModel): ...

View file

@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import List, Optional
import fire
import httpx
from termcolor import cprint
from .memory_banks import * # noqa: F403
class MemoryBanksClient(MemoryBanks):
    """Thin HTTP client for the MemoryBanks API, mirroring the server protocol."""

    def __init__(self, base_url: str):
        # Base URL of the llama-stack server, e.g. "http://localhost:5000".
        self.base_url = base_url

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def list_memory_banks(self) -> List[MemoryBankSpec]:
        """Fetch all registered memory bank specs from the server."""
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{self.base_url}/memory_banks/list",
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            return [MemoryBankSpec(**x) for x in response.json()]

    async def get_memory_bank(
        self, bank_type: MemoryBankType
    ) -> Optional[MemoryBankSpec]:
        """Fetch the spec for a single bank type, or None if not registered."""
        async with httpx.AsyncClient() as client:
            # httpx's .get() helper does not accept a `json=` body and would
            # raise TypeError; use .request() so the payload is actually sent.
            response = await client.request(
                "GET",
                f"{self.base_url}/memory_banks/get",
                json={
                    # NOTE(review): if MemoryBankType is an Enum, this may need
                    # `bank_type.value` to be JSON-serializable — confirm.
                    "bank_type": bank_type,
                },
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            j = response.json()
            if j is None:
                return None
            return MemoryBankSpec(**j)
async def run_main(host: str, port: int, stream: bool):
    # Connect to the server and print every registered memory bank.
    banks_client = MemoryBanksClient(f"http://{host}:{port}")
    listing = await banks_client.list_memory_banks()
    cprint(f"list_memory_banks response={listing}", "green")


def main(host: str, port: int, stream: bool = True):
    # fire entry point; drives the async helper to completion.
    asyncio.run(run_main(host, port, stream))


if __name__ == "__main__":
    fire.Fire(main)

View file

@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.memory import MemoryBankType
from llama_stack.distribution.datatypes import GenericProviderConfig
@json_schema_type
class MemoryBankSpec(BaseModel):
    """Pairing of a memory bank type with the provider configured to serve it."""

    bank_type: MemoryBankType
    provider_config: GenericProviderConfig = Field(
        # Fixed copy-paste from the models API: this describes a memory bank.
        description="Provider config for the memory bank, including provider_id, and corresponding config. ",
    )
class MemoryBanks(Protocol):
    """API for inspecting the registered memory banks and their providers."""

    # Lists every memory bank spec in the routing table.
    @webmethod(route="/memory_banks/list", method="GET")
    async def list_memory_banks(self) -> List[MemoryBankSpec]: ...

    # Looks up a single bank by type; None when no such bank is registered.
    @webmethod(route="/memory_banks/get", method="GET")
    async def get_memory_bank(
        self, bank_type: MemoryBankType
    ) -> Optional[MemoryBankSpec]: ...

View file

@ -5,15 +5,11 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
import json
from pathlib import Path
from typing import Any, Dict, List, Optional from typing import List, Optional
import fire import fire
import httpx import httpx
from llama_stack.distribution.datatypes import RemoteProviderConfig
from termcolor import cprint from termcolor import cprint
from .models import * # noqa: F403 from .models import * # noqa: F403
@ -29,18 +25,18 @@ class ModelsClient(Models):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def list_models(self) -> ModelsListResponse: async def list_models(self) -> List[ModelServingSpec]:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get( response = await client.get(
f"{self.base_url}/models/list", f"{self.base_url}/models/list",
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
) )
response.raise_for_status() response.raise_for_status()
return ModelsListResponse(**response.json()) return [ModelServingSpec(**x) for x in response.json()]
async def get_model(self, core_model_id: str) -> ModelsGetResponse: async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.get(
f"{self.base_url}/models/get", f"{self.base_url}/models/get",
json={ json={
"core_model_id": core_model_id, "core_model_id": core_model_id,
@ -48,7 +44,10 @@ class ModelsClient(Models):
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
) )
response.raise_for_status() response.raise_for_status()
return ModelsGetResponse(**response.json()) j = response.json()
if j is None:
return None
return ModelServingSpec(**j)
async def run_main(host: str, port: int, stream: bool): async def run_main(host: str, port: int, stream: bool):

View file

@ -4,14 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict, List, Optional, Protocol from typing import List, Optional, Protocol
from llama_models.llama3.api.datatypes import Model from llama_models.llama3.api.datatypes import Model
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
from llama_stack.distribution.datatypes import GenericProviderConfig
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_stack.distribution.datatypes import GenericProviderConfig
@json_schema_type @json_schema_type
class ModelServingSpec(BaseModel): class ModelServingSpec(BaseModel):
@ -21,25 +22,11 @@ class ModelServingSpec(BaseModel):
provider_config: GenericProviderConfig = Field( provider_config: GenericProviderConfig = Field(
description="Provider config for the model, including provider_id, and corresponding config. ", description="Provider config for the model, including provider_id, and corresponding config. ",
) )
api: str = Field(
description="The API that this model is serving (e.g. inference / safety).",
default="inference",
)
@json_schema_type
class ModelsListResponse(BaseModel):
models_list: List[ModelServingSpec]
@json_schema_type
class ModelsGetResponse(BaseModel):
core_model_spec: Optional[ModelServingSpec] = None
class Models(Protocol): class Models(Protocol):
@webmethod(route="/models/list", method="GET") @webmethod(route="/models/list", method="GET")
async def list_models(self) -> ModelsListResponse: ... async def list_models(self) -> List[ModelServingSpec]: ...
@webmethod(route="/models/get", method="POST") @webmethod(route="/models/get", method="GET")
async def get_model(self, core_model_id: str) -> ModelsGetResponse: ... async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]: ...

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .shields import * # noqa: F401 F403

View file

@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import List, Optional
import fire
import httpx
from termcolor import cprint
from .shields import * # noqa: F403
class ShieldsClient(Shields):
    """Thin HTTP client for the Shields API, mirroring the server protocol."""

    def __init__(self, base_url: str):
        # Base URL of the llama-stack server, e.g. "http://localhost:5000".
        self.base_url = base_url

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def list_shields(self) -> List[ShieldSpec]:
        """Fetch all registered shield specs from the server."""
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{self.base_url}/shields/list",
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            return [ShieldSpec(**x) for x in response.json()]

    async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
        """Fetch the spec for a single shield type, or None if not registered."""
        async with httpx.AsyncClient() as client:
            # httpx's .get() helper does not accept a `json=` body and would
            # raise TypeError; use .request() so the payload is actually sent.
            response = await client.request(
                "GET",
                f"{self.base_url}/shields/get",
                json={
                    "shield_type": shield_type,
                },
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            j = response.json()
            if j is None:
                return None
            return ShieldSpec(**j)
async def run_main(host: str, port: int, stream: bool):
    # Query the server and print every registered shield.
    shields_client = ShieldsClient(f"http://{host}:{port}")
    listing = await shields_client.list_shields()
    cprint(f"list_shields response={listing}", "green")


def main(host: str, port: int, stream: bool = True):
    # fire entry point; drives the async helper to completion.
    asyncio.run(run_main(host, port, stream))


if __name__ == "__main__":
    fire.Fire(main)

View file

@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.distribution.datatypes import GenericProviderConfig
@json_schema_type
class ShieldSpec(BaseModel):
    """Pairing of a shield type with the provider configured to serve it."""

    shield_type: str
    provider_config: GenericProviderConfig = Field(
        # Fixed copy-paste from the models API: this describes a shield.
        description="Provider config for the shield, including provider_id, and corresponding config. ",
    )
class Shields(Protocol):
    """API for inspecting the registered shields and their providers."""

    # Lists every shield spec in the routing table.
    @webmethod(route="/shields/list", method="GET")
    async def list_shields(self) -> List[ShieldSpec]: ...

    # Looks up a single shield by type; None when none is registered.
    @webmethod(route="/shields/get", method="GET")
    async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]: ...

View file

@ -6,7 +6,7 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type from llama_models.schema_utils import json_schema_type
@ -19,8 +19,12 @@ class Api(Enum):
safety = "safety" safety = "safety"
agents = "agents" agents = "agents"
memory = "memory" memory = "memory"
telemetry = "telemetry" telemetry = "telemetry"
models = "models" models = "models"
shields = "shields"
memory_banks = "memory_banks"
@json_schema_type @json_schema_type
@ -44,27 +48,36 @@ class ProviderSpec(BaseModel):
) )
class RoutingTable(Protocol):
def get_routing_keys(self) -> List[str]: ...
def get_provider_impl(self, routing_key: str) -> Any: ...
class GenericProviderConfig(BaseModel): class GenericProviderConfig(BaseModel):
provider_id: str provider_id: str
config: Dict[str, Any] config: Dict[str, Any]
@json_schema_type class RoutableProviderConfig(GenericProviderConfig):
class ProviderRoutingEntry(GenericProviderConfig):
routing_key: str routing_key: str
class RoutingTableConfig(BaseModel):
entries: List[RoutableProviderConfig] = Field(...)
keys: Optional[List[str]] = Field(
default=None,
)
# Example: /inference, /safety
@json_schema_type @json_schema_type
class RouterProviderSpec(ProviderSpec): class AutoRoutedProviderSpec(ProviderSpec):
provider_id: str = "router" provider_id: str = "router"
config_class: str = "" config_class: str = ""
docker_image: Optional[str] = None docker_image: Optional[str] = None
routing_table_api: Api
routing_table: List[ProviderRoutingEntry] = Field(
default_factory=list,
description="Routing table entries corresponding to the API",
)
module: str = Field( module: str = Field(
..., ...,
description=""" description="""
@ -79,18 +92,17 @@ class RouterProviderSpec(ProviderSpec):
@property @property
def pip_packages(self) -> List[str]: def pip_packages(self) -> List[str]:
raise AssertionError("Should not be called on RouterProviderSpec") raise AssertionError("Should not be called on AutoRoutedProviderSpec")
# Example: /models, /shields
@json_schema_type @json_schema_type
class BuiltinProviderSpec(ProviderSpec): class RoutingTableProviderSpec(ProviderSpec):
provider_id: str = "builtin" provider_id: str = "routing_table"
config_class: str = "" config_class: str = ""
docker_image: Optional[str] = None docker_image: Optional[str] = None
api_dependencies: List[Api] = []
provider_data_validator: Optional[str] = Field( inner_specs: List[ProviderSpec]
default=None,
)
module: str = Field( module: str = Field(
..., ...,
description=""" description="""
@ -99,10 +111,7 @@ class BuiltinProviderSpec(ProviderSpec):
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation - `get_router_impl(config, provider_specs, deps)`: returns the router implementation
""", """,
) )
pip_packages: List[str] = Field( pip_packages: List[str] = Field(default_factory=list)
default_factory=list,
description="The pip dependencies needed for this implementation",
)
@json_schema_type @json_schema_type
@ -130,10 +139,6 @@ Fully-qualified name of the module to import. The module is expected to have:
provider_data_validator: Optional[str] = Field( provider_data_validator: Optional[str] = Field(
default=None, default=None,
) )
supported_model_ids: List[str] = Field(
default_factory=list,
description="The list of model ids that this adapter supports",
)
@json_schema_type @json_schema_type
@ -243,9 +248,6 @@ in the runtime configuration to help route to the correct provider.""",
) )
ProviderMapEntry = GenericProviderConfig
@json_schema_type @json_schema_type
class StackRunConfig(BaseModel): class StackRunConfig(BaseModel):
built_at: datetime built_at: datetime
@ -269,25 +271,17 @@ this could be just a hash
description=""" description="""
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""", The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
) )
provider_map: Dict[str, ProviderMapEntry] = Field(
api_providers: Dict[str, GenericProviderConfig] = Field(
description=""" description="""
Provider configurations for each of the APIs provided by this package. Provider configurations for each of the APIs provided by this package.
""",
Given an API, you can specify a single provider or a "routing table". Each entry in the routing
table has a (routing_key, provider_config) tuple. How the key is interpreted is API-specific.
As examples:
- the "inference" API interprets the routing_key as a "model"
- the "memory" API interprets the routing_key as the type of a "memory bank"
The key may support wild-cards alsothe routing_key to route to the correct provider.""",
) )
provider_routing_table: Dict[str, List[ProviderRoutingEntry]] = Field( routing_tables: Dict[str, RoutingTableConfig] = Field(
default_factory=dict, default_factory=dict,
description=""" description="""
API: List[ProviderRoutingEntry] map. Each ProviderRoutingEntry is a (routing_key, provider_config) tuple.
E.g. The following is a ProviderRoutingEntry for inference API: E.g. The following is a ProviderRoutingEntry for models:
- routing_key: Meta-Llama3.1-8B-Instruct - routing_key: Meta-Llama3.1-8B-Instruct
provider_id: meta-reference provider_id: meta-reference
config: config:

View file

@ -8,11 +8,15 @@ import importlib
import inspect import inspect
from typing import Dict, List from typing import Dict, List
from pydantic import BaseModel
from llama_stack.apis.agents import Agents from llama_stack.apis.agents import Agents
from llama_stack.apis.inference import Inference from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.models import Models from llama_stack.apis.models import Models
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry from llama_stack.apis.telemetry import Telemetry
from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec
@ -30,6 +34,28 @@ def stack_apis() -> List[Api]:
return [v for v in Api] return [v for v in Api]
class AutoRoutedApiInfo(BaseModel):
routing_table_api: Api
router_api: Api
def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
return [
AutoRoutedApiInfo(
routing_table_api=Api.models,
router_api=Api.inference,
),
AutoRoutedApiInfo(
routing_table_api=Api.shields,
router_api=Api.safety,
),
AutoRoutedApiInfo(
routing_table_api=Api.memory_banks,
router_api=Api.memory,
),
]
def api_endpoints() -> Dict[Api, List[ApiEndpoint]]: def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
apis = {} apis = {}
@ -40,6 +66,8 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
Api.memory: Memory, Api.memory: Memory,
Api.telemetry: Telemetry, Api.telemetry: Telemetry,
Api.models: Models, Api.models: Models,
Api.shields: Shields,
Api.memory_banks: MemoryBanks,
} }
for api, protocol in protocols.items(): for api, protocol in protocols.items():
@ -68,7 +96,13 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]: def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
ret = {} ret = {}
routing_table_apis = set(
x.routing_table_api for x in builtin_automatically_routed_apis()
)
for api in stack_apis(): for api in stack_apis():
if api in routing_table_apis:
continue
name = api.name.lower() name = api.name.lower()
module = importlib.import_module(f"llama_stack.providers.registry.{name}") module = importlib.import_module(f"llama_stack.providers.registry.{name}")
ret[api] = { ret[api] = {

View file

@ -4,25 +4,47 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict, List, Tuple from typing import Any, List, Tuple
from llama_stack.distribution.datatypes import Api, ProviderRoutingEntry from llama_stack.distribution.datatypes import * # noqa: F403
async def get_router_impl( async def get_routing_table_impl(
api: str, provider_routing_table: Dict[str, List[ProviderRoutingEntry]] api: Api,
): inner_impls: List[Tuple[str, Any]],
from .routers import InferenceRouter, MemoryRouter routing_table_config: RoutingTableConfig,
from .routing_table import RoutingTable _deps,
) -> Any:
from .routing_tables import (
MemoryBanksRoutingTable,
ModelsRoutingTable,
ShieldsRoutingTable,
)
api2routers = { api_to_tables = {
"memory": MemoryRouter, "memory_banks": MemoryBanksRoutingTable,
"inference": InferenceRouter, "models": ModelsRoutingTable,
"shields": ShieldsRoutingTable,
} }
if api.value not in api_to_tables:
raise ValueError(f"API {api.value} not found in router map")
# initialize routing table with concrete provider impls impl = api_to_tables[api.value](inner_impls, routing_table_config)
routing_table = RoutingTable(provider_routing_table) await impl.initialize()
return impl
impl = api2routers[api](routing_table)
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
from .routers import InferenceRouter, MemoryRouter, SafetyRouter
api_to_routers = {
"memory": MemoryRouter,
"inference": InferenceRouter,
"safety": SafetyRouter,
}
if api.value not in api_to_routers:
raise ValueError(f"API {api.value} not found in router map")
impl = api_to_routers[api.value](routing_table)
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -4,17 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, AsyncGenerator, Dict, List, Tuple from typing import Any, AsyncGenerator, Dict, List
from llama_stack.distribution.datatypes import Api from llama_stack.distribution.datatypes import RoutingTable
from .routing_table import RoutingTable
from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from types import MethodType
from termcolor import cprint
class MemoryRouter(Memory): class MemoryRouter(Memory):
@ -24,22 +20,24 @@ class MemoryRouter(Memory):
self, self,
routing_table: RoutingTable, routing_table: RoutingTable,
) -> None: ) -> None:
self.api = Api.memory.value
self.routing_table = routing_table self.routing_table = routing_table
self.bank_id_to_type = {} self.bank_id_to_type = {}
async def initialize(self) -> None: async def initialize(self) -> None:
await self.routing_table.initialize(self.api) pass
async def shutdown(self) -> None: async def shutdown(self) -> None:
await self.routing_table.shutdown(self.api) pass
def get_provider_from_bank_id(self, bank_id: str) -> Any: def get_provider_from_bank_id(self, bank_id: str) -> Any:
bank_type = self.bank_id_to_type.get(bank_id) bank_type = self.bank_id_to_type.get(bank_id)
if not bank_type: if not bank_type:
raise ValueError(f"Could not find bank type for {bank_id}") raise ValueError(f"Could not find bank type for {bank_id}")
return self.routing_table.get_provider_impl(self.api, bank_type) provider = self.routing_table.get_provider_impl(bank_type)
if not provider:
raise ValueError(f"Could not find provider for {bank_type}")
return provider
async def create_memory_bank( async def create_memory_bank(
self, self,
@ -48,14 +46,15 @@ class MemoryRouter(Memory):
url: Optional[URL] = None, url: Optional[URL] = None,
) -> MemoryBank: ) -> MemoryBank:
bank_type = config.type bank_type = config.type
bank = await self.routing_table.get_provider_impl( provider = await self.routing_table.get_provider_impl(
self.api, bank_type bank_type
).create_memory_bank(name, config, url) ).create_memory_bank(name, config, url)
self.bank_id_to_type[bank.bank_id] = bank_type self.bank_id_to_type[bank.bank_id] = bank_type
return bank return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
return await self.get_provider_from_bank_id(bank_id).get_memory_bank(bank_id) provider = self.get_provider_from_bank_id(bank_id)
return await provider.get_memory_bank(bank_id)
async def insert_documents( async def insert_documents(
self, self,
@ -85,34 +84,31 @@ class InferenceRouter(Inference):
self, self,
routing_table: RoutingTable, routing_table: RoutingTable,
) -> None: ) -> None:
self.api = Api.inference.value
self.routing_table = routing_table self.routing_table = routing_table
async def initialize(self) -> None: async def initialize(self) -> None:
await self.routing_table.initialize(self.api) pass
async def shutdown(self) -> None: async def shutdown(self) -> None:
await self.routing_table.shutdown(self.api) pass
async def chat_completion( async def chat_completion(
self, self,
model: str, model: str,
messages: List[Message], messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(), sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = [], tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto, tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
# TODO: we need to fix streaming response to align provider implementations with Protocol. # TODO: we need to fix streaming response to align provider implementations with Protocol.
async for chunk in self.routing_table.get_provider_impl( async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
self.api, model
).chat_completion(
model=model, model=model,
messages=messages, messages=messages,
sampling_params=sampling_params, sampling_params=sampling_params,
tools=tools, tools=tools or [],
tool_choice=tool_choice, tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format, tool_prompt_format=tool_prompt_format,
stream=stream, stream=stream,
@ -128,7 +124,7 @@ class InferenceRouter(Inference):
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
return await self.routing_table.get_provider_impl(self.api, model).completion( return await self.routing_table.get_provider_impl(model).completion(
model=model, model=model,
content=content, content=content,
sampling_params=sampling_params, sampling_params=sampling_params,
@ -141,7 +137,33 @@ class InferenceRouter(Inference):
model: str, model: str,
contents: List[InterleavedTextMedia], contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse: ) -> EmbeddingsResponse:
return await self.routing_table.get_provider_impl(self.api, model).embeddings( return await self.routing_table.get_provider_impl(model).embeddings(
model=model, model=model,
contents=contents, contents=contents,
) )
class SafetyRouter(Safety):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
self.routing_table = routing_table
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def run_shield(
self,
shield_type: str,
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
return await self.routing_table.get_provider_impl(shield_type).run_shield(
shield_type=shield_type,
messages=messages,
params=params,
)

View file

@ -1,60 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
from llama_stack.distribution.datatypes import (
Api,
GenericProviderConfig,
ProviderRoutingEntry,
)
from llama_stack.distribution.distribution import api_providers
from llama_stack.distribution.utils.dynamic import instantiate_provider
from termcolor import cprint
class RoutingTable:
    """Legacy routing table: lazily instantiates one provider impl per
    (api, routing_key) pair from the run config's provider_routing_table."""

    def __init__(self, provider_routing_table: Dict[str, List[ProviderRoutingEntry]]):
        self.provider_routing_table = provider_routing_table
        # map {api: {routing_key: impl}}, e.g. {'inference': {'8b': <MetaReferenceImpl>, '70b': <OllamaImpl>}}
        self.api2routes = {}

    async def initialize(self, api_str: str) -> None:
        """Initialize the routing table with concrete provider impls"""
        if api_str not in self.provider_routing_table:
            raise ValueError(f"API {api_str} not found in routing table")

        # Registry of available provider specs for this API.
        providers = api_providers()[Api(api_str)]
        routing_list = self.provider_routing_table[api_str]

        self.api2routes[api_str] = {}
        for rt_entry in routing_list:
            rt_key = rt_entry.routing_key
            provider_id = rt_entry.provider_id
            # Instantiate the concrete impl for this entry; dependencies are
            # not supported here (always passed empty).
            impl = await instantiate_provider(
                providers[provider_id],
                deps=[],
                provider_config=GenericProviderConfig(
                    provider_id=provider_id, config=rt_entry.config
                ),
            )
            cprint(f"impl = {impl}", "red")  # debug output
            self.api2routes[api_str][rt_key] = impl

        cprint(f"> Initialized implementations for {api_str} in routing table", "blue")

    async def shutdown(self, api_str: str) -> None:
        """Shutdown the routing table"""
        # No-op if initialize() was never called for this API.
        if api_str not in self.api2routes:
            return
        for impl in self.api2routes[api_str].values():
            await impl.shutdown()

    def get_provider_impl(self, api: str, routing_key: str) -> Any:
        """Get the provider impl for a given api and routing key"""
        # Plain dict lookup: raises KeyError if the api was never initialized
        # or the routing key was never registered.
        return self.api2routes[api][routing_key]

View file

@ -0,0 +1,118 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, List, Optional, Tuple
from llama_models.sku_list import resolve_model
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
class CommonRoutingTableImpl(RoutingTable):
    """Shared routing-table implementation behind the models/shields/memory_banks APIs.

    Maps each routing key (model id, shield type, bank type) to the concrete
    provider implementation registered for it.
    """

    def __init__(
        self,
        inner_impls: List[Tuple[str, Any]],
        routing_table_config: RoutingTableConfig,
    ) -> None:
        # (routing_key, provider impl) pairs -> lookup dict.
        self.providers = {k: v for k, v in inner_impls}
        self.routing_keys = list(self.providers.keys())
        self.routing_table_config = routing_table_config

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        for p in self.providers.values():
            await p.shutdown()

    def get_provider_impl(self, routing_key: str) -> Optional[Any]:
        # Sync, matching the RoutingTable protocol: routers chain method calls
        # directly on the returned impl (e.g. `.get_provider_impl(model)
        # .chat_completion(...)`). Previously declared `async`, which handed
        # callers a coroutine instead of the provider.
        return self.providers.get(routing_key)

    def get_routing_keys(self) -> List[str]:
        # Sync, matching the RoutingTable protocol.
        return self.routing_keys

    async def get_provider_config(
        self, routing_key: str
    ) -> Optional[GenericProviderConfig]:
        """Return the config entry for a routing key, or None if absent."""
        for entry in self.routing_table_config.entries:
            if entry.routing_key == routing_key:
                return entry
        return None
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
    """Serves the /models API straight from the routing-table configuration."""

    async def list_models(self) -> List[ModelServingSpec]:
        # One serving spec per routing-table entry; the entry itself doubles
        # as the provider config.
        return [
            ModelServingSpec(
                llama_model=resolve_model(entry.routing_key),
                provider_config=entry,
            )
            for entry in self.routing_table_config.entries
        ]

    async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
        # Linear scan over the configured entries; None when the id is unknown.
        match = next(
            (
                entry
                for entry in self.routing_table_config.entries
                if entry.routing_key == core_model_id
            ),
            None,
        )
        if match is None:
            return None
        return ModelServingSpec(
            llama_model=resolve_model(core_model_id),
            provider_config=match,
        )
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
    """Serves the /shields API straight from the routing-table configuration."""

    async def list_shields(self) -> List[ShieldSpec]:
        # Every routing-table entry is one shield; the entry doubles as its
        # provider config.
        return [
            ShieldSpec(shield_type=entry.routing_key, provider_config=entry)
            for entry in self.routing_table_config.entries
        ]

    async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
        # Linear scan; None when the shield type is not configured.
        for entry in self.routing_table_config.entries:
            if entry.routing_key != shield_type:
                continue
            return ShieldSpec(shield_type=entry.routing_key, provider_config=entry)
        return None
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
    """Serves the /memory_banks API straight from the routing-table configuration."""

    async def list_memory_banks(self) -> List[MemoryBankSpec]:
        # Every routing-table entry is one bank; the entry doubles as its
        # provider config.
        return [
            MemoryBankSpec(bank_type=entry.routing_key, provider_config=entry)
            for entry in self.routing_table_config.entries
        ]

    async def get_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]:
        # Linear scan; None when the bank type is not configured.
        for entry in self.routing_table_config.entries:
            if entry.routing_key != bank_type:
                continue
            return MemoryBankSpec(bank_type=entry.routing_key, provider_config=entry)
        return None

View file

@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
import importlib
import inspect import inspect
import json import json
import signal import signal
@ -36,6 +35,9 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.providers.utils.telemetry.tracing import ( from llama_stack.providers.utils.telemetry.tracing import (
end_trace, end_trace,
@ -43,18 +45,15 @@ from llama_stack.providers.utils.telemetry.tracing import (
SpanStatus, SpanStatus,
start_trace, start_trace,
) )
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import api_endpoints, api_providers from llama_stack.distribution.distribution import (
from llama_stack.distribution.request_headers import set_request_provider_data api_endpoints,
from llama_stack.distribution.utils.dynamic import ( api_providers,
instantiate_builtin_provider, builtin_automatically_routed_apis,
instantiate_provider,
instantiate_router,
) )
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.utils.dynamic import instantiate_provider
def is_async_iterator_type(typ): def is_async_iterator_type(typ):
@ -292,9 +291,7 @@ def snake_to_camel(snake_str):
return "".join(word.capitalize() for word in snake_str.split("_")) return "".join(word.capitalize() for word in snake_str.split("_"))
async def resolve_impls_with_routing( async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
stack_run_config: StackRunConfig,
) -> Dict[Api, Any]:
""" """
Does two things: Does two things:
- flatmaps, sorts and resolves the providers in dependency order - flatmaps, sorts and resolves the providers in dependency order
@ -302,48 +299,80 @@ async def resolve_impls_with_routing(
""" """
all_providers = api_providers() all_providers = api_providers()
specs = {} specs = {}
configs = {}
for api_str in stack_run_config.apis_to_serve: for api_str, config in run_config.api_providers.items():
api = Api(api_str) api = Api(api_str)
# TODO: check that these APIs are not in the routing table part of the config
providers = all_providers[api] providers = all_providers[api]
# check for routing table, we need to pass routing table to the router implementation if config.provider_id not in providers:
if api_str in stack_run_config.provider_routing_table: raise ValueError(
specs[api] = RouterProviderSpec( f"Unknown provider `{config.provider_id}` is not available for API `{api}`"
api=api,
module=f"llama_stack.distribution.routers",
api_dependencies=[],
routing_table=stack_run_config.provider_routing_table[api_str],
) )
else: specs[api] = providers[config.provider_id]
if api_str in stack_run_config.provider_map: configs[api] = config
provider_map_entry = stack_run_config.provider_map[api_str]
provider_id = provider_map_entry.provider_id
else:
# not defined in config, will be a builtin provider, assign builtin provider id
provider_id = "builtin"
if provider_id not in providers: apis_to_serve = run_config.apis_to_serve or set(
list(specs.keys()) + list(run_config.routing_tables.keys())
)
print("apis_to_serve", apis_to_serve)
for info in builtin_automatically_routed_apis():
source_api = info.routing_table_api
assert (
source_api not in specs
), f"Routing table API {source_api} specified in wrong place?"
assert (
info.router_api not in specs
), f"Auto-routed API {info.router_api} specified in wrong place?"
if info.router_api.value not in apis_to_serve:
continue
if source_api.value not in run_config.routing_tables:
raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
routing_table = run_config.routing_tables[source_api.value]
providers = all_providers[info.router_api]
inner_specs = []
for rt_entry in routing_table.entries:
if rt_entry.provider_id not in providers:
raise ValueError( raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api}`" f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
) )
specs[api] = providers[provider_id] inner_specs.append(providers[rt_entry.provider_id])
specs[source_api] = RoutingTableProviderSpec(
api=source_api,
module="llama_stack.distribution.routers",
api_dependencies=[],
inner_specs=inner_specs,
)
configs[source_api] = routing_table
specs[info.router_api] = AutoRoutedProviderSpec(
api=info.router_api,
module="llama_stack.distribution.routers",
routing_table_api=source_api,
api_dependencies=[source_api],
)
configs[info.router_api] = {}
sorted_specs = topological_sort(specs.values()) sorted_specs = topological_sort(specs.values())
print(f"Resolved {len(sorted_specs)} providers in topological order")
for spec in sorted_specs:
print(f" {spec.api}: {spec.provider_id}")
print("")
impls = {} impls = {}
for spec in sorted_specs: for spec in sorted_specs:
api = spec.api api = spec.api
deps = {api: impls[api] for api in spec.api_dependencies} deps = {api: impls[api] for api in spec.api_dependencies}
if api.value in stack_run_config.provider_map: impl = await instantiate_provider(spec, deps, configs[api])
provider_config = stack_run_config.provider_map[api.value]
impl = await instantiate_provider(spec, deps, provider_config)
elif api.value in stack_run_config.provider_routing_table:
impl = await instantiate_router(
spec, api.value, stack_run_config.provider_routing_table
)
else:
impl = await instantiate_builtin_provider(spec, stack_run_config)
impls[api] = impl impls[api] = impl
return impls, specs return impls, specs
@ -355,16 +384,23 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
app = FastAPI() app = FastAPI()
# impls, specs = asyncio.run(resolve_impls(config.provider_map))
impls, specs = asyncio.run(resolve_impls_with_routing(config)) impls, specs = asyncio.run(resolve_impls_with_routing(config))
if Api.telemetry in impls: if Api.telemetry in impls:
setup_logger(impls[Api.telemetry]) setup_logger(impls[Api.telemetry])
all_endpoints = api_endpoints() all_endpoints = api_endpoints()
apis_to_serve = config.apis_to_serve or list(config.provider_map.keys()) if config.apis_to_serve:
apis_to_serve = set(config.apis_to_serve)
for inf in builtin_automatically_routed_apis():
if inf.router_api.value in apis_to_serve:
apis_to_serve.add(inf.routing_table_api)
else:
apis_to_serve = set(impls.keys())
for api_str in apis_to_serve: for api_str in apis_to_serve:
api = Api(api_str) api = Api(api_str)
endpoints = all_endpoints[api] endpoints = all_endpoints[api]
impl = impls[api] impl = impls[api]
@ -391,7 +427,11 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
create_dynamic_typed_route( create_dynamic_typed_route(
impl_method, impl_method,
endpoint.method, endpoint.method,
provider_spec.provider_data_validator, (
provider_spec.provider_data_validator
if not isinstance(provider_spec, RoutingTableProviderSpec)
else None
),
) )
) )

View file

@ -16,36 +16,11 @@ def instantiate_class_type(fully_qualified_name):
return getattr(module, class_name) return getattr(module, class_name)
async def instantiate_router(
provider_spec: RouterProviderSpec,
api: str,
provider_routing_table: Dict[str, Any],
):
module = importlib.import_module(provider_spec.module)
fn = getattr(module, "get_router_impl")
impl = await fn(api, provider_routing_table)
impl.__provider_spec__ = provider_spec
return impl
async def instantiate_builtin_provider(
provider_spec: BuiltinProviderSpec,
run_config: StackRunConfig,
):
print("!!! instantiate_builtin_provider")
module = importlib.import_module(provider_spec.module)
fn = getattr(module, "get_builtin_impl")
impl = await fn(run_config)
impl.__provider_spec__ = provider_spec
return impl
# returns a class implementing the protocol corresponding to the Api # returns a class implementing the protocol corresponding to the Api
async def instantiate_provider( async def instantiate_provider(
provider_spec: ProviderSpec, provider_spec: ProviderSpec,
deps: Dict[str, Any], deps: Dict[str, Any],
provider_config: ProviderMapEntry, provider_config: Union[GenericProviderConfig, RoutingTable],
): ):
module = importlib.import_module(provider_spec.module) module = importlib.import_module(provider_spec.module)
@ -60,6 +35,29 @@ async def instantiate_provider(
config_type = instantiate_class_type(provider_spec.config_class) config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config) config = config_type(**provider_config.config)
args = [config, deps] args = [config, deps]
elif isinstance(provider_spec, AutoRoutedProviderSpec):
method = "get_auto_router_impl"
config = None
args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
elif isinstance(provider_spec, RoutingTableProviderSpec):
method = "get_routing_table_impl"
assert isinstance(provider_config, RoutingTableConfig)
routing_table = provider_config
inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
inner_impls = []
for routing_entry in routing_table.entries:
impl = await instantiate_provider(
inner_specs[routing_entry.provider_id],
deps,
routing_entry,
)
inner_impls.append((routing_entry.routing_key, impl))
config = None
args = [provider_spec.api, inner_impls, routing_table, deps]
else: else:
method = "get_provider_impl" method = "get_provider_impl"

View file

@ -1,21 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec, StackRunConfig
from .config import BuiltinImplConfig # noqa
async def get_builtin_impl(config: StackRunConfig):
from .models import BuiltinModelsImpl
assert isinstance(config, StackRunConfig), f"Unexpected config type: {type(config)}"
impl = BuiltinModelsImpl(config)
await impl.initialize()
return impl

View file

@ -1,113 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import AsyncIterator, Union
from llama_models.llama3.api.datatypes import StopReason
from llama_models.sku_list import resolve_model
from llama_stack.distribution.distribution import Api, api_providers
from llama_stack.apis.models import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.datatypes import CoreModelId, Model
from llama_models.sku_list import resolve_model
from llama_stack.distribution.datatypes import * # noqa: F403
from termcolor import cprint
class BuiltinModelsImpl(Models):
def __init__(
self,
config: StackRunConfig,
) -> None:
self.run_config = config
self.models = {}
# check against inference & safety api
apis_with_models = [Api.inference, Api.safety]
all_providers = api_providers()
for api in apis_with_models:
# check against provider_map (simple case single model)
if api.value in config.provider_map:
providers_for_api = all_providers[api]
provider_spec = config.provider_map[api.value]
core_model_id = provider_spec.config
# get supported model ids from the provider
supported_model_ids = self.get_supported_model_ids(
api.value, provider_spec, providers_for_api
)
for model_id in supported_model_ids:
self.models[model_id] = ModelServingSpec(
llama_model=resolve_model(model_id),
provider_config=provider_spec,
api=api.value,
)
# check against provider_routing_table (router with multiple models)
# with routing table, we use the routing_key as the supported models
if api.value in config.provider_routing_table:
routing_table = config.provider_routing_table[api.value]
for rt_entry in routing_table:
model_id = rt_entry.routing_key
self.models[model_id] = ModelServingSpec(
llama_model=resolve_model(model_id),
provider_config=GenericProviderConfig(
provider_id=rt_entry.provider_id,
config=rt_entry.config,
),
api=api.value,
)
print("BuiltinModelsImpl models", self.models)
def get_supported_model_ids(
self,
api: str,
provider_spec: GenericProviderConfig,
providers_for_api: Dict[str, ProviderSpec],
) -> List[str]:
serving_models_list = []
if api == Api.inference.value:
provider_id = provider_spec.provider_id
if provider_id == "meta-reference":
serving_models_list.append(provider_spec.config["model"])
if provider_id in {
remote_provider_id("ollama"),
remote_provider_id("fireworks"),
remote_provider_id("together"),
}:
adapter_supported_models = providers_for_api[
provider_id
].adapter.supported_model_ids
serving_models_list.extend(adapter_supported_models)
elif api == Api.safety.value:
if provider_spec.config and "llama_guard_shield" in provider_spec.config:
llama_guard_shield = provider_spec.config["llama_guard_shield"]
serving_models_list.append(llama_guard_shield["model"])
if provider_spec.config and "prompt_guard_shield" in provider_spec.config:
prompt_guard_shield = provider_spec.config["prompt_guard_shield"]
serving_models_list.append(prompt_guard_shield["model"])
else:
raise NotImplementedError(f"Unsupported api {api} for builtin models")
return serving_models_list
async def initialize(self) -> None:
pass
async def list_models(self) -> ModelsListResponse:
return ModelsListResponse(models_list=list(self.models.values()))
async def get_model(self, core_model_id: str) -> ModelsGetResponse:
if core_model_id in self.models:
return ModelsGetResponse(core_model_spec=self.models[core_model_id])
print(f"Cannot find {core_model_id} in model registry")
return ModelsGetResponse()

View file

@ -6,17 +6,14 @@
from typing import Optional from typing import Optional
from llama_models.datatypes import ModelFamily from llama_models.datatypes import * # noqa: F403
from llama_models.schema_utils import json_schema_type
from llama_models.sku_list import all_registered_models, resolve_model from llama_models.sku_list import all_registered_models, resolve_model
from llama_stack.apis.inference import * # noqa: F401, F403
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from llama_stack.apis.inference import QuantizationConfig
@json_schema_type
class MetaReferenceImplConfig(BaseModel): class MetaReferenceImplConfig(BaseModel):
model: str = Field( model: str = Field(
default="Meta-Llama3.1-8B-Instruct", default="Meta-Llama3.1-8B-Instruct",
@ -34,6 +31,7 @@ class MetaReferenceImplConfig(BaseModel):
m.descriptor() m.descriptor()
for m in all_registered_models() for m in all_registered_models()
if m.model_family == ModelFamily.llama3_1 if m.model_family == ModelFamily.llama3_1
or m.core_model_id == CoreModelId.llama_guard_3_8b
] ]
if model not in permitted_models: if model not in permitted_models:
model_list = "\n\t".join(permitted_models) model_list = "\n\t".join(permitted_models)

View file

@ -32,10 +32,6 @@ def available_providers() -> List[ProviderSpec]:
adapter_id="ollama", adapter_id="ollama",
pip_packages=["ollama"], pip_packages=["ollama"],
module="llama_stack.providers.adapters.inference.ollama", module="llama_stack.providers.adapters.inference.ollama",
supported_model_ids=[
"Meta-Llama3.1-8B-Instruct",
"Meta-Llama3.1-70B-Instruct",
],
), ),
), ),
remote_provider_spec( remote_provider_spec(
@ -56,11 +52,6 @@ def available_providers() -> List[ProviderSpec]:
], ],
module="llama_stack.providers.adapters.inference.fireworks", module="llama_stack.providers.adapters.inference.fireworks",
config_class="llama_stack.providers.adapters.inference.fireworks.FireworksImplConfig", config_class="llama_stack.providers.adapters.inference.fireworks.FireworksImplConfig",
supported_model_ids=[
"Meta-Llama3.1-8B-Instruct",
"Meta-Llama3.1-70B-Instruct",
"Meta-Llama3.1-405B-Instruct",
],
), ),
), ),
remote_provider_spec( remote_provider_spec(
@ -73,11 +64,6 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.adapters.inference.together", module="llama_stack.providers.adapters.inference.together",
config_class="llama_stack.providers.adapters.inference.together.TogetherImplConfig", config_class="llama_stack.providers.adapters.inference.together.TogetherImplConfig",
header_extractor_class="llama_stack.providers.adapters.inference.together.TogetherHeaderExtractor", header_extractor_class="llama_stack.providers.adapters.inference.together.TogetherHeaderExtractor",
supported_model_ids=[
"Meta-Llama3.1-8B-Instruct",
"Meta-Llama3.1-70B-Instruct",
"Meta-Llama3.1-405B-Instruct",
],
), ),
), ),
] ]

View file

@ -1,22 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.distribution.datatypes import * # noqa: F403
def available_providers() -> List[ProviderSpec]:
return [
BuiltinProviderSpec(
api=Api.models,
provider_id="builtin",
pip_packages=[],
module="llama_stack.providers.impls.builtin.models",
config_class="llama_stack.providers.impls.builtin.models.BuiltinImplConfig",
api_dependencies=[],
)
]