forked from phoenix-oss/llama-stack-mirror
chore: remove llama_models.llama3.api imports from providers (#1107)
There should be a choke-point for llama3.api imports -- this is the prompt adapter. Creating a ChatFormat() object on demand is inexpensive. The underlying Tokenizer is a singleton anyway.
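To make the intent concrete, here is a minimal sketch of the choke-point pattern this commit converges on (paraphrasing the prompt_adapter change in the diff below; the small `_formatter` helper is illustrative only, not the upstream code):

    # Sketch: providers stop importing llama_models.llama3.api directly; the
    # prompt adapter is the single place that touches ChatFormat/Tokenizer.
    from llama_models.datatypes import StopReason
    from llama_models.llama3.api.chat_format import ChatFormat
    from llama_models.llama3.api.tokenizer import Tokenizer


    def _formatter() -> ChatFormat:
        # Tokenizer.get_instance() returns a process-wide singleton, so building
        # a ChatFormat per call is cheap; providers need not cache one.
        return ChatFormat(Tokenizer.get_instance())


    def decode_assistant_message(content: str, stop_reason: StopReason):
        # Callers (e.g. openai_compat.py) use this instead of holding a formatter.
        return _formatter().decode_assistant_message_from_content(content, stop_reason)

Provider call sites shrink accordingly: `chat_completion_request_to_prompt(request, llama_model)` no longer takes a formatter argument.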
parent e9b8259cf9
commit cdcbeb005b
13 changed files with 77 additions and 113 deletions
@@ -9,7 +9,6 @@ import os
 import uuid
 from typing import AsyncGenerator, List, Optional

-from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -62,7 +61,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
     def __init__(self, config: VLLMConfig):
         self.config = config
         self.engine = None
-        self.formatter = ChatFormat(Tokenizer.get_instance())

     async def initialize(self):
         log.info("Initializing vLLM inference provider.")
@@ -177,7 +175,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         log.info("Sampling params: %s", sampling_params)
         request_id = _random_uuid()

-        prompt = await chat_completion_request_to_prompt(request, self.config.model, self.formatter)
+        prompt = await chat_completion_request_to_prompt(request, self.config.model)
         vllm_sampling_params = self._sampling_params(request.sampling_params)
         results_generator = self.engine.generate(prompt, vllm_sampling_params, request_id)
         if stream:
@@ -201,11 +199,13 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         response = OpenAICompatCompletionResponse(
             choices=[choice],
         )
-        return process_chat_completion_response(response, self.formatter, request)
+        return process_chat_completion_response(response, request)

     async def _stream_chat_completion(
         self, request: ChatCompletionRequest, results_generator: AsyncGenerator
     ) -> AsyncGenerator:
+        tokenizer = Tokenizer.get_instance()
+
         async def _generate_and_convert_to_openai_compat():
             cur = []
             async for chunk in results_generator:
@@ -216,7 +216,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                 output = chunk.outputs[-1]

                 new_tokens = output.token_ids[len(cur) :]
-                text = self.formatter.tokenizer.decode(new_tokens)
+                text = tokenizer.decode(new_tokens)
                 cur.extend(new_tokens)
                 choice = OpenAICompatCompletionChoice(
                     finish_reason=output.finish_reason,
@@ -227,7 +227,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                 )

         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

     async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:

@@ -8,8 +8,6 @@ import json
 from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

 from botocore.client import BaseClient
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import (
@@ -54,7 +52,6 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         self._config = config

         self._client = create_bedrock_client(config)
-        self.formatter = ChatFormat(Tokenizer.get_instance())

     @property
     def client(self) -> BaseClient:
@@ -119,7 +116,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         )

         response = OpenAICompatCompletionResponse(choices=[choice])
-        return process_chat_completion_response(response, self.formatter, request)
+        return process_chat_completion_response(response, request)

     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params_for_chat_completion(request)
@@ -137,7 +134,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
                 yield OpenAICompatCompletionResponse(choices=[choice])

         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

     async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
@@ -151,7 +148,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         if sampling_params.repetition_penalty > 0:
             options["repetition_penalty"] = sampling_params.repetition_penalty

-        prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model), self.formatter)
+        prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
         return {
             "modelId": bedrock_model,
             "body": json.dumps(

@@ -7,8 +7,6 @@
 from typing import AsyncGenerator, List, Optional, Union

 from cerebras.cloud.sdk import AsyncCerebras
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import (
@@ -53,7 +51,6 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
             model_aliases=model_aliases,
         )
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())

         self.client = AsyncCerebras(
             base_url=self.config.base_url,
@@ -96,14 +93,14 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):

         r = await self.client.completions.create(**params)

-        return process_completion_response(r, self.formatter)
+        return process_completion_response(r)

     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)

         stream = await self.client.completions.create(**params)

-        async for chunk in process_completion_stream_response(stream, self.formatter):
+        async for chunk in process_completion_stream_response(stream):
             yield chunk

     async def chat_completion(
@@ -143,14 +140,14 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):

         r = await self.client.completions.create(**params)

-        return process_chat_completion_response(r, self.formatter, request)
+        return process_chat_completion_response(r, request)

     async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)

         stream = await self.client.completions.create(**params)

-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
@@ -159,11 +156,9 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):

         prompt = ""
         if isinstance(request, ChatCompletionRequest):
-            prompt = await chat_completion_request_to_prompt(
-                request, self.get_llama_model(request.model), self.formatter
-            )
+            prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
         elif isinstance(request, CompletionRequest):
-            prompt = await completion_request_to_prompt(request, self.formatter)
+            prompt = await completion_request_to_prompt(request)
         else:
             raise ValueError(f"Unknown request type {type(request)}")

@@ -6,8 +6,6 @@

 from typing import AsyncGenerator, List, Optional

-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI

 from llama_stack.apis.common.content_types import InterleavedContent
@@ -54,12 +52,8 @@ model_aliases = [

 class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: DatabricksImplConfig) -> None:
-        ModelRegistryHelper.__init__(
-            self,
-            model_aliases=model_aliases,
-        )
+        ModelRegistryHelper.__init__(self, model_aliases=model_aliases)
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())

     async def initialize(self) -> None:
         return
@@ -112,7 +106,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
         r = client.completions.create(**params)
-        return process_chat_completion_response(r, self.formatter, request)
+        return process_chat_completion_response(r, request)

     async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
         params = self._get_params(request)
@@ -123,13 +117,13 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
                 yield chunk

         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

     def _get_params(self, request: ChatCompletionRequest) -> dict:
         return {
             "model": request.model,
-            "prompt": chat_completion_request_to_prompt(request, self.get_llama_model(request.model), self.formatter),
+            "prompt": chat_completion_request_to_prompt(request, self.get_llama_model(request.model)),
             "stream": request.stream,
             **get_sampling_options(request.sampling_params),
         }

@@ -7,8 +7,6 @@
 from typing import AsyncGenerator, List, Optional, Union

 from fireworks.client import Fireworks
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import (
@@ -56,7 +54,6 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
     def __init__(self, config: FireworksImplConfig) -> None:
         ModelRegistryHelper.__init__(self, MODEL_ALIASES)
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())

     async def initialize(self) -> None:
         pass
@@ -105,7 +102,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
         r = await self._get_client().completion.acreate(**params)
-        return process_completion_response(r, self.formatter)
+        return process_completion_response(r)

     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -117,7 +114,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
             yield chunk

         stream = _to_async_generator()
-        async for chunk in process_completion_stream_response(stream, self.formatter):
+        async for chunk in process_completion_stream_response(stream):
             yield chunk

     def _build_options(
@@ -186,7 +183,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
             r = await self._get_client().chat.completions.acreate(**params)
         else:
             r = await self._get_client().completion.acreate(**params)
-        return process_chat_completion_response(r, self.formatter, request)
+        return process_chat_completion_response(r, request)

     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -200,7 +197,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
             yield chunk

         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
@@ -214,11 +211,11 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
                 ]
             else:
                 input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model), self.formatter
+                    request, self.get_llama_model(request.model)
                 )
         else:
             assert not media_present, "Fireworks does not support media for Completion requests"
-            input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter)
+            input_dict["prompt"] = await completion_request_to_prompt(request)

         # Fireworks always prepends with BOS
         if "prompt" in input_dict:

@@ -8,8 +8,6 @@ import logging
 from typing import AsyncGenerator, List, Optional, Union

 import httpx
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 from ollama import AsyncClient

 from llama_stack.apis.common.content_types import (
@@ -138,7 +136,6 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, url: str) -> None:
         self.register_helper = ModelRegistryHelper(model_aliases)
         self.url = url
-        self.formatter = ChatFormat(Tokenizer.get_instance())

     @property
     def client(self) -> AsyncClient:
@@ -197,7 +194,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
             )

         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_completion_stream_response(stream, self.formatter):
+        async for chunk in process_completion_stream_response(stream):
             yield chunk

     async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
@@ -212,7 +209,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
             choices=[choice],
         )

-        return process_completion_response(response, self.formatter)
+        return process_completion_response(response)

     async def chat_completion(
         self,
@@ -262,11 +259,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 input_dict["prompt"] = await chat_completion_request_to_prompt(
                     request,
                     self.register_helper.get_llama_model(request.model),
-                    self.formatter,
                 )
         else:
             assert not media_present, "Ollama does not support media for Completion requests"
-            input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter)
+            input_dict["prompt"] = await completion_request_to_prompt(request)
             input_dict["raw"] = True

         if fmt := request.response_format:
@@ -304,7 +300,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         response = OpenAICompatCompletionResponse(
             choices=[choice],
         )
-        return process_chat_completion_response(response, self.formatter, request)
+        return process_chat_completion_response(response, request)

     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -330,7 +326,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
             )

         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

     async def embeddings(

@@ -5,8 +5,6 @@
 # the root directory of this source tree.
 from typing import AsyncGenerator

-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI

 from llama_stack.apis.inference import *  # noqa: F403
@@ -45,7 +43,6 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: RunpodImplConfig) -> None:
         ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())

     async def initialize(self) -> None:
         return
@@ -56,7 +53,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
     async def completion(
         self,
         model: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
@@ -97,7 +94,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
         r = client.completions.create(**params)
-        return process_chat_completion_response(r, self.formatter, request)
+        return process_chat_completion_response(r, request)

     async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
         params = self._get_params(request)
@@ -108,13 +105,13 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
             yield chunk

         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

     def _get_params(self, request: ChatCompletionRequest) -> dict:
         return {
             "model": self.map_to_provider_model(request.model),
-            "prompt": chat_completion_request_to_prompt(request, self.formatter),
+            "prompt": chat_completion_request_to_prompt(request),
             "stream": request.stream,
             **get_sampling_options(request.sampling_params),
         }
@@ -122,6 +119,6 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

@@ -7,8 +7,6 @@
 import json
 from typing import AsyncGenerator

-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI

 from llama_stack.apis.common.content_types import (
@@ -38,13 +36,8 @@ from .models import MODEL_ALIASES

 class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: SambaNovaImplConfig) -> None:
-        ModelRegistryHelper.__init__(
-            self,
-            model_aliases=MODEL_ALIASES,
-        )
-
+        ModelRegistryHelper.__init__(self, model_aliases=MODEL_ALIASES)
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())

     async def initialize(self) -> None:
         return
@@ -120,7 +113,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
             yield chunk

         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

     async def embeddings(

@@ -9,8 +9,6 @@ import logging
 from typing import AsyncGenerator, List, Optional

 from huggingface_hub import AsyncInferenceClient, HfApi
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import (
@@ -72,7 +70,6 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     model_id: str

     def __init__(self) -> None:
-        self.formatter = ChatFormat(Tokenizer.get_instance())
         self.register_helper = ModelRegistryHelper(build_model_aliases())
         self.huggingface_repo_to_llama_model_id = {
             model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo
@@ -149,7 +146,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         return options

     async def _get_params_for_completion(self, request: CompletionRequest) -> dict:
-        prompt, input_tokens = await completion_request_to_prompt_model_input_info(request, self.formatter)
+        prompt, input_tokens = await completion_request_to_prompt_model_input_info(request)

         return dict(
             prompt=prompt,
@@ -177,7 +174,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
             )

         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_completion_stream_response(stream, self.formatter):
+        async for chunk in process_completion_stream_response(stream):
             yield chunk

     async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
@@ -193,7 +190,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
             choices=[choice],
         )

-        return process_completion_response(response, self.formatter)
+        return process_completion_response(response)

     async def chat_completion(
         self,
@@ -236,7 +233,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         response = OpenAICompatCompletionResponse(
             choices=[choice],
         )
-        return process_chat_completion_response(response, self.formatter, request)
+        return process_chat_completion_response(response, request)

     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -252,12 +249,12 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
             )

         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

     async def _get_params(self, request: ChatCompletionRequest) -> dict:
         prompt, input_tokens = await chat_completion_request_to_model_input_info(
-            request, self.register_helper.get_llama_model(request.model), self.formatter
+            request, self.register_helper.get_llama_model(request.model)
         )
         return dict(
             prompt=prompt,

@@ -6,8 +6,6 @@

 from typing import AsyncGenerator, List, Optional, Union

-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 from together import Together

 from llama_stack.apis.common.content_types import InterleavedContent
@@ -55,7 +53,6 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     def __init__(self, config: TogetherImplConfig) -> None:
         ModelRegistryHelper.__init__(self, MODEL_ALIASES)
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())

     async def initialize(self) -> None:
         pass
@@ -102,7 +99,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
         r = self._get_client().completions.create(**params)
-        return process_completion_response(r, self.formatter)
+        return process_completion_response(r)

     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -114,7 +111,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
             yield chunk

         stream = _to_async_generator()
-        async for chunk in process_completion_stream_response(stream, self.formatter):
+        async for chunk in process_completion_stream_response(stream):
             yield chunk

     def _build_options(
@@ -180,7 +177,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
             r = self._get_client().chat.completions.create(**params)
         else:
             r = self._get_client().completions.create(**params)
-        return process_chat_completion_response(r, self.formatter, request)
+        return process_chat_completion_response(r, request)

     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -195,7 +192,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
             yield chunk

         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
@@ -206,11 +203,11 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
                 input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
             else:
                 input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model), self.formatter
+                    request, self.get_llama_model(request.model)
                 )
         else:
             assert not media_present, "Together does not support media for Completion requests"
-            input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter)
+            input_dict["prompt"] = await completion_request_to_prompt(request)

         return {
             "model": request.model,

@@ -8,8 +8,6 @@ import logging
 from typing import AsyncGenerator, List, Optional, Union

 from llama_models.datatypes import StopReason, ToolCall
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI

 from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus
@@ -191,7 +189,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
         self.register_helper = ModelRegistryHelper(build_model_aliases())
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())
         self.client = None

     async def initialize(self) -> None:
@@ -286,14 +283,14 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         if len(request.tools) > 0:
             res = _process_vllm_chat_completion_stream_response(stream)
         else:
-            res = process_chat_completion_stream_response(stream, self.formatter, request)
+            res = process_chat_completion_stream_response(stream, request)
         async for chunk in res:
             yield chunk

     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
         r = self.client.completions.create(**params)
-        return process_completion_response(r, self.formatter)
+        return process_completion_response(r)

     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -305,7 +302,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             yield chunk

         stream = _to_async_generator()
-        async for chunk in process_completion_stream_response(stream, self.formatter):
+        async for chunk in process_completion_stream_response(stream):
             yield chunk

     async def register_model(self, model: Model) -> Model:
@@ -332,10 +329,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
         else:
             assert not request_has_media(request), "vLLM does not support media for Completion requests"
-            input_dict["prompt"] = await completion_request_to_prompt(
-                request,
-                self.formatter,
-            )
+            input_dict["prompt"] = await completion_request_to_prompt(request)

         if fmt := request.response_format:
             if fmt.type == ResponseFormatType.json_schema.value:

@@ -7,7 +7,6 @@ import json
 import logging
 from typing import AsyncGenerator, Dict, List, Optional, Union

-from llama_models.llama3.api.chat_format import ChatFormat
 from openai.types.chat import ChatCompletionMessageToolCall
 from pydantic import BaseModel

@@ -40,6 +39,7 @@ from llama_stack.models.llama.datatypes import (
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_content_to_url,
+    decode_assistant_message,
 )

 logger = logging.getLogger(__name__)
@@ -149,7 +149,7 @@ def convert_openai_completion_logprobs_stream(text: str, logprobs: Optional[Unio
     return None


-def process_completion_response(response: OpenAICompatCompletionResponse, formatter: ChatFormat) -> CompletionResponse:
+def process_completion_response(response: OpenAICompatCompletionResponse) -> CompletionResponse:
     choice = response.choices[0]
     # drop suffix <eot_id> if present and return stop reason as end of turn
     if choice.text.endswith("<|eot_id|>"):
@@ -174,16 +174,13 @@ def process_completion_response(response: OpenAICompatCompletionResponse, format

 def process_chat_completion_response(
     response: OpenAICompatCompletionResponse,
-    formatter: ChatFormat,
     request: ChatCompletionRequest,
 ) -> ChatCompletionResponse:
     choice = response.choices[0]

     # TODO: This does not work well with tool calls for vLLM remote provider
     # Ref: https://github.com/meta-llama/llama-stack/issues/1058
-    raw_message = formatter.decode_assistant_message_from_content(
-        text_from_choice(choice), get_stop_reason(choice.finish_reason)
-    )
+    raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason))

     # NOTE: If we do not set tools in chat-completion request, we should not
     # expect the ToolCall in the response. Instead, we should return the raw
@@ -217,7 +214,7 @@ def process_chat_completion_response(


 async def process_completion_stream_response(
-    stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
+    stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
 ) -> AsyncGenerator:
     stop_reason = None

@@ -254,7 +251,6 @@ async def process_completion_stream_response(

 async def process_chat_completion_stream_response(
     stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
-    formatter: ChatFormat,
     request: ChatCompletionRequest,
 ) -> AsyncGenerator:
     yield ChatCompletionResponseStreamChunk(
@@ -333,7 +329,7 @@ async def process_chat_completion_stream_response(
         )

         # parse tool calls and report errors
-        message = formatter.decode_assistant_message_from_content(buffer, stop_reason)
+        message = decode_assistant_message(buffer, stop_reason)

         parsed_tool_calls = len(message.tool_calls) > 0
         if ipython and not parsed_tool_calls:

@@ -13,7 +13,9 @@ import re
 from typing import List, Optional, Tuple, Union

 import httpx
+from llama_models.datatypes import StopReason
 from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.tokenizer import Tokenizer
 from PIL import Image as PIL_Image

 from llama_stack.apis.common.content_types import (
@@ -66,6 +68,11 @@ class CompletionRequestWithRawContent(CompletionRequest):
     content: RawContent


+def decode_assistant_message(content: str, stop_reason: StopReason) -> RawMessage:
+    formatter = ChatFormat(Tokenizer.get_instance())
+    return formatter.decode_assistant_message_from_content(content, stop_reason)
+
+
 def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
     def _process(c) -> str:
         if isinstance(c, str):
@@ -207,20 +214,22 @@ async def convert_image_content_to_url(
         return base64.b64encode(content).decode("utf-8")


-async def completion_request_to_prompt(request: CompletionRequest, formatter: ChatFormat) -> str:
+async def completion_request_to_prompt(request: CompletionRequest) -> str:
     content = augment_content_with_response_format_prompt(request.response_format, request.content)
     request.content = content
     request = await convert_request_to_raw(request)
+
+    formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_content(request.content)
     return formatter.tokenizer.decode(model_input.tokens)


-async def completion_request_to_prompt_model_input_info(
-    request: CompletionRequest, formatter: ChatFormat
-) -> Tuple[str, int]:
+async def completion_request_to_prompt_model_input_info(request: CompletionRequest) -> Tuple[str, int]:
     content = augment_content_with_response_format_prompt(request.response_format, request.content)
     request.content = content
     request = await convert_request_to_raw(request)
+
+    formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_content(request.content)
     return (formatter.tokenizer.decode(model_input.tokens), len(model_input.tokens))

@@ -237,22 +246,24 @@ def augment_content_with_response_format_prompt(response_format, content):
     return content


-async def chat_completion_request_to_prompt(
-    request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
-) -> str:
+async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str:
     messages = chat_completion_request_to_messages(request, llama_model)
     request.messages = messages
     request = await convert_request_to_raw(request)
+
+    formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_dialog_prompt(request.messages)
     return formatter.tokenizer.decode(model_input.tokens)


 async def chat_completion_request_to_model_input_info(
-    request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
+    request: ChatCompletionRequest, llama_model: str
 ) -> Tuple[str, int]:
     messages = chat_completion_request_to_messages(request, llama_model)
     request.messages = messages
     request = await convert_request_to_raw(request)
+
+    formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_dialog_prompt(request.messages)
     return (
         formatter.tokenizer.decode(model_input.tokens),