chore: remove llama_models.llama3.api imports from providers (#1107)

There should be a choke-point for llama3.api imports -- this is the
prompt adapter. Creating a ChatFormat() object on demand is inexpensive.
The underlying Tokenizer is a singleton anyway.
This commit is contained in:
Ashwin Bharambe 2025-02-19 19:01:29 -08:00 committed by GitHub
parent e9b8259cf9
commit cdcbeb005b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 77 additions and 113 deletions

View file

@ -9,7 +9,6 @@ import os
import uuid
from typing import AsyncGenerator, List, Optional
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
@ -62,7 +61,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
def __init__(self, config: VLLMConfig):
self.config = config
self.engine = None
self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self):
log.info("Initializing vLLM inference provider.")
@ -177,7 +175,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
log.info("Sampling params: %s", sampling_params)
request_id = _random_uuid()
prompt = await chat_completion_request_to_prompt(request, self.config.model, self.formatter)
prompt = await chat_completion_request_to_prompt(request, self.config.model)
vllm_sampling_params = self._sampling_params(request.sampling_params)
results_generator = self.engine.generate(prompt, vllm_sampling_params, request_id)
if stream:
@ -201,11 +199,13 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_chat_completion_response(response, self.formatter, request)
return process_chat_completion_response(response, request)
async def _stream_chat_completion(
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
) -> AsyncGenerator:
tokenizer = Tokenizer.get_instance()
async def _generate_and_convert_to_openai_compat():
cur = []
async for chunk in results_generator:
@ -216,7 +216,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
output = chunk.outputs[-1]
new_tokens = output.token_ids[len(cur) :]
text = self.formatter.tokenizer.decode(new_tokens)
text = tokenizer.decode(new_tokens)
cur.extend(new_tokens)
choice = OpenAICompatCompletionChoice(
finish_reason=output.finish_reason,
@ -227,7 +227,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:

View file

@ -8,8 +8,6 @@ import json
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from botocore.client import BaseClient
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
@ -54,7 +52,6 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
self._config = config
self._client = create_bedrock_client(config)
self.formatter = ChatFormat(Tokenizer.get_instance())
@property
def client(self) -> BaseClient:
@ -119,7 +116,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
)
response = OpenAICompatCompletionResponse(choices=[choice])
return process_chat_completion_response(response, self.formatter, request)
return process_chat_completion_response(response, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params_for_chat_completion(request)
@ -137,7 +134,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
yield OpenAICompatCompletionResponse(choices=[choice])
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
@ -151,7 +148,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
if sampling_params.repetition_penalty > 0:
options["repetition_penalty"] = sampling_params.repetition_penalty
prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model), self.formatter)
prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
return {
"modelId": bedrock_model,
"body": json.dumps(

View file

@ -7,8 +7,6 @@
from typing import AsyncGenerator, List, Optional, Union
from cerebras.cloud.sdk import AsyncCerebras
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
@ -53,7 +51,6 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
model_aliases=model_aliases,
)
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
self.client = AsyncCerebras(
base_url=self.config.base_url,
@ -96,14 +93,14 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
r = await self.client.completions.create(**params)
return process_completion_response(r, self.formatter)
return process_completion_response(r)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
stream = await self.client.completions.create(**params)
async for chunk in process_completion_stream_response(stream, self.formatter):
async for chunk in process_completion_stream_response(stream):
yield chunk
async def chat_completion(
@ -143,14 +140,14 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
r = await self.client.completions.create(**params)
return process_chat_completion_response(r, self.formatter, request)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
stream = await self.client.completions.create(**params)
async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
@ -159,11 +156,9 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
prompt = ""
if isinstance(request, ChatCompletionRequest):
prompt = await chat_completion_request_to_prompt(
request, self.get_llama_model(request.model), self.formatter
)
prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
elif isinstance(request, CompletionRequest):
prompt = await completion_request_to_prompt(request, self.formatter)
prompt = await completion_request_to_prompt(request)
else:
raise ValueError(f"Unknown request type {type(request)}")

View file

@ -6,8 +6,6 @@
from typing import AsyncGenerator, List, Optional
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI
from llama_stack.apis.common.content_types import InterleavedContent
@ -54,12 +52,8 @@ model_aliases = [
class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: DatabricksImplConfig) -> None:
ModelRegistryHelper.__init__(
self,
model_aliases=model_aliases,
)
ModelRegistryHelper.__init__(self, model_aliases=model_aliases)
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self) -> None:
return
@ -112,7 +106,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
) -> ChatCompletionResponse:
params = self._get_params(request)
r = client.completions.create(**params)
return process_chat_completion_response(r, self.formatter, request)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
params = self._get_params(request)
@ -123,13 +117,13 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
"model": request.model,
"prompt": chat_completion_request_to_prompt(request, self.get_llama_model(request.model), self.formatter),
"prompt": chat_completion_request_to_prompt(request, self.get_llama_model(request.model)),
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}

