Further generalize Xi's changes (#88)

* Further generalize Xi's changes

- Introduce a slightly more general notion of an AutoRouted provider
- Each AutoRouted provider is associated with a RoutingTable provider
- e.g. inference -> models
- Introduce the safety -> shields and memory -> memory_banks
  correspondences

* typo

* Basic build and run succeeded
This commit is contained in:
Ashwin Bharambe 2024-09-22 16:31:18 -07:00 committed by GitHub
parent b8914bb56f
commit c1ab66f1e6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 597 additions and 418 deletions

View file

@ -4,17 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncGenerator, Dict, List, Tuple
from typing import Any, AsyncGenerator, Dict, List
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.datatypes import RoutingTable
from .routing_table import RoutingTable
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from types import MethodType
from termcolor import cprint
from llama_stack.apis.safety import * # noqa: F403
class MemoryRouter(Memory):
@ -24,22 +20,24 @@ class MemoryRouter(Memory):
self,
routing_table: RoutingTable,
) -> None:
self.api = Api.memory.value
self.routing_table = routing_table
self.bank_id_to_type = {}
async def initialize(self) -> None:
await self.routing_table.initialize(self.api)
pass
async def shutdown(self) -> None:
await self.routing_table.shutdown(self.api)
pass
def get_provider_from_bank_id(self, bank_id: str) -> Any:
bank_type = self.bank_id_to_type.get(bank_id)
if not bank_type:
raise ValueError(f"Could not find bank type for {bank_id}")
return self.routing_table.get_provider_impl(self.api, bank_type)
provider = self.routing_table.get_provider_impl(bank_type)
if not provider:
raise ValueError(f"Could not find provider for {bank_type}")
return provider
async def create_memory_bank(
self,
@ -48,14 +46,15 @@ class MemoryRouter(Memory):
url: Optional[URL] = None,
) -> MemoryBank:
bank_type = config.type
bank = await self.routing_table.get_provider_impl(
self.api, bank_type
provider = await self.routing_table.get_provider_impl(
bank_type
).create_memory_bank(name, config, url)
self.bank_id_to_type[bank.bank_id] = bank_type
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
return await self.get_provider_from_bank_id(bank_id).get_memory_bank(bank_id)
provider = self.get_provider_from_bank_id(bank_id)
return await provider.get_memory_bank(bank_id)
async def insert_documents(
self,
@ -85,34 +84,31 @@ class InferenceRouter(Inference):
self,
routing_table: RoutingTable,
) -> None:
self.api = Api.inference.value
self.routing_table = routing_table
async def initialize(self) -> None:
await self.routing_table.initialize(self.api)
pass
async def shutdown(self) -> None:
await self.routing_table.shutdown(self.api)
pass
async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = [],
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
# TODO: we need to fix streaming response to align provider implementations with Protocol.
async for chunk in self.routing_table.get_provider_impl(
self.api, model
).chat_completion(
async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
@ -128,7 +124,7 @@ class InferenceRouter(Inference):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
return await self.routing_table.get_provider_impl(self.api, model).completion(
return await self.routing_table.get_provider_impl(model).completion(
model=model,
content=content,
sampling_params=sampling_params,
@ -141,7 +137,33 @@ class InferenceRouter(Inference):
model: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
return await self.routing_table.get_provider_impl(self.api, model).embeddings(
return await self.routing_table.get_provider_impl(model).embeddings(
model=model,
contents=contents,
)
class SafetyRouter(Safety):
    """Safety API implementation that forwards calls to a routed provider.

    The provider handling a request is looked up in the routing table using
    the request's shield type.
    """

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        """Store the routing table used to look up provider implementations."""
        self.routing_table = routing_table

    async def initialize(self) -> None:
        """No-op."""
        pass

    async def shutdown(self) -> None:
        """No-op."""
        pass

    async def run_shield(
        self,
        shield_type: str,
        messages: List[Message],
        params: Optional[Dict[str, Any]] = None,
    ) -> RunShieldResponse:
        """Run the shield of the given type via its routed provider.

        Args:
            shield_type: Key used to look up the provider in the routing
                table; also forwarded to the provider.
            messages: Messages forwarded unchanged to the provider.
            params: Optional extra parameters forwarded unchanged to the
                provider (``None`` when omitted).

        Returns:
            Whatever the provider's ``run_shield`` returns.

        Raises:
            ValueError: If the routing table yields no provider for
                ``shield_type``.
        """
        provider = self.routing_table.get_provider_impl(shield_type)
        # Fail with an explicit message rather than an AttributeError on None.
        if provider is None:
            raise ValueError(f"Could not find provider for {shield_type}")
        return await provider.run_shield(
            shield_type=shield_type,
            messages=messages,
            params=params,
        )