From 7e25db8478cd500578647934b2b6010b5c080932 Mon Sep 17 00:00:00 2001
From: Hardik Shah
Date: Thu, 19 Sep 2024 21:36:10 -0700
Subject: [PATCH] Allow TGI adaptor to have non-standard llama model names

---
 .../providers/adapters/inference/tgi/tgi.py | 20 -------------------
 1 file changed, 20 deletions(-)

diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py
index 3be1f3e98..6c3b38347 100644
--- a/llama_stack/providers/adapters/inference/tgi/tgi.py
+++ b/llama_stack/providers/adapters/inference/tgi/tgi.py
@@ -18,12 +18,6 @@ from llama_stack.providers.utils.inference.prepare_messages import prepare_messa
 
 from .config import TGIImplConfig
 
-HF_SUPPORTED_MODELS = {
-    "Meta-Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct",
-    "Meta-Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct",
-    "Meta-Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct",
-}
-
 
 class TGIAdapter(Inference):
     def __init__(self, config: TGIImplConfig) -> None:
@@ -50,16 +44,6 @@ class TGIAdapter(Inference):
                 raise RuntimeError("Missing max_total_tokens in model info")
             self.max_tokens = info["max_total_tokens"]
 
-            model_id = info["model_id"]
-            model_name = next(
-                (name for name, id in HF_SUPPORTED_MODELS.items() if id == model_id),
-                None,
-            )
-            if model_name is None:
-                raise RuntimeError(
-                    f"TGI is serving model: {model_id}, use one of the supported models: {', '.join(HF_SUPPORTED_MODELS.values())}"
-                )
-            self.model_name = model_name
             self.inference_url = info["inference_url"]
         except Exception as e:
             import traceback
@@ -116,10 +100,6 @@ class TGIAdapter(Inference):
 
         print(f"Calculated max_new_tokens: {max_new_tokens}")
 
-        assert (
-            request.model == self.model_name
-        ), f"Model mismatch, expected {self.model_name}, got {request.model}"
-
         options = self.get_chat_options(request)
         if not request.stream:
             response = self.client.text_generation(
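
For readers skimming the patch, here is a minimal sketch (not part of the commit) contrasting the old and new behavior of TGIAdapter initialization: the served model reported by TGI is now accepted as-is instead of being matched against the removed HF_SUPPORTED_MODELS mapping. The standalone helper names below are hypothetical and exist only for illustration; the only details taken from the source are the mapping itself and the `info` keys visible in the hunks above (`model_id`, `max_total_tokens`, `inference_url`).

    # Illustration only; not part of the patch. The two helpers are
    # hypothetical names used to contrast pre- and post-patch behavior.

    HF_SUPPORTED_MODELS = {
        "Meta-Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "Meta-Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct",
        "Meta-Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct",
    }

    def accept_model_before(info: dict) -> str:
        # Pre-patch: the endpoint's model_id had to map back to one of the
        # three hard-coded Llama 3.1 Instruct names, otherwise startup failed.
        model_id = info["model_id"]
        for name, hf_id in HF_SUPPORTED_MODELS.items():
            if hf_id == model_id:
                return name
        raise RuntimeError(f"TGI is serving unsupported model: {model_id}")

    def accept_model_after(info: dict) -> str:
        # Post-patch: whatever model TGI reports is accepted, so fine-tuned or
        # renamed Llama checkpoints work without editing a mapping table.
        return info["model_id"]

    # Example: a fine-tuned checkpoint with a non-standard name.
    info = {"model_id": "my-org/llama-3.1-8b-finetune", "max_total_tokens": 4096}
    print(accept_model_after(info))      # accepted as-is
    # accept_model_before(info) would raise RuntimeError

The request-time assert removed in the last hunk follows the same logic: with no canonical model_name stored on the adapter, there is nothing to compare request.model against, so that check is dropped as well.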