Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-08 19:10:56 +00:00)
inference registry updates

commit 59302a86df (parent 4215cc9331)
12 changed files with 570 additions and 535 deletions
@@ -12,7 +12,6 @@ from llama_models.sku_list import resolve_model

 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.distribution.datatypes import RoutableProvider
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
 )
@@ -25,24 +24,39 @@ from .model_parallel import LlamaModelParallelGenerator

 SEMAPHORE = asyncio.Semaphore(1)


-class MetaReferenceInferenceImpl(Inference, RoutableProvider):
+class MetaReferenceInferenceImpl(Inference):
     def __init__(self, config: MetaReferenceImplConfig) -> None:
         self.config = config
         model = resolve_model(config.model)
         if model is None:
             raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
         self.model = model
+        self.registered_model_defs = []
         # verify that the checkpoint actually is for this model lol

     async def initialize(self) -> None:
         self.generator = LlamaModelParallelGenerator(self.config)
         self.generator.start()

-    async def validate_routing_keys(self, routing_keys: List[str]) -> None:
-        assert (
-            len(routing_keys) == 1
-        ), f"Only one routing key is supported {routing_keys}"
-        assert routing_keys[0] == self.config.model
+    async def register_model(self, model: ModelDef) -> None:
+        existing = await self.get_model(model.identifier)
+        if existing is not None:
+            return
+
+        if model.identifier != self.model.descriptor():
+            raise RuntimeError(
+                f"Model mismatch: {model.identifier} != {self.model.descriptor()}"
+            )
+        self.registered_model_defs = [model]
+
+    async def list_models(self) -> List[ModelDef]:
+        return self.registered_model_defs
+
+    async def get_model(self, identifier: str) -> Optional[ModelDef]:
+        for model in self.registered_model_defs:
+            if model.identifier == identifier:
+                return model
+        return None

     async def shutdown(self) -> None:
         self.generator.stop()
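The net effect of this hunk is that the provider no longer validates router-supplied routing keys; instead it keeps a one-element registry of ModelDef objects and exposes register/list/get over it. Below is a minimal, self-contained sketch of that flow, not the commit's code: ModelDef is a stand-in dataclass rather than the real llama_stack type, and ToyInferenceProvider plus the "Llama3.1-8B-Instruct" identifier are invented for illustration, since the real MetaReferenceInferenceImpl needs checkpoints and GPUs to run.

import asyncio
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ModelDef:
    # Stand-in for llama_stack's ModelDef; only the identifier matters here.
    identifier: str


class ToyInferenceProvider:
    """Mirrors the register/list/get flow added in this commit."""

    def __init__(self, served_model: str) -> None:
        self.served_model = served_model
        self.registered_model_defs: List[ModelDef] = []

    async def register_model(self, model: ModelDef) -> None:
        if await self.get_model(model.identifier) is not None:
            return  # already registered, nothing to do
        if model.identifier != self.served_model:
            # mirrors the "Model mismatch" check in the hunk above
            raise RuntimeError(
                f"Model mismatch: {model.identifier} != {self.served_model}"
            )
        self.registered_model_defs = [model]

    async def list_models(self) -> List[ModelDef]:
        return self.registered_model_defs

    async def get_model(self, identifier: str) -> Optional[ModelDef]:
        for model in self.registered_model_defs:
            if model.identifier == identifier:
                return model
        return None


async def main() -> None:
    provider = ToyInferenceProvider("Llama3.1-8B-Instruct")
    await provider.register_model(ModelDef(identifier="Llama3.1-8B-Instruct"))
    print(await provider.list_models())             # one registered ModelDef
    print(await provider.get_model("other-model"))  # None


asyncio.run(main())

Registering an identifier that does not match the served model raises, so a provider can only ever hold the single model definition it actually serves.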
@@ -13,7 +13,6 @@ import numpy as np

 from numpy.typing import NDArray

 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.distribution.datatypes import RoutableProvider

 from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.providers.utils.memory.vector_store import (
@@ -62,7 +61,7 @@ class FaissIndex(EmbeddingIndex):

         return QueryDocumentsResponse(chunks=chunks, scores=scores)


-class FaissMemoryImpl(Memory, RoutableProvider):
+class FaissMemoryImpl(Memory):
     def __init__(self, config: FaissImplConfig) -> None:
         self.config = config
         self.cache = {}
@@ -83,7 +82,6 @@ class FaissMemoryImpl(Memory, RoutableProvider):

            bank=memory_bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
        )
        self.cache[memory_bank.identifier] = index
        return bank

    async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
        index = self.cache.get(identifier)
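The memory provider makes the same move: instead of routing-key validation, FaissMemoryImpl keeps self.cache keyed by the bank identifier, and get_memory_bank is a plain cache lookup. The sketch below is a reduced illustration of that pattern only: MemoryBankDef and BankWithIndex are stand-ins (a list replaces the real FAISS index), and the method name register_memory_bank is an assumption, since the hunk above does not show the enclosing method.

from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class MemoryBankDef:
    # Stand-in for llama_stack's MemoryBankDef; only the identifier matters here.
    identifier: str


@dataclass
class BankWithIndex:
    bank: MemoryBankDef
    chunks: List[str] = field(default_factory=list)  # stand-in for a FAISS index


class ToyMemoryProvider:
    def __init__(self) -> None:
        # identifier -> BankWithIndex, mirroring FaissMemoryImpl.cache
        self.cache: Dict[str, BankWithIndex] = {}

    async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
        # create a fresh per-bank index and remember it under the bank's identifier
        self.cache[memory_bank.identifier] = BankWithIndex(bank=memory_bank)

    async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
        index = self.cache.get(identifier)
        return index.bank if index is not None else None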