chore: remove llama_models.llama3.api imports from providers (#1107)

There should be a single choke-point for llama3.api imports -- that choke-point is the
prompt adapter. Creating a ChatFormat() object on demand there is inexpensive,
since the underlying Tokenizer is a singleton anyway.
Ashwin Bharambe 2025-02-19 19:01:29 -08:00 committed by GitHub
parent e9b8259cf9
commit cdcbeb005b
13 changed files with 77 additions and 113 deletions
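In practice this means every provider calls a small helper in the prompt adapter, and only that helper ever builds a ChatFormat. A minimal sketch of the pattern, using the names that appear in the diff below (Tokenizer.get_instance(), ChatFormat, decode_assistant_message_from_content); it mirrors the new helper rather than quoting the repository code exactly:

from llama_models.datatypes import StopReason
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer

def decode_assistant_message(content: str, stop_reason: StopReason):
    # ChatFormat itself is cheap to construct; the costly part is the tokenizer,
    # and Tokenizer.get_instance() returns a cached singleton, so building a
    # fresh formatter per call adds essentially no overhead.
    formatter = ChatFormat(Tokenizer.get_instance())
    return formatter.decode_assistant_message_from_content(content, stop_reason)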

View file

@@ -7,7 +7,6 @@ import json
import logging
from typing import AsyncGenerator, Dict, List, Optional, Union
from llama_models.llama3.api.chat_format import ChatFormat
from openai.types.chat import ChatCompletionMessageToolCall
from pydantic import BaseModel
@@ -40,6 +39,7 @@ from llama_stack.models.llama.datatypes import (
)
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_content_to_url,
decode_assistant_message,
)
logger = logging.getLogger(__name__)
@@ -149,7 +149,7 @@ def convert_openai_completion_logprobs_stream(text: str, logprobs: Optional[Unio
return None
def process_completion_response(response: OpenAICompatCompletionResponse, formatter: ChatFormat) -> CompletionResponse:
def process_completion_response(response: OpenAICompatCompletionResponse) -> CompletionResponse:
choice = response.choices[0]
# drop suffix <eot_id> if present and return stop reason as end of turn
if choice.text.endswith("<|eot_id|>"):
@@ -174,16 +174,13 @@ def process_completion_response(response: OpenAICompatCompletionResponse, format
def process_chat_completion_response(
response: OpenAICompatCompletionResponse,
formatter: ChatFormat,
request: ChatCompletionRequest,
) -> ChatCompletionResponse:
choice = response.choices[0]
# TODO: This does not work well with tool calls for vLLM remote provider
# Ref: https://github.com/meta-llama/llama-stack/issues/1058
raw_message = formatter.decode_assistant_message_from_content(
text_from_choice(choice), get_stop_reason(choice.finish_reason)
)
raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason))
# NOTE: If we do not set tools in chat-completion request, we should not
# expect the ToolCall in the response. Instead, we should return the raw
@@ -217,7 +214,7 @@ process_chat_completion_response(
async def process_completion_stream_response(
stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
) -> AsyncGenerator:
stop_reason = None
@@ -254,7 +251,6 @@ async def process_completion_stream_response(
async def process_chat_completion_stream_response(
stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
formatter: ChatFormat,
request: ChatCompletionRequest,
) -> AsyncGenerator:
yield ChatCompletionResponseStreamChunk(
@@ -333,7 +329,7 @@ async def process_chat_completion_stream_response(
)
# parse tool calls and report errors
message = formatter.decode_assistant_message_from_content(buffer, stop_reason)
message = decode_assistant_message(buffer, stop_reason)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
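For providers that consume these helpers, the visible change is only the dropped formatter argument; decoding now happens inside the helpers via the prompt adapter. A hypothetical call site after this change (the names r and request are illustrative, not taken from the diff):

# formatter is no longer threaded through; pass only the response,
# plus the original request for the chat-completion variant.
completion = process_completion_response(r)
chat_completion = process_chat_completion_response(r, request)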

View file

@@ -13,7 +13,9 @@ import re
from typing import List, Optional, Tuple, Union
import httpx
from llama_models.datatypes import StopReason
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from PIL import Image as PIL_Image
from llama_stack.apis.common.content_types import (
@@ -66,6 +68,11 @@ class CompletionRequestWithRawContent(CompletionRequest):
content: RawContent
def decode_assistant_message(content: str, stop_reason: StopReason) -> RawMessage:
formatter = ChatFormat(Tokenizer.get_instance())
return formatter.decode_assistant_message_from_content(content, stop_reason)
def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
def _process(c) -> str:
if isinstance(c, str):
@@ -207,20 +214,22 @@ async def convert_image_content_to_url(
return base64.b64encode(content).decode("utf-8")
async def completion_request_to_prompt(request: CompletionRequest, formatter: ChatFormat) -> str:
async def completion_request_to_prompt(request: CompletionRequest) -> str:
content = augment_content_with_response_format_prompt(request.response_format, request.content)
request.content = content
request = await convert_request_to_raw(request)
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_content(request.content)
return formatter.tokenizer.decode(model_input.tokens)
async def completion_request_to_prompt_model_input_info(
request: CompletionRequest, formatter: ChatFormat
) -> Tuple[str, int]:
async def completion_request_to_prompt_model_input_info(request: CompletionRequest) -> Tuple[str, int]:
content = augment_content_with_response_format_prompt(request.response_format, request.content)
request.content = content
request = await convert_request_to_raw(request)
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_content(request.content)
return (formatter.tokenizer.decode(model_input.tokens), len(model_input.tokens))
@@ -237,22 +246,24 @@ def augment_content_with_response_format_prompt(response_format, content):
return content
async def chat_completion_request_to_prompt(
request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
) -> str:
async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str:
messages = chat_completion_request_to_messages(request, llama_model)
request.messages = messages
request = await convert_request_to_raw(request)
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_dialog_prompt(request.messages)
return formatter.tokenizer.decode(model_input.tokens)
async def chat_completion_request_to_model_input_info(
request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
request: ChatCompletionRequest, llama_model: str
) -> Tuple[str, int]:
messages = chat_completion_request_to_messages(request, llama_model)
request.messages = messages
request = await convert_request_to_raw(request)
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_dialog_prompt(request.messages)
return (
formatter.tokenizer.decode(model_input.tokens),