forked from phoenix-oss/llama-stack-mirror
Fix conversion to RawMessage everywhere
This commit is contained in:
parent fbca51d6da
commit b7a7caa9a8

11 changed files with 87 additions and 78 deletions
@@ -20,6 +20,7 @@ from llama_models.llama3.api.datatypes import (
     RawContent,
     RawContentItem,
     RawMediaItem,
+    RawMessage,
     RawTextItem,
     Role,
     ToolPromptFormat,
@@ -58,6 +59,14 @@ from llama_stack.providers.utils.inference import supported_inference_models
 log = logging.getLogger(__name__)


+class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
+    messages: List[RawMessage]
+
+
+class CompletionRequestWithRawContent(CompletionRequest):
+    content: RawContent
+
+
 def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
     def _process(c) -> str:
         if isinstance(c, str):
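Note: the two *WithRawContent subclasses only narrow the typed fields of their parents (messages and content), so code further down the pipeline can assume the content has already been converted to the Raw* types. A minimal, self-contained sketch of that field-narrowing pattern with pydantic v2 (the parent request classes are pydantic models, as the model_dump() call below implies); all names here are illustrative stand-ins, not the repo's:

from typing import Any, List
from pydantic import BaseModel


class Message(BaseModel):
    role: str
    content: Any              # may be text, images, etc.


class RawMessage(BaseModel):
    role: str
    content: str              # already flattened to raw text


class ChatRequest(BaseModel):
    model: str
    messages: List[Message]


# Subclass that redeclares a field with a narrower type, mirroring
# ChatCompletionRequestWithRawContent above.
class ChatRequestWithRaw(ChatRequest):
    messages: List[RawMessage]


req = ChatRequestWithRaw(
    model="demo-model",
    messages=[RawMessage(role="user", content="hi")],
)
print(req.model_dump())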
@@ -75,6 +84,23 @@ def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
     return _process(content)


+async def convert_request_to_raw(
+    request: Union[ChatCompletionRequest, CompletionRequest],
+) -> Union[ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent]:
+    if isinstance(request, ChatCompletionRequest):
+        messages = []
+        for m in request.messages:
+            content = await interleaved_content_convert_to_raw(m.content)
+            d = m.model_dump()
+            d["content"] = content
+            messages.append(RawMessage(**d))
+        request.messages = messages
+    else:
+        request.content = await interleaved_content_convert_to_raw(request.content)
+
+    return request
+
+
 async def interleaved_content_convert_to_raw(
     content: InterleavedContent,
 ) -> RawContent:
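Note: convert_request_to_raw round-trips each chat message through model_dump(), patches in the converted content, and re-validates the result as a RawMessage; plain completion requests only get their content field converted. A hedged sketch of that dump/patch/rebuild step on stand-in pydantic models (the to_raw helper below is a placeholder for interleaved_content_convert_to_raw, not the real function):

import asyncio
from typing import Any, List
from pydantic import BaseModel


class Message(BaseModel):
    role: str
    content: Any               # text, images, ...


class RawMessage(BaseModel):
    role: str
    content: str


async def to_raw(content: Any) -> str:
    # Placeholder for interleaved_content_convert_to_raw: just stringify.
    return str(content)


async def convert(messages: List[Message]) -> List[RawMessage]:
    out = []
    for m in messages:
        d = m.model_dump()                       # keep role and any other fields
        d["content"] = await to_raw(m.content)   # replace only the content
        out.append(RawMessage(**d))              # re-validate as the raw type
    return out


print(asyncio.run(convert([Message(role="user", content=["hi", {"image": "..."}])])))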
@@ -169,23 +195,27 @@ async def convert_image_content_to_url(
     return base64.b64encode(content).decode("utf-8")


-def completion_request_to_prompt(
+async def completion_request_to_prompt(
     request: CompletionRequest, formatter: ChatFormat
 ) -> str:
     content = augment_content_with_response_format_prompt(
         request.response_format, request.content
     )
-    model_input = formatter.encode_content(content)
+    request.content = content
+    request = await convert_request_to_raw(request)
+    model_input = formatter.encode_content(request.content)
     return formatter.tokenizer.decode(model_input.tokens)


-def completion_request_to_prompt_model_input_info(
+async def completion_request_to_prompt_model_input_info(
     request: CompletionRequest, formatter: ChatFormat
 ) -> Tuple[str, int]:
     content = augment_content_with_response_format_prompt(
         request.response_format, request.content
     )
-    model_input = formatter.encode_content(content)
+    request.content = content
+    request = await convert_request_to_raw(request)
+    model_input = formatter.encode_content(request.content)
     return (formatter.tokenizer.decode(model_input.tokens), len(model_input.tokens))
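Note: since these helpers are now coroutines, every call site has to be updated as well; calling them without await yields a coroutine object instead of a prompt string. A generic illustration (no llama-stack types involved):

import asyncio


async def render_prompt(content: str) -> str:
    # Stand-in for the now-async completion_request_to_prompt.
    return f"<prompt>{content}</prompt>"


async def main() -> None:
    good = await render_prompt("hello")
    print(good)                      # <prompt>hello</prompt>

    bad = render_prompt("hello")     # missing await
    print(type(bad))                 # <class 'coroutine'>
    bad.close()                      # avoid the "never awaited" warning


asyncio.run(main())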
@@ -199,19 +229,23 @@ def augment_content_with_response_format_prompt(response_format, content):
     return content


-def chat_completion_request_to_prompt(
+async def chat_completion_request_to_prompt(
     request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
 ) -> str:
     messages = chat_completion_request_to_messages(request, llama_model)
-    model_input = formatter.encode_dialog_prompt(messages)
+    request.messages = messages
+    request = await convert_request_to_raw(request)
+    model_input = formatter.encode_dialog_prompt(request.messages)
     return formatter.tokenizer.decode(model_input.tokens)


-def chat_completion_request_to_model_input_info(
+async def chat_completion_request_to_model_input_info(
     request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
 ) -> Tuple[str, int]:
     messages = chat_completion_request_to_messages(request, llama_model)
-    model_input = formatter.encode_dialog_prompt(messages)
+    request.messages = messages
+    request = await convert_request_to_raw(request)
+    model_input = formatter.encode_dialog_prompt(request.messages)
     return (
         formatter.tokenizer.decode(model_input.tokens),
         len(model_input.tokens),
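Note: the chat variants change the same way, and the *_model_input_info helpers still return the rendered prompt together with its token count (useful, for example, when a provider wants to check prompt length before generation). A generic, self-contained sketch of awaiting such a helper; the names and the length check are placeholders:

import asyncio
from typing import List, Tuple


async def prompt_and_length(messages: List[str]) -> Tuple[str, int]:
    # Stand-in for the now-async chat_completion_request_to_model_input_info.
    prompt = "\n".join(f"<|{m}|>" for m in messages)
    return prompt, len(prompt.split())


async def main() -> None:
    prompt, n_tokens = await prompt_and_length(["system: be brief", "user: hi"])
    if n_tokens > 512:               # placeholder for a real length limit
        raise ValueError("prompt too long")
    print(prompt, n_tokens)


asyncio.run(main())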