mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
144 lines
4.8 KiB
Python
144 lines
4.8 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from typing import Any, List, Optional, Tuple
|
|
|
|
from llama_models.sku_list import resolve_model
|
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
|
|
|
from llama_stack.apis.models import * # noqa: F403
|
|
from llama_stack.apis.shields import * # noqa: F403
|
|
from llama_stack.apis.memory_banks import * # noqa: F403
|
|
|
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
|
|
|
|
|
class CommonRoutingTableImpl(RoutingTable):
    """Routing table logic shared by the API-specific tables in this module.

    Maps every routing key to the provider implementation registered for it
    and keeps the routing table configuration entries for later lookup.
    """

    def __init__(
        self,
        inner_impls: List[Tuple[RoutingKey, Any]],
        routing_table_config: List[RoutableProviderConfig],
    ) -> None:
        """Index the provider implementations by routing key.

        Args:
            inner_impls: ``(routing key, implementation)`` pairs. A routing
                key may be a single key or a list of keys that are all served
                by the same implementation.
            routing_table_config: Configuration entries. The methods of this
                class and its subclasses iterate over it and read
                ``routing_key`` from each element, so it is annotated as a
                list of entries.

        Raises:
            ValueError: If a routing key is registered more than once.
        """
        # One (keys, impl) pair per implementation, used by initialize() and
        # shutdown() so each implementation is visited exactly once.
        self.unique_providers = []
        # Flat lookup: individual routing key -> implementation.
        self.providers = {}
        # Individual routing keys in registration order.
        self.routing_keys = []

        for key, impl in inner_impls:
            # Normalize so single keys and key lists share one code path.
            keys = key if isinstance(key, list) else [key]
            self.unique_providers.append((keys, impl))

            for k in keys:
                if k in self.providers:
                    raise ValueError(f"Duplicate routing key {k}")
                self.providers[k] = impl
                self.routing_keys.append(k)

        self.routing_table_config = routing_table_config

    async def initialize(self) -> None:
        """Ask every provider to validate the routing keys assigned to it.

        Providers whose spec is a ``RemoteProviderSpec`` without an adapter
        are skipped.
        """
        for keys, p in self.unique_providers:
            spec = p.__provider_spec__
            if isinstance(spec, RemoteProviderSpec) and spec.adapter is None:
                continue

            await p.validate_routing_keys(keys)

    async def shutdown(self) -> None:
        """Await ``shutdown()`` on every registered provider, one at a time."""
        for _, p in self.unique_providers:
            await p.shutdown()

    def get_provider_impl(self, routing_key: str) -> Any:
        """Return the implementation registered for ``routing_key``.

        Raises:
            ValueError: If no implementation is registered for the key.
        """
        if routing_key not in self.providers:
            raise ValueError(f"Could not find provider for {routing_key}")
        return self.providers[routing_key]

    def get_routing_keys(self) -> List[str]:
        """Return the registered routing keys in registration order.

        The internal list itself is returned, not a copy.
        """
        return self.routing_keys

    def get_provider_config(self, routing_key: str) -> Optional[GenericProviderConfig]:
        """Return the first configuration entry that serves ``routing_key``.

        An entry matches when its ``routing_key`` equals the requested key,
        or when its ``routing_key`` is a list containing the requested key.
        The list case mirrors ``__init__``, which registers every member of a
        key list individually.

        Returns:
            The matching entry, or ``None`` when no entry matches.
        """
        for entry in self.routing_table_config:
            entry_key = entry.routing_key
            if entry_key == routing_key:
                return entry
            if isinstance(entry_key, list) and routing_key in entry_key:
                return entry
        return None
|
|
|
|
|
|
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
    """Routing table that answers the ``Models`` API from the configuration."""

    async def list_models(self) -> List[ModelServingSpec]:
        """Build one ``ModelServingSpec`` per routing table entry.

        The entry's ``routing_key`` is used as the model id passed to
        ``resolve_model``; the entry itself becomes the provider config.
        """
        return [
            ModelServingSpec(
                llama_model=resolve_model(cfg.routing_key),
                provider_config=cfg,
            )
            for cfg in self.routing_table_config
        ]

    async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
        """Return the serving spec for ``core_model_id``.

        Returns:
            A spec built from the first entry whose ``routing_key`` equals
            ``core_model_id``, or ``None`` when there is no such entry.
        """
        matching_cfg = next(
            (
                cfg
                for cfg in self.routing_table_config
                if cfg.routing_key == core_model_id
            ),
            None,
        )
        if matching_cfg is None:
            return None

        return ModelServingSpec(
            llama_model=resolve_model(core_model_id),
            provider_config=matching_cfg,
        )
|
|
|
|
|
|
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
    """Routing table that answers the ``Shields`` API from the configuration."""

    async def list_shields(self) -> List[ShieldSpec]:
        """Build one ``ShieldSpec`` per shield type in the routing table.

        An entry whose ``routing_key`` is a list contributes one spec per
        list member; any other entry contributes a single spec.
        """
        specs = []
        for entry in self.routing_table_config:
            key = entry.routing_key
            # Normalize so single keys and key lists share one code path.
            shield_types = key if isinstance(key, list) else [key]
            specs.extend(
                ShieldSpec(shield_type=shield_type, provider_config=entry)
                for shield_type in shield_types
            )
        return specs

    async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
        """Return the spec for ``shield_type``.

        An entry matches when its ``routing_key`` equals ``shield_type`` or
        is a list containing it. The list case keeps this lookup consistent
        with ``list_shields``, which reports every member of a key list as
        its own shield.

        Returns:
            A spec built from the first matching entry, or ``None`` when no
            entry matches.
        """
        for entry in self.routing_table_config:
            key = entry.routing_key
            if key == shield_type:
                matched_type = key
            elif isinstance(key, list) and shield_type in key:
                # Report the requested member, not the whole key list.
                matched_type = shield_type
            else:
                continue

            return ShieldSpec(
                shield_type=matched_type,
                provider_config=entry,
            )
        return None
|
|
|
|
|
|
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
    """Routing table that answers the ``MemoryBanks`` API from the configuration."""

    async def list_available_memory_banks(self) -> List[MemoryBankSpec]:
        """Build one ``MemoryBankSpec`` per routing table entry.

        The entry's ``routing_key`` is used as the bank type and the entry
        itself as the provider config.
        """
        return [
            MemoryBankSpec(bank_type=cfg.routing_key, provider_config=cfg)
            for cfg in self.routing_table_config
        ]

    async def get_serving_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]:
        """Return the spec for ``bank_type``.

        Returns:
            A spec built from the first entry whose ``routing_key`` equals
            ``bank_type``, or ``None`` when there is no such entry.
        """
        matching_cfg = next(
            (
                cfg
                for cfg in self.routing_table_config
                if cfg.routing_key == bank_type
            ),
            None,
        )
        if matching_cfg is None:
            return None

        return MemoryBankSpec(
            bank_type=matching_cfg.routing_key,
            provider_config=matching_cfg,
        )
|