mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-09 11:20:58 +00:00
Another round of simplification and clarity for models/shields/memory_banks stuff
This commit is contained in:
parent
73a0a34e39
commit
b55034c0de
27 changed files with 454 additions and 444 deletions
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
|
||||
|
||||
|
||||
class ModelRegistryHelper:
|
||||
class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||
|
||||
def __init__(self, stack_to_provider_models_map: Dict[str, str]):
|
||||
self.stack_to_provider_models_map = stack_to_provider_models_map
|
||||
|
|
@ -33,3 +33,15 @@ class ModelRegistryHelper:
|
|||
raise ValueError(
|
||||
f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}"
|
||||
)
|
||||
|
||||
async def list_models(self) -> List[ModelDef]:
|
||||
models = []
|
||||
for llama_model, provider_model in self.stack_to_provider_models_map.items():
|
||||
models.append(ModelDef(identifier=llama_model, llama_model=llama_model))
|
||||
return models
|
||||
|
||||
async def get_model(self, identifier: str) -> Optional[ModelDef]:
|
||||
if identifier not in self.stack_to_provider_models_map:
|
||||
return None
|
||||
|
||||
return ModelDef(identifier=identifier, llama_model=identifier)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue