llama-stack-mirror/llama_stack/distribution/routers/routing_tables.py
Ashwin Bharambe e45a417543 more fixes, plug shutdown handlers
Still, FastAPI's SIGINT handler is not calling ours.
2024-10-08 17:23:02 -07:00

113 lines
4.1 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional

from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_stack.apis.models import *  # noqa: F403
from llama_stack.apis.shields import *  # noqa: F403
from llama_stack.apis.memory_banks import *  # noqa: F403
from llama_stack.distribution.datatypes import *  # noqa: F403
# TODO: this routing table keeps its state purely in memory. We need to
# add persistence to it once objects can be registered dynamically.
class CommonRoutingTableImpl(RoutingTable):
    """Shared bookkeeping for routing tables keyed by object identifier.

    Every routable object carries an ``identifier`` (the routing key) and a
    ``provider_id``. The table maps each identifier to its object, and each
    provider id to a provider implementation, so a request for a routing key
    can be dispatched to the right provider. Subclasses implement
    ``register_object`` to call the provider's type-specific registration
    method.
    """

    def __init__(
        self,
        registry: List[RoutableObject],
        impls_by_provider_id: Dict[str, RoutedProtocol],
    ) -> None:
        """Validate and index the initial set of routable objects.

        Args:
            registry: Objects known at construction time. The list is copied,
                so later dynamic registrations never mutate the caller's list.
            impls_by_provider_id: Provider implementations keyed by provider id.

        Raises:
            ValueError: if an object references a provider id that is not in
                ``impls_by_provider_id``, or if two objects share an identifier.
        """
        routing_key_to_object: Dict[str, RoutableObject] = {}
        for obj in registry:
            if obj.provider_id not in impls_by_provider_id:
                raise ValueError(
                    f"Provider `{obj.provider_id}` pointed by `{obj.identifier}` not found"
                )
            # A duplicate identifier would make routing ambiguous (only the
            # last one would be reachable), so reject it exactly as
            # register_object_common does for dynamic registration.
            if obj.identifier in routing_key_to_object:
                raise ValueError(f"Object `{obj.identifier}` already registered")
            routing_key_to_object[obj.identifier] = obj

        self.impls_by_provider_id = impls_by_provider_id
        # Copy: register_object_common appends to this list, and we must not
        # mutate the list owned by the caller.
        self.registry = list(registry)
        self.routing_key_to_object = routing_key_to_object

    async def initialize(self) -> None:
        """Register every object in the registry with its provider."""
        for obj in self.registry:
            p = self.impls_by_provider_id[obj.provider_id]
            await self.register_object(obj, p)

    async def shutdown(self) -> None:
        """Call ``shutdown()`` on every provider implementation in this table."""
        for p in self.impls_by_provider_id.values():
            await p.shutdown()

    async def register_object(self, obj: RoutableObject, p: RoutedProtocol) -> None:
        """Hand ``obj`` to provider ``p`` via its type-specific registration call.

        Subclasses must override this.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement register_object()"
        )

    def get_provider_impl(self, routing_key: str) -> Any:
        """Return the provider implementation that serves ``routing_key``.

        Raises:
            ValueError: if no object is registered under ``routing_key``.
        """
        try:
            obj = self.routing_key_to_object[routing_key]
        except KeyError:
            raise ValueError(f"Could not find provider for {routing_key}") from None
        return self.impls_by_provider_id[obj.provider_id]

    def get_object_by_identifier(self, identifier: str) -> Optional[RoutableObject]:
        """Return the object registered under ``identifier``, or ``None``."""
        # routing_key_to_object holds exactly the objects in self.registry
        # (identifiers are unique), so a dict lookup replaces the linear scan.
        return self.routing_key_to_object.get(identifier)

    async def register_object_common(self, obj: RoutableObject) -> None:
        """Dynamically register ``obj`` with its provider and index it.

        Raises:
            ValueError: if the identifier is already registered or the
                provider id is unknown.
        """
        if obj.identifier in self.routing_key_to_object:
            raise ValueError(f"Object `{obj.identifier}` already registered")
        if obj.provider_id not in self.impls_by_provider_id:
            raise ValueError(f"Provider `{obj.provider_id}` not found")
        p = self.impls_by_provider_id[obj.provider_id]
        # Providers expose type-specific methods (register_model,
        # register_shield, register_memory_bank); dispatch through the
        # subclass hook, the same path initialize() uses.
        await self.register_object(obj, p)
        # Record the object only after the provider accepted it.
        self.routing_key_to_object[obj.identifier] = obj
        self.registry.append(obj)
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
    """Routing table for model definitions, dispatched to inference providers."""

    async def register_object(self, obj: ModelDef, p: Inference) -> None:
        """Register model ``obj`` with inference provider ``p``."""
        await p.register_model(obj)

    async def list_models(self) -> List[ModelDef]:
        """Return all registered models.

        A copy is returned so callers cannot mutate the table's registry
        behind the back of its identifier index.
        """
        return list(self.registry)

    async def get_model(self, identifier: str) -> Optional[ModelDef]:
        """Return the model registered under ``identifier``, or ``None``."""
        return self.get_object_by_identifier(identifier)

    async def register_model(self, model: ModelDef) -> None:
        """Dynamically register ``model`` with the provider it names."""
        await self.register_object_common(model)
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
    """Routing table for shield definitions, dispatched to safety providers."""

    async def register_object(self, obj: ShieldDef, p: Safety) -> None:
        """Register shield ``obj`` with safety provider ``p``."""
        await p.register_shield(obj)

    async def list_shields(self) -> List[ShieldDef]:
        """Return all registered shields.

        A copy is returned so callers cannot mutate the table's registry
        behind the back of its identifier index.
        """
        return list(self.registry)

    async def get_shield(self, shield_type: str) -> Optional[ShieldDef]:
        """Return the shield whose identifier equals ``shield_type``, or ``None``."""
        return self.get_object_by_identifier(shield_type)

    async def register_shield(self, shield: ShieldDef) -> None:
        """Dynamically register ``shield`` with the provider it names."""
        await self.register_object_common(shield)
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
    """Routing table for memory bank definitions, dispatched to memory providers."""

    async def register_object(self, obj: MemoryBankDef, p: Memory) -> None:
        """Register memory bank ``obj`` with memory provider ``p``."""
        await p.register_memory_bank(obj)

    async def list_memory_banks(self) -> List[MemoryBankDef]:
        """Return all registered memory banks.

        A copy is returned so callers cannot mutate the table's registry
        behind the back of its identifier index.
        """
        return list(self.registry)

    async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
        """Return the memory bank registered under ``identifier``, or ``None``."""
        return self.get_object_by_identifier(identifier)

    async def register_memory_bank(self, bank: MemoryBankDef) -> None:
        """Dynamically register ``bank`` with the provider it names."""
        await self.register_object_common(bank)