Allow TGI adaptor to have non-standard llama model names

Hardik Shah 2024-09-19 21:36:10 -07:00
parent 59af1c8fec
commit 42d29f3a5a


@@ -50,16 +50,6 @@ class TGIAdapter(Inference):
                 raise RuntimeError("Missing max_total_tokens in model info")
             self.max_tokens = info["max_total_tokens"]
             model_id = info["model_id"]
-            model_name = next(
-                (name for name, id in HF_SUPPORTED_MODELS.items() if id == model_id),
-                None,
-            )
-            if model_name is None:
-                raise RuntimeError(
-                    f"TGI is serving model: {model_id}, use one of the supported models: {', '.join(HF_SUPPORTED_MODELS.values())}"
-                )
-            self.model_name = model_name
             self.inference_url = info["inference_url"]
         except Exception as e:
             import traceback
@@ -116,10 +106,6 @@ class TGIAdapter(Inference):
         print(f"Calculated max_new_tokens: {max_new_tokens}")
-        assert (
-            request.model == self.model_name
-        ), f"Model mismatch, expected {self.model_name}, got {request.model}"
         options = self.get_chat_options(request)
         if not request.stream:
             response = self.client.text_generation(
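
For context, here is a minimal sketch of what the adapter's model handling looks like after this change. It is not the upstream file: only the lines visible in the diff context are mirrored, and names such as `TGIAdapterSketch`, `_configure_from_info`, and `ChatRequest` are hypothetical stand-ins. The point of the change is that the adapter now trusts whatever `model_id` the TGI server reports instead of mapping it back through `HF_SUPPORTED_MODELS`, and `chat_completion` no longer asserts that `request.model` matches a canonical name, so non-standard (e.g. renamed or fine-tuned) Llama checkpoints are accepted.

```python
# Minimal sketch, assuming only the diff context; not the upstream file.
# TGIAdapterSketch, _configure_from_info, and ChatRequest are hypothetical
# names used for illustration.
from dataclasses import dataclass


@dataclass
class ChatRequest:
    model: str          # whatever name the caller uses; no longer validated
    stream: bool = False


class TGIAdapterSketch:
    def _configure_from_info(self, info: dict) -> None:
        # Mirrors the surviving context lines of the first hunk: max_total_tokens
        # is still required, but model_id is no longer matched against
        # HF_SUPPORTED_MODELS, so non-standard Llama model names are accepted.
        if "max_total_tokens" not in info:
            raise RuntimeError("Missing max_total_tokens in model info")
        self.max_tokens = info["max_total_tokens"]
        self.model_id = info["model_id"]
        self.inference_url = info["inference_url"]

    def chat_completion(self, request: ChatRequest) -> None:
        # The former `assert request.model == self.model_name` is gone, so a
        # request naming a non-standard model no longer fails at this point.
        ...
```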