inference registry updates

Ashwin Bharambe 2024-10-05 22:25:48 -07:00 committed by Ashwin Bharambe
parent 4215cc9331
commit 59302a86df
12 changed files with 570 additions and 535 deletions

View file

@@ -17,14 +17,19 @@ class DistributionInspectConfig(BaseModel):
     pass


-def get_provider_impl(*args, **kwargs):
-    return DistributionInspectImpl()
+async def get_provider_impl(*args, **kwargs):
+    impl = DistributionInspectImpl()
+    await impl.initialize()
+    return impl


 class DistributionInspectImpl(Inspect):
     def __init__(self):
         pass

+    async def initialize(self) -> None:
+        pass
+
     async def list_providers(self) -> Dict[str, List[ProviderInfo]]:
         ret = {}
         all_providers = get_provider_registry()
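
The change above makes the provider factory asynchronous so the implementation is initialized before it is handed back. A tiny standalone sketch of that pattern, using illustrative names rather than the real llama_stack classes:

import asyncio


class ExampleImpl:
    def __init__(self) -> None:
        self.ready = False

    async def initialize(self) -> None:
        # e.g. open connections, warm caches
        self.ready = True


async def get_example_impl(*args, **kwargs) -> ExampleImpl:
    # factory awaits initialize() so callers never see a half-constructed impl
    impl = ExampleImpl()
    await impl.initialize()
    return impl


print(asyncio.run(get_example_impl()).ready)  # True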

View file

@@ -20,6 +20,7 @@ class ProviderWithSpec(Provider):
     spec: ProviderSpec


+# TODO: this code is not very straightforward to follow and needs one more round of refactoring
 async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
     """
     Does two things:
@@ -134,7 +135,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
         print("")

     impls = {}
-    inner_impls_by_provider_id = {f"inner-{x}": {} for x in router_apis}
+    inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
     for api_str, provider in sorted_providers:
         deps = {a: impls[a] for a in provider.spec.api_dependencies}

View file

@@ -14,14 +14,13 @@ from llama_stack.apis.safety import * # noqa: F403
 class MemoryRouter(Memory):
-    """Routes to an provider based on the memory bank type"""
+    """Routes to an provider based on the memory bank identifier"""

     def __init__(
         self,
         routing_table: RoutingTable,
     ) -> None:
         self.routing_table = routing_table
-        self.bank_id_to_type = {}

     async def initialize(self) -> None:
         pass
@@ -29,32 +28,14 @@ class MemoryRouter(Memory):
     async def shutdown(self) -> None:
         pass

-    def get_provider_from_bank_id(self, bank_id: str) -> Any:
-        bank_type = self.bank_id_to_type.get(bank_id)
-        if not bank_type:
-            raise ValueError(f"Could not find bank type for {bank_id}")
-
-        provider = self.routing_table.get_provider_impl(bank_type)
-        if not provider:
-            raise ValueError(f"Could not find provider for {bank_type}")
-        return provider
+    async def list_memory_banks(self) -> List[MemoryBankDef]:
+        return self.routing_table.list_memory_banks()

-    async def create_memory_bank(
-        self,
-        name: str,
-        config: MemoryBankConfig,
-        url: Optional[URL] = None,
-    ) -> MemoryBank:
-        bank_type = config.type
-        bank = await self.routing_table.get_provider_impl(bank_type).create_memory_bank(
-            name, config, url
-        )
-        self.bank_id_to_type[bank.bank_id] = bank_type
-        return bank
+    async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
+        return self.routing_table.get_memory_bank(identifier)

-    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
-        provider = self.get_provider_from_bank_id(bank_id)
-        return await provider.get_memory_bank(bank_id)
+    async def register_memory_bank(self, bank: MemoryBankDef) -> None:
+        await self.routing_table.register_memory_bank(bank)

     async def insert_documents(
         self,
@@ -62,7 +43,7 @@ class MemoryRouter(Memory):
         documents: List[MemoryBankDocument],
         ttl_seconds: Optional[int] = None,
     ) -> None:
-        return await self.get_provider_from_bank_id(bank_id).insert_documents(
+        return await self.routing_table.get_provider_impl(bank_id).insert_documents(
            bank_id, documents, ttl_seconds
        )
@@ -72,7 +53,7 @@ class MemoryRouter(Memory):
         query: InterleavedTextMedia,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryDocumentsResponse:
-        return await self.get_provider_from_bank_id(bank_id).query_documents(
+        return await self.routing_table.get_provider_impl(bank_id).query_documents(
            bank_id, query, params
        )
@@ -92,6 +73,15 @@ class InferenceRouter(Inference):
     async def shutdown(self) -> None:
         pass

+    async def list_models(self) -> List[ModelDef]:
+        return self.routing_table.list_models()
+
+    async def get_model(self, identifier: str) -> Optional[ModelDef]:
+        return self.routing_table.get_model(identifier)
+
+    async def register_model(self, model: ModelDef) -> None:
+        await self.routing_table.register_model(model)
+
     async def chat_completion(
         self,
         model: str,
@@ -159,6 +149,15 @@ class SafetyRouter(Safety):
     async def shutdown(self) -> None:
         pass

+    async def list_shields(self) -> List[ShieldDef]:
+        return self.routing_table.list_shields()
+
+    async def get_shield(self, shield_type: str) -> Optional[ShieldDef]:
+        return self.routing_table.get_shield(shield_type)
+
+    async def register_shield(self, shield: ShieldDef) -> None:
+        await self.routing_table.register_shield(shield)
+
     async def run_shield(
         self,
         shield_type: str,
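
The MemoryRouter, InferenceRouter and SafetyRouter changes above all follow the same shape: registration and listing are delegated to the routing table, and per-request calls ask the routing table for the provider implementation behind a routing key. A minimal standalone sketch of that flow, using simplified stand-in classes (FakeProvider, FakeRoutingTable and FakeInferenceRouter are illustrative, not the real llama_stack types):

import asyncio
from typing import Dict


class FakeProvider:
    def __init__(self, name: str) -> None:
        self.name = name

    async def chat_completion(self, model: str, prompt: str) -> str:
        return f"[{self.name}] answer to {prompt!r}"


class FakeRoutingTable:
    """Maps a routing key (model identifier) to the provider impl that serves it."""

    def __init__(self) -> None:
        self.key_to_provider: Dict[str, FakeProvider] = {}

    async def register_model(self, identifier: str, provider: FakeProvider) -> None:
        self.key_to_provider[identifier] = provider

    def get_provider_impl(self, routing_key: str) -> FakeProvider:
        return self.key_to_provider[routing_key]


class FakeInferenceRouter:
    """Registration and request dispatch both go through the routing table."""

    def __init__(self, table: FakeRoutingTable) -> None:
        self.routing_table = table

    async def register_model(self, identifier: str, provider: FakeProvider) -> None:
        await self.routing_table.register_model(identifier, provider)

    async def chat_completion(self, model: str, prompt: str) -> str:
        impl = self.routing_table.get_provider_impl(model)
        return await impl.chat_completion(model, prompt)


async def demo() -> None:
    table = FakeRoutingTable()
    router = FakeInferenceRouter(table)
    await router.register_model("Llama3.1-8B-Instruct", FakeProvider("ollama"))
    print(await router.chat_completion("Llama3.1-8B-Instruct", "hello"))


asyncio.run(demo())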

View file

@@ -15,6 +15,8 @@ from llama_stack.apis.memory_banks import * # noqa: F403
 from llama_stack.distribution.datatypes import * # noqa: F403


+# TODO: this routing table maintains state in memory purely. We need to
+# add persistence to it when we add dynamic registration of objects.
 class CommonRoutingTableImpl(RoutingTable):
     def __init__(
         self,
@@ -54,7 +56,7 @@ class CommonRoutingTableImpl(RoutingTable):
                 return obj
         return None

-    def register_object(self, obj: RoutableObject) -> None:
+    async def register_object_common(self, obj: RoutableObject) -> None:
         if obj.identifier in self.routing_key_to_object:
             raise ValueError(f"Object `{obj.identifier}` already registered")

@@ -79,7 +81,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         return self.get_object_by_identifier(identifier)

     async def register_model(self, model: ModelDef) -> None:
-        await self.register_object(model)
+        await self.register_object_common(model)


 class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
@@ -93,7 +95,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
         return self.get_object_by_identifier(shield_type)

     async def register_shield(self, shield: ShieldDef) -> None:
-        await self.register_object(shield)
+        await self.register_object_common(shield)


 class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
@@ -107,4 +109,4 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
         return self.get_object_by_identifier(identifier)

     async def register_memory_bank(self, bank: MemoryBankDef) -> None:
-        await self.register_object(bank)
+        await self.register_object_common(bank)
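
A short sketch of the in-memory registration guard that register_object_common performs, using a stand-in dataclass instead of the real RoutableObject; as the TODO above notes, this state currently lives only in process memory:

import asyncio
from dataclasses import dataclass
from typing import Dict


@dataclass
class FakeRoutableObject:
    identifier: str


class InMemoryRoutingTable:
    def __init__(self) -> None:
        # purely in-memory mapping of routing key -> registered object
        self.routing_key_to_object: Dict[str, FakeRoutableObject] = {}

    async def register_object_common(self, obj: FakeRoutableObject) -> None:
        # duplicates can only be caught within the current process
        if obj.identifier in self.routing_key_to_object:
            raise ValueError(f"Object `{obj.identifier}` already registered")
        self.routing_key_to_object[obj.identifier] = obj


async def demo() -> None:
    table = InMemoryRoutingTable()
    await table.register_object_common(FakeRoutableObject("Llama3.1-8B-Instruct"))
    try:
        await table.register_object_common(FakeRoutableObject("Llama3.1-8B-Instruct"))
    except ValueError as e:
        print(e)  # Object `Llama3.1-8B-Instruct` already registered


asyncio.run(demo())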

View file

@@ -13,7 +13,7 @@ from botocore.config import Config
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer

-from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

 from llama_stack.apis.inference import * # noqa: F403
 from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
@@ -26,7 +26,7 @@ BEDROCK_SUPPORTED_MODELS = {
 }


-class BedrockInferenceAdapter(Inference, RoutableProviderForModels):
+class BedrockInferenceAdapter(ModelRegistryHelper, Inference):

     @staticmethod
     def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
@@ -69,7 +69,7 @@ class BedrockInferenceAdapter(Inference, RoutableProviderForModels):
         return boto3_session.client("bedrock-runtime", config=boto3_config)

     def __init__(self, config: BedrockConfig) -> None:
-        RoutableProviderForModels.__init__(
+        ModelRegistryHelper.__init__(
             self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS
         )
         self._config = config

View file

@@ -13,7 +13,7 @@ from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message, StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer

-from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

 from llama_stack.apis.inference import * # noqa: F403
 from llama_stack.providers.utils.inference.augment_messages import (
@@ -30,9 +30,9 @@ FIREWORKS_SUPPORTED_MODELS = {
 }


-class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
+class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: FireworksImplConfig) -> None:
-        RoutableProviderForModels.__init__(
+        ModelRegistryHelper.__init__(
             self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
         )
         self.config = config

View file

@@ -18,7 +18,7 @@ from llama_stack.apis.inference import * # noqa: F403
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
 )
-from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

 # TODO: Eventually this will move to the llama cli model list command
 # mapping of Model SKUs to ollama models
@@ -27,12 +27,13 @@ OLLAMA_SUPPORTED_SKUS = {
     "Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
     "Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16",
     "Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16",
+    "Llama-Guard-3-8B": "xe/llamaguard3:latest",
 }


-class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
+class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, url: str) -> None:
-        RoutableProviderForModels.__init__(
+        ModelRegistryHelper.__init__(
             self, stack_to_provider_models_map=OLLAMA_SUPPORTED_SKUS
         )
         self.url = url

View file

@@ -18,7 +18,7 @@ from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
 )
-from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

 from .config import TogetherImplConfig
@@ -34,10 +34,10 @@ TOGETHER_SUPPORTED_MODELS = {


 class TogetherInferenceAdapter(
-    Inference, NeedsRequestProviderData, RoutableProviderForModels
+    ModelRegistryHelper, Inference, NeedsRequestProviderData
 ):
     def __init__(self, config: TogetherImplConfig) -> None:
-        RoutableProviderForModels.__init__(
+        ModelRegistryHelper.__init__(
             self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
         )
         self.config = config

View file

@@ -12,7 +12,6 @@ from llama_models.sku_list import resolve_model

 from llama_models.llama3.api.datatypes import * # noqa: F403
 from llama_stack.apis.inference import * # noqa: F403
-from llama_stack.distribution.datatypes import RoutableProvider
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
 )
@@ -25,24 +24,39 @@ from .model_parallel import LlamaModelParallelGenerator
 SEMAPHORE = asyncio.Semaphore(1)


-class MetaReferenceInferenceImpl(Inference, RoutableProvider):
+class MetaReferenceInferenceImpl(Inference):
     def __init__(self, config: MetaReferenceImplConfig) -> None:
         self.config = config
         model = resolve_model(config.model)
         if model is None:
             raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
         self.model = model
+        self.registered_model_defs = []
         # verify that the checkpoint actually is for this model lol

     async def initialize(self) -> None:
         self.generator = LlamaModelParallelGenerator(self.config)
         self.generator.start()

-    async def validate_routing_keys(self, routing_keys: List[str]) -> None:
-        assert (
-            len(routing_keys) == 1
-        ), f"Only one routing key is supported {routing_keys}"
-        assert routing_keys[0] == self.config.model
+    async def register_model(self, model: ModelDef) -> None:
+        existing = await self.get_model(model.identifier)
+        if existing is not None:
+            return
+
+        if model.identifier != self.model.descriptor():
+            raise RuntimeError(
+                f"Model mismatch: {model.identifier} != {self.model.descriptor()}"
+            )
+        self.registered_model_defs = [model]
+
+    async def list_models(self) -> List[ModelDef]:
+        return self.registered_model_defs
+
+    async def get_model(self, identifier: str) -> Optional[ModelDef]:
+        for model in self.registered_model_defs:
+            if model.identifier == identifier:
+                return model
+        return None

     async def shutdown(self) -> None:
         self.generator.stop()
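
The meta-reference implementation replaces routing-key validation with register_model: re-registering the same identifier is a no-op, and any identifier other than the configured model is rejected. A hedged standalone sketch of that behaviour, with a stand-in ModelDef and a plain string in place of resolve_model() / self.model.descriptor():

import asyncio
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ModelDef:
    identifier: str


class SingleModelImpl:
    def __init__(self, configured: str) -> None:
        self.configured = configured          # e.g. "Llama3.1-8B-Instruct"
        self.registered: List[ModelDef] = []

    async def get_model(self, identifier: str) -> Optional[ModelDef]:
        return next((m for m in self.registered if m.identifier == identifier), None)

    async def register_model(self, model: ModelDef) -> None:
        if await self.get_model(model.identifier) is not None:
            return                            # idempotent re-registration
        if model.identifier != self.configured:
            raise RuntimeError(f"Model mismatch: {model.identifier} != {self.configured}")
        self.registered = [model]


impl = SingleModelImpl(configured="Llama3.1-8B-Instruct")
asyncio.run(impl.register_model(ModelDef(identifier="Llama3.1-8B-Instruct")))   # accepted
# asyncio.run(impl.register_model(ModelDef(identifier="Llama3.2-1B-Instruct")))  # would raise RuntimeError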

View file

@@ -13,7 +13,6 @@ import numpy as np
 from numpy.typing import NDArray

 from llama_models.llama3.api.datatypes import * # noqa: F403
-from llama_stack.distribution.datatypes import RoutableProvider
 from llama_stack.apis.memory import * # noqa: F403
 from llama_stack.providers.utils.memory.vector_store import (
@@ -62,7 +61,7 @@ class FaissIndex(EmbeddingIndex):
         return QueryDocumentsResponse(chunks=chunks, scores=scores)


-class FaissMemoryImpl(Memory, RoutableProvider):
+class FaissMemoryImpl(Memory):
     def __init__(self, config: FaissImplConfig) -> None:
         self.config = config
         self.cache = {}
@@ -83,7 +82,6 @@ class FaissMemoryImpl(Memory, RoutableProvider):
             bank=memory_bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
         )
         self.cache[memory_bank.identifier] = index
-        return bank

     async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
         index = self.cache.get(identifier)

View file

@@ -0,0 +1,51 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Dict, List
+
+from llama_models.sku_list import resolve_model
+
+from llama_stack.apis.models import * # noqa: F403
+
+
+class ModelRegistryHelper:
+    def __init__(self, stack_to_provider_models_map: Dict[str, str]):
+        self.stack_to_provider_models_map = stack_to_provider_models_map
+        self.registered_models = []
+
+    def map_to_provider_model(self, identifier: str) -> str:
+        model = resolve_model(identifier)
+        if not model:
+            raise ValueError(f"Unknown model: `{identifier}`")
+
+        if identifier not in self.stack_to_provider_models_map:
+            raise ValueError(
+                f"Model {identifier} not found in map {self.stack_to_provider_models_map}"
+            )
+
+        return self.stack_to_provider_models_map[identifier]
+
+    async def register_model(self, model: ModelDef) -> None:
+        existing = await self.get_model(model.identifier)
+        if existing is not None:
+            return
+
+        if model.identifier not in self.stack_to_provider_models_map:
+            raise ValueError(
+                f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}"
+            )
+
+        self.registered_models.append(model)
+
+    async def list_models(self) -> List[ModelDef]:
+        return self.registered_models
+
+    async def get_model(self, identifier: str) -> Optional[ModelDef]:
+        for model in self.registered_models:
+            if model.identifier == identifier:
+                return model
+        return None
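
ModelRegistryHelper is the replacement for RoutableProviderForModels: adapters mix it in, pass their SKU map to __init__, and inherit register_model / list_models / get_model plus map_to_provider_model. A standalone sketch of that pattern with simplified stand-ins (no llama_stack import; MiniRegistryHelper, ExampleAdapter and the model map are illustrative, not from the commit):

import asyncio
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ModelDef:
    identifier: str


class MiniRegistryHelper:
    """Mirrors the shape of ModelRegistryHelper: SKU map plus registered models."""

    def __init__(self, stack_to_provider_models_map: Dict[str, str]) -> None:
        self.stack_to_provider_models_map = stack_to_provider_models_map
        self.registered_models: List[ModelDef] = []

    def map_to_provider_model(self, identifier: str) -> str:
        if identifier not in self.stack_to_provider_models_map:
            raise ValueError(f"Model {identifier} not found in map")
        return self.stack_to_provider_models_map[identifier]

    async def register_model(self, model: ModelDef) -> None:
        if any(m.identifier == model.identifier for m in self.registered_models):
            return
        if model.identifier not in self.stack_to_provider_models_map:
            raise ValueError(f"Unsupported model {model.identifier}")
        self.registered_models.append(model)


class ExampleAdapter(MiniRegistryHelper):
    """An adapter mixes the helper in and only adds its transport-specific logic."""

    async def chat_completion(self, model: str, prompt: str) -> str:
        provider_model = self.map_to_provider_model(model)
        return f"would call provider model {provider_model!r} with {prompt!r}"


async def demo() -> None:
    adapter = ExampleAdapter({"Llama3.1-8B-Instruct": "example/llama-3.1-8b"})
    await adapter.register_model(ModelDef(identifier="Llama3.1-8B-Instruct"))
    print(await adapter.chat_completion("Llama3.1-8B-Instruct", "hi"))


asyncio.run(demo())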

View file

@@ -1,36 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from typing import Dict, List
-
-from llama_models.sku_list import resolve_model
-
-from llama_stack.distribution.datatypes import RoutableProvider
-
-
-class RoutableProviderForModels(RoutableProvider):
-    def __init__(self, stack_to_provider_models_map: Dict[str, str]):
-        self.stack_to_provider_models_map = stack_to_provider_models_map
-
-    async def validate_routing_keys(self, routing_keys: List[str]):
-        for routing_key in routing_keys:
-            if routing_key not in self.stack_to_provider_models_map:
-                raise ValueError(
-                    f"Routing key {routing_key} not found in map {self.stack_to_provider_models_map}"
-                )
-
-    def map_to_provider_model(self, routing_key: str) -> str:
-        model = resolve_model(routing_key)
-        if not model:
-            raise ValueError(f"Unknown model: `{routing_key}`")
-
-        if routing_key not in self.stack_to_provider_models_map:
-            raise ValueError(
-                f"Model {routing_key} not found in map {self.stack_to_provider_models_map}"
-            )
-
-        return self.stack_to_provider_models_map[routing_key]