Allow TGI adaptor to have non-standard llama model names

Hardik Shah 2024-09-19 21:36:10 -07:00
parent 59af1c8fec
commit 7e25db8478


@@ -18,12 +18,6 @@ from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
 from .config import TGIImplConfig
-HF_SUPPORTED_MODELS = {
-    "Meta-Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct",
-    "Meta-Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct",
-    "Meta-Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct",
-}
 class TGIAdapter(Inference):
     def __init__(self, config: TGIImplConfig) -> None:
@@ -50,16 +44,6 @@ class TGIAdapter(Inference):
raise RuntimeError("Missing max_total_tokens in model info")
self.max_tokens = info["max_total_tokens"]
model_id = info["model_id"]
model_name = next(
(name for name, id in HF_SUPPORTED_MODELS.items() if id == model_id),
None,
)
if model_name is None:
raise RuntimeError(
f"TGI is serving model: {model_id}, use one of the supported models: {', '.join(HF_SUPPORTED_MODELS.values())}"
)
self.model_name = model_name
self.inference_url = info["inference_url"]
except Exception as e:
import traceback
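
This view shows only the removed lines, so the following is just a rough sketch of what initialization can look like once the HF_SUPPORTED_MODELS lookup is gone: the adapter records whatever model the TGI server reports instead of rejecting unrecognized ids. The helper name and the use of TGI's /info route are illustrative assumptions, not the committed code.

import requests

def describe_tgi_endpoint(url: str) -> dict:
    # Sketch only: TGI publishes its served model and token limits on /info.
    # The reported model_id is accepted as-is; there is no hard-coded whitelist.
    info = requests.get(f"{url}/info", timeout=5).json()
    return {
        "model_id": info["model_id"],
        "max_total_tokens": info["max_total_tokens"],
        "inference_url": url,
    }

# Example (assumes a local TGI server): describe_tgi_endpoint("http://localhost:8080")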
@@ -116,10 +100,6 @@ class TGIAdapter(Inference):
print(f"Calculated max_new_tokens: {max_new_tokens}")
assert (
request.model == self.model_name
), f"Model mismatch, expected {self.model_name}, got {request.model}"
options = self.get_chat_options(request)
if not request.stream:
response = self.client.text_generation(
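
With the model-name assert removed, the request path no longer requires the served model to match one of the three mapped names. A hypothetical call against a TGI endpoint hosting, say, a fine-tuned checkpoint (the URL and prompt are placeholders, not part of this commit):

from huggingface_hub import InferenceClient

# Placeholder endpoint; whichever model it serves is now accepted by the adapter.
client = InferenceClient("http://localhost:8080")
print(client.text_generation("Write a haiku about inference servers.", max_new_tokens=64))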