Another round of simplification and clarity for models/shields/memory_banks

Ashwin Bharambe 2024-10-09 19:19:26 -07:00
parent 73a0a34e39
commit b55034c0de
27 changed files with 454 additions and 444 deletions
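
For orientation, here is a rough sketch of what the provider-private protocols referenced in the diffs below (ModelsProtocolPrivate, MemoryBanksProtocolPrivate) might look like. Only the method names and the ModelDef fields visible in this commit are taken from the diff; the stand-in classes, the register_memory_bank signature, and everything else are assumptions, not the actual llama_stack.providers.datatypes definitions.

# Hypothetical sketch of the provider-private protocols wired in by this commit.
# Method names mirror the diff below; the stand-in *Def models and signatures
# are assumptions, not the real llama_stack definitions.
from typing import List, Optional, Protocol

from pydantic import BaseModel


class ModelDef(BaseModel):  # stand-in; the real class lives in llama_stack
    identifier: str
    llama_model: str


class MemoryBankDef(BaseModel):  # stand-in; the real class lives in llama_stack
    identifier: str


class ModelsProtocolPrivate(Protocol):
    async def register_model(self, model: ModelDef) -> None: ...
    async def list_models(self) -> List[ModelDef]: ...
    async def get_model(self, identifier: str) -> Optional[ModelDef]: ...


class MemoryBanksProtocolPrivate(Protocol):
    async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...
    async def list_memory_banks(self) -> List[MemoryBankDef]: ...
    async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: ...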

View file

@@ -12,6 +12,7 @@ from llama_models.sku_list import resolve_model
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_messages,
)
@@ -24,7 +25,7 @@ from .model_parallel import LlamaModelParallelGenerator
SEMAPHORE = asyncio.Semaphore(1)
class MetaReferenceInferenceImpl(Inference):
class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
    def __init__(self, config: MetaReferenceImplConfig) -> None:
        self.config = config
        model = resolve_model(config.model)
@@ -39,14 +40,38 @@ class MetaReferenceInferenceImpl(Inference):
        self.generator.start()

    async def register_model(self, model: ModelDef) -> None:
        if model.identifier != self.model.descriptor():
            raise RuntimeError(
                f"Model mismatch: {model.identifier} != {self.model.descriptor()}"
            )
        raise ValueError("Dynamic model registration is not supported")

    async def list_models(self) -> List[ModelDef]:
        return [
            ModelDef(
                identifier=self.model.descriptor(),
                llama_model=self.model.descriptor(),
            )
        ]

    async def get_model(self, identifier: str) -> Optional[ModelDef]:
        if self.model.descriptor() != identifier:
            return None

        return ModelDef(
            identifier=self.model.descriptor(),
            llama_model=self.model.descriptor(),
        )
    async def shutdown(self) -> None:
        self.generator.stop()

    def completion(
        self,
        model: str,
        content: InterleavedTextMedia,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
        raise NotImplementedError()

    def chat_completion(
        self,
        model: str,
@@ -255,3 +280,10 @@ class MetaReferenceInferenceImpl(Inference):
                    stop_reason=stop_reason,
                )
            )

    async def embeddings(
        self,
        model: str,
        contents: List[InterleavedTextMedia],
    ) -> EmbeddingsResponse:
        raise NotImplementedError()
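
To show how a caller might use the new model-lookup surface, here is a small illustrative sketch; only list_models and get_model come from the diff above, while the registry dict and the resolution flow are hypothetical.

# Hypothetical consumer of the new ModelsProtocolPrivate methods.
# `provider` would be a MetaReferenceInferenceImpl instance; the registry
# and this flow are illustrative assumptions, not llama_stack code.
async def build_model_registry(provider) -> dict:
    # Collect every ModelDef the provider serves, keyed by identifier.
    return {m.identifier: m for m in await provider.list_models()}


async def resolve_model_or_fail(provider, identifier: str):
    # get_model returns None when the provider does not serve the model.
    model_def = await provider.get_model(identifier)
    if model_def is None:
        raise ValueError(f"Provider does not serve model '{identifier}'")
    return model_def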

View file

@@ -15,6 +15,8 @@ from numpy.typing import NDArray
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import (
    ALL_MINILM_L6_V2_DIMENSION,
    BankWithIndex,
@@ -61,7 +63,7 @@ class FaissIndex(EmbeddingIndex):
        return QueryDocumentsResponse(chunks=chunks, scores=scores)


class FaissMemoryImpl(Memory):
class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
    def __init__(self, config: FaissImplConfig) -> None:
        self.config = config
        self.cache = {}
@@ -83,6 +85,16 @@ class FaissMemoryImpl(Memory):
        )
        self.cache[memory_bank.identifier] = index
    async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
        banks = await self.list_memory_banks()
        for bank in banks:
            if bank.identifier == identifier:
                return bank
        return None

    async def list_memory_banks(self) -> List[MemoryBankDef]:
        return [i.bank for i in self.cache.values()]

    async def insert_documents(
        self,
        bank_id: str,
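
Analogously, a rough sketch of how a caller might resolve a bank before inserting or querying documents; only get_memory_bank and list_memory_banks come from the diff above, and the helper itself is assumed.

# Hypothetical caller-side lookup built on the new memory-bank methods.
# `memory_impl` would be a FaissMemoryImpl; this helper is an illustrative
# assumption, not part of the commit.
async def require_memory_bank(memory_impl, bank_id: str):
    bank = await memory_impl.get_memory_bank(bank_id)
    if bank is None:
        known = [b.identifier for b in await memory_impl.list_memory_banks()]
        raise ValueError(f"Unknown memory bank '{bank_id}'; registered: {known}")
    return bank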