diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index 5536ea3a5..5b0df91e7 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -9,7 +9,6 @@ import os
 import uuid
 from typing import AsyncGenerator, List, Optional
 
-from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -62,7 +61,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
     def __init__(self, config: VLLMConfig):
         self.config = config
         self.engine = None
-        self.formatter = ChatFormat(Tokenizer.get_instance())
 
     async def initialize(self):
         log.info("Initializing vLLM inference provider.")
@@ -177,7 +175,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         log.info("Sampling params: %s", sampling_params)
         request_id = _random_uuid()
 
-        prompt = await chat_completion_request_to_prompt(request, self.config.model, self.formatter)
+        prompt = await chat_completion_request_to_prompt(request, self.config.model)
         vllm_sampling_params = self._sampling_params(request.sampling_params)
         results_generator = self.engine.generate(prompt, vllm_sampling_params, request_id)
         if stream:
@@ -201,11 +199,13 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         response = OpenAICompatCompletionResponse(
             choices=[choice],
         )
-        return process_chat_completion_response(response, self.formatter, request)
+        return process_chat_completion_response(response, request)
 
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest, results_generator: AsyncGenerator
     ) -> AsyncGenerator:
+        tokenizer = Tokenizer.get_instance()
+
         async def _generate_and_convert_to_openai_compat():
             cur = []
             async for chunk in results_generator:
@@ -216,7 +216,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                 output = chunk.outputs[-1]
 
                 new_tokens = output.token_ids[len(cur) :]
-                text = self.formatter.tokenizer.decode(new_tokens)
+                text = tokenizer.decode(new_tokens)
                 cur.extend(new_tokens)
                 choice = OpenAICompatCompletionChoice(
                     finish_reason=output.finish_reason,
@@ -227,7 +227,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                 )
 
         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
     async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:
diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py
index 3d306e61f..05e61361c 100644
--- a/llama_stack/providers/remote/inference/databricks/databricks.py
+++ b/llama_stack/providers/remote/inference/databricks/databricks.py
@@ -6,8 +6,6 @@
 
 from typing import AsyncGenerator, List, Optional
 
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI
 
 from llama_stack.apis.common.content_types import InterleavedContent
@@ -54,12 +52,8 @@ model_aliases = [
 
 class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: DatabricksImplConfig) -> None:
-        ModelRegistryHelper.__init__(
-            self,
-            model_aliases=model_aliases,
-        )
+        ModelRegistryHelper.__init__(self, model_aliases=model_aliases)
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())
 
     async def initialize(self) -> None:
         return
@@ -112,7 +106,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
         r = client.completions.create(**params)
-        return process_chat_completion_response(r, self.formatter, request)
+        return process_chat_completion_response(r, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
         params = self._get_params(request)
@@ -123,13 +117,13 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
                 yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
            yield chunk
 
     def _get_params(self, request: ChatCompletionRequest) -> dict:
         return {
             "model": request.model,
-            "prompt": chat_completion_request_to_prompt(request, self.get_llama_model(request.model), self.formatter),
+            "prompt": chat_completion_request_to_prompt(request, self.get_llama_model(request.model)),
             "stream": request.stream,
             **get_sampling_options(request.sampling_params),
         }