From 42d29f3a5aff36fd8edba40f4bfd540d7d181057 Mon Sep 17 00:00:00 2001
From: Hardik Shah
Date: Thu, 19 Sep 2024 21:36:10 -0700
Subject: [PATCH] Allow TGI adaptor to have non-standard llama model names

TGI can serve arbitrary models (e.g. fine-tunes) whose model_id is not
listed in HF_SUPPORTED_MODELS. Drop the reverse lookup that mapped the
served model_id back to a known llama model name, and the per-request
assertion that request.model matches it, so the adapter no longer rejects
non-standard model names.
---
 .../providers/adapters/inference/tgi/tgi.py   | 14 --------------
 1 file changed, 14 deletions(-)

diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py
index 3be1f3e98..bb0b0ca6a 100644
--- a/llama_stack/providers/adapters/inference/tgi/tgi.py
+++ b/llama_stack/providers/adapters/inference/tgi/tgi.py
@@ -50,16 +50,6 @@ class TGIAdapter(Inference):
                 raise RuntimeError("Missing max_total_tokens in model info")
             self.max_tokens = info["max_total_tokens"]
 
-            model_id = info["model_id"]
-            model_name = next(
-                (name for name, id in HF_SUPPORTED_MODELS.items() if id == model_id),
-                None,
-            )
-            if model_name is None:
-                raise RuntimeError(
-                    f"TGI is serving model: {model_id}, use one of the supported models: {', '.join(HF_SUPPORTED_MODELS.values())}"
-                )
-            self.model_name = model_name
             self.inference_url = info["inference_url"]
         except Exception as e:
             import traceback
@@ -116,10 +106,6 @@ class TGIAdapter(Inference):
 
         print(f"Calculated max_new_tokens: {max_new_tokens}")
 
-        assert (
-            request.model == self.model_name
-        ), f"Model mismatch, expected {self.model_name}, got {request.model}"
-
         options = self.get_chat_options(request)
         if not request.stream:
             response = self.client.text_generation(
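
For context, the fields the removed check consumed (model_id, max_total_tokens)
are the ones TGI itself reports from its /info endpoint. A minimal sketch of
inspecting them, not part of the patch, assuming a TGI server at a placeholder
URL:

    import requests

    # Placeholder endpoint; point this at your own TGI deployment.
    TGI_URL = "http://localhost:8080"

    # TGI's /info endpoint reports the served model and its token limits.
    info = requests.get(f"{TGI_URL}/info").json()
    print(info["model_id"])          # after this patch, any served name is accepted
    print(info["max_total_tokens"])  # still read into self.max_tokens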