View file

@ -7,8 +7,6 @@
from typing import AsyncGenerator, List, Optional, Union
from fireworks.client import Fireworks
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
@ -56,7 +54,6 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
def __init__(self, config: FireworksImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ALIASES)
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self) -> None:
pass
@ -105,7 +102,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request)
r = await self._get_client().completion.acreate(**params)
return process_completion_response(r, self.formatter)
return process_completion_response(r)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
@ -117,7 +114,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
yield chunk
stream = _to_async_generator()
async for chunk in process_completion_stream_response(stream, self.formatter):
async for chunk in process_completion_stream_response(stream):
yield chunk
def _build_options(
@ -186,7 +183,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
r = await self._get_client().chat.completions.acreate(**params)
else:
r = await self._get_client().completion.acreate(**params)
return process_chat_completion_response(r, self.formatter, request)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
@ -200,7 +197,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
@ -214,11 +211,11 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
]
else:
input_dict["prompt"] = await chat_completion_request_to_prompt(
request, self.get_llama_model(request.model), self.formatter
request, self.get_llama_model(request.model)
)
else:
assert not media_present, "Fireworks does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter)
input_dict["prompt"] = await completion_request_to_prompt(request)
# Fireworks always prepends with BOS
if "prompt" in input_dict:

View file

