make TGI work well

Hardik Shah 2025-03-28 15:38:27 -07:00
parent e58c7f6c37
commit 021dd0d35d
9 changed files with 617 additions and 326 deletions

View file

@@ -33,10 +33,9 @@ from llama_stack.apis.inference import (
from llama_stack.apis.models.models import Model
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
convert_chat_completion_request_to_openai_params,
convert_message_to_openai_dict_new,
convert_openai_chat_completion_choice,
convert_openai_chat_completion_stream,
@@ -55,7 +54,9 @@ class LiteLLMOpenAIMixin(
Inference,
NeedsRequestProviderData,
):
def __init__(self, model_entries, api_key_from_config: str, provider_data_api_key_field: str):
def __init__(
self, model_entries, api_key_from_config: str, provider_data_api_key_field: str
):
ModelRegistryHelper.__init__(self, model_entries)
self.api_key_from_config = api_key_from_config
self.provider_data_api_key_field = provider_data_api_key_field
@@ -95,7 +96,9 @@ class LiteLLMOpenAIMixin(
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
@@ -110,7 +113,17 @@ class LiteLLMOpenAIMixin(
tool_config=tool_config,
)
params = await self._get_params(request)
params = await convert_chat_completion_request_to_openai_params(request)
# add api_key to params if available
provider_data = self.get_request_provider_data()
key_field = self.provider_data_api_key_field
if provider_data and getattr(provider_data, key_field, None):
api_key = getattr(provider_data, key_field)
else:
api_key = self.api_key_from_config
params["api_key"] = api_key
logger.debug(f"params to litellm (openai compat): {params}")
# unfortunately, we need to use synchronous litellm.completion here because litellm
# caches various httpx.client objects in a non-eventloop aware manner
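The comment spells out the constraint: litellm's cached httpx clients are not event-loop aware, so the provider calls the synchronous litellm.completion and bridges its output back into async iteration (see the yield_from_stream helper further down, which iterates the sync stream directly). One generic pattern for keeping such a blocking stream off the event loop entirely — a sketch, not this commit's approach — is to pull chunks on a worker thread:

import asyncio
from typing import AsyncIterator, Callable, Iterable, Iterator

async def aiter_blocking_stream(make_stream: Callable[[], Iterable]) -> AsyncIterator:
    # open the blocking stream in a worker thread
    stream: Iterator = iter(await asyncio.to_thread(make_stream))
    sentinel = object()
    while True:
        # pull each chunk on a worker thread so the event loop never blocks
        chunk = await asyncio.to_thread(next, stream, sentinel)
        if chunk is sentinel:
            return
        yield chunk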
@@ -132,87 +145,6 @@ class LiteLLMOpenAIMixin(
):
yield chunk
def _add_additional_properties_recursive(self, schema):
"""
Recursively add additionalProperties: False to all object schemas
"""
if isinstance(schema, dict):
if schema.get("type") == "object":
schema["additionalProperties"] = False
# Add required field with all property keys if properties exist
if "properties" in schema and schema["properties"]:
schema["required"] = list(schema["properties"].keys())
if "properties" in schema:
for prop_schema in schema["properties"].values():
self._add_additional_properties_recursive(prop_schema)
for key in ["anyOf", "allOf", "oneOf"]:
if key in schema:
for sub_schema in schema[key]:
self._add_additional_properties_recursive(sub_schema)
if "not" in schema:
self._add_additional_properties_recursive(schema["not"])
# Handle $defs/$ref
if "$defs" in schema:
for def_schema in schema["$defs"].values():
self._add_additional_properties_recursive(def_schema)
return schema
async def _get_params(self, request: ChatCompletionRequest) -> dict:
input_dict = {}
input_dict["messages"] = [await convert_message_to_openai_dict_new(m) for m in request.messages]
if fmt := request.response_format:
if not isinstance(fmt, JsonSchemaResponseFormat):
raise ValueError(
f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
)
fmt = fmt.json_schema
name = fmt["title"]
del fmt["title"]
fmt["additionalProperties"] = False
# Apply additionalProperties: False recursively to all objects
fmt = self._add_additional_properties_recursive(fmt)
input_dict["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": name,
"schema": fmt,
"strict": True,
},
}
if request.tools:
input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
if request.tool_config.tool_choice:
input_dict["tool_choice"] = (
request.tool_config.tool_choice.value
if isinstance(request.tool_config.tool_choice, ToolChoice)
else request.tool_config.tool_choice
)
provider_data = self.get_request_provider_data()
key_field = self.provider_data_api_key_field
if provider_data and getattr(provider_data, key_field, None):
api_key = getattr(provider_data, key_field)
else:
api_key = self.api_key_from_config
return {
"model": request.model,
"api_key": api_key,
**input_dict,
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
async def embeddings(
self,
model_id: str,

View file

@@ -6,55 +6,7 @@
import json
import logging
import warnings
from typing import AsyncGenerator, Dict, Iterable, List, Optional, Union
from openai import AsyncStream
from openai.types.chat import (
ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
)
from openai.types.chat import (
ChatCompletionChunk as OpenAIChatCompletionChunk,
)
from openai.types.chat import (
ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam,
)
from openai.types.chat import (
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
)
from openai.types.chat import (
ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam,
)
from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessage,
)
from openai.types.chat import (
ChatCompletionMessageToolCall,
)
from openai.types.chat import (
ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
)
from openai.types.chat import (
ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
)
from openai.types.chat import (
ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
)
from openai.types.chat import (
ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
)
from openai.types.chat.chat_completion import (
Choice as OpenAIChoice,
)
from openai.types.chat.chat_completion import (
ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs
)
from openai.types.chat.chat_completion_content_part_image_param import (
ImageURL as OpenAIImageURL,
)
from openai.types.chat.chat_completion_message_tool_call_param import (
Function as OpenAIFunction,
)
from pydantic import BaseModel
from typing import AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union
from llama_stack.apis.common.content_types import (
ImageContentItem,
@@ -71,11 +23,14 @@ from llama_stack.apis.inference import (
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
JsonSchemaResponseFormat,
Message,
SystemMessage,
TokenLogProbs,
ToolChoice,
ToolResponseMessage,
UserMessage,
)
@@ -94,6 +49,32 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
decode_assistant_message,
)
from openai import AsyncStream, Stream
from openai.types.chat import (
ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
ChatCompletionChunk as OpenAIChatCompletionChunk,
ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam,
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam,
ChatCompletionMessageParam as OpenAIChatCompletionMessage,
ChatCompletionMessageToolCall as OpenAIChatCompletionMessageToolCall,
ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCallParam,
ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
)
from openai.types.chat.chat_completion import (
Choice as OpenAIChoice,
ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs
)
from openai.types.chat.chat_completion_content_part_image_param import (
ImageURL as OpenAIImageURL,
)
from openai.types.chat.chat_completion_message_tool_call_param import (
Function as OpenAIFunction,
)
from pydantic import BaseModel
logger = logging.getLogger(__name__)
@@ -188,12 +169,16 @@ def convert_openai_completion_logprobs(
if logprobs.tokens and logprobs.token_logprobs:
return [
TokenLogProbs(logprobs_by_token={token: token_lp})
for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs, strict=False)
for token, token_lp in zip(
logprobs.tokens, logprobs.token_logprobs, strict=False
)
]
return None
def convert_openai_completion_logprobs_stream(text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]):
def convert_openai_completion_logprobs_stream(
text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]
):
if logprobs is None:
return None
if isinstance(logprobs, float):
@@ -238,7 +223,9 @@ def process_chat_completion_response(
if not choice.message or not choice.message.tool_calls:
raise ValueError("Tool calls are not present in the response")
tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls]
tool_calls = [
convert_tool_call(tool_call) for tool_call in choice.message.tool_calls
]
if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
# If we couldn't parse a tool call, jsonify the tool calls and return them
return ChatCompletionResponse(
@@ -262,7 +249,9 @@ def process_chat_completion_response(
# TODO: This does not work well with tool calls for vLLM remote provider
# Ref: https://github.com/meta-llama/llama-stack/issues/1058
raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason))
raw_message = decode_assistant_message(
text_from_choice(choice), get_stop_reason(choice.finish_reason)
)
# NOTE: If we do not set tools in chat-completion request, we should not
# expect the ToolCall in the response. Instead, we should return the raw
@@ -463,13 +452,17 @@ async def process_chat_completion_stream_response(
)
async def convert_message_to_openai_dict(message: Message, download: bool = False) -> dict:
async def convert_message_to_openai_dict(
message: Message, download: bool = False
) -> dict:
async def _convert_content(content) -> dict:
if isinstance(content, ImageContentItem):
return {
"type": "image_url",
"image_url": {
"url": await convert_image_content_to_url(content, download=download),
"url": await convert_image_content_to_url(
content, download=download
),
},
}
else:
@@ -548,7 +541,9 @@ async def convert_message_to_openai_dict_new(
elif isinstance(content_, ImageContentItem):
return OpenAIChatCompletionContentPartImageParam(
type="image_url",
image_url=OpenAIImageURL(url=await convert_image_content_to_url(content_)),
image_url=OpenAIImageURL(
url=await convert_image_content_to_url(content_)
),
)
elif isinstance(content_, list):
return [await impl(item) for item in content_]
@@ -574,14 +569,32 @@ async def convert_message_to_openai_dict_new(
role="assistant",
content=await _convert_message_content(message.content),
tool_calls=[
OpenAIChatCompletionMessageToolCall(
id=tool.call_id,
function=OpenAIFunction(
name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
arguments=json.dumps(tool.arguments),
),
type="function",
)
# OpenAIChatCompletionMessageToolCall(
# id=tool.call_id,
# function=OpenAIFunction(
# name=(
# tool.tool_name
# if not isinstance(tool.tool_name, BuiltinTool)
# else tool.tool_name.value
# ),
# arguments=json.dumps(tool.arguments),
# ),
# type="function",
# )
# using a dict instead of an OpenAIChatCompletionMessageToolCall object
# because the object fails to JSON-encode
{
"id": tool.call_id,
"function": {
"name": (
tool.tool_name
if not isinstance(tool.tool_name, BuiltinTool)
else tool.tool_name.value
),
"arguments": json.dumps(tool.arguments),
},
"type": "function",
}
for tool in message.tool_calls
]
or None,
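The dict workaround comes down to serializability: json.dumps handles plain dicts natively but not pydantic model instances such as the OpenAI tool-call class. A minimal illustration with a stand-in model (not the imported OpenAI type):

import json
from pydantic import BaseModel

class FakeToolCall(BaseModel):  # stand-in for the OpenAI pydantic tool-call type
    id: str
    type: str

call = FakeToolCall(id="call_1", type="function")
# json.dumps(call)  # raises TypeError: Object of type FakeToolCall is not JSON serializable
json.dumps({"id": "call_1", "type": "function"})  # a plain dict encodes fine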
@@ -604,7 +617,7 @@ async def convert_message_to_openai_dict_new(
def convert_tool_call(
tool_call: ChatCompletionMessageToolCall,
tool_call: OpenAIChatCompletionMessageToolCall,
) -> Union[ToolCall, UnparseableToolCall]:
"""
Convert a ChatCompletionMessageToolCall tool call to either a
@@ -696,7 +709,11 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
properties = parameters["properties"]
required = []
for param_name, param in tool.parameters.items():
properties[param_name] = {"type": PYTHON_TYPE_TO_LITELLM_TYPE.get(param.param_type, param.param_type)}
properties[param_name] = {
"type": PYTHON_TYPE_TO_LITELLM_TYPE.get(
param.param_type, param.param_type
)
}
if param.description:
properties[param_name].update(description=param.description)
if param.default:
@@ -762,15 +779,30 @@ def _convert_openai_tool_calls(
if not tool_calls:
return [] # CompletionMessage tool_calls is not optional
return [
ToolCall(
call_id=call.id,
tool_name=call.function.name,
arguments=json.loads(call.function.arguments),
arguments_json=call.function.arguments,
ls_tool_calls = []
for call in tool_calls:
args = call.function.arguments
# TGI sends a dict instead of a JSON string,
# while the OpenAI spec expects a JSON string
if isinstance(args, str):
arguments = json.loads(args)
arguments_json = args
elif isinstance(args, dict):
arguments = args
arguments_json = json.dumps(args)
else:
raise ValueError(f"Unsupported arguments type: {type(args)}")
ls_tool_calls.append(
ToolCall(
call_id=call.id,
tool_name=call.function.name,
arguments=arguments,
arguments_json=arguments_json,
)
)
for call in tool_calls
]
return ls_tool_calls
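Net effect: both provider shapes end up as the same (arguments, arguments_json) pair. A compact sketch of the normalization, with illustrative values:

import json

def normalize_tool_arguments(args):
    # OpenAI spec: a JSON string; TGI: already a dict
    if isinstance(args, str):
        return json.loads(args), args
    if isinstance(args, dict):
        return args, json.dumps(args)
    raise ValueError(f"Unsupported arguments type: {type(args)}")

# OpenAI-compatible shape
assert normalize_tool_arguments('{"city": "SF"}') == ({"city": "SF"}, '{"city": "SF"}')
# TGI shape
assert normalize_tool_arguments({"city": "SF"}) == ({"city": "SF"}, '{"city": "SF"}')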
def _convert_openai_logprobs(
@@ -802,7 +834,11 @@ def _convert_openai_logprobs(
return None
return [
TokenLogProbs(logprobs_by_token={logprobs.token: logprobs.logprob for logprobs in content.top_logprobs})
TokenLogProbs(
logprobs_by_token={
logprobs.token: logprobs.logprob for logprobs in content.top_logprobs
}
)
for content in logprobs.content
]
@@ -840,14 +876,17 @@ def convert_openai_chat_completion_choice(
end_of_message = "end_of_message"
out_of_tokens = "out_of_tokens"
"""
assert hasattr(choice, "message") and choice.message, "error in server response: message not found"
assert hasattr(choice, "finish_reason") and choice.finish_reason, (
"error in server response: finish_reason not found"
)
assert (
hasattr(choice, "message") and choice.message
), "error in server response: message not found"
assert (
hasattr(choice, "finish_reason") and choice.finish_reason
), "error in server response: finish_reason not found"
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=choice.message.content or "", # CompletionMessage content is not optional
content=choice.message.content
or "", # CompletionMessage content is not optional
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
tool_calls=_convert_openai_tool_calls(choice.message.tool_calls),
),
@@ -856,13 +895,24 @@ def convert_openai_chat_completion_choice(
async def convert_openai_chat_completion_stream(
stream: AsyncStream[OpenAIChatCompletionChunk],
stream: Union[
AsyncStream[OpenAIChatCompletionChunk], Stream[OpenAIChatCompletionChunk]
],
enable_incremental_tool_calls: bool,
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
"""
Convert a stream of OpenAI chat completion chunks into a stream
of ChatCompletionResponseStreamChunk.
"""
async def yield_from_stream(stream):
if isinstance(stream, AsyncGenerator):
async for chunk in stream:
yield chunk
elif isinstance(stream, Generator):
for chunk in stream:
yield chunk
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
@@ -874,7 +924,7 @@ async def convert_openai_chat_completion_stream(
stop_reason = None
tool_call_idx_to_buffer = {}
async for chunk in stream:
async for chunk in yield_from_stream(stream):
choice = chunk.choices[0] # assuming only one choice per chunk
# we assume there's only one finish_reason in the stream
@@ -916,12 +966,60 @@ async def convert_openai_chat_completion_stream(
)
)
else:
for tool_call in choice.delta.tool_calls:
idx = tool_call.index if hasattr(tool_call, "index") else 0
if isinstance(choice.delta.tool_calls, list):
tool_calls = choice.delta.tool_calls
for tool_call in tool_calls:
idx = tool_call.index if hasattr(tool_call, "index") else 0
if idx not in tool_call_idx_to_buffer:
tool_call_idx_to_buffer[idx] = {
"call_id": tool_call.id,
"name": None,
"arguments": "",
"content": "",
}
buffer = tool_call_idx_to_buffer[idx]
if tool_call.function:
if tool_call.function.name:
buffer["name"] = tool_call.function.name
delta = f"{buffer['name']}("
buffer["content"] += delta
if tool_call.function.arguments:
delta = tool_call.function.arguments
buffer["arguments"] += delta
buffer["content"] += delta
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=ToolCallDelta(
tool_call=delta,
parse_status=ToolCallParseStatus.in_progress,
),
logprobs=_convert_openai_logprobs(logprobs),
)
)
# TGI streams a non-OpenAI-compatible response
elif isinstance(choice.delta.tool_calls, dict):
# tool_calls is a dict of the format
# {
# 'index': 0,
# 'id': '',
# 'type': 'function',
# 'function': {
# 'name': None,
# 'arguments': '{"'
# }
# }
tool_call = choice.delta.tool_calls
idx = tool_call["index"] if "index" in tool_call else 0
if idx not in tool_call_idx_to_buffer:
tool_call_idx_to_buffer[idx] = {
"call_id": tool_call.id,
"call_id": tool_call["id"],
"name": None,
"arguments": "",
"content": "",
@@ -929,14 +1027,15 @@ async def convert_openai_chat_completion_stream(
buffer = tool_call_idx_to_buffer[idx]
if tool_call.function:
if tool_call.function.name:
buffer["name"] = tool_call.function.name
if "function" in tool_call:
function = tool_call["function"]
if function["name"]:
buffer["name"] = function["name"]
delta = f"{buffer['name']}("
buffer["content"] += delta
if tool_call.function.arguments:
delta = tool_call.function.arguments
if function["arguments"]:
delta = function["arguments"]
buffer["arguments"] += delta
buffer["content"] += delta
@@ -994,7 +1093,6 @@ async def convert_openai_chat_completion_stream(
)
)
except json.JSONDecodeError as e:
print(f"Failed to parse arguments: {e}")
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
@@ -1005,6 +1103,51 @@ async def convert_openai_chat_completion_stream(
stop_reason=stop_reason,
)
)
else:
# name is None but the arguments contain the entire function call
# example response where arguments is -->
# '{"function": {"_name": "get_weather", "location": "San Francisco, CA"}}<|eot_id|>'
# - parse the arguments
# - try to build a ToolCall and return it, or return the content as-is
if buffer["arguments"]:
arguments = buffer["arguments"]
# remove the eot_id and eom_id from the arguments
if arguments.endswith("<|eom_id|>"):
arguments = arguments[: -len("<|eom_id|>")]
if arguments.endswith("<|eot_id|>"):
arguments = arguments[: -len("<|eot_id|>")]
arguments = json.loads(arguments)
try:
tool_name = arguments["function"].pop("_name", None)
parsed_tool_call = ToolCall(
call_id=buffer["call_id"],
tool_name=tool_name,
arguments=arguments["function"],
arguments_json=json.dumps(arguments["function"]),
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call=parsed_tool_call,
parse_status=ToolCallParseStatus.succeeded,
),
stop_reason=stop_reason,
)
)
except (KeyError, json.JSONDecodeError) as e:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call=buffer["content"],
parse_status=ToolCallParseStatus.failed,
),
stop_reason=stop_reason,
)
)
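End to end, this fallback turns a buffered payload like the one in the comment above into a tool name plus arguments. A self-contained sketch of just the parsing step:

import json

def parse_buffered_arguments(arguments: str):
    # strip the llama end-of-message / end-of-turn markers if present
    for marker in ("<|eom_id|>", "<|eot_id|>"):
        if arguments.endswith(marker):
            arguments = arguments[: -len(marker)]
    payload = json.loads(arguments)
    # the prompt template tucks the tool name into "_name" inside the arguments
    function = payload["function"]
    name = function.pop("_name", None)
    return name, function

name, args = parse_buffered_arguments(
    '{"function": {"_name": "get_weather", "location": "San Francisco, CA"}}<|eot_id|>'
)
assert name == "get_weather"
assert args == {"location": "San Francisco, CA"}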
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
@@ -1013,3 +1156,113 @@ async def convert_openai_chat_completion_stream(
stop_reason=stop_reason,
)
)
async def convert_completion_request_to_openai_params(
request: CompletionRequest,
) -> dict:
"""
Convert a CompletionRequest to an OpenAI API-compatible dictionary.
"""
input_dict = {}
if request.logprobs:
input_dict["logprobs"] = request.logprobs.top_k
return {
"model": request.model,
"prompt": request.content,
**input_dict,
"stream": request.stream,
**get_sampling_options(request.sampling_params),
"n": 1,
}
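For illustration, assuming get_sampling_options passes a temperature setting through under the same name (the sampling keys here are assumptions, not taken from the diff), a non-streaming request for model "tgi" with prompt "Hello" would map to roughly:

{
    "model": "tgi",
    "prompt": "Hello",
    "stream": False,
    "temperature": 0.7,
    "n": 1,
}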
async def convert_chat_completion_request_to_openai_params(
request: ChatCompletionRequest,
) -> dict:
"""
Convert a ChatCompletionRequest to an OpenAI API-compatible parameter dictionary.
"""
input_dict = {}
input_dict["messages"] = [
await convert_message_to_openai_dict_new(m) for m in request.messages
]
if fmt := request.response_format:
if not isinstance(fmt, JsonSchemaResponseFormat):
raise ValueError(
f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
)
fmt = fmt.json_schema
name = fmt["title"]
del fmt["title"]
fmt["additionalProperties"] = False
# Apply additionalProperties: False recursively to all objects
fmt = _add_additional_properties_recursive(fmt)
from rich.pretty import pprint
pprint(fmt)
input_dict["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": name,
"schema": fmt,
"strict": True,
},
}
if request.tools:
input_dict["tools"] = [
convert_tooldef_to_openai_tool(tool) for tool in request.tools
]
if request.tool_config.tool_choice:
input_dict["tool_choice"] = (
request.tool_config.tool_choice.value
if isinstance(request.tool_config.tool_choice, ToolChoice)
else request.tool_config.tool_choice
)
return {
"model": request.model,
**input_dict,
"stream": request.stream,
**get_sampling_options(request.sampling_params),
"n": 1,
}
def _add_additional_properties_recursive(schema):
"""
Recursively add `additionalProperties: False` to all object schemas
"""
if isinstance(schema, dict):
if schema.get("type") == "object":
schema["additionalProperties"] = False
# Add required field with all property keys if properties exist
if "properties" in schema and schema["properties"]:
schema["required"] = list(schema["properties"].keys())
if "properties" in schema:
for prop_schema in schema["properties"].values():
_add_additional_properties_recursive(prop_schema)
for key in ["anyOf", "allOf", "oneOf"]:
if key in schema:
for sub_schema in schema[key]:
_add_additional_properties_recursive(sub_schema)
if "not" in schema:
_add_additional_properties_recursive(schema["not"])
# Handle $defs/$ref
if "$defs" in schema:
for def_schema in schema["$defs"].values():
_add_additional_properties_recursive(def_schema)
return schema
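To see what the recursion produces, a small before/after on a hypothetical schema:

schema = {
    "type": "object",
    "properties": {
        "city": {"type": "string"},
        "units": {
            "type": "object",
            "properties": {"temperature": {"type": "string"}},
        },
    },
}
result = _add_additional_properties_recursive(schema)
# every object level is now closed and fully required
assert result["additionalProperties"] is False
assert result["required"] == ["city", "units"]
inner = result["properties"]["units"]
assert inner["additionalProperties"] is False
assert inner["required"] == ["temperature"]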

View file

@@ -12,7 +12,6 @@ import re
from typing import List, Optional, Tuple, Union
import httpx
from PIL import Image as PIL_Image
from llama_stack.apis.common.content_types import (
ImageContentItem,
@@ -34,6 +33,7 @@ from llama_stack.apis.inference import (
)
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import (
is_multimodal,
ModelFamily,
RawContent,
RawContentItem,
@@ -43,7 +43,6 @@ from llama_stack.models.llama.datatypes import (
Role,
StopReason,
ToolPromptFormat,
is_multimodal,
)
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.prompt_templates import (
@@ -56,6 +55,7 @@ from llama_stack.models.llama.llama3.prompt_templates import (
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.utils.inference import supported_inference_models
from PIL import Image as PIL_Image
log = get_logger(name=__name__, category="inference")
@@ -129,7 +129,9 @@ async def interleaved_content_convert_to_raw(
if image.url.uri.startswith("data"):
match = re.match(r"data:image/(\w+);base64,(.+)", image.url.uri)
if not match:
raise ValueError(f"Invalid data URL format, {image.url.uri[:40]}...")
raise ValueError(
f"Invalid data URL format, {image.url.uri[:40]}..."
)
_, image_data = match.groups()
data = base64.b64decode(image_data)
elif image.url.uri.startswith("file://"):
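The data-URL branch above can be checked in isolation; a tiny example with an illustrative URI:

import base64
import re

uri = "data:image/png;base64,aGVsbG8="
match = re.match(r"data:image/(\w+);base64,(.+)", uri)
assert match is not None
image_format, image_data = match.groups()  # ("png", "aGVsbG8=")
assert base64.b64decode(image_data) == b"hello"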
@@ -209,13 +211,17 @@ async def convert_image_content_to_url(
content, format = await localize_image_content(media)
if include_format:
return f"data:image/{format};base64," + base64.b64encode(content).decode("utf-8")
return f"data:image/{format};base64," + base64.b64encode(content).decode(
"utf-8"
)
else:
return base64.b64encode(content).decode("utf-8")
async def completion_request_to_prompt(request: CompletionRequest) -> str:
content = augment_content_with_response_format_prompt(request.response_format, request.content)
content = augment_content_with_response_format_prompt(
request.response_format, request.content
)
request.content = content
request = await convert_request_to_raw(request)
@@ -224,8 +230,12 @@ async def completion_request_to_prompt(request: CompletionRequest) -> str:
return formatter.tokenizer.decode(model_input.tokens)
async def completion_request_to_prompt_model_input_info(request: CompletionRequest) -> Tuple[str, int]:
content = augment_content_with_response_format_prompt(request.response_format, request.content)
async def completion_request_to_prompt_model_input_info(
request: CompletionRequest,
) -> Tuple[str, int]:
content = augment_content_with_response_format_prompt(
request.response_format, request.content
)
request.content = content
request = await convert_request_to_raw(request)
@@ -246,7 +256,9 @@ def augment_content_with_response_format_prompt(response_format, content):
return content
async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str:
async def chat_completion_request_to_prompt(
request: ChatCompletionRequest, llama_model: str
) -> str:
messages = chat_completion_request_to_messages(request, llama_model)
request.messages = messages
request = await convert_request_to_raw(request)
@@ -254,7 +266,8 @@ async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llam
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_dialog_prompt(
request.messages,
tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
tool_prompt_format=request.tool_config.tool_prompt_format
or get_default_tool_prompt_format(llama_model),
)
return formatter.tokenizer.decode(model_input.tokens)
@@ -269,10 +282,17 @@ async def chat_completion_request_to_model_input_info(
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_dialog_prompt(
request.messages,
tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
tool_prompt_format=request.tool_config.tool_prompt_format
or get_default_tool_prompt_format(llama_model),
)
tokens = []
for t in model_input.tokens:
if t == 128256:
tokens.append(formatter.vision_token)
else:
tokens.append(t)
return (
formatter.tokenizer.decode(model_input.tokens),
formatter.tokenizer.decode(tokens),
len(model_input.tokens),
)
@@ -298,7 +318,8 @@ def chat_completion_request_to_messages(
return request.messages
if model.model_family == ModelFamily.llama3_1 or (
model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id)
model.model_family == ModelFamily.llama3_2
and is_multimodal(model.core_model_id)
):
# llama3.1 and llama3.2 multimodal models follow the same tool prompt format
messages = augment_messages_for_tools_llama_3_1(request)
@@ -334,7 +355,9 @@ def augment_messages_for_tools_llama_3_1(
if existing_messages[0].role == Role.system.value:
existing_system_message = existing_messages.pop(0)
assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
assert (
existing_messages[0].role != Role.system.value
), "Should only have 1 system message"
messages = []
@@ -366,9 +389,13 @@ def augment_messages_for_tools_llama_3_1(
if isinstance(existing_system_message.content, str):
sys_content += _process(existing_system_message.content)
elif isinstance(existing_system_message.content, list):
sys_content += "\n".join([_process(c) for c in existing_system_message.content])
sys_content += "\n".join(
[_process(c) for c in existing_system_message.content]
)
tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
tool_choice_prompt = _get_tool_choice_prompt(
request.tool_config.tool_choice, request.tools
)
if tool_choice_prompt:
sys_content += "\n" + tool_choice_prompt
@@ -402,7 +429,9 @@ def augment_messages_for_tools_llama_3_2(
if existing_messages[0].role == Role.system.value:
existing_system_message = existing_messages.pop(0)
assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
assert (
existing_messages[0].role != Role.system.value
), "Should only have 1 system message"
sys_content = ""
custom_tools, builtin_tools = [], []
@@ -423,10 +452,16 @@ def augment_messages_for_tools_llama_3_2(
if custom_tools:
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list
if fmt != ToolPromptFormat.python_list:
raise ValueError(f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}")
raise ValueError(
f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}"
)
system_prompt = None
if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
if (
existing_system_message
and request.tool_config.system_message_behavior
== SystemMessageBehavior.replace
):
system_prompt = existing_system_message.content
tool_template = PythonListCustomToolGenerator().gen(custom_tools, system_prompt)
@@ -435,11 +470,16 @@ def augment_messages_for_tools_llama_3_2(
sys_content += "\n"
if existing_system_message and (
request.tool_config.system_message_behavior == SystemMessageBehavior.append or not custom_tools
request.tool_config.system_message_behavior == SystemMessageBehavior.append
or not custom_tools
):
sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n")
sys_content += interleaved_content_as_str(
existing_system_message.content, sep="\n"
)
tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
tool_choice_prompt = _get_tool_choice_prompt(
request.tool_config.tool_choice, request.tools
)
if tool_choice_prompt:
sys_content += "\n" + tool_choice_prompt
@@ -447,11 +487,15 @@ def augment_messages_for_tools_llama_3_2(
return messages
def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefinition]) -> str:
def _get_tool_choice_prompt(
tool_choice: ToolChoice | str, tools: List[ToolDefinition]
) -> str:
if tool_choice == ToolChoice.auto:
return ""
elif tool_choice == ToolChoice.required:
return "You MUST use one of the provided functions/tools to answer the user query."
return (
"You MUST use one of the provided functions/tools to answer the user query."
)
elif tool_choice == ToolChoice.none:
# tools are already not passed in
return ""
@@ -463,11 +507,14 @@ def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefin
def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
llama_model = resolve_model(model)
if llama_model is None:
log.warning(f"Could not resolve model {model}, defaulting to json tool prompt format")
log.warning(
f"Could not resolve model {model}, defaulting to json tool prompt format"
)
return ToolPromptFormat.json
if llama_model.model_family == ModelFamily.llama3_1 or (
llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
llama_model.model_family == ModelFamily.llama3_2
and is_multimodal(llama_model.core_model_id)
):
# llama3.1 and llama3.2 multimodal models follow the same tool prompt format
return ToolPromptFormat.json