make TGI work well

Hardik Shah 2025-03-28 15:38:27 -07:00
parent e58c7f6c37
commit 021dd0d35d
9 changed files with 617 additions and 326 deletions

@@ -12,7 +12,6 @@ import re
 from typing import List, Optional, Tuple, Union

 import httpx
-from PIL import Image as PIL_Image

 from llama_stack.apis.common.content_types import (
     ImageContentItem,
@@ -34,6 +33,7 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import (
+    is_multimodal,
     ModelFamily,
     RawContent,
     RawContentItem,
@@ -43,7 +43,6 @@ from llama_stack.models.llama.datatypes import (
     Role,
     StopReason,
     ToolPromptFormat,
-    is_multimodal,
 )
 from llama_stack.models.llama.llama3.chat_format import ChatFormat
 from llama_stack.models.llama.llama3.prompt_templates import (
@@ -56,6 +55,7 @@ from llama_stack.models.llama.llama3.prompt_templates import (
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.utils.inference import supported_inference_models
+from PIL import Image as PIL_Image


 log = get_logger(name=__name__, category="inference")
@@ -129,7 +129,9 @@ async def interleaved_content_convert_to_raw(
                 if image.url.uri.startswith("data"):
                     match = re.match(r"data:image/(\w+);base64,(.+)", image.url.uri)
                     if not match:
-                        raise ValueError(f"Invalid data URL format, {image.url.uri[:40]}...")
+                        raise ValueError(
+                            f"Invalid data URL format, {image.url.uri[:40]}..."
+                        )
                     _, image_data = match.groups()
                     data = base64.b64decode(image_data)
                 elif image.url.uri.startswith("file://"):
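
A note on the hunk above: the data-URL branch expects URIs of the form data:image/<format>;base64,<payload>, and the ValueError here is only being re-wrapped. A minimal standalone sketch of the same parse-and-decode step, using the regex from the diff; decode_image_data_url and the example bytes are made up for illustration:

    import base64
    import re

    def decode_image_data_url(uri: str) -> tuple[str, bytes]:
        """Split a base64 data URL into (image format, raw image bytes)."""
        match = re.match(r"data:image/(\w+);base64,(.+)", uri)
        if not match:
            raise ValueError(f"Invalid data URL format, {uri[:40]}...")
        image_format, image_data = match.groups()
        return image_format, base64.b64decode(image_data)

    # Build an example data URL from fake bytes so the round trip is self-contained.
    payload = base64.b64encode(b"\x89PNG fake image bytes").decode("ascii")
    fmt, raw = decode_image_data_url(f"data:image/png;base64,{payload}")
    print(fmt, raw == b"\x89PNG fake image bytes")  # -> png True
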
@@ -209,13 +211,17 @@ async def convert_image_content_to_url(

     content, format = await localize_image_content(media)
     if include_format:
-        return f"data:image/{format};base64," + base64.b64encode(content).decode("utf-8")
+        return f"data:image/{format};base64," + base64.b64encode(content).decode(
+            "utf-8"
+        )
     else:
         return base64.b64encode(content).decode("utf-8")


 async def completion_request_to_prompt(request: CompletionRequest) -> str:
-    content = augment_content_with_response_format_prompt(request.response_format, request.content)
+    content = augment_content_with_response_format_prompt(
+        request.response_format, request.content
+    )
     request.content = content
     request = await convert_request_to_raw(request)

@@ -224,8 +230,12 @@ async def completion_request_to_prompt(request: CompletionRequest) -> str:
     return formatter.tokenizer.decode(model_input.tokens)


-async def completion_request_to_prompt_model_input_info(request: CompletionRequest) -> Tuple[str, int]:
-    content = augment_content_with_response_format_prompt(request.response_format, request.content)
+async def completion_request_to_prompt_model_input_info(
+    request: CompletionRequest,
+) -> Tuple[str, int]:
+    content = augment_content_with_response_format_prompt(
+        request.response_format, request.content
+    )
     request.content = content
     request = await convert_request_to_raw(request)

@@ -246,7 +256,9 @@ def augment_content_with_response_format_prompt(response_format, content):
     return content


-async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str:
+async def chat_completion_request_to_prompt(
+    request: ChatCompletionRequest, llama_model: str
+) -> str:
     messages = chat_completion_request_to_messages(request, llama_model)
     request.messages = messages
     request = await convert_request_to_raw(request)
@@ -254,7 +266,8 @@ async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llam
     formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_dialog_prompt(
         request.messages,
-        tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
+        tool_prompt_format=request.tool_config.tool_prompt_format
+        or get_default_tool_prompt_format(llama_model),
     )
     return formatter.tokenizer.decode(model_input.tokens)

@@ -269,10 +282,17 @@ async def chat_completion_request_to_model_input_info(
     formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_dialog_prompt(
         request.messages,
-        tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
+        tool_prompt_format=request.tool_config.tool_prompt_format
+        or get_default_tool_prompt_format(llama_model),
     )
+    tokens = []
+    for t in model_input.tokens:
+        if t == 128256:
+            tokens.append(formatter.vision_token)
+        else:
+            tokens.append(t)
     return (
-        formatter.tokenizer.decode(model_input.tokens),
+        formatter.tokenizer.decode(tokens),
         len(model_input.tokens),
     )

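
The only behavioral change in this hunk is the token loop: before the rendered dialog is decoded back into a prompt string (which is what ultimately gets posted to TGI), token id 128256, the Llama 3 image special token, is replaced with formatter.vision_token so the image placeholder survives the decode. A rough, self-contained sketch of that substitution pattern, assuming the placeholder renders as the literal string "<|image|>"; the names and toy decoder below are illustrative, not the llama_stack API:

    from typing import List, Union

    VISION_TOKEN_ID = 128256          # Llama 3 image special token id (from the diff)
    VISION_TOKEN_TEXT = "<|image|>"   # assumed textual placeholder

    def substitute_vision_tokens(token_ids: List[int]) -> List[Union[int, str]]:
        """Swap the raw image token id for a printable placeholder before decoding."""
        return [VISION_TOKEN_TEXT if t == VISION_TOKEN_ID else t for t in token_ids]

    def toy_decode(tokens: List[Union[int, str]]) -> str:
        """Stand-in for tokenizer.decode: ints map through a toy vocab, strings pass through."""
        vocab = {1: "Describe", 2: "this", 3: "image:"}
        return " ".join(t if isinstance(t, str) else vocab.get(t, "<unk>") for t in tokens)

    print(toy_decode(substitute_vision_tokens([1, 2, 3, VISION_TOKEN_ID])))
    # -> Describe this image: <|image|>
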
@@ -298,7 +318,8 @@ def chat_completion_request_to_messages(
         return request.messages

     if model.model_family == ModelFamily.llama3_1 or (
-        model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id)
+        model.model_family == ModelFamily.llama3_2
+        and is_multimodal(model.core_model_id)
     ):
         # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
         messages = augment_messages_for_tools_llama_3_1(request)
@@ -334,7 +355,9 @@ def augment_messages_for_tools_llama_3_1(
     if existing_messages[0].role == Role.system.value:
         existing_system_message = existing_messages.pop(0)

-    assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
+    assert (
+        existing_messages[0].role != Role.system.value
+    ), "Should only have 1 system message"

     messages = []

@@ -366,9 +389,13 @@ def augment_messages_for_tools_llama_3_1(
         if isinstance(existing_system_message.content, str):
             sys_content += _process(existing_system_message.content)
         elif isinstance(existing_system_message.content, list):
-            sys_content += "\n".join([_process(c) for c in existing_system_message.content])
+            sys_content += "\n".join(
+                [_process(c) for c in existing_system_message.content]
+            )

-    tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
+    tool_choice_prompt = _get_tool_choice_prompt(
+        request.tool_config.tool_choice, request.tools
+    )
     if tool_choice_prompt:
         sys_content += "\n" + tool_choice_prompt

@@ -402,7 +429,9 @@ def augment_messages_for_tools_llama_3_2(
     if existing_messages[0].role == Role.system.value:
         existing_system_message = existing_messages.pop(0)

-    assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
+    assert (
+        existing_messages[0].role != Role.system.value
+    ), "Should only have 1 system message"

     sys_content = ""
     custom_tools, builtin_tools = [], []
@@ -423,10 +452,16 @@ def augment_messages_for_tools_llama_3_2(
     if custom_tools:
         fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list
         if fmt != ToolPromptFormat.python_list:
-            raise ValueError(f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}")
+            raise ValueError(
+                f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}"
+            )

         system_prompt = None
-        if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
+        if (
+            existing_system_message
+            and request.tool_config.system_message_behavior
+            == SystemMessageBehavior.replace
+        ):
             system_prompt = existing_system_message.content

         tool_template = PythonListCustomToolGenerator().gen(custom_tools, system_prompt)
@@ -435,11 +470,16 @@ def augment_messages_for_tools_llama_3_2(
         sys_content += "\n"

     if existing_system_message and (
-        request.tool_config.system_message_behavior == SystemMessageBehavior.append or not custom_tools
+        request.tool_config.system_message_behavior == SystemMessageBehavior.append
+        or not custom_tools
     ):
-        sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n")
+        sys_content += interleaved_content_as_str(
+            existing_system_message.content, sep="\n"
+        )

-    tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
+    tool_choice_prompt = _get_tool_choice_prompt(
+        request.tool_config.tool_choice, request.tools
+    )
     if tool_choice_prompt:
         sys_content += "\n" + tool_choice_prompt

@@ -447,11 +487,15 @@ def augment_messages_for_tools_llama_3_2(
     return messages


-def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefinition]) -> str:
+def _get_tool_choice_prompt(
+    tool_choice: ToolChoice | str, tools: List[ToolDefinition]
+) -> str:
     if tool_choice == ToolChoice.auto:
         return ""
     elif tool_choice == ToolChoice.required:
-        return "You MUST use one of the provided functions/tools to answer the user query."
+        return (
+            "You MUST use one of the provided functions/tools to answer the user query."
+        )
     elif tool_choice == ToolChoice.none:
         # tools are already not passed in
         return ""
@@ -463,11 +507,14 @@ def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefin
 def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
     llama_model = resolve_model(model)
     if llama_model is None:
-        log.warning(f"Could not resolve model {model}, defaulting to json tool prompt format")
+        log.warning(
+            f"Could not resolve model {model}, defaulting to json tool prompt format"
+        )
         return ToolPromptFormat.json

     if llama_model.model_family == ModelFamily.llama3_1 or (
-        llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
+        llama_model.model_family == ModelFamily.llama3_2
+        and is_multimodal(llama_model.core_model_id)
     ):
         # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
         return ToolPromptFormat.json
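
The diff ends mid-function, but the visible part of get_default_tool_prompt_format already shows the fallback rule: unresolved models, the Llama 3.1 family, and Llama 3.2 multimodal models default to ToolPromptFormat.json, while later branches (not shown) handle the remaining families. A hedged sketch of how such a default is consumed by the `or` fallback seen in the chat-completion hunks above; the types and the model-id matching below are simplified stand-ins, not llama_stack code:

    from dataclasses import dataclass
    from enum import Enum
    from typing import Optional

    class ToolPromptFormat(Enum):
        json = "json"
        function_tag = "function_tag"
        python_list = "python_list"

    @dataclass
    class ToolConfig:
        tool_prompt_format: Optional[ToolPromptFormat] = None

    def default_tool_prompt_format(model_id: str) -> ToolPromptFormat:
        # Assumption for this sketch: 3.1 models and vision models default to json,
        # everything else to python_list.
        if "3.1" in model_id or "Vision" in model_id:
            return ToolPromptFormat.json
        return ToolPromptFormat.python_list

    def effective_format(config: ToolConfig, model_id: str) -> ToolPromptFormat:
        # Same `request.tool_config.tool_prompt_format or get_default_...` fallback pattern.
        return config.tool_prompt_format or default_tool_prompt_format(model_id)

    print(effective_format(ToolConfig(), "meta-llama/Llama-3.1-8B-Instruct"))                       # json
    print(effective_format(ToolConfig(ToolPromptFormat.json), "meta-llama/Llama-3.2-3B-Instruct"))  # json (pinned)
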