diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py
index e896f0597..0ec225428 100644
--- a/llama_stack/providers/remote/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -8,8 +8,6 @@ import json
 from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 
 from botocore.client import BaseClient
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 
 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import (
@@ -69,7 +67,6 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         self._config = config
 
         self._client = create_bedrock_client(config)
-        self.formatter = ChatFormat(Tokenizer.get_instance())
 
     @property
     def client(self) -> BaseClient:
@@ -134,7 +131,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         )
 
         response = OpenAICompatCompletionResponse(choices=[choice])
-        return process_chat_completion_response(response, self.formatter, request)
+        return process_chat_completion_response(response, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params_for_chat_completion(request)
@@ -152,7 +149,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
             yield OpenAICompatCompletionResponse(choices=[choice])
 
         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
     async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
@@ -166,7 +163,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         if sampling_params.repetition_penalty > 0:
             options["repetition_penalty"] = sampling_params.repetition_penalty
 
-        prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model), self.formatter)
+        prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
         return {
             "modelId": bedrock_model,
             "body": json.dumps(
diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py
index 1ce267e8d..96ae492ff 100644
--- a/llama_stack/providers/remote/inference/cerebras/cerebras.py
+++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py
@@ -7,8 +7,6 @@
 from typing import AsyncGenerator, List, Optional, Union
 
 from cerebras.cloud.sdk import AsyncCerebras
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 
 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import (
@@ -64,7 +62,6 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
             model_aliases=model_aliases,
         )
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())
 
         self.client = AsyncCerebras(
             base_url=self.config.base_url,
@@ -107,14 +104,14 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
 
         r = await self.client.completions.create(**params)
 
-        return process_completion_response(r, self.formatter)
+        return process_completion_response(r)
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
         stream = await self.client.completions.create(**params)
 
-        async for chunk in process_completion_stream_response(stream, self.formatter):
+        async for chunk in process_completion_stream_response(stream):
             yield chunk
 
     async def chat_completion(
@@ -154,14 +151,14 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
 
         r = await self.client.completions.create(**params)
 
-        return process_chat_completion_response(r, self.formatter, request)
+        return process_chat_completion_response(r, request)
 
     async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
         stream = await self.client.completions.create(**params)
 
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
@@ -170,11 +167,9 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
 
         prompt = ""
         if isinstance(request, ChatCompletionRequest):
-            prompt = await chat_completion_request_to_prompt(
-                request, self.get_llama_model(request.model), self.formatter
-            )
+            prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
         elif isinstance(request, CompletionRequest):
-            prompt = await completion_request_to_prompt(request, self.formatter)
+            prompt = await completion_request_to_prompt(request)
         else:
             raise ValueError(f"Unknown request type {type(request)}")
 
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index acf37b248..cb85274f7 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -7,8 +7,6 @@
 from typing import AsyncGenerator, List, Optional, Union
 
 from fireworks.client import Fireworks
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 
 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import (
@@ -100,7 +98,6 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
     def __init__(self, config: FireworksImplConfig) -> None:
         ModelRegistryHelper.__init__(self, MODEL_ALIASES)
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())
 
     async def initialize(self) -> None:
         pass
@@ -149,7 +146,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
         r = await self._get_client().completion.acreate(**params)
-        return process_completion_response(r, self.formatter)
+        return process_completion_response(r)
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -161,7 +158,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
                 yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_completion_stream_response(stream, self.formatter):
+        async for chunk in process_completion_stream_response(stream):
             yield chunk
 
     def _build_options(
@@ -230,7 +227,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
             r = await self._get_client().chat.completions.acreate(**params)
         else:
             r = await self._get_client().completion.acreate(**params)
-        return process_chat_completion_response(r, self.formatter, request)
+        return process_chat_completion_response(r, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -244,7 +241,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
                 yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
@@ -258,11 +255,11 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
                 ]
             else:
                 input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model), self.formatter
+                    request, self.get_llama_model(request.model)
                 )
         else:
             assert not media_present, "Fireworks does not support media for Completion requests"
-            input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter)
+            input_dict["prompt"] = await completion_request_to_prompt(request)
 
         # Fireworks always prepends with BOS
         if "prompt" in input_dict:
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index f524c0734..2488d9071 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -8,8 +8,6 @@ import logging
 from typing import AsyncGenerator, List, Optional, Union
 
 import httpx
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 from ollama import AsyncClient
 
 from llama_stack.apis.common.content_types import (
@@ -138,7 +136,6 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, url: str) -> None:
         self.register_helper = ModelRegistryHelper(model_aliases)
         self.url = url
-        self.formatter = ChatFormat(Tokenizer.get_instance())
 
     @property
     def client(self) -> AsyncClient:
@@ -197,7 +194,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 )
 
         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_completion_stream_response(stream, self.formatter):
