From 707da55c23742fba40ada290cda8bcc119452c35 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sat, 23 Nov 2024 08:47:05 -0800 Subject: [PATCH] Fix TGI register_model() issue --- .../providers/remote/inference/tgi/tgi.py | 40 +++++++++++-------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index dad055cbd..621188284 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -17,6 +17,10 @@ from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.models import * # noqa: F403 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate +from llama_stack.providers.utils.inference.model_registry import ( + build_model_alias, + ModelRegistryHelper, +) from llama_stack.providers.utils.inference.openai_compat import ( get_sampling_options, @@ -37,6 +41,17 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImpl log = logging.getLogger(__name__) +def build_model_aliases(): + return [ + build_model_alias( + model.huggingface_repo, + model.descriptor(), + ) + for model in all_registered_models() + if model.huggingface_repo + ] + + class _HfAdapter(Inference, ModelsProtocolPrivate): client: AsyncInferenceClient max_tokens: int @@ -44,31 +59,24 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): def __init__(self) -> None: self.formatter = ChatFormat(Tokenizer.get_instance()) + self.register_helper = ModelRegistryHelper(build_model_aliases()) self.huggingface_repo_to_llama_model_id = { model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo } - async def register_model(self, model: Model) -> None: - pass - - async def list_models(self) -> List[Model]: - repo = self.model_id - identifier = self.huggingface_repo_to_llama_model_id[repo] - return [ - Model( - identifier=identifier, - 
llama_model=identifier, - metadata={ - "huggingface_repo": repo, - }, - ) - ] - async def shutdown(self) -> None: pass + async def register_model(self, model: Model) -> Model: + model = await self.register_helper.register_model(model) + if model.provider_resource_id != self.model_id: + raise ValueError( + f"Model {model.provider_resource_id} does not match the model {self.model_id} served by TGI." + ) + return model + async def unregister_model(self, model_id: str) -> None: pass