chore: remove llama_models.llama3.api imports from providers (#1107)

There should be a single choke-point for llama3.api imports, and that choke-point is the
prompt adapter. Creating a ChatFormat() object on demand is inexpensive, and the
underlying Tokenizer is a singleton anyway.
Ashwin Bharambe 2025-02-19 19:01:29 -08:00 committed by GitHub
parent e9b8259cf9
commit cdcbeb005b
13 changed files with 77 additions and 113 deletions
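
The resulting pattern, in a minimal sketch assembled from the hunks below (the module path comes from the import shown in the openai_compat diff; type annotations and surrounding code are elided):

# Sketch only: the prompt adapter (llama_stack.providers.utils.inference.prompt_adapter)
# is now the one place that imports llama3.api and builds ChatFormat on demand.
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer

def decode_assistant_message(content, stop_reason):
    # ChatFormat is cheap to construct; Tokenizer.get_instance() returns a singleton.
    formatter = ChatFormat(Tokenizer.get_instance())
    return formatter.decode_assistant_message_from_content(content, stop_reason)

# Providers drop their self.formatter field and call the shared helpers without it, e.g.:
#     prompt = await chat_completion_request_to_prompt(request, llama_model)
#     return process_chat_completion_response(response, request)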

View file

@@ -9,7 +9,6 @@ import os
 import uuid
 from typing import AsyncGenerator, List, Optional
 
-from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -62,7 +61,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
     def __init__(self, config: VLLMConfig):
         self.config = config
         self.engine = None
-        self.formatter = ChatFormat(Tokenizer.get_instance())
 
     async def initialize(self):
         log.info("Initializing vLLM inference provider.")
@@ -177,7 +175,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         log.info("Sampling params: %s", sampling_params)
         request_id = _random_uuid()
 
-        prompt = await chat_completion_request_to_prompt(request, self.config.model, self.formatter)
+        prompt = await chat_completion_request_to_prompt(request, self.config.model)
         vllm_sampling_params = self._sampling_params(request.sampling_params)
         results_generator = self.engine.generate(prompt, vllm_sampling_params, request_id)
         if stream:
@@ -201,11 +199,13 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         response = OpenAICompatCompletionResponse(
             choices=[choice],
         )
-        return process_chat_completion_response(response, self.formatter, request)
+        return process_chat_completion_response(response, request)
 
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest, results_generator: AsyncGenerator
     ) -> AsyncGenerator:
+        tokenizer = Tokenizer.get_instance()
+
         async def _generate_and_convert_to_openai_compat():
             cur = []
             async for chunk in results_generator:
@@ -216,7 +216,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                 output = chunk.outputs[-1]
                 new_tokens = output.token_ids[len(cur) :]
-                text = self.formatter.tokenizer.decode(new_tokens)
+                text = tokenizer.decode(new_tokens)
                 cur.extend(new_tokens)
                 choice = OpenAICompatCompletionChoice(
                     finish_reason=output.finish_reason,
@@ -227,7 +227,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                 )
 
         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
     async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:

View file

@@ -8,8 +8,6 @@ import json
 from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 
 from botocore.client import BaseClient
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 
 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import (
@@ -54,7 +52,6 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         self._config = config
         self._client = create_bedrock_client(config)
-        self.formatter = ChatFormat(Tokenizer.get_instance())
 
     @property
     def client(self) -> BaseClient:
@@ -119,7 +116,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         )
         response = OpenAICompatCompletionResponse(choices=[choice])
-        return process_chat_completion_response(response, self.formatter, request)
+        return process_chat_completion_response(response, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params_for_chat_completion(request)
@@ -137,7 +134,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
                 yield OpenAICompatCompletionResponse(choices=[choice])
 
         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
     async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
@@ -151,7 +148,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         if sampling_params.repetition_penalty > 0:
             options["repetition_penalty"] = sampling_params.repetition_penalty
 
-        prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model), self.formatter)
+        prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
         return {
             "modelId": bedrock_model,
             "body": json.dumps(

View file

@@ -7,8 +7,6 @@
 from typing import AsyncGenerator, List, Optional, Union
 
 from cerebras.cloud.sdk import AsyncCerebras
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 
 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import (
@@ -53,7 +51,6 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
             model_aliases=model_aliases,
         )
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())
 
         self.client = AsyncCerebras(
             base_url=self.config.base_url,
@@ -96,14 +93,14 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
 
         r = await self.client.completions.create(**params)
 
-        return process_completion_response(r, self.formatter)
+        return process_completion_response(r)
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
         stream = await self.client.completions.create(**params)
 
-        async for chunk in process_completion_stream_response(stream, self.formatter):
+        async for chunk in process_completion_stream_response(stream):
             yield chunk
 
     async def chat_completion(
@@ -143,14 +140,14 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
 
         r = await self.client.completions.create(**params)
 
-        return process_chat_completion_response(r, self.formatter, request)
+        return process_chat_completion_response(r, request)
 
     async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
         stream = await self.client.completions.create(**params)
 
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
@@ -159,11 +156,9 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
 
         prompt = ""
         if isinstance(request, ChatCompletionRequest):
-            prompt = await chat_completion_request_to_prompt(
-                request, self.get_llama_model(request.model), self.formatter
-            )
+            prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
         elif isinstance(request, CompletionRequest):
-            prompt = await completion_request_to_prompt(request, self.formatter)
+            prompt = await completion_request_to_prompt(request)
         else:
             raise ValueError(f"Unknown request type {type(request)}")

View file

@@ -6,8 +6,6 @@
 from typing import AsyncGenerator, List, Optional
 
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI
 
 from llama_stack.apis.common.content_types import InterleavedContent
@@ -54,12 +52,8 @@ model_aliases = [
 class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: DatabricksImplConfig) -> None:
-        ModelRegistryHelper.__init__(
-            self,
-            model_aliases=model_aliases,
-        )
+        ModelRegistryHelper.__init__(self, model_aliases=model_aliases)
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())
 
     async def initialize(self) -> None:
         return
@@ -112,7 +106,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
         r = client.completions.create(**params)
-        return process_chat_completion_response(r, self.formatter, request)
+        return process_chat_completion_response(r, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
         params = self._get_params(request)
@@ -123,13 +117,13 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
                 yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
     def _get_params(self, request: ChatCompletionRequest) -> dict:
         return {
             "model": request.model,
-            "prompt": chat_completion_request_to_prompt(request, self.get_llama_model(request.model), self.formatter),
+            "prompt": chat_completion_request_to_prompt(request, self.get_llama_model(request.model)),
             "stream": request.stream,
             **get_sampling_options(request.sampling_params),
         }

View file

@@ -7,8 +7,6 @@
 from typing import AsyncGenerator, List, Optional, Union
 
 from fireworks.client import Fireworks
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 
 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import (
@@ -56,7 +54,6 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
     def __init__(self, config: FireworksImplConfig) -> None:
         ModelRegistryHelper.__init__(self, MODEL_ALIASES)
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())
 
     async def initialize(self) -> None:
         pass
@@ -105,7 +102,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
         r = await self._get_client().completion.acreate(**params)
-        return process_completion_response(r, self.formatter)
+        return process_completion_response(r)
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -117,7 +114,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
                 yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_completion_stream_response(stream, self.formatter):
+        async for chunk in process_completion_stream_response(stream):
             yield chunk
 
     def _build_options(
@@ -186,7 +183,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
             r = await self._get_client().chat.completions.acreate(**params)
         else:
             r = await self._get_client().completion.acreate(**params)
-        return process_chat_completion_response(r, self.formatter, request)
+        return process_chat_completion_response(r, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -200,7 +197,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
                 yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
@@ -214,11 +211,11 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
                 ]
             else:
                 input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model), self.formatter
+                    request, self.get_llama_model(request.model)
                 )
         else:
             assert not media_present, "Fireworks does not support media for Completion requests"
-            input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter)
+            input_dict["prompt"] = await completion_request_to_prompt(request)
 
         # Fireworks always prepends with BOS
         if "prompt" in input_dict:

View file

@@ -8,8 +8,6 @@ import logging
 from typing import AsyncGenerator, List, Optional, Union
 
 import httpx
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 from ollama import AsyncClient
 
 from llama_stack.apis.common.content_types import (
@@ -138,7 +136,6 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, url: str) -> None:
         self.register_helper = ModelRegistryHelper(model_aliases)
         self.url = url
-        self.formatter = ChatFormat(Tokenizer.get_instance())
 
     @property
     def client(self) -> AsyncClient:
@@ -197,7 +194,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 )
 
         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_completion_stream_response(stream, self.formatter):
+        async for chunk in process_completion_stream_response(stream):
             yield chunk
 
     async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
@@ -212,7 +209,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
             choices=[choice],
         )
 
-        return process_completion_response(response, self.formatter)
+        return process_completion_response(response)
 
     async def chat_completion(
         self,
@@ -262,11 +259,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 input_dict["prompt"] = await chat_completion_request_to_prompt(
                     request,
                     self.register_helper.get_llama_model(request.model),
-                    self.formatter,
                 )
         else:
             assert not media_present, "Ollama does not support media for Completion requests"
-            input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter)
+            input_dict["prompt"] = await completion_request_to_prompt(request)
             input_dict["raw"] = True
 
         if fmt := request.response_format:
@@ -304,7 +300,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         response = OpenAICompatCompletionResponse(
             choices=[choice],
         )
-        return process_chat_completion_response(response, self.formatter, request)
+        return process_chat_completion_response(response, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -330,7 +326,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 )
 
         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
     async def embeddings(

View file

@@ -5,8 +5,6 @@
 # the root directory of this source tree.
 from typing import AsyncGenerator
 
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI
 
 from llama_stack.apis.inference import *  # noqa: F403
@@ -45,7 +43,6 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: RunpodImplConfig) -> None:
         ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())
 
     async def initialize(self) -> None:
         return
@@ -56,7 +53,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
     async def completion(
         self,
         model: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
@@ -97,7 +94,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
         r = client.completions.create(**params)
-        return process_chat_completion_response(r, self.formatter, request)
+        return process_chat_completion_response(r, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
         params = self._get_params(request)
@@ -108,13 +105,13 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
                 yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
            yield chunk
 
     def _get_params(self, request: ChatCompletionRequest) -> dict:
         return {
             "model": self.map_to_provider_model(request.model),
-            "prompt": chat_completion_request_to_prompt(request, self.formatter),
+            "prompt": chat_completion_request_to_prompt(request),
             "stream": request.stream,
             **get_sampling_options(request.sampling_params),
         }
@@ -122,6 +119,6 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

View file

@@ -7,8 +7,6 @@
 import json
 from typing import AsyncGenerator
 
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI
 
 from llama_stack.apis.common.content_types import (
@@ -38,13 +36,8 @@ from .models import MODEL_ALIASES
 
 class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: SambaNovaImplConfig) -> None:
-        ModelRegistryHelper.__init__(
-            self,
-            model_aliases=MODEL_ALIASES,
-        )
+        ModelRegistryHelper.__init__(self, model_aliases=MODEL_ALIASES)
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())
 
     async def initialize(self) -> None:
         return
@@ -120,7 +113,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
                 yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
     async def embeddings(

View file

@@ -9,8 +9,6 @@ import logging
 from typing import AsyncGenerator, List, Optional
 
 from huggingface_hub import AsyncInferenceClient, HfApi
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 
 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import (
@@ -72,7 +70,6 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     model_id: str
 
     def __init__(self) -> None:
-        self.formatter = ChatFormat(Tokenizer.get_instance())
         self.register_helper = ModelRegistryHelper(build_model_aliases())
         self.huggingface_repo_to_llama_model_id = {
             model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo
@@ -149,7 +146,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         return options
 
     async def _get_params_for_completion(self, request: CompletionRequest) -> dict:
-        prompt, input_tokens = await completion_request_to_prompt_model_input_info(request, self.formatter)
+        prompt, input_tokens = await completion_request_to_prompt_model_input_info(request)
 
         return dict(
             prompt=prompt,
@@ -177,7 +174,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
                 )
 
         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_completion_stream_response(stream, self.formatter):
+        async for chunk in process_completion_stream_response(stream):
             yield chunk
 
     async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
@@ -193,7 +190,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
             choices=[choice],
         )
-        return process_completion_response(response, self.formatter)
+        return process_completion_response(response)
 
     async def chat_completion(
         self,
@@ -236,7 +233,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         response = OpenAICompatCompletionResponse(
             choices=[choice],
         )
-        return process_chat_completion_response(response, self.formatter, request)
+        return process_chat_completion_response(response, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -252,12 +249,12 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
                 )
 
         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
     async def _get_params(self, request: ChatCompletionRequest) -> dict:
         prompt, input_tokens = await chat_completion_request_to_model_input_info(
-            request, self.register_helper.get_llama_model(request.model), self.formatter
+            request, self.register_helper.get_llama_model(request.model)
         )
         return dict(
             prompt=prompt,

View file

@@ -6,8 +6,6 @@
 from typing import AsyncGenerator, List, Optional, Union
 
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 from together import Together
 
 from llama_stack.apis.common.content_types import InterleavedContent
@@ -55,7 +53,6 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     def __init__(self, config: TogetherImplConfig) -> None:
         ModelRegistryHelper.__init__(self, MODEL_ALIASES)
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())
 
     async def initialize(self) -> None:
         pass
@@ -102,7 +99,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
         r = self._get_client().completions.create(**params)
-        return process_completion_response(r, self.formatter)
+        return process_completion_response(r)
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -114,7 +111,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
                 yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_completion_stream_response(stream, self.formatter):
+        async for chunk in process_completion_stream_response(stream):
             yield chunk
 
     def _build_options(
@@ -180,7 +177,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
             r = self._get_client().chat.completions.create(**params)
         else:
             r = self._get_client().completions.create(**params)
-        return process_chat_completion_response(r, self.formatter, request)
+        return process_chat_completion_response(r, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -195,7 +192,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
                 yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
@@ -206,11 +203,11 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
                 input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
             else:
                 input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model), self.formatter
+                    request, self.get_llama_model(request.model)
                 )
         else:
             assert not media_present, "Together does not support media for Completion requests"
-            input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter)
+            input_dict["prompt"] = await completion_request_to_prompt(request)
 
         return {
             "model": request.model,

View file

@@ -8,8 +8,6 @@ import logging
 from typing import AsyncGenerator, List, Optional, Union
 
 from llama_models.datatypes import StopReason, ToolCall
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI
 
 from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus
@@ -191,7 +189,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
         self.register_helper = ModelRegistryHelper(build_model_aliases())
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())
         self.client = None
 
     async def initialize(self) -> None:
@@ -286,14 +283,14 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         if len(request.tools) > 0:
             res = _process_vllm_chat_completion_stream_response(stream)
         else:
-            res = process_chat_completion_stream_response(stream, self.formatter, request)
+            res = process_chat_completion_stream_response(stream, request)
         async for chunk in res:
             yield chunk
 
     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
         r = self.client.completions.create(**params)
-        return process_completion_response(r, self.formatter)
+        return process_completion_response(r)
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -305,7 +302,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
                 yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_completion_stream_response(stream, self.formatter):
+        async for chunk in process_completion_stream_response(stream):
             yield chunk
 
     async def register_model(self, model: Model) -> Model:
@@ -332,10 +329,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
         else:
             assert not request_has_media(request), "vLLM does not support media for Completion requests"
-            input_dict["prompt"] = await completion_request_to_prompt(
-                request,
-                self.formatter,
-            )
+            input_dict["prompt"] = await completion_request_to_prompt(request)
 
         if fmt := request.response_format:
             if fmt.type == ResponseFormatType.json_schema.value:

View file

@@ -7,7 +7,6 @@ import json
 import logging
 from typing import AsyncGenerator, Dict, List, Optional, Union
 
-from llama_models.llama3.api.chat_format import ChatFormat
 from openai.types.chat import ChatCompletionMessageToolCall
 from pydantic import BaseModel
 
@@ -40,6 +39,7 @@ from llama_stack.models.llama.datatypes import (
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_content_to_url,
+    decode_assistant_message,
 )
 
 logger = logging.getLogger(__name__)
@@ -149,7 +149,7 @@ def convert_openai_completion_logprobs_stream(text: str, logprobs: Optional[Unio
         return None
 
 
-def process_completion_response(response: OpenAICompatCompletionResponse, formatter: ChatFormat) -> CompletionResponse:
+def process_completion_response(response: OpenAICompatCompletionResponse) -> CompletionResponse:
     choice = response.choices[0]
     # drop suffix <eot_id> if present and return stop reason as end of turn
     if choice.text.endswith("<|eot_id|>"):
@@ -174,16 +174,13 @@ def process_completion_response(response: OpenAICompatCompletionResponse, format
 
 def process_chat_completion_response(
     response: OpenAICompatCompletionResponse,
-    formatter: ChatFormat,
     request: ChatCompletionRequest,
 ) -> ChatCompletionResponse:
     choice = response.choices[0]
 
     # TODO: This does not work well with tool calls for vLLM remote provider
     #   Ref: https://github.com/meta-llama/llama-stack/issues/1058
-    raw_message = formatter.decode_assistant_message_from_content(
-        text_from_choice(choice), get_stop_reason(choice.finish_reason)
-    )
+    raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason))
 
     # NOTE: If we do not set tools in chat-completion request, we should not
     # expect the ToolCall in the response. Instead, we should return the raw
@@ -217,7 +214,7 @@ def process_chat_completion_response(
 
 async def process_completion_stream_response(
-    stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
+    stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
 ) -> AsyncGenerator:
     stop_reason = None
 
@@ -254,7 +251,6 @@ async def process_completion_stream_response(
 
 async def process_chat_completion_stream_response(
     stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
-    formatter: ChatFormat,
     request: ChatCompletionRequest,
 ) -> AsyncGenerator:
     yield ChatCompletionResponseStreamChunk(
@@ -333,7 +329,7 @@ async def process_chat_completion_stream_response(
     )
 
     # parse tool calls and report errors
-    message = formatter.decode_assistant_message_from_content(buffer, stop_reason)
+    message = decode_assistant_message(buffer, stop_reason)
 
     parsed_tool_calls = len(message.tool_calls) > 0
     if ipython and not parsed_tool_calls:

View file

@@ -13,7 +13,9 @@ import re
 from typing import List, Optional, Tuple, Union
 
 import httpx
+from llama_models.datatypes import StopReason
 from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.tokenizer import Tokenizer
 from PIL import Image as PIL_Image
 
 from llama_stack.apis.common.content_types import (
@@ -66,6 +68,11 @@ class CompletionRequestWithRawContent(CompletionRequest):
     content: RawContent
 
 
+def decode_assistant_message(content: str, stop_reason: StopReason) -> RawMessage:
+    formatter = ChatFormat(Tokenizer.get_instance())
+    return formatter.decode_assistant_message_from_content(content, stop_reason)
+
+
 def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
     def _process(c) -> str:
         if isinstance(c, str):
@@ -207,20 +214,22 @@ async def convert_image_content_to_url(
         return base64.b64encode(content).decode("utf-8")
 
 
-async def completion_request_to_prompt(request: CompletionRequest, formatter: ChatFormat) -> str:
+async def completion_request_to_prompt(request: CompletionRequest) -> str:
     content = augment_content_with_response_format_prompt(request.response_format, request.content)
     request.content = content
     request = await convert_request_to_raw(request)
+
+    formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_content(request.content)
     return formatter.tokenizer.decode(model_input.tokens)
 
 
-async def completion_request_to_prompt_model_input_info(
-    request: CompletionRequest, formatter: ChatFormat
-) -> Tuple[str, int]:
+async def completion_request_to_prompt_model_input_info(request: CompletionRequest) -> Tuple[str, int]:
    content = augment_content_with_response_format_prompt(request.response_format, request.content)
    request.content = content
    request = await convert_request_to_raw(request)
+
+    formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
    model_input = formatter.encode_content(request.content)
    return (formatter.tokenizer.decode(model_input.tokens), len(model_input.tokens))
@@ -237,22 +246,24 @@ def augment_content_with_response_format_prompt(response_format, content):
     return content
 
 
-async def chat_completion_request_to_prompt(
-    request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
-) -> str:
+async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str:
     messages = chat_completion_request_to_messages(request, llama_model)
     request.messages = messages
     request = await convert_request_to_raw(request)
+
+    formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_dialog_prompt(request.messages)
     return formatter.tokenizer.decode(model_input.tokens)
 
 
 async def chat_completion_request_to_model_input_info(
-    request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
+    request: ChatCompletionRequest, llama_model: str
 ) -> Tuple[str, int]:
     messages = chat_completion_request_to_messages(request, llama_model)
     request.messages = messages
     request = await convert_request_to_raw(request)
+
+    formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_dialog_prompt(request.messages)
     return (
         formatter.tokenizer.decode(model_input.tokens),