+        async for chunk in process_completion_stream_response(stream):
             yield chunk
 
     async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
@@ -212,7 +209,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
             choices=[choice],
         )
 
-        return process_completion_response(response, self.formatter)
+        return process_completion_response(response)
 
     async def chat_completion(
         self,
@@ -262,11 +259,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 input_dict["prompt"] = await chat_completion_request_to_prompt(
                     request,
                     self.register_helper.get_llama_model(request.model),
-                    self.formatter,
                 )
         else:
             assert not media_present, "Ollama does not support media for Completion requests"
-            input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter)
+            input_dict["prompt"] = await completion_request_to_prompt(request)
             input_dict["raw"] = True
 
         if fmt := request.response_format:
@@ -304,7 +300,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         response = OpenAICompatCompletionResponse(
             choices=[choice],
         )
-        return process_chat_completion_response(response, self.formatter, request)
+        return process_chat_completion_response(response, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -330,7 +326,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 )
 
         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
     async def embeddings(
diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py
index 1abb17336..09122a8e6 100644
--- a/llama_stack/providers/remote/inference/runpod/runpod.py
+++ b/llama_stack/providers/remote/inference/runpod/runpod.py
@@ -5,8 +5,6 @@
 # the root directory of this source tree.
 from typing import AsyncGenerator
 
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI
 
 from llama_stack.apis.inference import *  # noqa: F403
@@ -45,7 +43,6 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: RunpodImplConfig) -> None:
         ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())
 
     async def initialize(self) -> None:
         return
