inference registry updates

This commit is contained in:
Ashwin Bharambe 2024-10-05 22:25:48 -07:00 committed by Ashwin Bharambe
parent 4215cc9331
commit 59302a86df
12 changed files with 570 additions and 535 deletions

View file

@ -12,7 +12,6 @@ from llama_models.sku_list import resolve_model
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
@ -25,24 +24,39 @@ from .model_parallel import LlamaModelParallelGenerator
SEMAPHORE = asyncio.Semaphore(1)
class MetaReferenceInferenceImpl(Inference, RoutableProvider):
class MetaReferenceInferenceImpl(Inference):
def __init__(self, config: MetaReferenceImplConfig) -> None:
self.config = config
model = resolve_model(config.model)
if model is None:
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
self.model = model
self.registered_model_defs = []
# verify that the checkpoint actually is for this model lol
async def initialize(self) -> None:
self.generator = LlamaModelParallelGenerator(self.config)
self.generator.start()
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
assert (
len(routing_keys) == 1
), f"Only one routing key is supported {routing_keys}"
assert routing_keys[0] == self.config.model
async def register_model(self, model: ModelDef) -> None:
existing = await self.get_model(model.identifier)
if existing is not None:
return
if model.identifier != self.model.descriptor():
raise RuntimeError(
f"Model mismatch: {model.identifier} != {self.model.descriptor()}"
)
self.registered_model_defs = [model]
async def list_models(self) -> List[ModelDef]:
return self.registered_model_defs
async def get_model(self, identifier: str) -> Optional[ModelDef]:
for model in self.registered_model_defs:
if model.identifier == identifier:
return model
return None
async def shutdown(self) -> None:
self.generator.stop()

View file

@ -13,7 +13,6 @@ import numpy as np
from numpy.typing import NDArray
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.utils.memory.vector_store import (
@ -62,7 +61,7 @@ class FaissIndex(EmbeddingIndex):
return QueryDocumentsResponse(chunks=chunks, scores=scores)
class FaissMemoryImpl(Memory, RoutableProvider):
class FaissMemoryImpl(Memory):
def __init__(self, config: FaissImplConfig) -> None:
self.config = config
self.cache = {}
@ -83,7 +82,6 @@ class FaissMemoryImpl(Memory, RoutableProvider):
bank=memory_bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
)
self.cache[memory_bank.identifier] = index
return bank
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
index = self.cache.get(identifier)