mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-08 19:10:56 +00:00
Introduce model_store, shield_store, memory_bank_store
This commit is contained in:
parent
e45a417543
commit
91e0063593
19 changed files with 172 additions and 297 deletions
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict, List
|
||||
from typing import Dict
|
||||
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
|
|
@ -15,7 +15,6 @@ class ModelRegistryHelper:
|
|||
|
||||
def __init__(self, stack_to_provider_models_map: Dict[str, str]):
|
||||
self.stack_to_provider_models_map = stack_to_provider_models_map
|
||||
self.registered_models = []
|
||||
|
||||
def map_to_provider_model(self, identifier: str) -> str:
|
||||
model = resolve_model(identifier)
|
||||
|
|
@ -30,22 +29,7 @@ class ModelRegistryHelper:
|
|||
return self.stack_to_provider_models_map[identifier]
|
||||
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
existing = await self.get_model(model.identifier)
|
||||
if existing is not None:
|
||||
return
|
||||
|
||||
if model.identifier not in self.stack_to_provider_models_map:
|
||||
raise ValueError(
|
||||
f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}"
|
||||
)
|
||||
|
||||
self.registered_models.append(model)
|
||||
|
||||
async def list_models(self) -> List[ModelDef]:
|
||||
return self.registered_models
|
||||
|
||||
async def get_model(self, identifier: str) -> Optional[ModelDef]:
|
||||
for model in self.registered_models:
|
||||
if model.identifier == identifier:
|
||||
return model
|
||||
return None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue