Mirror of https://github.com/meta-llama/llama-stack.git

Commit 021dd0d35d: make TGI work well
Parent: e58c7f6c37
9 changed files with 617 additions and 326 deletions
@@ -12,7 +12,6 @@ import re
 from typing import List, Optional, Tuple, Union
 
 import httpx
-from PIL import Image as PIL_Image
 
 from llama_stack.apis.common.content_types import (
     ImageContentItem,
@@ -34,6 +33,7 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import (
+    is_multimodal,
     ModelFamily,
     RawContent,
     RawContentItem,
@@ -43,7 +43,6 @@ from llama_stack.models.llama.datatypes import (
     Role,
     StopReason,
     ToolPromptFormat,
-    is_multimodal,
 )
 from llama_stack.models.llama.llama3.chat_format import ChatFormat
 from llama_stack.models.llama.llama3.prompt_templates import (
@@ -56,6 +55,7 @@ from llama_stack.models.llama.llama3.prompt_templates import (
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.utils.inference import supported_inference_models
+from PIL import Image as PIL_Image
 
 log = get_logger(name=__name__, category="inference")
 
@@ -129,7 +129,9 @@ async def interleaved_content_convert_to_raw(
                 if image.url.uri.startswith("data"):
                     match = re.match(r"data:image/(\w+);base64,(.+)", image.url.uri)
                     if not match:
-                        raise ValueError(f"Invalid data URL format, {image.url.uri[:40]}...")
+                        raise ValueError(
+                            f"Invalid data URL format, {image.url.uri[:40]}..."
+                        )
                     _, image_data = match.groups()
                     data = base64.b64decode(image_data)
                 elif image.url.uri.startswith("file://"):
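Note: the data-URL branch above expects URIs of the form data:image/<format>;base64,<payload>. A minimal, self-contained sketch of that parsing step; the helper name and sample input are illustrative, not from this change:

import base64
import re


def parse_image_data_url(uri: str) -> tuple[str, bytes]:
    # Same pattern as the regex in the hunk above.
    match = re.match(r"data:image/(\w+);base64,(.+)", uri)
    if not match:
        raise ValueError(f"Invalid data URL format, {uri[:40]}...")
    image_format, payload = match.groups()
    return image_format, base64.b64decode(payload)


# Sample input: the 8-byte PNG signature, base64-encoded (illustrative only).
uri = "data:image/png;base64," + base64.b64encode(b"\x89PNG\r\n\x1a\n").decode("utf-8")
fmt, raw = parse_image_data_url(uri)
print(fmt, len(raw))  # -> png 8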
@@ -209,13 +211,17 @@ async def convert_image_content_to_url(
 
     content, format = await localize_image_content(media)
     if include_format:
-        return f"data:image/{format};base64," + base64.b64encode(content).decode("utf-8")
+        return f"data:image/{format};base64," + base64.b64encode(content).decode(
+            "utf-8"
+        )
     else:
         return base64.b64encode(content).decode("utf-8")
 
 
 async def completion_request_to_prompt(request: CompletionRequest) -> str:
-    content = augment_content_with_response_format_prompt(request.response_format, request.content)
+    content = augment_content_with_response_format_prompt(
+        request.response_format, request.content
+    )
     request.content = content
     request = await convert_request_to_raw(request)
 
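Note: convert_image_content_to_url goes the opposite direction: localized image bytes become either a full data URL or a bare base64 string, depending on include_format. A standalone sketch of that encoding step (the function name is hypothetical):

import base64


def encode_image_bytes(content: bytes, image_format: str, include_format: bool = True) -> str:
    # With include_format, build a data URL; otherwise return only the base64 payload.
    encoded = base64.b64encode(content).decode("utf-8")
    if include_format:
        return f"data:image/{image_format};base64," + encoded
    return encoded


print(encode_image_bytes(b"\x89PNG\r\n\x1a\n", "png"))  # -> data:image/png;base64,iVBORw0KGgo=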
@@ -224,8 +230,12 @@ async def completion_request_to_prompt(request: CompletionRequest) -> str:
     return formatter.tokenizer.decode(model_input.tokens)
 
 
-async def completion_request_to_prompt_model_input_info(request: CompletionRequest) -> Tuple[str, int]:
-    content = augment_content_with_response_format_prompt(request.response_format, request.content)
+async def completion_request_to_prompt_model_input_info(
+    request: CompletionRequest,
+) -> Tuple[str, int]:
+    content = augment_content_with_response_format_prompt(
+        request.response_format, request.content
+    )
     request.content = content
     request = await convert_request_to_raw(request)
 
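Note: both completion helpers first run the content through augment_content_with_response_format_prompt, whose body is not part of this diff. As a rough, assumption-laden illustration only, it can be thought of as appending a formatting instruction when a structured response format is requested:

from typing import Optional


def augment_with_response_format(content: str, response_format: Optional[dict]) -> str:
    # Illustrative stand-in, not the real helper: append a JSON instruction
    # when the request asks for a json_schema-style response format.
    if response_format and response_format.get("type") == "json_schema":
        return f"{content}\n\nPlease respond in JSON format with the schema: {response_format['json_schema']}"
    return content


print(augment_with_response_format("List three colors.", {"type": "json_schema", "json_schema": {"type": "array"}}))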
@@ -246,7 +256,9 @@ def augment_content_with_response_format_prompt(response_format, content):
     return content
 
 
-async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str:
+async def chat_completion_request_to_prompt(
+    request: ChatCompletionRequest, llama_model: str
+) -> str:
     messages = chat_completion_request_to_messages(request, llama_model)
     request.messages = messages
     request = await convert_request_to_raw(request)
@@ -254,7 +266,8 @@ async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llam
     formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_dialog_prompt(
         request.messages,
-        tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
+        tool_prompt_format=request.tool_config.tool_prompt_format
+        or get_default_tool_prompt_format(llama_model),
     )
     return formatter.tokenizer.decode(model_input.tokens)
 
@@ -269,10 +282,17 @@ async def chat_completion_request_to_model_input_info(
     formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_dialog_prompt(
         request.messages,
-        tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
+        tool_prompt_format=request.tool_config.tool_prompt_format
+        or get_default_tool_prompt_format(llama_model),
     )
+    tokens = []
+    for t in model_input.tokens:
+        if t == 128256:
+            tokens.append(formatter.vision_token)
+        else:
+            tokens.append(t)
     return (
-        formatter.tokenizer.decode(model_input.tokens),
+        formatter.tokenizer.decode(tokens),
         len(model_input.tokens),
     )
 
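Note: this hunk carries the functional part of the change. Token id 128256 appears to be an image-placeholder id that the tokenizer cannot decode, so it is swapped for formatter.vision_token before the prompt is decoded into the string handed to a text-in/text-out server such as TGI; the returned length still counts the original model_input.tokens. A self-contained sketch of the same substitution, with purely hypothetical ids:

from typing import List


def replace_placeholder_tokens(tokens: List[int], placeholder_id: int, replacement_id: int) -> List[int]:
    # Swap every occurrence of an id the tokenizer cannot decode for one it can,
    # leaving all other token ids untouched.
    return [replacement_id if t == placeholder_id else t for t in tokens]


# 128256 as the placeholder, 128011 standing in for the vision token id (assumed values).
print(replace_placeholder_tokens([1, 128256, 42], 128256, 128011))  # -> [1, 128011, 42]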
@@ -298,7 +318,8 @@ def chat_completion_request_to_messages(
         return request.messages
 
     if model.model_family == ModelFamily.llama3_1 or (
-        model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id)
+        model.model_family == ModelFamily.llama3_2
+        and is_multimodal(model.core_model_id)
     ):
         # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
         messages = augment_messages_for_tools_llama_3_1(request)
@@ -334,7 +355,9 @@ def augment_messages_for_tools_llama_3_1(
     if existing_messages[0].role == Role.system.value:
         existing_system_message = existing_messages.pop(0)
 
-    assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
+    assert (
+        existing_messages[0].role != Role.system.value
+    ), "Should only have 1 system message"
 
     messages = []
 
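Note: the reformatted assert encodes the invariant that at most one leading system message is expected; the caller pops it before rebuilding the message list. A small sketch of that extraction, using plain dicts in place of the Message types:

from typing import List, Optional, Tuple


def pop_system_message(messages: List[dict]) -> Tuple[Optional[dict], List[dict]]:
    # Remove a single leading system message, if any, and assert no second one follows.
    rest = list(messages)
    system = rest.pop(0) if rest and rest[0]["role"] == "system" else None
    assert not (rest and rest[0]["role"] == "system"), "Should only have 1 system message"
    return system, rest


print(pop_system_message([{"role": "system", "content": "be brief"}, {"role": "user", "content": "hi"}]))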
@@ -366,9 +389,13 @@ def augment_messages_for_tools_llama_3_1(
         if isinstance(existing_system_message.content, str):
             sys_content += _process(existing_system_message.content)
         elif isinstance(existing_system_message.content, list):
-            sys_content += "\n".join([_process(c) for c in existing_system_message.content])
+            sys_content += "\n".join(
+                [_process(c) for c in existing_system_message.content]
+            )
 
-    tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
+    tool_choice_prompt = _get_tool_choice_prompt(
+        request.tool_config.tool_choice, request.tools
+    )
     if tool_choice_prompt:
         sys_content += "\n" + tool_choice_prompt
 
@@ -402,7 +429,9 @@ def augment_messages_for_tools_llama_3_2(
     if existing_messages[0].role == Role.system.value:
         existing_system_message = existing_messages.pop(0)
 
-    assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
+    assert (
+        existing_messages[0].role != Role.system.value
+    ), "Should only have 1 system message"
 
     sys_content = ""
     custom_tools, builtin_tools = [], []
@@ -423,10 +452,16 @@ def augment_messages_for_tools_llama_3_2(
     if custom_tools:
         fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list
         if fmt != ToolPromptFormat.python_list:
-            raise ValueError(f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}")
+            raise ValueError(
+                f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}"
+            )
 
         system_prompt = None
-        if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
+        if (
+            existing_system_message
+            and request.tool_config.system_message_behavior
+            == SystemMessageBehavior.replace
+        ):
             system_prompt = existing_system_message.content
 
         tool_template = PythonListCustomToolGenerator().gen(custom_tools, system_prompt)
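Note: for the python_list tool format, SystemMessageBehavior decides whether the caller's system message is folded into the generated tool prompt (replace, via the system_prompt argument above) or kept separate and appended later in this function. A sketch of that branching with stand-in names; the generator here is a stub, not PythonListCustomToolGenerator:

from enum import Enum
from typing import List, Optional


class SystemMessageBehavior(Enum):
    append = "append"
    replace = "replace"


def render_tool_prompt(custom_tools: List[str], system_prompt: Optional[str]) -> str:
    # Stub standing in for the real prompt-template generator.
    base = "You have access to these functions: " + ", ".join(custom_tools)
    return f"{system_prompt}\n{base}" if system_prompt else base


def build_sys_content(custom_tools: List[str], existing: Optional[str], behavior: SystemMessageBehavior) -> str:
    # replace: the existing system message becomes the template's own system prompt;
    # append: the template is generated without it and the message is added afterwards.
    system_prompt = existing if behavior == SystemMessageBehavior.replace else None
    content = render_tool_prompt(custom_tools, system_prompt)
    if existing and behavior == SystemMessageBehavior.append:
        content += "\n" + existing
    return content


print(build_sys_content(["get_weather"], "Answer tersely.", SystemMessageBehavior.append))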
@@ -435,11 +470,16 @@ def augment_messages_for_tools_llama_3_2(
         sys_content += "\n"
 
     if existing_system_message and (
-        request.tool_config.system_message_behavior == SystemMessageBehavior.append or not custom_tools
+        request.tool_config.system_message_behavior == SystemMessageBehavior.append
+        or not custom_tools
     ):
-        sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n")
+        sys_content += interleaved_content_as_str(
+            existing_system_message.content, sep="\n"
+        )
 
-    tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
+    tool_choice_prompt = _get_tool_choice_prompt(
+        request.tool_config.tool_choice, request.tools
+    )
     if tool_choice_prompt:
         sys_content += "\n" + tool_choice_prompt
 
@@ -447,11 +487,15 @@ def augment_messages_for_tools_llama_3_2(
     return messages
 
 
-def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefinition]) -> str:
+def _get_tool_choice_prompt(
+    tool_choice: ToolChoice | str, tools: List[ToolDefinition]
+) -> str:
     if tool_choice == ToolChoice.auto:
         return ""
     elif tool_choice == ToolChoice.required:
-        return "You MUST use one of the provided functions/tools to answer the user query."
+        return (
+            "You MUST use one of the provided functions/tools to answer the user query."
+        )
     elif tool_choice == ToolChoice.none:
         # tools are already not passed in
         return ""
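Note: _get_tool_choice_prompt turns the request's tool_choice into an extra system-prompt sentence: nothing for auto or none, a MUST-use instruction for required. A condensed usage sketch of the visible branches; the original also accepts a raw tool-name string, which falls outside this hunk:

from enum import Enum


class ToolChoice(Enum):
    auto = "auto"
    required = "required"
    none = "none"


def tool_choice_prompt(tool_choice: ToolChoice) -> str:
    # Only `required` adds an instruction in the branches shown above.
    if tool_choice == ToolChoice.required:
        return "You MUST use one of the provided functions/tools to answer the user query."
    return ""


sys_content = "You are a helpful assistant."
extra = tool_choice_prompt(ToolChoice.required)
if extra:
    sys_content += "\n" + extra
print(sys_content)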
@@ -463,11 +507,14 @@ def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefin
 def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
     llama_model = resolve_model(model)
     if llama_model is None:
-        log.warning(f"Could not resolve model {model}, defaulting to json tool prompt format")
+        log.warning(
+            f"Could not resolve model {model}, defaulting to json tool prompt format"
+        )
         return ToolPromptFormat.json
 
     if llama_model.model_family == ModelFamily.llama3_1 or (
-        llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
+        llama_model.model_family == ModelFamily.llama3_2
+        and is_multimodal(llama_model.core_model_id)
     ):
         # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
         return ToolPromptFormat.json
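Note: get_default_tool_prompt_format derives the default from the resolved model family: json for Llama 3.1 and for multimodal Llama 3.2 models, and json again when the model cannot be resolved at all. A small usage-style sketch of that selection with plain strings; the python_list fallback for other families is an assumption, not shown in this hunk:

from enum import Enum


class ToolPromptFormat(Enum):
    json = "json"
    python_list = "python_list"


def default_tool_prompt_format(family: str, multimodal: bool) -> ToolPromptFormat:
    # Visible rule: llama3_1, and multimodal llama3_2, default to the json format.
    if family == "llama3_1" or (family == "llama3_2" and multimodal):
        return ToolPromptFormat.json
    # Assumed fallback for other families.
    return ToolPromptFormat.python_list


print(default_tool_prompt_format("llama3_2", multimodal=False))  # -> ToolPromptFormat.python_list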