chore: remove deprecated inference.chat_completion implementations (#3654)

# What does this PR do?

Remove the unused `chat_completion` implementations.

vLLM features ported (see the sketch below):
 - require `max_tokens` to be set, falling back to the config value
 - set `tool_choice` to `none` if no tools are provided
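
For context, a minimal sketch of what those two behaviors might look like in an OpenAI-compatible vLLM adapter. The class, the `default_max_tokens` attribute, and the raw `AsyncOpenAI` client are illustrative assumptions, not the actual llama-stack adapter code:

```python
# Hypothetical sketch of the two ported vLLM behaviors; not the adapter's real code.
from typing import Any

from openai import AsyncOpenAI


class VLLMChatSketch:
    def __init__(self, base_url: str, api_key: str, default_max_tokens: int):
        self.client = AsyncOpenAI(base_url=base_url, api_key=api_key)
        self.default_max_tokens = default_max_tokens  # assumed config value

    async def openai_chat_completion(
        self,
        model: str,
        messages: list[dict[str, Any]],
        max_tokens: int | None = None,
        tools: list[dict[str, Any]] | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> Any:
        # vLLM requires max_tokens, so fall back to the configured default when unset.
        if max_tokens is None:
            max_tokens = self.default_max_tokens

        params: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            **kwargs,
        }
        if tools:
            params["tools"] = tools
            params["tool_choice"] = tool_choice or "auto"
        else:
            # vLLM rejects a tool choice when no tools are given, so send "none".
            params["tool_choice"] = "none"
        return await self.client.chat.completions.create(**params)
```

The removed `chat_completion` path carried `max_tokens` inside `SamplingParams`; the OpenAI-compatible path takes `max_tokens` directly, as the warmup change in the diff below shows.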


## Test Plan

CI
Matthew Farrellee, 2025-10-03 07:55:34 -04:00, committed by GitHub
parent 4dfbe46954
commit d266c59c2a
18 changed files with 193 additions and 1410 deletions


@@ -5,37 +5,17 @@
# the root directory of this source tree.
import asyncio
import os
import sys
from collections.abc import AsyncGenerator
from collections.abc import AsyncIterator
from typing import Any
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.apis.common.content_types import (
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
InferenceProvider,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
StopReason,
TokenLogProbs,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
UserMessage,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
@@ -53,13 +33,6 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
convert_request_to_raw,
)
from .config import MetaReferenceInferenceConfig
from .generators import LlamaGenerator
@@ -76,7 +49,6 @@ def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_
class MetaReferenceInferenceImpl(
OpenAIChatCompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin,
InferenceProvider,
ModelsProtocolPrivate,
@@ -161,10 +133,10 @@ class MetaReferenceInferenceImpl(
self.llama_model = llama_model
log.info("Warming up...")
await self.chat_completion(
model_id=model_id,
messages=[UserMessage(content="Hi how are you?")],
sampling_params=SamplingParams(max_tokens=20),
await self.openai_chat_completion(
model=model_id,
messages=[{"role": "user", "content": "Hi how are you?"}],
max_tokens=20,
)
log.info("Warmed up!")
@@ -176,242 +148,30 @@ class MetaReferenceInferenceImpl(
elif request.model != self.model_id:
raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
async def chat_completion(
async def openai_chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = ChatCompletionRequest(
model=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config or ToolConfig(),
)
self.check_model(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
if request.stream:
return self._stream_chat_completion(request)
else:
results = await self._nonstream_chat_completion([request])
return results[0]
async def _nonstream_chat_completion(
self, request_batch: list[ChatCompletionRequest]
) -> list[ChatCompletionResponse]:
tokenizer = self.generator.formatter.tokenizer
first_request = request_batch[0]
class ItemState(BaseModel):
tokens: list[int] = []
logprobs: list[TokenLogProbs] = []
stop_reason: StopReason | None = None
finished: bool = False
def impl():
states = [ItemState() for _ in request_batch]
for token_results in self.generator.chat_completion(request_batch):
first = token_results[0]
if not first.finished and not first.ignore_token:
if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"):
cprint(first.text, color="cyan", end="", file=sys.stderr)
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
cprint(f"<{first.token}>", color="magenta", end="", file=sys.stderr)
for result in token_results:
idx = result.batch_idx
state = states[idx]
if state.finished or result.ignore_token:
continue
state.finished = result.finished
if first_request.logprobs:
state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
state.tokens.append(result.token)
if result.token == tokenizer.eot_id:
state.stop_reason = StopReason.end_of_turn
elif result.token == tokenizer.eom_id:
state.stop_reason = StopReason.end_of_message
results = []
for state in states:
if state.stop_reason is None:
state.stop_reason = StopReason.out_of_tokens
raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
results.append(
ChatCompletionResponse(
completion_message=CompletionMessage(
content=raw_message.content,
stop_reason=raw_message.stop_reason,
tool_calls=raw_message.tool_calls,
),
logprobs=state.logprobs if first_request.logprobs else None,
)
)
return results
if self.config.create_distributed_process_group:
async with SEMAPHORE:
return impl()
else:
return impl()
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
tokenizer = self.generator.formatter.tokenizer
def impl():
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta=TextDelta(text=""),
)
)
tokens = []
logprobs = []
stop_reason = None
ipython = False
for token_results in self.generator.chat_completion([request]):
token_result = token_results[0]
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
cprint(token_result.text, color="cyan", end="", file=sys.stderr)
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
cprint(f"<{token_result.token}>", color="magenta", end="", file=sys.stderr)
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
tokens.append(token_result.token)
if not ipython and token_result.text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call="",
parse_status=ToolCallParseStatus.started,
),
)
)
continue
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if ipython:
delta = ToolCallDelta(
tool_call=text,
parse_status=ToolCallParseStatus.in_progress,
)
else:
delta = TextDelta(text=text)
if stop_reason is None:
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call="",
parse_status=ToolCallParseStatus.failed,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call=tool_call,
parse_status=ToolCallParseStatus.succeeded,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=TextDelta(text=""),
stop_reason=stop_reason,
)
)
if self.config.create_distributed_process_group:
async with SEMAPHORE:
for x in impl():
yield x
else:
for x in impl():
yield x
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider")


@@ -4,21 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from collections.abc import AsyncIterator
from typing import Any
from llama_stack.apis.inference import (
InferenceProvider,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import OpenAICompletion
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
@@ -73,21 +71,6 @@ class SentenceTransformersInferenceImpl(
async def unregister_model(self, model_id: str) -> None:
pass
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
raise ValueError("Sentence transformers don't support chat completion")
async def openai_completion(
self,
# Standard OpenAI completion parameters
@@ -115,3 +98,31 @@ class SentenceTransformersInferenceImpl(
suffix: str | None = None,
) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by sentence transformers provider")
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by sentence transformers provider")