From 1bb01f934608672736716f65438c7b7f5ee79cd2 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 12 Nov 2024 20:00:48 -0800 Subject: [PATCH] remove model lookup class --- docs/source/getting_started/index.md | 2 +- .../utils/inference/model_registry.py | 22 ++++--------------- 2 files changed, 5 insertions(+), 19 deletions(-) diff --git a/docs/source/getting_started/index.md b/docs/source/getting_started/index.md index d1d61d770..eb95db7cc 100644 --- a/docs/source/getting_started/index.md +++ b/docs/source/getting_started/index.md @@ -538,7 +538,7 @@ Once the server is set up, we can test it with a client to verify it's working c $ curl http://localhost:5000/inference/chat_completion \ -H "Content-Type: application/json" \ -d '{ - "model": "Llama3.1-8B-Instruct", + "model_id": "Llama3.1-8B-Instruct", "messages": [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Write me a 2 sentence poem about the moon"} diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index c44c641a2..7120e9e97 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -15,7 +15,6 @@ ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_mo def get_huggingface_repo(model_descriptor: str) -> Optional[str]: - """Get the Hugging Face repository for a given CoreModelId.""" for model in all_registered_models(): if model.descriptor() == model_descriptor: return model.huggingface_repo @@ -33,11 +32,8 @@ def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAli ) -class ModelLookup: - def __init__( - self, - model_aliases: List[ModelAlias], - ): +class ModelRegistryHelper(ModelsProtocolPrivate): + def __init__(self, model_aliases: List[ModelAlias]): self.alias_to_provider_id_map = {} self.provider_id_to_llama_model_map = {} for alias_obj in model_aliases: @@ -57,22 +53,12 @@ class ModelLookup: else: raise ValueError(f"Unknown model: `{identifier}`") - -class ModelRegistryHelper(ModelsProtocolPrivate): - - def __init__(self, model_aliases: List[ModelAlias]): - self.model_lookup = ModelLookup(model_aliases) - def get_llama_model(self, provider_model_id: str) -> str: - return self.model_lookup.provider_id_to_llama_model_map[provider_model_id] + return self.provider_id_to_llama_model_map[provider_model_id] async def register_model(self, model: Model) -> Model: - provider_model_id = self.model_lookup.get_provider_model_id( + model.provider_resource_id = self.get_provider_model_id( model.provider_resource_id ) - if not provider_model_id: - raise ValueError(f"Unknown model: `{model.provider_resource_id}`") - - model.provider_resource_id = provider_model_id return model