Another round of simplification and clarity for models/shields/memory_banks stuff

This commit is contained in:
Ashwin Bharambe 2024-10-09 19:19:26 -07:00
parent 73a0a34e39
commit b55034c0de
27 changed files with 454 additions and 444 deletions

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Dict, List, Optional
from llama_models.sku_list import resolve_model
from llama_stack.apis.models import * # noqa: F403
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
class ModelRegistryHelper:
class ModelRegistryHelper(ModelsProtocolPrivate):
def __init__(self, stack_to_provider_models_map: Dict[str, str]):
self.stack_to_provider_models_map = stack_to_provider_models_map
@ -33,3 +33,15 @@ class ModelRegistryHelper:
raise ValueError(
f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}"
)
async def list_models(self) -> List[ModelDef]:
models = []
for llama_model, provider_model in self.stack_to_provider_models_map.items():
models.append(ModelDef(identifier=llama_model, llama_model=llama_model))
return models
async def get_model(self, identifier: str) -> Optional[ModelDef]:
if identifier not in self.stack_to_provider_models_map:
return None
return ModelDef(identifier=identifier, llama_model=identifier)