Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-30 07:39:38 +00:00
Allow TGI adaptor to have non-standard llama model names
parent 59af1c8fec
commit 7e25db8478
1 changed file with 0 additions and 20 deletions
@@ -18,12 +18,6 @@ from llama_stack.providers.utils.inference.prepare_messages import prepare_messa

 from .config import TGIImplConfig

-HF_SUPPORTED_MODELS = {
-    "Meta-Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct",
-    "Meta-Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct",
-    "Meta-Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct",
-}
-

 class TGIAdapter(Inference):
     def __init__(self, config: TGIImplConfig) -> None:
@@ -50,16 +44,6 @@ class TGIAdapter(Inference):
                 raise RuntimeError("Missing max_total_tokens in model info")
             self.max_tokens = info["max_total_tokens"]

-            model_id = info["model_id"]
-            model_name = next(
-                (name for name, id in HF_SUPPORTED_MODELS.items() if id == model_id),
-                None,
-            )
-            if model_name is None:
-                raise RuntimeError(
-                    f"TGI is serving model: {model_id}, use one of the supported models: {', '.join(HF_SUPPORTED_MODELS.values())}"
-                )
-            self.model_name = model_name
             self.inference_url = info["inference_url"]
         except Exception as e:
             import traceback
@@ -116,10 +100,6 @@ class TGIAdapter(Inference):

         print(f"Calculated max_new_tokens: {max_new_tokens}")

-        assert (
-            request.model == self.model_name
-        ), f"Model mismatch, expected {self.model_name}, got {request.model}"
-
         options = self.get_chat_options(request)
         if not request.stream:
             response = self.client.text_generation(
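For reference, here is a minimal, runnable sketch of the lookup this commit deletes. The HF_SUPPORTED_MODELS table and the error message are copied from the removed lines above; resolve_model_name is a hypothetical standalone wrapper, added only so the snippet runs outside the adapter, and the indentation and surrounding adapter code are assumptions based on the hunk context.

# Sketch of the strict reverse lookup removed by this commit (hypothetical standalone form).
# The dictionary and error message come from the deleted lines in the diff above.
HF_SUPPORTED_MODELS = {
    "Meta-Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "Meta-Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct",
    "Meta-Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct",
}


def resolve_model_name(model_id: str) -> str:
    """Map the model id reported by the TGI server back to its Llama alias, rejecting anything unknown."""
    model_name = next(
        (name for name, hf_id in HF_SUPPORTED_MODELS.items() if hf_id == model_id),
        None,
    )
    if model_name is None:
        raise RuntimeError(
            f"TGI is serving model: {model_id}, use one of the supported models: "
            f"{', '.join(HF_SUPPORTED_MODELS.values())}"
        )
    return model_name


# A TGI endpoint serving a fine-tuned or locally renamed Llama model fails this lookup;
# with the lookup and the request.model assertion deleted, the adapter no longer rejects
# such non-standard model names.
print(resolve_model_name("meta-llama/Meta-Llama-3.1-8B-Instruct"))  # -> Meta-Llama3.1-8B-Instruct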