@ -8,8 +8,6 @@ import logging
from typing import AsyncGenerator, List, Optional, Union
import httpx
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from ollama import AsyncClient
from llama_stack.apis.common.content_types import (
@ -138,7 +136,6 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
def __init__(self, url: str) -> None:
self.register_helper = ModelRegistryHelper(model_aliases)
self.url = url
self.formatter = ChatFormat(Tokenizer.get_instance())
@property
def client(self) -> AsyncClient:
@ -197,7 +194,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_completion_stream_response(stream, self.formatter):
async for chunk in process_completion_stream_response(stream):
yield chunk
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
@ -212,7 +209,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
choices=[choice],
)
return process_completion_response(response, self.formatter)
return process_completion_response(response)
async def chat_completion(
self,
@ -262,11 +259,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
input_dict["prompt"] = await chat_completion_request_to_prompt(
request,
self.register_helper.get_llama_model(request.model),
self.formatter,
)
else:
assert not media_present, "Ollama does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter)
input_dict["prompt"] = await completion_request_to_prompt(request)
input_dict["raw"] = True
if fmt := request.response_format:
@ -304,7 +300,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_chat_completion_response(response, self.formatter, request)
return process_chat_completion_response(response, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
@ -330,7 +326,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def embeddings(

View file

@ -5,8 +5,6 @@
# the root directory of this source tree.
from typing import AsyncGenerator
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI
from llama_stack.apis.inference import * # noqa: F403
@ -45,7 +43,6 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: RunpodImplConfig) -> None:
ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self) -> None:
return
@ -56,7 +53,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
async def completion(
self,
model: str,
content: InterleavedTextMedia,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
@ -97,7 +94,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
) -> ChatCompletionResponse:
params = self._get_params(request)
r = client.completions.create(**params)
return process_chat_completion_response(r, self.formatter, request)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
params = self._get_params(request)
@ -108,13 +105,13 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
"model": self.map_to_provider_model(request.model),
"prompt": chat_completion_request_to_prompt(request, self.formatter),
"prompt": chat_completion_request_to_prompt(request),
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
@ -122,6 +119,6 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
async def embeddings(
self,
model: str,
contents: List[InterleavedTextMedia],
contents: List[InterleavedContent],
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@ -7,8 +7,6 @@
import json
from typing import AsyncGenerator
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI
from llama_stack.apis.common.content_types import (
@ -38,13 +36,8 @@ from .models import MODEL_ALIASES
class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: SambaNovaImplConfig) -> None:
ModelRegistryHelper.__init__(
self,
model_aliases=MODEL_ALIASES,
)
ModelRegistryHelper.__init__(self, model_aliases=MODEL_ALIASES)
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self) -> None:
return
@ -120,7 +113,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def embeddings(

View file

@ -9,8 +9,6 @@ import logging
from typing import AsyncGenerator, List, Optional
from huggingface_hub import AsyncInferenceClient, HfApi
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
@ -72,7 +70,6 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
model_id: str
def __init__(self) -> None:
self.formatter = ChatFormat(Tokenizer.get_instance())
self.register_helper = ModelRegistryHelper(build_model_aliases())
self.huggingface_repo_to_llama_model_id = {
model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo
@ -149,7 +146,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
return options
async def _get_params_for_completion(self, request: CompletionRequest) -> dict:
prompt, input_tokens = await completion_request_to_prompt_model_input_info(request, self.formatter)
prompt, input_tokens = await completion_request_to_prompt_model_input_info(request)
return dict(
prompt=prompt,
@ -177,7 +174,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_completion_stream_response(stream, self.formatter):
async for chunk in process_completion_stream_response(stream):
yield chunk
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
@ -193,7 +190,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
choices=[choice],
)
return process_completion_response(response, self.formatter)
return process_completion_response(response)
async def chat_completion(
self,
@ -236,7 +233,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_chat_completion_response(response, self.formatter, request)
return process_chat_completion_response(response, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
@ -252,12 +249,12 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: ChatCompletionRequest) -> dict:
prompt, input_tokens = await chat_completion_request_to_model_input_info(
request, self.register_helper.get_llama_model(request.model), self.formatter
request, self.register_helper.get_llama_model(request.model)
)
return dict(
prompt=prompt,

View file

@ -6,8 +6,6 @@
from typing import AsyncGenerator, List, Optional, Union
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from together import Together
from llama_stack.apis.common.content_types import InterleavedContent
@ -55,7 +53,6 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
def __init__(self, config: TogetherImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ALIASES)
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self) -> None:
pass
@ -102,7 +99,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
r = self._get_client().completions.create(**params)
return process_completion_response(r, self.formatter)
return process_completion_response(r)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
@ -114,7 +111,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
yield chunk
stream = _to_async_generator()
async for chunk in process_completion_stream_response(stream, self.formatter):
async for chunk in process_completion_stream_response(stream):
yield chunk
def _build_options(
@ -180,7 +177,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
r = self._get_client().chat.completions.create(**params)
else:
r = self._get_client().completions.create(**params)
return process_chat_completion_response(r, self.formatter, request)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
@ -195,7 +192,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
@ -206,11 +203,11 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
else:
input_dict["prompt"] = await chat_completion_request_to_prompt(
request, self.get_llama_model(request.model), self.formatter
request, self.get_llama_model(request.model)
)
else:
assert not media_present, "Together does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter)
input_dict["prompt"] = await completion_request_to_prompt(request)
return {
"model": request.model,

View file

@ -8,8 +8,6 @@ import logging
from typing import AsyncGenerator, List, Optional, Union
from llama_models.datatypes import StopReason, ToolCall
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI
from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus
@ -191,7 +189,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
self.register_helper = ModelRegistryHelper(build_model_aliases())
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
self.client = None
async def initialize(self) -> None:
@ -286,14 +283,14 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
if len(request.tools) > 0:
res = _process_vllm_chat_completion_stream_response(stream)
else:
res = process_chat_completion_stream_response(stream, self.formatter, request)
res = process_chat_completion_stream_response(stream, request)
async for chunk in res:
yield chunk
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request)
r = self.client.completions.create(**params)
return process_completion_response(r, self.formatter)
return process_completion_response(r)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
@ -305,7 +302,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
yield chunk
stream = _to_async_generator()
async for chunk in process_completion_stream_response(stream, self.formatter):
async for chunk in process_completion_stream_response(stream):
yield chunk
async def register_model(self, model: Model) -> Model:
@ -332,10 +329,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
else:
assert not request_has_media(request), "vLLM does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(
request,
self.formatter,
)
input_dict["prompt"] = await completion_request_to_prompt(request)
if fmt := request.response_format:
if fmt.type == ResponseFormatType.json_schema.value:

View file

@ -7,7 +7,6 @@ import json
import logging
from typing import AsyncGenerator, Dict, List, Optional, Union
from llama_models.llama3.api.chat_format import ChatFormat
from openai.types.chat import ChatCompletionMessageToolCall
from pydantic import BaseModel
@ -40,6 +39,7 @@ from llama_stack.models.llama.datatypes import (
)
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_content_to_url,
decode_assistant_message,
)
logger = logging.getLogger(__name__)
@ -149,7 +149,7 @@ def convert_openai_completion_logprobs_stream(text: str, logprobs: Optional[Unio
return None
def process_completion_response(response: OpenAICompatCompletionResponse, formatter: ChatFormat) -> CompletionResponse:
def process_completion_response(response: OpenAICompatCompletionResponse) -> CompletionResponse:
choice = response.choices[0]
# drop suffix <eot_id> if present and return stop reason as end of turn
if choice.text.endswith("<|eot_id|>"):
@ -174,16 +174,13 @@ def process_completion_response(response: OpenAICompatCompletionResponse, format
def process_chat_completion_response(
response: OpenAICompatCompletionResponse,
formatter: ChatFormat,
request: ChatCompletionRequest,
) -> ChatCompletionResponse:
choice = response.choices[0]
# TODO: This does not work well with tool calls for vLLM remote provider
# Ref: https://github.com/meta-llama/llama-stack/issues/1058
raw_message = formatter.decode_assistant_message_from_content(
text_from_choice(choice), get_stop_reason(choice.finish_reason)
)
raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason))
# NOTE: If we do not set tools in chat-completion request, we should not
# expect the ToolCall in the response. Instead, we should return the raw
@ -217,7 +214,7 @@ def process_chat_completion_response(
async def process_completion_stream_response(
stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
) -> AsyncGenerator:
stop_reason = None
@ -254,7 +251,6 @@ async def process_completion_stream_response(
async def process_chat_completion_stream_response(
stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
formatter: ChatFormat,
request: ChatCompletionRequest,
) -> AsyncGenerator:
yield ChatCompletionResponseStreamChunk(
@ -333,7 +329,7 @@ async def process_chat_completion_stream_response(
)
# parse tool calls and report errors
message = formatter.decode_assistant_message_from_content(buffer, stop_reason)
message = decode_assistant_message(buffer, stop_reason)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:

View file

@ -13,7 +13,9 @@ import re
from typing import List, Optional, Tuple, Union
import httpx
from llama_models.datatypes import StopReason
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from PIL import Image as PIL_Image
from llama_stack.apis.common.content_types import (
@ -66,6 +68,11 @@ class CompletionRequestWithRawContent(CompletionRequest):
content: RawContent
def decode_assistant_message(content: str, stop_reason: StopReason) -> RawMessage:
formatter = ChatFormat(Tokenizer.get_instance())
return formatter.decode_assistant_message_from_content(content, stop_reason)
def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
def _process(c) -> str:
if isinstance(c, str):
@ -207,20 +214,22 @@ async def convert_image_content_to_url(
return base64.b64encode(content).decode("utf-8")
async def completion_request_to_prompt(request: CompletionRequest, formatter: ChatFormat) -> str:
async def completion_request_to_prompt(request: CompletionRequest) -> str:
content = augment_content_with_response_format_prompt(request.response_format, request.content)
request.content = content
request = await convert_request_to_raw(request)
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_content(request.content)
return formatter.tokenizer.decode(model_input.tokens)
async def completion_request_to_prompt_model_input_info(
request: CompletionRequest, formatter: ChatFormat
) -> Tuple[str, int]:
async def completion_request_to_prompt_model_input_info(request: CompletionRequest) -> Tuple[str, int]:
content = augment_content_with_response_format_prompt(request.response_format, request.content)
request.content = content
request = await convert_request_to_raw(request)
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_content(request.content)
return (formatter.tokenizer.decode(model_input.tokens), len(model_input.tokens))
@ -237,22 +246,24 @@ def augment_content_with_response_format_prompt(response_format, content):
return content
async def chat_completion_request_to_prompt(
request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
) -> str:
async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str:
messages = chat_completion_request_to_messages(request, llama_model)
request.messages = messages
request = await convert_request_to_raw(request)
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_dialog_prompt(request.messages)
return formatter.tokenizer.decode(model_input.tokens)
async def chat_completion_request_to_model_input_info(
request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
request: ChatCompletionRequest, llama_model: str
) -> Tuple[str, int]:
messages = chat_completion_request_to_messages(request, llama_model)
request.messages = messages
request = await convert_request_to_raw(request)
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_dialog_prompt(request.messages)
return (
formatter.tokenizer.decode(model_input.tokens),