@@ -56,7 +53,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
     async def completion(
         self,
         model: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
@@ -97,7 +94,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
         r = client.completions.create(**params)
-        return process_chat_completion_response(r, self.formatter, request)
+        return process_chat_completion_response(r, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
         params = self._get_params(request)
@@ -108,13 +105,13 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
                 yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
     def _get_params(self, request: ChatCompletionRequest) -> dict:
         return {
             "model": self.map_to_provider_model(request.model),
-            "prompt": chat_completion_request_to_prompt(request, self.formatter),
+            "prompt": chat_completion_request_to_prompt(request),
             "stream": request.stream,
             **get_sampling_options(request.sampling_params),
         }
@@ -122,6 +119,6 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py
index b906e0dcb..fae4b24c6 100644
--- a/llama_stack/providers/remote/inference/sambanova/sambanova.py
+++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py
@@ -7,8 +7,6 @@
 import json
 from typing import AsyncGenerator
 
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI
 
 from llama_stack.apis.common.content_types import (
@@ -78,13 +76,8 @@ MODEL_ALIASES = [
 
 class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: SambaNovaImplConfig) -> None:
-        ModelRegistryHelper.__init__(
-            self,
-            model_aliases=MODEL_ALIASES,
-        )
-
+        ModelRegistryHelper.__init__(self, model_aliases=MODEL_ALIASES)
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())
 
     async def initialize(self) -> None:
         return
@@ -160,7 +153,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
                 yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
     async def embeddings(
diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py
index 1909e01f8..7ffeced95 100644
--- a/llama_stack/providers/remote/inference/tgi/tgi.py
+++ b/llama_stack/providers/remote/inference/tgi/tgi.py
@@ -9,8 +9,6 @@ import logging
 from typing import AsyncGenerator, List, Optional
 
 from huggingface_hub import AsyncInferenceClient, HfApi
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 
 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import (
@@ -72,7 +70,6 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     model_id: str
 
     def __init__(self) -> None:
-        self.formatter = ChatFormat(Tokenizer.get_instance())
         self.register_helper = ModelRegistryHelper(build_model_aliases())
         self.huggingface_repo_to_llama_model_id = {
             model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo
@@ -149,7 +146,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         return options
 
     async def _get_params_for_completion(self, request: CompletionRequest) -> dict:
-        prompt, input_tokens = await completion_request_to_prompt_model_input_info(request, self.formatter)
+        prompt, input_tokens = await completion_request_to_prompt_model_input_info(request)
 
         return dict(
             prompt=prompt,
@@ -177,7 +174,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
                 )
 
         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_completion_stream_response(stream, self.formatter):
+        async for chunk in process_completion_stream_response(stream):
             yield chunk
 
     async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
@@ -193,7 +190,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
             choices=[choice],
         )
 
-        return process_completion_response(response, self.formatter)
+        return process_completion_response(response)
 
     async def chat_completion(
         self,
@@ -236,7 +233,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         response = OpenAICompatCompletionResponse(
             choices=[choice],
         )
-        return process_chat_completion_response(response, self.formatter, request)
+        return process_chat_completion_response(response, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -252,12 +249,12 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
                 )
 
         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
     async def _get_params(self, request: ChatCompletionRequest) -> dict:
         prompt, input_tokens = await chat_completion_request_to_model_input_info(
-            request, self.register_helper.get_llama_model(request.model), self.formatter
+            request, self.register_helper.get_llama_model(request.model)
         )
         return dict(
             prompt=prompt,
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 054501da8..93e0547bb 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -6,8 +6,6 @@
 
 from typing import AsyncGenerator, List, Optional, Union
 
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 from together import Together
 
 from llama_stack.apis.common.content_types import InterleavedContent
@@ -95,7 +93,6 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     def __init__(self, config: TogetherImplConfig) -> None:
         ModelRegistryHelper.__init__(self, MODEL_ALIASES)
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())
 
     async def initialize(self) -> None:
         pass
@@ -142,7 +139,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
         r = self._get_client().completions.create(**params)
-        return process_completion_response(r, self.formatter)
+        return process_completion_response(r)
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -154,7 +151,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
                 yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_completion_stream_response(stream, self.formatter):
+        async for chunk in process_completion_stream_response(stream):
             yield chunk
 
     def _build_options(
@@ -220,7 +217,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
             r = self._get_client().chat.completions.create(**params)
         else:
             r = self._get_client().completions.create(**params)
-        return process_chat_completion_response(r, self.formatter, request)
+        return process_chat_completion_response(r, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -235,7 +232,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
                 yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
@@ -246,11 +243,11 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
                 input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
             else:
                 input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model), self.formatter
+                    request, self.get_llama_model(request.model)
                 )
         else:
             assert not media_present, "Together does not support media for Completion requests"
-            input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter)
+            input_dict["prompt"] = await completion_request_to_prompt(request)
 
         return {
             "model": request.model,
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index b22284302..e073b98c6 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -8,8 +8,6 @@ import logging
 from typing import AsyncGenerator, List, Optional, Union
 
 from llama_models.datatypes import StopReason, ToolCall
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI
 
 from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus
@@ -191,7 +189,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
         self.register_helper = ModelRegistryHelper(build_model_aliases())
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())
         self.client = None
 
     async def initialize(self) -> None:
@@ -286,14 +283,14 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         if len(request.tools) > 0:
             res = _process_vllm_chat_completion_stream_response(stream)
         else:
-            res = process_chat_completion_stream_response(stream, self.formatter, request)
+            res = process_chat_completion_stream_response(stream, request)
         async for chunk in res:
             yield chunk
 
     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
         r = self.client.completions.create(**params)
-        return process_completion_response(r, self.formatter)
+        return process_completion_response(r)
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -305,7 +302,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
                 yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_completion_stream_response(stream, self.formatter):
+        async for chunk in process_completion_stream_response(stream):
             yield chunk
 
     async def register_model(self, model: Model) -> Model:
@@ -332,10 +329,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
        else:
             assert not request_has_media(request), "vLLM does not support media for Completion requests"
-            input_dict["prompt"] = await completion_request_to_prompt(
-                request,
-                self.formatter,
-            )
+            input_dict["prompt"] = await completion_request_to_prompt(request)
 
         if fmt := request.response_format:
             if fmt.type == ResponseFormatType.json_schema.value:
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index da8e3ce2d..db9c761ba 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -7,7 +7,6 @@
 import json
 import logging
 from typing import AsyncGenerator, Dict, List, Optional, Union
 
-from llama_models.llama3.api.chat_format import ChatFormat
 from openai.types.chat import ChatCompletionMessageToolCall
 from pydantic import BaseModel
 
@@ -40,6 +39,7 @@ from llama_stack.models.llama.datatypes import (
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_content_to_url,
+    decode_assistant_message,
 )
 
 logger = logging.getLogger(__name__)
@@ -149,7 +149,7 @@ def convert_openai_completion_logprobs_stream(text: str, logprobs: Optional[Unio
     return None
 
 
-def process_completion_response(response: OpenAICompatCompletionResponse, formatter: ChatFormat) -> CompletionResponse:
+def process_completion_response(response: OpenAICompatCompletionResponse) -> CompletionResponse:
     choice = response.choices[0]
     # drop suffix if present and return stop reason as end of turn
     if choice.text.endswith("<|eot_id|>"):
@@ -174,16 +174,13 @@ def process_completion_response(response: OpenAICompatCompletionResponse, format
 
 def process_chat_completion_response(
     response: OpenAICompatCompletionResponse,
-    formatter: ChatFormat,
     request: ChatCompletionRequest,
 ) -> ChatCompletionResponse:
     choice = response.choices[0]
 
     # TODO: This does not work well with tool calls for vLLM remote provider
     #   Ref: https://github.com/meta-llama/llama-stack/issues/1058
-    raw_message = formatter.decode_assistant_message_from_content(
-        text_from_choice(choice), get_stop_reason(choice.finish_reason)
-    )
+    raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason))
 
     # NOTE: If we do not set tools in chat-completion request, we should not
     # expect the ToolCall in the response. Instead, we should return the raw
@@ -217,7 +214,7 @@ def process_chat_completion_response(
 
 
 async def process_completion_stream_response(
-    stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
+    stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
 ) -> AsyncGenerator:
     stop_reason = None
 
@@ -254,7 +251,6 @@ async def process_completion_stream_response(
 
 async def process_chat_completion_stream_response(
     stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
-    formatter: ChatFormat,
     request: ChatCompletionRequest,
 ) -> AsyncGenerator:
     yield ChatCompletionResponseStreamChunk(
@@ -333,7 +329,7 @@ async def process_chat_completion_stream_response(
         )
 
     # parse tool calls and report errors
-    message = formatter.decode_assistant_message_from_content(buffer, stop_reason)
+    message = decode_assistant_message(buffer, stop_reason)
     parsed_tool_calls = len(message.tool_calls) > 0
 
     if ipython and not parsed_tool_calls:
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index b7945dee7..80d2baba7 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -13,7 +13,9 @@ import re
 from typing import List, Optional, Tuple, Union
 
 import httpx
+from llama_models.datatypes import StopReason
 from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.tokenizer import Tokenizer
 from PIL import Image as PIL_Image
 
 from llama_stack.apis.common.content_types import (
@@ -65,6 +67,11 @@ class CompletionRequestWithRawContent(CompletionRequest):
     content: RawContent
 
 
+def decode_assistant_message(content: str, stop_reason: StopReason) -> RawMessage:
+    formatter = ChatFormat(Tokenizer.get_instance())
+    return formatter.decode_assistant_message_from_content(content, stop_reason)
+
+
 def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
     def _process(c) -> str:
         if isinstance(c, str):
@@ -206,20 +213,22 @@ async def convert_image_content_to_url(
     return base64.b64encode(content).decode("utf-8")
 
 
-async def completion_request_to_prompt(request: CompletionRequest, formatter: ChatFormat) -> str:
+async def completion_request_to_prompt(request: CompletionRequest) -> str:
     content = augment_content_with_response_format_prompt(request.response_format, request.content)
     request.content = content
     request = await convert_request_to_raw(request)
+
+    formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_content(request.content)
     return formatter.tokenizer.decode(model_input.tokens)
 
 
-async def completion_request_to_prompt_model_input_info(
-    request: CompletionRequest, formatter: ChatFormat
-) -> Tuple[str, int]:
+async def completion_request_to_prompt_model_input_info(request: CompletionRequest) -> Tuple[str, int]:
     content = augment_content_with_response_format_prompt(request.response_format, request.content)
     request.content = content
     request = await convert_request_to_raw(request)
+
+    formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_content(request.content)
     return (formatter.tokenizer.decode(model_input.tokens), len(model_input.tokens))
 
@@ -236,22 +245,24 @@ def augment_content_with_response_format_prompt(response_format, content):
     return content
 
 
-async def chat_completion_request_to_prompt(
-    request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
-) -> str:
+async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str:
     messages = chat_completion_request_to_messages(request, llama_model)
     request.messages = messages
     request = await convert_request_to_raw(request)
+
+    formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_dialog_prompt(request.messages)
     return formatter.tokenizer.decode(model_input.tokens)
 
 
 async def chat_completion_request_to_model_input_info(
-    request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
+    request: ChatCompletionRequest, llama_model: str
 ) -> Tuple[str, int]:
     messages = chat_completion_request_to_messages(request, llama_model)
     request.messages = messages
     request = await convert_request_to_raw(request)
+
+    formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_dialog_prompt(request.messages)
     return (
         formatter.tokenizer.decode(model_input.tokens),