inference registry updates

This commit is contained in:
Ashwin Bharambe 2024-10-05 22:25:48 -07:00 committed by Ashwin Bharambe
parent 4215cc9331
commit 59302a86df
12 changed files with 570 additions and 535 deletions

View file

@ -17,14 +17,19 @@ class DistributionInspectConfig(BaseModel):
pass
def get_provider_impl(*args, **kwargs):
return DistributionInspectImpl()
async def get_provider_impl(*args, **kwargs):
impl = DistributionInspectImpl()
await impl.initialize()
return impl
class DistributionInspectImpl(Inspect):
def __init__(self):
pass
async def initialize(self) -> None:
pass
async def list_providers(self) -> Dict[str, List[ProviderInfo]]:
ret = {}
all_providers = get_provider_registry()

View file

@ -20,6 +20,7 @@ class ProviderWithSpec(Provider):
spec: ProviderSpec
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
"""
Does two things:
@ -134,7 +135,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
print("")
impls = {}
inner_impls_by_provider_id = {f"inner-{x}": {} for x in router_apis}
inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
for api_str, provider in sorted_providers:
deps = {a: impls[a] for a in provider.spec.api_dependencies}

View file

@ -14,14 +14,13 @@ from llama_stack.apis.safety import * # noqa: F403
class MemoryRouter(Memory):
"""Routes to an provider based on the memory bank type"""
"""Routes to an provider based on the memory bank identifier"""
def __init__(
self,
routing_table: RoutingTable,
) -> None:
self.routing_table = routing_table
self.bank_id_to_type = {}
async def initialize(self) -> None:
pass
@ -29,32 +28,14 @@ class MemoryRouter(Memory):
async def shutdown(self) -> None:
pass
def get_provider_from_bank_id(self, bank_id: str) -> Any:
bank_type = self.bank_id_to_type.get(bank_id)
if not bank_type:
raise ValueError(f"Could not find bank type for {bank_id}")
async def list_memory_banks(self) -> List[MemoryBankDef]:
return self.routing_table.list_memory_banks()
provider = self.routing_table.get_provider_impl(bank_type)
if not provider:
raise ValueError(f"Could not find provider for {bank_type}")
return provider
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
return self.routing_table.get_memory_bank(identifier)
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
bank_type = config.type
bank = await self.routing_table.get_provider_impl(bank_type).create_memory_bank(
name, config, url
)
self.bank_id_to_type[bank.bank_id] = bank_type
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
provider = self.get_provider_from_bank_id(bank_id)
return await provider.get_memory_bank(bank_id)
async def register_memory_bank(self, bank: MemoryBankDef) -> None:
await self.routing_table.register_memory_bank(bank)
async def insert_documents(
self,
@ -62,7 +43,7 @@ class MemoryRouter(Memory):
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
return await self.get_provider_from_bank_id(bank_id).insert_documents(
return await self.routing_table.get_provider_impl(bank_id).insert_documents(
bank_id, documents, ttl_seconds
)
@ -72,7 +53,7 @@ class MemoryRouter(Memory):
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
return await self.get_provider_from_bank_id(bank_id).query_documents(
return await self.routing_table.get_provider_impl(bank_id).query_documents(
bank_id, query, params
)
@ -92,6 +73,15 @@ class InferenceRouter(Inference):
async def shutdown(self) -> None:
pass
async def list_models(self) -> List[ModelDef]:
return self.routing_table.list_models()
async def get_model(self, identifier: str) -> Optional[ModelDef]:
return self.routing_table.get_model(identifier)
async def register_model(self, model: ModelDef) -> None:
await self.routing_table.register_model(model)
async def chat_completion(
self,
model: str,
@ -159,6 +149,15 @@ class SafetyRouter(Safety):
async def shutdown(self) -> None:
pass
async def list_shields(self) -> List[ShieldDef]:
return self.routing_table.list_shields()
async def get_shield(self, shield_type: str) -> Optional[ShieldDef]:
return self.routing_table.get_shield(shield_type)
async def register_shield(self, shield: ShieldDef) -> None:
await self.routing_table.register_shield(shield)
async def run_shield(
self,
shield_type: str,

View file

@ -15,6 +15,8 @@ from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
# TODO: this routing table maintains state in memory purely. We need to
# add persistence to it when we add dynamic registration of objects.
class CommonRoutingTableImpl(RoutingTable):
def __init__(
self,
@ -54,7 +56,7 @@ class CommonRoutingTableImpl(RoutingTable):
return obj
return None
def register_object(self, obj: RoutableObject) -> None:
async def register_object_common(self, obj: RoutableObject) -> None:
if obj.identifier in self.routing_key_to_object:
raise ValueError(f"Object `{obj.identifier}` already registered")
@ -79,7 +81,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
return self.get_object_by_identifier(identifier)
async def register_model(self, model: ModelDef) -> None:
await self.register_object(model)
await self.register_object_common(model)
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
@ -93,7 +95,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
return self.get_object_by_identifier(shield_type)
async def register_shield(self, shield: ShieldDef) -> None:
await self.register_object(shield)
await self.register_object_common(shield)
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
@ -107,4 +109,4 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
return self.get_object_by_identifier(identifier)
async def register_memory_bank(self, bank: MemoryBankDef) -> None:
await self.register_object(bank)
await self.register_object_common(bank)