From cdcbeb005b6f8026506263d8e18ed3fb656fa01b Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 19 Feb 2025 19:01:29 -0800 Subject: [PATCH] chore: remove llama_models.llama3.api imports from providers (#1107) There should be a choke-point for llama3.api imports -- this is the prompt adapter. Creating a ChatFormat() object on demand is inexpensive. The underlying Tokenizer is a singleton anyway. --- .../providers/inline/inference/vllm/vllm.py | 12 ++++----- .../remote/inference/bedrock/bedrock.py | 9 +++---- .../remote/inference/cerebras/cerebras.py | 17 +++++------- .../remote/inference/databricks/databricks.py | 14 +++------- .../remote/inference/fireworks/fireworks.py | 15 +++++------ .../remote/inference/ollama/ollama.py | 14 ++++------ .../remote/inference/runpod/runpod.py | 13 ++++----- .../remote/inference/sambanova/sambanova.py | 11 ++------ .../providers/remote/inference/tgi/tgi.py | 15 +++++------ .../remote/inference/together/together.py | 15 +++++------ .../providers/remote/inference/vllm/vllm.py | 14 +++------- .../utils/inference/openai_compat.py | 14 ++++------ .../utils/inference/prompt_adapter.py | 27 +++++++++++++------ 13 files changed, 77 insertions(+), 113 deletions(-) diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py index 5536ea3a5..5b0df91e7 100644 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ b/llama_stack/providers/inline/inference/vllm/vllm.py @@ -9,7 +9,6 @@ import os import uuid from typing import AsyncGenerator, List, Optional -from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -62,7 +61,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): def __init__(self, config: VLLMConfig): self.config = config self.engine = None - self.formatter = ChatFormat(Tokenizer.get_instance()) async def initialize(self): log.info("Initializing vLLM inference provider.") @@ -177,7 +175,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): log.info("Sampling params: %s", sampling_params) request_id = _random_uuid() - prompt = await chat_completion_request_to_prompt(request, self.config.model, self.formatter) + prompt = await chat_completion_request_to_prompt(request, self.config.model) vllm_sampling_params = self._sampling_params(request.sampling_params) results_generator = self.engine.generate(prompt, vllm_sampling_params, request_id) if stream: @@ -201,11 +199,13 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): response = OpenAICompatCompletionResponse( choices=[choice], ) - return process_chat_completion_response(response, self.formatter, request) + return process_chat_completion_response(response, request) async def _stream_chat_completion( self, request: ChatCompletionRequest, results_generator: AsyncGenerator ) -> AsyncGenerator: + tokenizer = Tokenizer.get_instance() + async def _generate_and_convert_to_openai_compat(): cur = [] async for chunk in results_generator: @@ -216,7 +216,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): output = chunk.outputs[-1] new_tokens = output.token_ids[len(cur) :] - text = self.formatter.tokenizer.decode(new_tokens) + text = tokenizer.decode(new_tokens) cur.extend(new_tokens) choice = OpenAICompatCompletionChoice( finish_reason=output.finish_reason, @@ -227,7 +227,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): ) stream 
= _generate_and_convert_to_openai_compat() - async for chunk in process_chat_completion_stream_response(stream, self.formatter, request): + async for chunk in process_chat_completion_stream_response(stream, request): yield chunk async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse: diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index a706d4304..610707f3f 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -8,8 +8,6 @@ import json from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union from botocore.client import BaseClient -from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.tokenizer import Tokenizer from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.inference import ( @@ -54,7 +52,6 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): self._config = config self._client = create_bedrock_client(config) - self.formatter = ChatFormat(Tokenizer.get_instance()) @property def client(self) -> BaseClient: @@ -119,7 +116,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): ) response = OpenAICompatCompletionResponse(choices=[choice]) - return process_chat_completion_response(response, self.formatter, request) + return process_chat_completion_response(response, request) async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: params = await self._get_params_for_chat_completion(request) @@ -137,7 +134,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): yield OpenAICompatCompletionResponse(choices=[choice]) stream = _generate_and_convert_to_openai_compat() - async for chunk in process_chat_completion_stream_response(stream, self.formatter, request): + async for chunk in process_chat_completion_stream_response(stream, request): yield chunk async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict: @@ -151,7 +148,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): if sampling_params.repetition_penalty > 0: options["repetition_penalty"] = sampling_params.repetition_penalty - prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model), self.formatter) + prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model)) return { "modelId": bedrock_model, "body": json.dumps( diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py index 0d8824fd2..e7b77a6e9 100644 --- a/llama_stack/providers/remote/inference/cerebras/cerebras.py +++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -7,8 +7,6 @@ from typing import AsyncGenerator, List, Optional, Union from cerebras.cloud.sdk import AsyncCerebras -from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.tokenizer import Tokenizer from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.inference import ( @@ -53,7 +51,6 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference): model_aliases=model_aliases, ) self.config = config - self.formatter = ChatFormat(Tokenizer.get_instance()) self.client = AsyncCerebras( base_url=self.config.base_url, @@ -96,14 +93,14 @@ class 
CerebrasInferenceAdapter(ModelRegistryHelper, Inference): r = await self.client.completions.create(**params) - return process_completion_response(r, self.formatter) + return process_completion_response(r) async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: params = await self._get_params(request) stream = await self.client.completions.create(**params) - async for chunk in process_completion_stream_response(stream, self.formatter): + async for chunk in process_completion_stream_response(stream): yield chunk async def chat_completion( @@ -143,14 +140,14 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference): r = await self.client.completions.create(**params) - return process_chat_completion_response(r, self.formatter, request) + return process_chat_completion_response(r, request) async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator: params = await self._get_params(request) stream = await self.client.completions.create(**params) - async for chunk in process_chat_completion_stream_response(stream, self.formatter, request): + async for chunk in process_chat_completion_stream_response(stream, request): yield chunk async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict: @@ -159,11 +156,9 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference): prompt = "" if isinstance(request, ChatCompletionRequest): - prompt = await chat_completion_request_to_prompt( - request, self.get_llama_model(request.model), self.formatter - ) + prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model)) elif isinstance(request, CompletionRequest): - prompt = await completion_request_to_prompt(request, self.formatter) + prompt = await completion_request_to_prompt(request) else: raise ValueError(f"Unknown request type {type(request)}") diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index 3d306e61f..05e61361c 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -6,8 +6,6 @@ from typing import AsyncGenerator, List, Optional -from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.tokenizer import Tokenizer from openai import OpenAI from llama_stack.apis.common.content_types import InterleavedContent @@ -54,12 +52,8 @@ model_aliases = [ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference): def __init__(self, config: DatabricksImplConfig) -> None: - ModelRegistryHelper.__init__( - self, - model_aliases=model_aliases, - ) + ModelRegistryHelper.__init__(self, model_aliases=model_aliases) self.config = config - self.formatter = ChatFormat(Tokenizer.get_instance()) async def initialize(self) -> None: return @@ -112,7 +106,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference): ) -> ChatCompletionResponse: params = self._get_params(request) r = client.completions.create(**params) - return process_chat_completion_response(r, self.formatter, request) + return process_chat_completion_response(r, request) async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator: params = self._get_params(request) @@ -123,13 +117,13 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference): yield chunk stream = _to_async_generator() - async for chunk in 
process_chat_completion_stream_response(stream, self.formatter, request): + async for chunk in process_chat_completion_stream_response(stream, request): yield chunk def _get_params(self, request: ChatCompletionRequest) -> dict: return { "model": request.model, - "prompt": chat_completion_request_to_prompt(request, self.get_llama_model(request.model), self.formatter), + "prompt": chat_completion_request_to_prompt(request, self.get_llama_model(request.model)), "stream": request.stream, **get_sampling_options(request.sampling_params), } diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index 3b834673d..4f8d167f1 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -7,8 +7,6 @@ from typing import AsyncGenerator, List, Optional, Union from fireworks.client import Fireworks -from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.tokenizer import Tokenizer from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.inference import ( @@ -56,7 +54,6 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv def __init__(self, config: FireworksImplConfig) -> None: ModelRegistryHelper.__init__(self, MODEL_ALIASES) self.config = config - self.formatter = ChatFormat(Tokenizer.get_instance()) async def initialize(self) -> None: pass @@ -105,7 +102,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse: params = await self._get_params(request) r = await self._get_client().completion.acreate(**params) - return process_completion_response(r, self.formatter) + return process_completion_response(r) async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: params = await self._get_params(request) @@ -117,7 +114,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv yield chunk stream = _to_async_generator() - async for chunk in process_completion_stream_response(stream, self.formatter): + async for chunk in process_completion_stream_response(stream): yield chunk def _build_options( @@ -186,7 +183,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv r = await self._get_client().chat.completions.acreate(**params) else: r = await self._get_client().completion.acreate(**params) - return process_chat_completion_response(r, self.formatter, request) + return process_chat_completion_response(r, request) async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: params = await self._get_params(request) @@ -200,7 +197,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv yield chunk stream = _to_async_generator() - async for chunk in process_chat_completion_stream_response(stream, self.formatter, request): + async for chunk in process_chat_completion_stream_response(stream, request): yield chunk async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict: @@ -214,11 +211,11 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv ] else: input_dict["prompt"] = await chat_completion_request_to_prompt( - request, self.get_llama_model(request.model), self.formatter + request, self.get_llama_model(request.model) ) else: assert 
not media_present, "Fireworks does not support media for Completion requests" - input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter) + input_dict["prompt"] = await completion_request_to_prompt(request) # Fireworks always prepends with BOS if "prompt" in input_dict: diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index f524c0734..2488d9071 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -8,8 +8,6 @@ import logging from typing import AsyncGenerator, List, Optional, Union import httpx -from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.tokenizer import Tokenizer from ollama import AsyncClient from llama_stack.apis.common.content_types import ( @@ -138,7 +136,6 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): def __init__(self, url: str) -> None: self.register_helper = ModelRegistryHelper(model_aliases) self.url = url - self.formatter = ChatFormat(Tokenizer.get_instance()) @property def client(self) -> AsyncClient: @@ -197,7 +194,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): ) stream = _generate_and_convert_to_openai_compat() - async for chunk in process_completion_stream_response(stream, self.formatter): + async for chunk in process_completion_stream_response(stream): yield chunk async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator: @@ -212,7 +209,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): choices=[choice], ) - return process_completion_response(response, self.formatter) + return process_completion_response(response) async def chat_completion( self, @@ -262,11 +259,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): input_dict["prompt"] = await chat_completion_request_to_prompt( request, self.register_helper.get_llama_model(request.model), - self.formatter, ) else: assert not media_present, "Ollama does not support media for Completion requests" - input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter) + input_dict["prompt"] = await completion_request_to_prompt(request) input_dict["raw"] = True if fmt := request.response_format: @@ -304,7 +300,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): response = OpenAICompatCompletionResponse( choices=[choice], ) - return process_chat_completion_response(response, self.formatter, request) + return process_chat_completion_response(response, request) async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: params = await self._get_params(request) @@ -330,7 +326,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): ) stream = _generate_and_convert_to_openai_compat() - async for chunk in process_chat_completion_stream_response(stream, self.formatter, request): + async for chunk in process_chat_completion_stream_response(stream, request): yield chunk async def embeddings( diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py index 1abb17336..09122a8e6 100644 --- a/llama_stack/providers/remote/inference/runpod/runpod.py +++ b/llama_stack/providers/remote/inference/runpod/runpod.py @@ -5,8 +5,6 @@ # the root directory of this source tree. 
from typing import AsyncGenerator -from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.tokenizer import Tokenizer from openai import OpenAI from llama_stack.apis.inference import * # noqa: F403 @@ -45,7 +43,6 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference): def __init__(self, config: RunpodImplConfig) -> None: ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS) self.config = config - self.formatter = ChatFormat(Tokenizer.get_instance()) async def initialize(self) -> None: return @@ -56,7 +53,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference): async def completion( self, model: str, - content: InterleavedTextMedia, + content: InterleavedContent, sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, @@ -97,7 +94,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference): ) -> ChatCompletionResponse: params = self._get_params(request) r = client.completions.create(**params) - return process_chat_completion_response(r, self.formatter, request) + return process_chat_completion_response(r, request) async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator: params = self._get_params(request) @@ -108,13 +105,13 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference): yield chunk stream = _to_async_generator() - async for chunk in process_chat_completion_stream_response(stream, self.formatter, request): + async for chunk in process_chat_completion_stream_response(stream, request): yield chunk def _get_params(self, request: ChatCompletionRequest) -> dict: return { "model": self.map_to_provider_model(request.model), - "prompt": chat_completion_request_to_prompt(request, self.formatter), + "prompt": chat_completion_request_to_prompt(request), "stream": request.stream, **get_sampling_options(request.sampling_params), } @@ -122,6 +119,6 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference): async def embeddings( self, model: str, - contents: List[InterleavedTextMedia], + contents: List[InterleavedContent], ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index 9b3562870..30a0934a3 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -7,8 +7,6 @@ import json from typing import AsyncGenerator -from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.tokenizer import Tokenizer from openai import OpenAI from llama_stack.apis.common.content_types import ( @@ -38,13 +36,8 @@ from .models import MODEL_ALIASES class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference): def __init__(self, config: SambaNovaImplConfig) -> None: - ModelRegistryHelper.__init__( - self, - model_aliases=MODEL_ALIASES, - ) - + ModelRegistryHelper.__init__(self, model_aliases=MODEL_ALIASES) self.config = config - self.formatter = ChatFormat(Tokenizer.get_instance()) async def initialize(self) -> None: return @@ -120,7 +113,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference): yield chunk stream = _to_async_generator() - async for chunk in process_chat_completion_stream_response(stream, self.formatter, request): + async for chunk in process_chat_completion_stream_response(stream, request): 
yield chunk async def embeddings( diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index 1909e01f8..7ffeced95 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -9,8 +9,6 @@ import logging from typing import AsyncGenerator, List, Optional from huggingface_hub import AsyncInferenceClient, HfApi -from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.tokenizer import Tokenizer from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.inference import ( @@ -72,7 +70,6 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): model_id: str def __init__(self) -> None: - self.formatter = ChatFormat(Tokenizer.get_instance()) self.register_helper = ModelRegistryHelper(build_model_aliases()) self.huggingface_repo_to_llama_model_id = { model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo @@ -149,7 +146,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): return options async def _get_params_for_completion(self, request: CompletionRequest) -> dict: - prompt, input_tokens = await completion_request_to_prompt_model_input_info(request, self.formatter) + prompt, input_tokens = await completion_request_to_prompt_model_input_info(request) return dict( prompt=prompt, @@ -177,7 +174,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): ) stream = _generate_and_convert_to_openai_compat() - async for chunk in process_completion_stream_response(stream, self.formatter): + async for chunk in process_completion_stream_response(stream): yield chunk async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator: @@ -193,7 +190,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): choices=[choice], ) - return process_completion_response(response, self.formatter) + return process_completion_response(response) async def chat_completion( self, @@ -236,7 +233,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): response = OpenAICompatCompletionResponse( choices=[choice], ) - return process_chat_completion_response(response, self.formatter, request) + return process_chat_completion_response(response, request) async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: params = await self._get_params(request) @@ -252,12 +249,12 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): ) stream = _generate_and_convert_to_openai_compat() - async for chunk in process_chat_completion_stream_response(stream, self.formatter, request): + async for chunk in process_chat_completion_stream_response(stream, request): yield chunk async def _get_params(self, request: ChatCompletionRequest) -> dict: prompt, input_tokens = await chat_completion_request_to_model_input_info( - request, self.register_helper.get_llama_model(request.model), self.formatter + request, self.register_helper.get_llama_model(request.model) ) return dict( prompt=prompt, diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 7a37ff616..75428e70a 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -6,8 +6,6 @@ from typing import AsyncGenerator, List, Optional, Union -from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.tokenizer import Tokenizer 
from together import Together from llama_stack.apis.common.content_types import InterleavedContent @@ -55,7 +53,6 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi def __init__(self, config: TogetherImplConfig) -> None: ModelRegistryHelper.__init__(self, MODEL_ALIASES) self.config = config - self.formatter = ChatFormat(Tokenizer.get_instance()) async def initialize(self) -> None: pass @@ -102,7 +99,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse: params = await self._get_params(request) r = self._get_client().completions.create(**params) - return process_completion_response(r, self.formatter) + return process_completion_response(r) async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: params = await self._get_params(request) @@ -114,7 +111,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi yield chunk stream = _to_async_generator() - async for chunk in process_completion_stream_response(stream, self.formatter): + async for chunk in process_completion_stream_response(stream): yield chunk def _build_options( @@ -180,7 +177,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi r = self._get_client().chat.completions.create(**params) else: r = self._get_client().completions.create(**params) - return process_chat_completion_response(r, self.formatter, request) + return process_chat_completion_response(r, request) async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: params = await self._get_params(request) @@ -195,7 +192,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi yield chunk stream = _to_async_generator() - async for chunk in process_chat_completion_stream_response(stream, self.formatter, request): + async for chunk in process_chat_completion_stream_response(stream, request): yield chunk async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict: @@ -206,11 +203,11 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages] else: input_dict["prompt"] = await chat_completion_request_to_prompt( - request, self.get_llama_model(request.model), self.formatter + request, self.get_llama_model(request.model) ) else: assert not media_present, "Together does not support media for Completion requests" - input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter) + input_dict["prompt"] = await completion_request_to_prompt(request) return { "model": request.model, diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index b22284302..e073b98c6 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -8,8 +8,6 @@ import logging from typing import AsyncGenerator, List, Optional, Union from llama_models.datatypes import StopReason, ToolCall -from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.tokenizer import Tokenizer from openai import OpenAI from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus @@ -191,7 +189,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): 
def __init__(self, config: VLLMInferenceAdapterConfig) -> None: self.register_helper = ModelRegistryHelper(build_model_aliases()) self.config = config - self.formatter = ChatFormat(Tokenizer.get_instance()) self.client = None async def initialize(self) -> None: @@ -286,14 +283,14 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): if len(request.tools) > 0: res = _process_vllm_chat_completion_stream_response(stream) else: - res = process_chat_completion_stream_response(stream, self.formatter, request) + res = process_chat_completion_stream_response(stream, request) async for chunk in res: yield chunk async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse: params = await self._get_params(request) r = self.client.completions.create(**params) - return process_completion_response(r, self.formatter) + return process_completion_response(r) async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: params = await self._get_params(request) @@ -305,7 +302,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): yield chunk stream = _to_async_generator() - async for chunk in process_completion_stream_response(stream, self.formatter): + async for chunk in process_completion_stream_response(stream): yield chunk async def register_model(self, model: Model) -> Model: @@ -332,10 +329,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages] else: assert not request_has_media(request), "vLLM does not support media for Completion requests" - input_dict["prompt"] = await completion_request_to_prompt( - request, - self.formatter, - ) + input_dict["prompt"] = await completion_request_to_prompt(request) if fmt := request.response_format: if fmt.type == ResponseFormatType.json_schema.value: diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index def7e8f37..946d27763 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -7,7 +7,6 @@ import json import logging from typing import AsyncGenerator, Dict, List, Optional, Union -from llama_models.llama3.api.chat_format import ChatFormat from openai.types.chat import ChatCompletionMessageToolCall from pydantic import BaseModel @@ -40,6 +39,7 @@ from llama_stack.models.llama.datatypes import ( ) from llama_stack.providers.utils.inference.prompt_adapter import ( convert_image_content_to_url, + decode_assistant_message, ) logger = logging.getLogger(__name__) @@ -149,7 +149,7 @@ def convert_openai_completion_logprobs_stream(text: str, logprobs: Optional[Unio return None -def process_completion_response(response: OpenAICompatCompletionResponse, formatter: ChatFormat) -> CompletionResponse: +def process_completion_response(response: OpenAICompatCompletionResponse) -> CompletionResponse: choice = response.choices[0] # drop suffix if present and return stop reason as end of turn if choice.text.endswith("<|eot_id|>"): @@ -174,16 +174,13 @@ def process_completion_response(response: OpenAICompatCompletionResponse, format def process_chat_completion_response( response: OpenAICompatCompletionResponse, - formatter: ChatFormat, request: ChatCompletionRequest, ) -> ChatCompletionResponse: choice = response.choices[0] # TODO: This does not work well with tool calls for vLLM remote provider # Ref: 
https://github.com/meta-llama/llama-stack/issues/1058 - raw_message = formatter.decode_assistant_message_from_content( - text_from_choice(choice), get_stop_reason(choice.finish_reason) - ) + raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason)) # NOTE: If we do not set tools in chat-completion request, we should not # expect the ToolCall in the response. Instead, we should return the raw @@ -217,7 +214,7 @@ def process_chat_completion_response( async def process_completion_stream_response( - stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat + stream: AsyncGenerator[OpenAICompatCompletionResponse, None], ) -> AsyncGenerator: stop_reason = None @@ -254,7 +251,6 @@ async def process_completion_stream_response( async def process_chat_completion_stream_response( stream: AsyncGenerator[OpenAICompatCompletionResponse, None], - formatter: ChatFormat, request: ChatCompletionRequest, ) -> AsyncGenerator: yield ChatCompletionResponseStreamChunk( @@ -333,7 +329,7 @@ async def process_chat_completion_stream_response( ) # parse tool calls and report errors - message = formatter.decode_assistant_message_from_content(buffer, stop_reason) + message = decode_assistant_message(buffer, stop_reason) parsed_tool_calls = len(message.tool_calls) > 0 if ipython and not parsed_tool_calls: diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index 2782c661f..10fe442e8 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -13,7 +13,9 @@ import re from typing import List, Optional, Tuple, Union import httpx +from llama_models.datatypes import StopReason from llama_models.llama3.api.chat_format import ChatFormat +from llama_models.llama3.api.tokenizer import Tokenizer from PIL import Image as PIL_Image from llama_stack.apis.common.content_types import ( @@ -66,6 +68,11 @@ class CompletionRequestWithRawContent(CompletionRequest): content: RawContent +def decode_assistant_message(content: str, stop_reason: StopReason) -> RawMessage: + formatter = ChatFormat(Tokenizer.get_instance()) + return formatter.decode_assistant_message_from_content(content, stop_reason) + + def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str: def _process(c) -> str: if isinstance(c, str): @@ -207,20 +214,22 @@ async def convert_image_content_to_url( return base64.b64encode(content).decode("utf-8") -async def completion_request_to_prompt(request: CompletionRequest, formatter: ChatFormat) -> str: +async def completion_request_to_prompt(request: CompletionRequest) -> str: content = augment_content_with_response_format_prompt(request.response_format, request.content) request.content = content request = await convert_request_to_raw(request) + + formatter = ChatFormat(tokenizer=Tokenizer.get_instance()) model_input = formatter.encode_content(request.content) return formatter.tokenizer.decode(model_input.tokens) -async def completion_request_to_prompt_model_input_info( - request: CompletionRequest, formatter: ChatFormat -) -> Tuple[str, int]: +async def completion_request_to_prompt_model_input_info(request: CompletionRequest) -> Tuple[str, int]: content = augment_content_with_response_format_prompt(request.response_format, request.content) request.content = content request = await convert_request_to_raw(request) + + formatter = ChatFormat(tokenizer=Tokenizer.get_instance()) model_input = 
formatter.encode_content(request.content) return (formatter.tokenizer.decode(model_input.tokens), len(model_input.tokens)) @@ -237,22 +246,24 @@ def augment_content_with_response_format_prompt(response_format, content): return content -async def chat_completion_request_to_prompt( - request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat -) -> str: +async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str: messages = chat_completion_request_to_messages(request, llama_model) request.messages = messages request = await convert_request_to_raw(request) + + formatter = ChatFormat(tokenizer=Tokenizer.get_instance()) model_input = formatter.encode_dialog_prompt(request.messages) return formatter.tokenizer.decode(model_input.tokens) async def chat_completion_request_to_model_input_info( - request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat + request: ChatCompletionRequest, llama_model: str ) -> Tuple[str, int]: messages = chat_completion_request_to_messages(request, llama_model) request.messages = messages request = await convert_request_to_raw(request) + + formatter = ChatFormat(tokenizer=Tokenizer.get_instance()) model_input = formatter.encode_dialog_prompt(request.messages) return ( formatter.tokenizer.decode(model_input.tokens),
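
Note for reviewers: the sketch below condenses the choke-point pattern this patch converges on, using only names that appear in the diff above (decode_assistant_message, ChatFormat, Tokenizer.get_instance, decode_assistant_message_from_content). It is illustrative only and not part of the patch itself.

    # llama_stack/providers/utils/inference/prompt_adapter.py (condensed sketch)
    from llama_models.datatypes import StopReason
    from llama_models.llama3.api.chat_format import ChatFormat
    from llama_models.llama3.api.tokenizer import Tokenizer


    def decode_assistant_message(content: str, stop_reason: StopReason):
        # ChatFormat is cheap to construct and Tokenizer.get_instance() is a
        # singleton, so the formatter is built on demand here rather than being
        # stored on every provider. This keeps llama3.api imports confined to
        # the prompt adapter.
        formatter = ChatFormat(Tokenizer.get_instance())
        return formatter.decode_assistant_message_from_content(content, stop_reason)

Call sites in the providers simplify accordingly, e.g. (taken from the hunks above):

    # before: each provider held self.formatter = ChatFormat(Tokenizer.get_instance())
    #   process_chat_completion_response(response, self.formatter, request)
    # after: the formatter parameter is gone everywhere
    #   process_chat_completion_response(response, request)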