chore: remove llama_models.llama3.api imports from providers (#1107)

There should be a choke-point for llama3.api imports -- this is the
prompt adapter. Creating a ChatFormat() object on demand is inexpensive.
The underlying Tokenizer is a singleton anyway.
Author: Ashwin Bharambe, 2025-02-19 19:01:29 -08:00 (committed by GitHub)
commit cdcbeb005b, parent e9b8259cf9
13 changed files with 77 additions and 113 deletions
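
For context, here is a minimal sketch of the choke-point the commit message describes. It is an illustration only, not the actual llama_stack prompt-adapter source: the helper constructs its own ChatFormat on demand, so the providers in the diffs below can drop their self.formatter fields along with the llama3.api imports. The encode_content and decode calls are assumed ChatFormat/Tokenizer APIs.

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer


async def completion_request_to_prompt(request) -> str:
    # Illustrative body only. The formatter is built here, on demand, instead of
    # being threaded through from every provider; Tokenizer.get_instance() returns
    # a process-wide singleton, so creating a ChatFormat wrapper per call is cheap.
    formatter = ChatFormat(Tokenizer.get_instance())
    model_input = formatter.encode_content(request.content)  # assumed ChatFormat API
    return formatter.tokenizer.decode(model_input.tokens)

Provider call sites then reduce to process_chat_completion_response(response, request) and chat_completion_request_to_prompt(request, llama_model), as the diffs below show.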


@ -8,8 +8,6 @@ import json
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from botocore.client import BaseClient
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
@ -54,7 +52,6 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
self._config = config
self._client = create_bedrock_client(config)
-self.formatter = ChatFormat(Tokenizer.get_instance())
@property
def client(self) -> BaseClient:
@ -119,7 +116,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
)
response = OpenAICompatCompletionResponse(choices=[choice])
-return process_chat_completion_response(response, self.formatter, request)
+return process_chat_completion_response(response, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params_for_chat_completion(request)
@ -137,7 +134,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
yield OpenAICompatCompletionResponse(choices=[choice])
stream = _generate_and_convert_to_openai_compat()
-async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
@ -151,7 +148,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
if sampling_params.repetition_penalty > 0:
options["repetition_penalty"] = sampling_params.repetition_penalty
-prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model), self.formatter)
+prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
return {
"modelId": bedrock_model,
"body": json.dumps(


@ -7,8 +7,6 @@
from typing import AsyncGenerator, List, Optional, Union
from cerebras.cloud.sdk import AsyncCerebras
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
@ -53,7 +51,6 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
model_aliases=model_aliases,
)
self.config = config
-self.formatter = ChatFormat(Tokenizer.get_instance())
self.client = AsyncCerebras(
base_url=self.config.base_url,
@ -96,14 +93,14 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
r = await self.client.completions.create(**params)
-return process_completion_response(r, self.formatter)
+return process_completion_response(r)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
stream = await self.client.completions.create(**params)
-async for chunk in process_completion_stream_response(stream, self.formatter):
+async for chunk in process_completion_stream_response(stream):
yield chunk
async def chat_completion(
@ -143,14 +140,14 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
r = await self.client.completions.create(**params)
-return process_chat_completion_response(r, self.formatter, request)
+return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
stream = await self.client.completions.create(**params)
-async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
@ -159,11 +156,9 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
prompt = ""
if isinstance(request, ChatCompletionRequest):
-prompt = await chat_completion_request_to_prompt(
-request, self.get_llama_model(request.model), self.formatter
-)
+prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
elif isinstance(request, CompletionRequest):
-prompt = await completion_request_to_prompt(request, self.formatter)
+prompt = await completion_request_to_prompt(request)
else:
raise ValueError(f"Unknown request type {type(request)}")


@ -6,8 +6,6 @@
from typing import AsyncGenerator, List, Optional
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI
from llama_stack.apis.common.content_types import InterleavedContent
@ -54,12 +52,8 @@ model_aliases = [
class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: DatabricksImplConfig) -> None:
-ModelRegistryHelper.__init__(
-self,
-model_aliases=model_aliases,
-)
+ModelRegistryHelper.__init__(self, model_aliases=model_aliases)
self.config = config
-self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self) -> None:
return
@ -112,7 +106,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
) -> ChatCompletionResponse:
params = self._get_params(request)
r = client.completions.create(**params)
-return process_chat_completion_response(r, self.formatter, request)
+return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
params = self._get_params(request)
@ -123,13 +117,13 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
yield chunk
stream = _to_async_generator()
-async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
"model": request.model,
-"prompt": chat_completion_request_to_prompt(request, self.get_llama_model(request.model), self.formatter),
+"prompt": chat_completion_request_to_prompt(request, self.get_llama_model(request.model)),
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}


@ -7,8 +7,6 @@
from typing import AsyncGenerator, List, Optional, Union
from fireworks.client import Fireworks
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
@ -56,7 +54,6 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
def __init__(self, config: FireworksImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ALIASES)
self.config = config
-self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self) -> None:
pass
@ -105,7 +102,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request)
r = await self._get_client().completion.acreate(**params)
-return process_completion_response(r, self.formatter)
+return process_completion_response(r)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
@ -117,7 +114,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
yield chunk
stream = _to_async_generator()
-async for chunk in process_completion_stream_response(stream, self.formatter):
+async for chunk in process_completion_stream_response(stream):
yield chunk
def _build_options(
@ -186,7 +183,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
r = await self._get_client().chat.completions.acreate(**params)
else:
r = await self._get_client().completion.acreate(**params)
-return process_chat_completion_response(r, self.formatter, request)
+return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
@ -200,7 +197,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
yield chunk
stream = _to_async_generator()
-async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
@ -214,11 +211,11 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
]
else:
input_dict["prompt"] = await chat_completion_request_to_prompt(
-request, self.get_llama_model(request.model), self.formatter
+request, self.get_llama_model(request.model)
)
else:
assert not media_present, "Fireworks does not support media for Completion requests"
-input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter)
+input_dict["prompt"] = await completion_request_to_prompt(request)
# Fireworks always prepends with BOS
if "prompt" in input_dict:


@ -8,8 +8,6 @@ import logging
from typing import AsyncGenerator, List, Optional, Union
import httpx
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
from ollama import AsyncClient
from llama_stack.apis.common.content_types import (
@ -138,7 +136,6 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
def __init__(self, url: str) -> None:
self.register_helper = ModelRegistryHelper(model_aliases)
self.url = url
-self.formatter = ChatFormat(Tokenizer.get_instance())
@property
def client(self) -> AsyncClient:
@ -197,7 +194,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
)
stream = _generate_and_convert_to_openai_compat()
-async for chunk in process_completion_stream_response(stream, self.formatter):
+async for chunk in process_completion_stream_response(stream):
yield chunk
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
@ -212,7 +209,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
choices=[choice],
)
-return process_completion_response(response, self.formatter)
+return process_completion_response(response)
async def chat_completion(
self,
@ -262,11 +259,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
input_dict["prompt"] = await chat_completion_request_to_prompt(
request,
self.register_helper.get_llama_model(request.model),
-self.formatter,
)
else:
assert not media_present, "Ollama does not support media for Completion requests"
-input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter)
+input_dict["prompt"] = await completion_request_to_prompt(request)
input_dict["raw"] = True
if fmt := request.response_format:
@ -304,7 +300,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
response = OpenAICompatCompletionResponse(
choices=[choice],
)
-return process_chat_completion_response(response, self.formatter, request)
+return process_chat_completion_response(response, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
@ -330,7 +326,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
)
stream = _generate_and_convert_to_openai_compat()
-async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def embeddings(


@ -5,8 +5,6 @@
# the root directory of this source tree.
from typing import AsyncGenerator
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI
from llama_stack.apis.inference import * # noqa: F403
@ -45,7 +43,6 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: RunpodImplConfig) -> None:
ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
self.config = config
-self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self) -> None:
return
@ -56,7 +53,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
async def completion(
self,
model: str,
-content: InterleavedTextMedia,
+content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
@ -97,7 +94,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
) -> ChatCompletionResponse:
params = self._get_params(request)
r = client.completions.create(**params)
-return process_chat_completion_response(r, self.formatter, request)
+return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
params = self._get_params(request)
@ -108,13 +105,13 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
yield chunk
stream = _to_async_generator()
-async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
"model": self.map_to_provider_model(request.model),
-"prompt": chat_completion_request_to_prompt(request, self.formatter),
+"prompt": chat_completion_request_to_prompt(request),
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
@ -122,6 +119,6 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
async def embeddings(
self,
model: str,
-contents: List[InterleavedTextMedia],
+contents: List[InterleavedContent],
) -> EmbeddingsResponse:
raise NotImplementedError()


@ -7,8 +7,6 @@
import json
from typing import AsyncGenerator
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI
from llama_stack.apis.common.content_types import (
@ -38,13 +36,8 @@ from .models import MODEL_ALIASES
class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: SambaNovaImplConfig) -> None:
-ModelRegistryHelper.__init__(
-self,
-model_aliases=MODEL_ALIASES,
-)
+ModelRegistryHelper.__init__(self, model_aliases=MODEL_ALIASES)
self.config = config
-self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self) -> None:
return
@ -120,7 +113,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
yield chunk
stream = _to_async_generator()
-async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def embeddings(


@ -9,8 +9,6 @@ import logging
from typing import AsyncGenerator, List, Optional
from huggingface_hub import AsyncInferenceClient, HfApi
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
@ -72,7 +70,6 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
model_id: str
def __init__(self) -> None:
-self.formatter = ChatFormat(Tokenizer.get_instance())
self.register_helper = ModelRegistryHelper(build_model_aliases())
self.huggingface_repo_to_llama_model_id = {
model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo
@ -149,7 +146,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
return options
async def _get_params_for_completion(self, request: CompletionRequest) -> dict:
-prompt, input_tokens = await completion_request_to_prompt_model_input_info(request, self.formatter)
+prompt, input_tokens = await completion_request_to_prompt_model_input_info(request)
return dict(
prompt=prompt,
@ -177,7 +174,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
)
stream = _generate_and_convert_to_openai_compat()
-async for chunk in process_completion_stream_response(stream, self.formatter):
+async for chunk in process_completion_stream_response(stream):
yield chunk
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
@ -193,7 +190,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
choices=[choice],
)
-return process_completion_response(response, self.formatter)
+return process_completion_response(response)
async def chat_completion(
self,
@ -236,7 +233,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
response = OpenAICompatCompletionResponse(
choices=[choice],
)
-return process_chat_completion_response(response, self.formatter, request)
+return process_chat_completion_response(response, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
@ -252,12 +249,12 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
)
stream = _generate_and_convert_to_openai_compat()
-async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: ChatCompletionRequest) -> dict:
prompt, input_tokens = await chat_completion_request_to_model_input_info(
-request, self.register_helper.get_llama_model(request.model), self.formatter
+request, self.register_helper.get_llama_model(request.model)
)
return dict(
prompt=prompt,


@ -6,8 +6,6 @@
from typing import AsyncGenerator, List, Optional, Union
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
from together import Together
from llama_stack.apis.common.content_types import InterleavedContent
@ -55,7 +53,6 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
def __init__(self, config: TogetherImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ALIASES)
self.config = config
-self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self) -> None:
pass
@ -102,7 +99,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
r = self._get_client().completions.create(**params)
-return process_completion_response(r, self.formatter)
+return process_completion_response(r)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
@ -114,7 +111,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
yield chunk
stream = _to_async_generator()
-async for chunk in process_completion_stream_response(stream, self.formatter):
+async for chunk in process_completion_stream_response(stream):
yield chunk
def _build_options(
@ -180,7 +177,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
r = self._get_client().chat.completions.create(**params)
else:
r = self._get_client().completions.create(**params)
-return process_chat_completion_response(r, self.formatter, request)
+return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
@ -195,7 +192,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
yield chunk
stream = _to_async_generator()
-async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
@ -206,11 +203,11 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
else:
input_dict["prompt"] = await chat_completion_request_to_prompt(
-request, self.get_llama_model(request.model), self.formatter
+request, self.get_llama_model(request.model)
)
else:
assert not media_present, "Together does not support media for Completion requests"
-input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter)
+input_dict["prompt"] = await completion_request_to_prompt(request)
return {
"model": request.model,


@ -8,8 +8,6 @@ import logging
from typing import AsyncGenerator, List, Optional, Union
from llama_models.datatypes import StopReason, ToolCall
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI
from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus
@ -191,7 +189,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
self.register_helper = ModelRegistryHelper(build_model_aliases())
self.config = config
-self.formatter = ChatFormat(Tokenizer.get_instance())
self.client = None
async def initialize(self) -> None:
@ -286,14 +283,14 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
if len(request.tools) > 0:
res = _process_vllm_chat_completion_stream_response(stream)
else:
-res = process_chat_completion_stream_response(stream, self.formatter, request)
+res = process_chat_completion_stream_response(stream, request)
async for chunk in res:
yield chunk
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request)
r = self.client.completions.create(**params)
-return process_completion_response(r, self.formatter)
+return process_completion_response(r)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
@ -305,7 +302,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
yield chunk
stream = _to_async_generator()
-async for chunk in process_completion_stream_response(stream, self.formatter):
+async for chunk in process_completion_stream_response(stream):
yield chunk
async def register_model(self, model: Model) -> Model:
@ -332,10 +329,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
else:
assert not request_has_media(request), "vLLM does not support media for Completion requests"
-input_dict["prompt"] = await completion_request_to_prompt(
-request,
-self.formatter,
-)
+input_dict["prompt"] = await completion_request_to_prompt(request)
if fmt := request.response_format:
if fmt.type == ResponseFormatType.json_schema.value: