chore: remove llama_models.llama3.api imports from providers (#1107)
There should be a choke-point for llama3.api imports -- this is the prompt adapter. Creating a ChatFormat() object on demand is inexpensive. The underlying Tokenizer is a singleton anyway.
parent e9b8259cf9
commit cdcbeb005b
13 changed files with 77 additions and 113 deletions
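To illustrate the pattern described in the commit message, here is a minimal sketch that relies only on the llama_models.llama3.api interfaces visible in the diff below; the helper name encode_text is illustrative and not part of this commit, and it assumes a plain string is accepted as content by ChatFormat.encode_content.

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer


def encode_text(text: str) -> list[int]:
    # Tokenizer.get_instance() returns a shared singleton, so constructing a
    # ChatFormat wrapper around it on demand is cheap.
    formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
    return formatter.encode_content(text).tokens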
@@ -13,7 +13,9 @@ import re
 from typing import List, Optional, Tuple, Union
 
 import httpx
 from llama_models.datatypes import StopReason
+from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.tokenizer import Tokenizer
 from PIL import Image as PIL_Image
 
 from llama_stack.apis.common.content_types import (
@@ -66,6 +68,11 @@ class CompletionRequestWithRawContent(CompletionRequest):
     content: RawContent
 
 
+def decode_assistant_message(content: str, stop_reason: StopReason) -> RawMessage:
+    formatter = ChatFormat(Tokenizer.get_instance())
+    return formatter.decode_assistant_message_from_content(content, stop_reason)
+
+
 def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
     def _process(c) -> str:
         if isinstance(c, str):
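As an illustration of the intended call site, a provider that previously built its own ChatFormat to parse model output can now go through this helper. The module path below is an assumption (the diff does not name the file), and StopReason.end_of_turn is used as a representative stop reason.

from llama_models.datatypes import StopReason

# Assumed location of the prompt adapter module; adjust to the real path.
from llama_stack.providers.utils.inference.prompt_adapter import decode_assistant_message

raw_message = decode_assistant_message("The capital of France is Paris.", StopReason.end_of_turn)
print(raw_message.content)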
@@ -207,20 +214,22 @@ async def convert_image_content_to_url(
     return base64.b64encode(content).decode("utf-8")
 
 
-async def completion_request_to_prompt(request: CompletionRequest, formatter: ChatFormat) -> str:
+async def completion_request_to_prompt(request: CompletionRequest) -> str:
     content = augment_content_with_response_format_prompt(request.response_format, request.content)
     request.content = content
     request = await convert_request_to_raw(request)
+
+    formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_content(request.content)
     return formatter.tokenizer.decode(model_input.tokens)
 
 
-async def completion_request_to_prompt_model_input_info(
-    request: CompletionRequest, formatter: ChatFormat
-) -> Tuple[str, int]:
+async def completion_request_to_prompt_model_input_info(request: CompletionRequest) -> Tuple[str, int]:
     content = augment_content_with_response_format_prompt(request.response_format, request.content)
     request.content = content
     request = await convert_request_to_raw(request)
+
+    formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_content(request.content)
     return (formatter.tokenizer.decode(model_input.tokens), len(model_input.tokens))
 
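On the provider side the call shrinks accordingly: the formatter argument disappears and only the request is passed. A hedged sketch, with import paths assumed as above:

from llama_stack.apis.inference import CompletionRequest
from llama_stack.providers.utils.inference.prompt_adapter import (
    completion_request_to_prompt,
    completion_request_to_prompt_model_input_info,
)


async def build_completion_prompt(request: CompletionRequest) -> tuple[str, int]:
    # Previously: await completion_request_to_prompt(request, self.formatter)
    prompt = await completion_request_to_prompt(request)
    # The *_model_input_info variant also reports the token count, useful for
    # context-window checks.
    _, n_tokens = await completion_request_to_prompt_model_input_info(request)
    return prompt, n_tokens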
@@ -237,22 +246,24 @@ def augment_content_with_response_format_prompt(response_format, content):
     return content
 
 
-async def chat_completion_request_to_prompt(
-    request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
-) -> str:
+async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str:
     messages = chat_completion_request_to_messages(request, llama_model)
     request.messages = messages
     request = await convert_request_to_raw(request)
+
+    formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_dialog_prompt(request.messages)
     return formatter.tokenizer.decode(model_input.tokens)
 
 
 async def chat_completion_request_to_model_input_info(
-    request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
+    request: ChatCompletionRequest, llama_model: str
 ) -> Tuple[str, int]:
     messages = chat_completion_request_to_messages(request, llama_model)
     request.messages = messages
     request = await convert_request_to_raw(request)
+
+    formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_dialog_prompt(request.messages)
     return (
         formatter.tokenizer.decode(model_input.tokens),
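The chat-completion helpers follow the same shape: a provider now passes only the request and the resolved llama model identifier. Again a sketch with assumed module paths, not code from this commit:

from llama_stack.apis.inference import ChatCompletionRequest
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_model_input_info,
    chat_completion_request_to_prompt,
)


async def build_chat_prompt(request: ChatCompletionRequest, llama_model: str) -> tuple[str, int]:
    prompt = await chat_completion_request_to_prompt(request, llama_model)
    _, n_tokens = await chat_completion_request_to_model_input_info(request, llama_model)
    return prompt, n_tokens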