From d266c59c2a8fabded65d390ccd594175f4901bca Mon Sep 17 00:00:00 2001
From: Matthew Farrellee
Date: Fri, 3 Oct 2025 07:55:34 -0400
Subject: [PATCH] chore: remove deprecated inference.chat_completion implementations (#3654)

# What does this PR do?

Remove the unused `chat_completion` implementations.

vLLM features ported:
- requires `max_tokens` to be set, falling back to the config value when it is not provided
- sets `tool_choice` to none if no tools are provided

## Test Plan

ci
---
 llama_stack/apis/inference/inference.py       |  39 ---
 llama_stack/core/routers/inference.py         |  88 -----
 .../inference/meta_reference/inference.py     | 316 +++---------------
 .../sentence_transformers.py                  |  61 ++--
 .../remote/inference/bedrock/bedrock.py       | 119 ++-----
 .../remote/inference/cerebras/cerebras.py     |  61 ----
 .../remote/inference/databricks/databricks.py |  26 --
 .../remote/inference/fireworks/fireworks.py   |  70 ----
 .../remote/inference/nvidia/nvidia.py         |  67 +---
 .../remote/inference/ollama/ollama.py         |  98 ------
 .../inference/passthrough/passthrough.py      |  85 +----
 .../remote/inference/runpod/runpod.py         |  56 ----
 .../providers/remote/inference/tgi/tgi.py     |  75 -----
 .../remote/inference/together/together.py     |  61 ----
 .../providers/remote/inference/vllm/vllm.py   | 160 ++++-----
 .../remote/inference/watsonx/watsonx.py       |  81 -----
 .../utils/inference/litellm_openai_mixin.py   |  62 ----
 .../providers/inference/test_remote_vllm.py   |  78 +----
 18 files changed, 193 insertions(+), 1410 deletions(-)

diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 829a94a6a..e88a16315 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -1006,45 +1006,6 @@ class InferenceProvider(Protocol):
 
     model_store: ModelStore | None = None
 
-    async def chat_completion(
-        self,
-        model_id: str,
-        messages: list[Message],
-        sampling_params: SamplingParams | None = None,
-        tools: list[ToolDefinition] | None = None,
-        tool_choice: ToolChoice | None = ToolChoice.auto,
-        tool_prompt_format: ToolPromptFormat | None = None,
-        response_format: ResponseFormat | None = None,
-        stream: bool | None = False,
-        logprobs: LogProbConfig | None = None,
-        tool_config: ToolConfig | None = None,
-    ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
-        """Generate a chat completion for the given messages using the specified model.
-
-        :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
-        :param messages: List of messages in the conversation.
-        :param sampling_params: Parameters to control the sampling strategy.
-        :param tools: (Optional) List of tool definitions available to the model.
-        :param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
-            .. deprecated::
-               Use tool_config instead.
-        :param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
-            - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
-            - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a tag.
-            - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
-            .. deprecated::
-               Use tool_config instead.
-        :param response_format: (Optional) Grammar specification for guided (structured) decoding. There are two options:
-        - `ResponseFormat.json_schema`: The grammar is a JSON schema. Most providers support this format.
- - `ResponseFormat.grammar`: The grammar is a BNF grammar. This format is more flexible, but not all providers support it. - :param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False. - :param logprobs: (Optional) If specified, log probabilities for each token position will be returned. - :param tool_config: (Optional) Configuration for tool use. - :returns: If stream=False, returns a ChatCompletionResponse with the full completion. - If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk. - """ - ... - @webmethod(route="/inference/rerank", method="POST", level=LLAMA_STACK_API_V1ALPHA) async def rerank( self, diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py index 4b004a82c..c4338e614 100644 --- a/llama_stack/core/routers/inference.py +++ b/llama_stack/core/routers/inference.py @@ -27,7 +27,6 @@ from llama_stack.apis.inference import ( CompletionResponseStreamChunk, Inference, ListOpenAIChatCompletionResponse, - LogProbConfig, Message, OpenAIAssistantMessageParam, OpenAIChatCompletion, @@ -42,12 +41,7 @@ from llama_stack.apis.inference import ( OpenAIMessageParam, OpenAIResponseFormatParam, Order, - ResponseFormat, - SamplingParams, StopReason, - ToolChoice, - ToolConfig, - ToolDefinition, ToolPromptFormat, ) from llama_stack.apis.models import Model, ModelType @@ -185,88 +179,6 @@ class InferenceRouter(Inference): raise ModelTypeError(model_id, model.model_type, expected_model_type) return model - async def chat_completion( - self, - model_id: str, - messages: list[Message], - sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, - tools: list[ToolDefinition] | None = None, - tool_choice: ToolChoice | None = None, - tool_prompt_format: ToolPromptFormat | None = None, - stream: bool | None = False, - logprobs: LogProbConfig | None = None, - tool_config: ToolConfig | None = None, - ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]: - logger.debug( - f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}", - ) - if sampling_params is None: - sampling_params = SamplingParams() - model = await self._get_model(model_id, ModelType.llm) - if tool_config: - if tool_choice and tool_choice != tool_config.tool_choice: - raise ValueError("tool_choice and tool_config.tool_choice must match") - if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format: - raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match") - else: - params = {} - if tool_choice: - params["tool_choice"] = tool_choice - if tool_prompt_format: - params["tool_prompt_format"] = tool_prompt_format - tool_config = ToolConfig(**params) - - tools = tools or [] - if tool_config.tool_choice == ToolChoice.none: - tools = [] - elif tool_config.tool_choice == ToolChoice.auto: - pass - elif tool_config.tool_choice == ToolChoice.required: - pass - else: - # verify tool_choice is one of the tools - tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools] - if tool_config.tool_choice not in tool_names: - raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}") - - params = dict( - model_id=model_id, - messages=messages, - sampling_params=sampling_params, - tools=tools, - tool_choice=tool_choice, - tool_prompt_format=tool_prompt_format, - response_format=response_format, - stream=stream, 
- logprobs=logprobs, - tool_config=tool_config, - ) - provider = await self.routing_table.get_provider_impl(model_id) - prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format) - - if stream: - response_stream = await provider.chat_completion(**params) - return self.stream_tokens_and_compute_metrics( - response=response_stream, - prompt_tokens=prompt_tokens, - model=model, - tool_prompt_format=tool_config.tool_prompt_format, - ) - - response = await provider.chat_completion(**params) - metrics = await self.count_tokens_and_compute_metrics( - response=response, - prompt_tokens=prompt_tokens, - model=model, - tool_prompt_format=tool_config.tool_prompt_format, - ) - # these metrics will show up in the client response. - response.metrics = ( - metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics - ) - return response - async def openai_completion( self, model: str, diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index db022d65d..fd65fa10d 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -5,37 +5,17 @@ # the root directory of this source tree. import asyncio -import os -import sys -from collections.abc import AsyncGenerator +from collections.abc import AsyncIterator +from typing import Any -from pydantic import BaseModel -from termcolor import cprint - -from llama_stack.apis.common.content_types import ( - TextDelta, - ToolCallDelta, - ToolCallParseStatus, -) from llama_stack.apis.inference import ( - ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionResponseEvent, - ChatCompletionResponseEventType, - ChatCompletionResponseStreamChunk, - CompletionMessage, InferenceProvider, - LogProbConfig, - Message, - ResponseFormat, - SamplingParams, - StopReason, - TokenLogProbs, - ToolChoice, - ToolConfig, - ToolDefinition, - ToolPromptFormat, - UserMessage, +) +from llama_stack.apis.inference.inference import ( + OpenAIChatCompletion, + OpenAIChatCompletionChunk, + OpenAIMessageParam, + OpenAIResponseFormatParam, ) from llama_stack.apis.models import Model, ModelType from llama_stack.log import get_logger @@ -53,13 +33,6 @@ from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, build_hf_repo_model_entry, ) -from llama_stack.providers.utils.inference.openai_compat import ( - OpenAIChatCompletionToLlamaStackMixin, -) -from llama_stack.providers.utils.inference.prompt_adapter import ( - chat_completion_request_to_messages, - convert_request_to_raw, -) from .config import MetaReferenceInferenceConfig from .generators import LlamaGenerator @@ -76,7 +49,6 @@ def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_ class MetaReferenceInferenceImpl( - OpenAIChatCompletionToLlamaStackMixin, SentenceTransformerEmbeddingMixin, InferenceProvider, ModelsProtocolPrivate, @@ -161,10 +133,10 @@ class MetaReferenceInferenceImpl( self.llama_model = llama_model log.info("Warming up...") - await self.chat_completion( - model_id=model_id, - messages=[UserMessage(content="Hi how are you?")], - sampling_params=SamplingParams(max_tokens=20), + await self.openai_chat_completion( + model=model_id, + messages=[{"role": "user", "content": "Hi how are you?"}], + max_tokens=20, ) log.info("Warmed up!") @@ -176,242 +148,30 @@ class MetaReferenceInferenceImpl( elif request.model != 
self.model_id: raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}") - async def chat_completion( + async def openai_chat_completion( self, - model_id: str, - messages: list[Message], - sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, - tools: list[ToolDefinition] | None = None, - tool_choice: ToolChoice | None = ToolChoice.auto, - tool_prompt_format: ToolPromptFormat | None = None, - stream: bool | None = False, - logprobs: LogProbConfig | None = None, - tool_config: ToolConfig | None = None, - ) -> AsyncGenerator: - if sampling_params is None: - sampling_params = SamplingParams() - if logprobs: - assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}" - - # wrapper request to make it easier to pass around (internal only, not exposed to API) - request = ChatCompletionRequest( - model=model_id, - messages=messages, - sampling_params=sampling_params, - tools=tools or [], - response_format=response_format, - stream=stream, - logprobs=logprobs, - tool_config=tool_config or ToolConfig(), - ) - self.check_model(request) - - # augment and rewrite messages depending on the model - request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value) - # download media and convert to raw content so we can send it to the model - request = await convert_request_to_raw(request) - - if self.config.create_distributed_process_group: - if SEMAPHORE.locked(): - raise RuntimeError("Only one concurrent request is supported") - - if request.stream: - return self._stream_chat_completion(request) - else: - results = await self._nonstream_chat_completion([request]) - return results[0] - - async def _nonstream_chat_completion( - self, request_batch: list[ChatCompletionRequest] - ) -> list[ChatCompletionResponse]: - tokenizer = self.generator.formatter.tokenizer - - first_request = request_batch[0] - - class ItemState(BaseModel): - tokens: list[int] = [] - logprobs: list[TokenLogProbs] = [] - stop_reason: StopReason | None = None - finished: bool = False - - def impl(): - states = [ItemState() for _ in request_batch] - - for token_results in self.generator.chat_completion(request_batch): - first = token_results[0] - if not first.finished and not first.ignore_token: - if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"): - cprint(first.text, color="cyan", end="", file=sys.stderr) - if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2": - cprint(f"<{first.token}>", color="magenta", end="", file=sys.stderr) - - for result in token_results: - idx = result.batch_idx - state = states[idx] - if state.finished or result.ignore_token: - continue - - state.finished = result.finished - if first_request.logprobs: - state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]})) - - state.tokens.append(result.token) - if result.token == tokenizer.eot_id: - state.stop_reason = StopReason.end_of_turn - elif result.token == tokenizer.eom_id: - state.stop_reason = StopReason.end_of_message - - results = [] - for state in states: - if state.stop_reason is None: - state.stop_reason = StopReason.out_of_tokens - - raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason) - results.append( - ChatCompletionResponse( - completion_message=CompletionMessage( - content=raw_message.content, - stop_reason=raw_message.stop_reason, - tool_calls=raw_message.tool_calls, - ), - logprobs=state.logprobs if first_request.logprobs else None, - ) - 
) - - return results - - if self.config.create_distributed_process_group: - async with SEMAPHORE: - return impl() - else: - return impl() - - async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: - tokenizer = self.generator.formatter.tokenizer - - def impl(): - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.start, - delta=TextDelta(text=""), - ) - ) - - tokens = [] - logprobs = [] - stop_reason = None - ipython = False - - for token_results in self.generator.chat_completion([request]): - token_result = token_results[0] - if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1": - cprint(token_result.text, color="cyan", end="", file=sys.stderr) - if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2": - cprint(f"<{token_result.token}>", color="magenta", end="", file=sys.stderr) - - if token_result.token == tokenizer.eot_id: - stop_reason = StopReason.end_of_turn - text = "" - elif token_result.token == tokenizer.eom_id: - stop_reason = StopReason.end_of_message - text = "" - else: - text = token_result.text - - if request.logprobs: - assert len(token_result.logprobs) == 1 - - logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})) - - tokens.append(token_result.token) - - if not ipython and token_result.text.startswith("<|python_tag|>"): - ipython = True - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - tool_call="", - parse_status=ToolCallParseStatus.started, - ), - ) - ) - continue - - if token_result.token == tokenizer.eot_id: - stop_reason = StopReason.end_of_turn - text = "" - elif token_result.token == tokenizer.eom_id: - stop_reason = StopReason.end_of_message - text = "" - else: - text = token_result.text - - if ipython: - delta = ToolCallDelta( - tool_call=text, - parse_status=ToolCallParseStatus.in_progress, - ) - else: - delta = TextDelta(text=text) - - if stop_reason is None: - if request.logprobs: - assert len(token_result.logprobs) == 1 - - logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})) - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=delta, - stop_reason=stop_reason, - logprobs=logprobs if request.logprobs else None, - ) - ) - - if stop_reason is None: - stop_reason = StopReason.out_of_tokens - - message = self.generator.formatter.decode_assistant_message(tokens, stop_reason) - - parsed_tool_calls = len(message.tool_calls) > 0 - if ipython and not parsed_tool_calls: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - tool_call="", - parse_status=ToolCallParseStatus.failed, - ), - stop_reason=stop_reason, - ) - ) - - for tool_call in message.tool_calls: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - tool_call=tool_call, - parse_status=ToolCallParseStatus.succeeded, - ), - stop_reason=stop_reason, - ) - ) - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.complete, - delta=TextDelta(text=""), - stop_reason=stop_reason, - ) - ) - - if self.config.create_distributed_process_group: - async with 
SEMAPHORE: - for x in impl(): - yield x - else: - for x in impl(): - yield x + model: str, + messages: list[OpenAIMessageParam], + frequency_penalty: float | None = None, + function_call: str | dict[str, Any] | None = None, + functions: list[dict[str, Any]] | None = None, + logit_bias: dict[str, float] | None = None, + logprobs: bool | None = None, + max_completion_tokens: int | None = None, + max_tokens: int | None = None, + n: int | None = None, + parallel_tool_calls: bool | None = None, + presence_penalty: float | None = None, + response_format: OpenAIResponseFormatParam | None = None, + seed: int | None = None, + stop: str | list[str] | None = None, + stream: bool | None = None, + stream_options: dict[str, Any] | None = None, + temperature: float | None = None, + tool_choice: str | dict[str, Any] | None = None, + tools: list[dict[str, Any]] | None = None, + top_logprobs: int | None = None, + top_p: float | None = None, + user: str | None = None, + ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: + raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider") diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index cd682dca6..b984d97bf 100644 --- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -4,21 +4,19 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from collections.abc import AsyncGenerator +from collections.abc import AsyncIterator from typing import Any from llama_stack.apis.inference import ( InferenceProvider, - LogProbConfig, - Message, - ResponseFormat, - SamplingParams, - ToolChoice, - ToolConfig, - ToolDefinition, - ToolPromptFormat, ) -from llama_stack.apis.inference.inference import OpenAICompletion +from llama_stack.apis.inference.inference import ( + OpenAIChatCompletion, + OpenAIChatCompletionChunk, + OpenAICompletion, + OpenAIMessageParam, + OpenAIResponseFormatParam, +) from llama_stack.apis.models import ModelType from llama_stack.log import get_logger from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate @@ -73,21 +71,6 @@ class SentenceTransformersInferenceImpl( async def unregister_model(self, model_id: str) -> None: pass - async def chat_completion( - self, - model_id: str, - messages: list[Message], - sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, - tools: list[ToolDefinition] | None = None, - tool_choice: ToolChoice | None = ToolChoice.auto, - tool_prompt_format: ToolPromptFormat | None = None, - stream: bool | None = False, - logprobs: LogProbConfig | None = None, - tool_config: ToolConfig | None = None, - ) -> AsyncGenerator: - raise ValueError("Sentence transformers don't support chat completion") - async def openai_completion( self, # Standard OpenAI completion parameters @@ -115,3 +98,31 @@ class SentenceTransformersInferenceImpl( suffix: str | None = None, ) -> OpenAICompletion: raise NotImplementedError("OpenAI completion not supported by sentence transformers provider") + + async def openai_chat_completion( + self, + model: str, + messages: list[OpenAIMessageParam], + frequency_penalty: float | None = None, + function_call: str | dict[str, Any] | None = None, + functions: 
list[dict[str, Any]] | None = None, + logit_bias: dict[str, float] | None = None, + logprobs: bool | None = None, + max_completion_tokens: int | None = None, + max_tokens: int | None = None, + n: int | None = None, + parallel_tool_calls: bool | None = None, + presence_penalty: float | None = None, + response_format: OpenAIResponseFormatParam | None = None, + seed: int | None = None, + stop: str | list[str] | None = None, + stream: bool | None = None, + stream_options: dict[str, Any] | None = None, + temperature: float | None = None, + tool_choice: str | dict[str, Any] | None = None, + tools: list[dict[str, Any]] | None = None, + top_logprobs: int | None = None, + top_p: float | None = None, + user: str | None = None, + ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: + raise NotImplementedError("OpenAI chat completion not supported by sentence transformers provider") diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index f87a5b5e2..9c8a74b47 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -5,39 +5,30 @@ # the root directory of this source tree. import json -from collections.abc import AsyncGenerator, AsyncIterator +from collections.abc import AsyncIterator from typing import Any from botocore.client import BaseClient from llama_stack.apis.inference import ( ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionResponseStreamChunk, Inference, - LogProbConfig, - Message, OpenAIEmbeddingsResponse, - ResponseFormat, - SamplingParams, - ToolChoice, - ToolConfig, - ToolDefinition, - ToolPromptFormat, ) -from llama_stack.apis.inference.inference import OpenAICompletion +from llama_stack.apis.inference.inference import ( + OpenAIChatCompletion, + OpenAIChatCompletionChunk, + OpenAICompletion, + OpenAIMessageParam, + OpenAIResponseFormatParam, +) from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig from llama_stack.providers.utils.bedrock.client import create_bedrock_client from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, ) from llama_stack.providers.utils.inference.openai_compat import ( - OpenAIChatCompletionToLlamaStackMixin, - OpenAICompatCompletionChoice, - OpenAICompatCompletionResponse, get_sampling_strategy_options, - process_chat_completion_response, - process_chat_completion_stream_response, ) from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_prompt, @@ -86,7 +77,6 @@ def _to_inference_profile_id(model_id: str, region: str = None) -> str: class BedrockInferenceAdapter( ModelRegistryHelper, Inference, - OpenAIChatCompletionToLlamaStackMixin, ): def __init__(self, config: BedrockConfig) -> None: ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES) @@ -106,71 +96,6 @@ class BedrockInferenceAdapter( if self._client is not None: self._client.close() - async def chat_completion( - self, - model_id: str, - messages: list[Message], - sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, - tools: list[ToolDefinition] | None = None, - tool_choice: ToolChoice | None = ToolChoice.auto, - tool_prompt_format: ToolPromptFormat | None = None, - stream: bool | None = False, - logprobs: LogProbConfig | None = None, - tool_config: ToolConfig | None = None, - ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]: - if 
sampling_params is None: - sampling_params = SamplingParams() - model = await self.model_store.get_model(model_id) - request = ChatCompletionRequest( - model=model.provider_resource_id, - messages=messages, - sampling_params=sampling_params, - tools=tools or [], - response_format=response_format, - stream=stream, - logprobs=logprobs, - tool_config=tool_config, - ) - - if stream: - return self._stream_chat_completion(request) - else: - return await self._nonstream_chat_completion(request) - - async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: - params = await self._get_params_for_chat_completion(request) - res = self.client.invoke_model(**params) - chunk = next(res["body"]) - result = json.loads(chunk.decode("utf-8")) - - choice = OpenAICompatCompletionChoice( - finish_reason=result["stop_reason"], - text=result["generation"], - ) - - response = OpenAICompatCompletionResponse(choices=[choice]) - return process_chat_completion_response(response, request) - - async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: - params = await self._get_params_for_chat_completion(request) - res = self.client.invoke_model_with_response_stream(**params) - event_stream = res["body"] - - async def _generate_and_convert_to_openai_compat(): - for chunk in event_stream: - chunk = chunk["chunk"]["bytes"] - result = json.loads(chunk.decode("utf-8")) - choice = OpenAICompatCompletionChoice( - finish_reason=result["stop_reason"], - text=result["generation"], - ) - yield OpenAICompatCompletionResponse(choices=[choice]) - - stream = _generate_and_convert_to_openai_compat() - async for chunk in process_chat_completion_stream_response(stream, request): - yield chunk - async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> dict: bedrock_model = request.model @@ -235,3 +160,31 @@ class BedrockInferenceAdapter( suffix: str | None = None, ) -> OpenAICompletion: raise NotImplementedError("OpenAI completion not supported by the Bedrock provider") + + async def openai_chat_completion( + self, + model: str, + messages: list[OpenAIMessageParam], + frequency_penalty: float | None = None, + function_call: str | dict[str, Any] | None = None, + functions: list[dict[str, Any]] | None = None, + logit_bias: dict[str, float] | None = None, + logprobs: bool | None = None, + max_completion_tokens: int | None = None, + max_tokens: int | None = None, + n: int | None = None, + parallel_tool_calls: bool | None = None, + presence_penalty: float | None = None, + response_format: OpenAIResponseFormatParam | None = None, + seed: int | None = None, + stop: str | list[str] | None = None, + stream: bool | None = None, + stream_options: dict[str, Any] | None = None, + temperature: float | None = None, + tool_choice: str | dict[str, Any] | None = None, + tools: list[dict[str, Any]] | None = None, + top_logprobs: int | None = None, + top_p: float | None = None, + user: str | None = None, + ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: + raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider") diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py index 43b984f7f..e3ce9bfab 100644 --- a/llama_stack/providers/remote/inference/cerebras/cerebras.py +++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # 
the root directory of this source tree. -from collections.abc import AsyncGenerator from urllib.parse import urljoin from cerebras.cloud.sdk import AsyncCerebras @@ -12,23 +11,12 @@ from cerebras.cloud.sdk import AsyncCerebras from llama_stack.apis.inference import ( ChatCompletionRequest, CompletionRequest, - CompletionResponse, Inference, - LogProbConfig, - Message, OpenAIEmbeddingsResponse, - ResponseFormat, - SamplingParams, - ToolChoice, - ToolConfig, - ToolDefinition, - ToolPromptFormat, TopKSamplingStrategy, ) from llama_stack.providers.utils.inference.openai_compat import ( get_sampling_options, - process_chat_completion_response, - process_chat_completion_stream_response, ) from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from llama_stack.providers.utils.inference.prompt_adapter import ( @@ -64,55 +52,6 @@ class CerebrasInferenceAdapter( async def shutdown(self) -> None: pass - async def chat_completion( - self, - model_id: str, - messages: list[Message], - sampling_params: SamplingParams | None = None, - tools: list[ToolDefinition] | None = None, - tool_choice: ToolChoice | None = ToolChoice.auto, - tool_prompt_format: ToolPromptFormat | None = None, - response_format: ResponseFormat | None = None, - stream: bool | None = False, - logprobs: LogProbConfig | None = None, - tool_config: ToolConfig | None = None, - ) -> AsyncGenerator: - if sampling_params is None: - sampling_params = SamplingParams() - model = await self.model_store.get_model(model_id) - request = ChatCompletionRequest( - model=model.provider_resource_id, - messages=messages, - sampling_params=sampling_params, - tools=tools or [], - tool_choice=tool_choice, - tool_prompt_format=tool_prompt_format, - response_format=response_format, - stream=stream, - logprobs=logprobs, - tool_config=tool_config, - ) - - if stream: - return self._stream_chat_completion(request) - else: - return await self._nonstream_chat_completion(request) - - async def _nonstream_chat_completion(self, request: CompletionRequest) -> CompletionResponse: - params = await self._get_params(request) - - r = await self._cerebras_client.completions.create(**params) - - return process_chat_completion_response(r, request) - - async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator: - params = await self._get_params(request) - - stream = await self._cerebras_client.completions.create(**params) - - async for chunk in process_chat_completion_stream_response(stream, request): - yield chunk - async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict: if request.sampling_params and isinstance(request.sampling_params.strategy, TopKSamplingStrategy): raise ValueError("`top_k` not supported by Cerebras") diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index cd5dfb40d..a2621b81e 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -4,25 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from collections.abc import AsyncIterator from typing import Any from databricks.sdk import WorkspaceClient from llama_stack.apis.inference import ( - ChatCompletionResponse, - ChatCompletionResponseStreamChunk, Inference, - LogProbConfig, - Message, Model, OpenAICompletion, - ResponseFormat, - SamplingParams, - ToolChoice, - ToolConfig, - ToolDefinition, - ToolPromptFormat, ) from llama_stack.apis.models import ModelType from llama_stack.log import get_logger @@ -83,21 +72,6 @@ class DatabricksInferenceAdapter( ) -> OpenAICompletion: raise NotImplementedError() - async def chat_completion( - self, - model_id: str, - messages: list[Message], - sampling_params: SamplingParams | None = None, - tools: list[ToolDefinition] | None = None, - tool_choice: ToolChoice | None = ToolChoice.auto, - tool_prompt_format: ToolPromptFormat | None = None, - response_format: ResponseFormat | None = None, - stream: bool | None = False, - logprobs: LogProbConfig | None = None, - tool_config: ToolConfig | None = None, - ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]: - raise NotImplementedError() - async def list_models(self) -> list[Model] | None: self._model_cache = {} # from OpenAIMixin ws_client = WorkspaceClient(host=self.config.url, token=self.get_api_key()) # TODO: this is not async diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index 83d9ac354..56c12fd49 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -4,23 +4,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from collections.abc import AsyncGenerator from fireworks.client import Fireworks from llama_stack.apis.inference import ( ChatCompletionRequest, - ChatCompletionResponse, Inference, LogProbConfig, - Message, ResponseFormat, ResponseFormatType, SamplingParams, - ToolChoice, - ToolConfig, - ToolDefinition, - ToolPromptFormat, ) from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.log import get_logger @@ -30,8 +23,6 @@ from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.openai_compat import ( convert_message_to_openai_dict, get_sampling_options, - process_chat_completion_response, - process_chat_completion_stream_response, ) from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from llama_stack.providers.utils.inference.prompt_adapter import ( @@ -80,67 +71,6 @@ class FireworksInferenceAdapter(OpenAIMixin, Inference, NeedsRequestProviderData fireworks_api_key = self.get_api_key() return Fireworks(api_key=fireworks_api_key) - def _preprocess_prompt_for_fireworks(self, prompt: str) -> str: - """Remove BOS token as Fireworks automatically prepends it""" - if prompt.startswith("<|begin_of_text|>"): - return prompt[len("<|begin_of_text|>") :] - return prompt - - async def chat_completion( - self, - model_id: str, - messages: list[Message], - sampling_params: SamplingParams | None = None, - tools: list[ToolDefinition] | None = None, - tool_choice: ToolChoice | None = ToolChoice.auto, - tool_prompt_format: ToolPromptFormat | None = None, - response_format: ResponseFormat | None = None, - stream: bool | None = False, - logprobs: LogProbConfig | None = None, - tool_config: ToolConfig | None = None, - ) -> AsyncGenerator: - if sampling_params is None: 
- sampling_params = SamplingParams() - model = await self.model_store.get_model(model_id) - request = ChatCompletionRequest( - model=model.provider_resource_id, - messages=messages, - sampling_params=sampling_params, - tools=tools or [], - response_format=response_format, - stream=stream, - logprobs=logprobs, - tool_config=tool_config, - ) - - if stream: - return self._stream_chat_completion(request) - else: - return await self._nonstream_chat_completion(request) - - async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: - params = await self._get_params(request) - if "messages" in params: - r = await self._get_client().chat.completions.acreate(**params) - else: - r = await self._get_client().completion.acreate(**params) - return process_chat_completion_response(r, request) - - async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: - params = await self._get_params(request) - - async def _to_async_generator(): - if "messages" in params: - stream = self._get_client().chat.completions.acreate(**params) - else: - stream = self._get_client().completion.acreate(**params) - async for chunk in stream: - yield chunk - - stream = _to_async_generator() - async for chunk in process_chat_completion_stream_response(stream, request): - yield chunk - def _build_options( self, sampling_params: SamplingParams | None, diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 8619b6b68..2e6c3d769 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -4,38 +4,19 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import warnings -from collections.abc import AsyncIterator -from openai import NOT_GIVEN, APIConnectionError +from openai import NOT_GIVEN from llama_stack.apis.inference import ( - ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionResponseStreamChunk, Inference, - LogProbConfig, - Message, OpenAIEmbeddingData, OpenAIEmbeddingsResponse, OpenAIEmbeddingUsage, - ResponseFormat, - SamplingParams, - ToolChoice, - ToolConfig, ) from llama_stack.log import get_logger -from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat -from llama_stack.providers.utils.inference.openai_compat import ( - convert_openai_chat_completion_choice, - convert_openai_chat_completion_stream, -) from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from . 
import NVIDIAConfig -from .openai_utils import ( - convert_chat_completion_request, -) from .utils import _is_nvidia_hosted logger = get_logger(name=__name__, category="inference::nvidia") @@ -149,49 +130,3 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference): model=response.model, usage=usage, ) - - async def chat_completion( - self, - model_id: str, - messages: list[Message], - sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, - tools: list[ToolDefinition] | None = None, - tool_choice: ToolChoice | None = ToolChoice.auto, - tool_prompt_format: ToolPromptFormat | None = None, - stream: bool | None = False, - logprobs: LogProbConfig | None = None, - tool_config: ToolConfig | None = None, - ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]: - if sampling_params is None: - sampling_params = SamplingParams() - if tool_prompt_format: - warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring", stacklevel=2) - - # await check_health(self._config) # this raises errors - - provider_model_id = await self._get_provider_model_id(model_id) - request = await convert_chat_completion_request( - request=ChatCompletionRequest( - model=provider_model_id, - messages=messages, - sampling_params=sampling_params, - response_format=response_format, - tools=tools, - stream=stream, - logprobs=logprobs, - tool_config=tool_config, - ), - n=1, - ) - - try: - response = await self.client.chat.completions.create(**request) - except APIConnectionError as e: - raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e - - if stream: - return convert_openai_chat_completion_stream(response, enable_incremental_tool_calls=False) - else: - # we pass n=1 to get only one completion - return convert_openai_chat_completion_choice(response.choices[0]) diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 85ad62f9a..de55c1b58 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -6,7 +6,6 @@ import asyncio -from collections.abc import AsyncGenerator from typing import Any from ollama import AsyncClient as AsyncOllamaClient @@ -18,19 +17,10 @@ from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.errors import UnsupportedModelError from llama_stack.apis.inference import ( ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionResponseStreamChunk, GrammarResponseFormat, InferenceProvider, JsonSchemaResponseFormat, - LogProbConfig, Message, - ResponseFormat, - SamplingParams, - ToolChoice, - ToolConfig, - ToolDefinition, - ToolPromptFormat, ) from llama_stack.apis.models import Model from llama_stack.log import get_logger @@ -46,11 +36,7 @@ from llama_stack.providers.utils.inference.model_registry import ( build_hf_repo_model_entry, ) from llama_stack.providers.utils.inference.openai_compat import ( - OpenAICompatCompletionChoice, - OpenAICompatCompletionResponse, get_sampling_options, - process_chat_completion_response, - process_chat_completion_stream_response, ) from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from llama_stack.providers.utils.inference.prompt_adapter import ( @@ -161,39 +147,6 @@ class OllamaInferenceAdapter( raise ValueError("Model store not set") return await self.model_store.get_model(model_id) - async def chat_completion( - self, - model_id: str, - messages: 
list[Message], - sampling_params: SamplingParams | None = None, - tools: list[ToolDefinition] | None = None, - tool_choice: ToolChoice | None = ToolChoice.auto, - tool_prompt_format: ToolPromptFormat | None = None, - response_format: ResponseFormat | None = None, - stream: bool | None = False, - logprobs: LogProbConfig | None = None, - tool_config: ToolConfig | None = None, - ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]: - if sampling_params is None: - sampling_params = SamplingParams() - model = await self._get_model(model_id) - if model.provider_resource_id is None: - raise ValueError(f"Model {model_id} has no provider_resource_id set") - request = ChatCompletionRequest( - model=model.provider_resource_id, - messages=messages, - sampling_params=sampling_params, - tools=tools or [], - stream=stream, - logprobs=logprobs, - response_format=response_format, - tool_config=tool_config, - ) - if stream: - return self._stream_chat_completion(request) - else: - return await self._nonstream_chat_completion(request) - async def _get_params(self, request: ChatCompletionRequest) -> dict: sampling_options = get_sampling_options(request.sampling_params) # This is needed since the Ollama API expects num_predict to be set @@ -233,57 +186,6 @@ class OllamaInferenceAdapter( return params - async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: - params = await self._get_params(request) - if "messages" in params: - r = await self.ollama_client.chat(**params) - else: - r = await self.ollama_client.generate(**params) - - if "message" in r: - choice = OpenAICompatCompletionChoice( - finish_reason=r["done_reason"] if r["done"] else None, - text=r["message"]["content"], - ) - else: - choice = OpenAICompatCompletionChoice( - finish_reason=r["done_reason"] if r["done"] else None, - text=r["response"], - ) - response = OpenAICompatCompletionResponse( - choices=[choice], - ) - return process_chat_completion_response(response, request) - - async def _stream_chat_completion( - self, request: ChatCompletionRequest - ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]: - params = await self._get_params(request) - - async def _generate_and_convert_to_openai_compat(): - if "messages" in params: - s = await self.ollama_client.chat(**params) - else: - s = await self.ollama_client.generate(**params) - async for chunk in s: - if "message" in chunk: - choice = OpenAICompatCompletionChoice( - finish_reason=chunk["done_reason"] if chunk["done"] else None, - text=chunk["message"]["content"], - ) - else: - choice = OpenAICompatCompletionChoice( - finish_reason=chunk["done_reason"] if chunk["done"] else None, - text=chunk["response"], - ) - yield OpenAICompatCompletionResponse( - choices=[choice], - ) - - stream = _generate_and_convert_to_openai_compat() - async for chunk in process_chat_completion_stream_response(stream, request): - yield chunk - async def register_model(self, model: Model) -> Model: if await self.check_model_availability(model.provider_model_id): return model diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py index 3ac45e949..e0ddb237e 100644 --- a/llama_stack/providers/remote/inference/passthrough/passthrough.py +++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py @@ -4,33 +4,22 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from collections.abc import AsyncGenerator, AsyncIterator +from collections.abc import AsyncIterator from typing import Any from llama_stack_client import AsyncLlamaStackClient from llama_stack.apis.inference import ( - ChatCompletionResponse, - ChatCompletionResponseStreamChunk, - CompletionMessage, Inference, - LogProbConfig, - Message, OpenAIChatCompletion, OpenAIChatCompletionChunk, OpenAICompletion, OpenAIEmbeddingsResponse, OpenAIMessageParam, OpenAIResponseFormatParam, - ResponseFormat, - SamplingParams, - ToolChoice, - ToolConfig, - ToolDefinition, - ToolPromptFormat, ) from llama_stack.apis.models import Model -from llama_stack.core.library_client import convert_pydantic_to_json_value, convert_to_pydantic +from llama_stack.core.library_client import convert_pydantic_to_json_value from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params @@ -85,76 +74,6 @@ class PassthroughInferenceAdapter(Inference): provider_data=provider_data, ) - async def chat_completion( - self, - model_id: str, - messages: list[Message], - sampling_params: SamplingParams | None = None, - tools: list[ToolDefinition] | None = None, - tool_choice: ToolChoice | None = ToolChoice.auto, - tool_prompt_format: ToolPromptFormat | None = None, - response_format: ResponseFormat | None = None, - stream: bool | None = False, - logprobs: LogProbConfig | None = None, - tool_config: ToolConfig | None = None, - ) -> AsyncGenerator: - if sampling_params is None: - sampling_params = SamplingParams() - model = await self.model_store.get_model(model_id) - - # TODO: revisit this remove tool_calls from messages logic - for message in messages: - if hasattr(message, "tool_calls"): - message.tool_calls = None - - request_params = { - "model_id": model.provider_resource_id, - "messages": messages, - "sampling_params": sampling_params, - "tools": tools, - "tool_choice": tool_choice, - "tool_prompt_format": tool_prompt_format, - "response_format": response_format, - "stream": stream, - "logprobs": logprobs, - } - - # only pass through the not None params - request_params = {key: value for key, value in request_params.items() if value is not None} - - # cast everything to json dict - json_params = self.cast_value_to_json_dict(request_params) - - if stream: - return self._stream_chat_completion(json_params) - else: - return await self._nonstream_chat_completion(json_params) - - async def _nonstream_chat_completion(self, json_params: dict[str, Any]) -> ChatCompletionResponse: - client = self._get_client() - response = await client.inference.chat_completion(**json_params) - - return ChatCompletionResponse( - completion_message=CompletionMessage( - content=response.completion_message.content.text, - stop_reason=response.completion_message.stop_reason, - tool_calls=response.completion_message.tool_calls, - ), - logprobs=response.logprobs, - ) - - async def _stream_chat_completion(self, json_params: dict[str, Any]) -> AsyncGenerator: - client = self._get_client() - stream_response = await client.inference.chat_completion(**json_params) - - async for chunk in stream_response: - chunk = chunk.to_dict() - - # temporary hack to remove the metrics from the response - chunk["metrics"] = [] - chunk = convert_to_pydantic(ChatCompletionResponseStreamChunk, chunk) - yield chunk - async def openai_embeddings( self, model: str, diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py 
b/llama_stack/providers/remote/inference/runpod/runpod.py index 77c5c7187..1c99182ea 100644 --- a/llama_stack/providers/remote/inference/runpod/runpod.py +++ b/llama_stack/providers/remote/inference/runpod/runpod.py @@ -3,9 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from collections.abc import AsyncGenerator -from openai import OpenAI from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import OpenAIEmbeddingsResponse @@ -13,10 +11,7 @@ from llama_stack.apis.inference import OpenAIEmbeddingsResponse # from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, build_hf_repo_model_entry from llama_stack.providers.utils.inference.openai_compat import ( - OpenAIChatCompletionToLlamaStackMixin, get_sampling_options, - process_chat_completion_response, - process_chat_completion_stream_response, ) from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_prompt, @@ -53,7 +48,6 @@ MODEL_ENTRIES = [ class RunpodInferenceAdapter( ModelRegistryHelper, Inference, - OpenAIChatCompletionToLlamaStackMixin, ): def __init__(self, config: RunpodImplConfig) -> None: ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS) @@ -65,56 +59,6 @@ class RunpodInferenceAdapter( async def shutdown(self) -> None: pass - async def chat_completion( - self, - model: str, - messages: list[Message], - sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, - tools: list[ToolDefinition] | None = None, - tool_choice: ToolChoice | None = ToolChoice.auto, - tool_prompt_format: ToolPromptFormat | None = None, - stream: bool | None = False, - logprobs: LogProbConfig | None = None, - tool_config: ToolConfig | None = None, - ) -> AsyncGenerator: - if sampling_params is None: - sampling_params = SamplingParams() - request = ChatCompletionRequest( - model=model, - messages=messages, - sampling_params=sampling_params, - tools=tools or [], - stream=stream, - logprobs=logprobs, - tool_config=tool_config, - ) - - client = OpenAI(base_url=self.config.url, api_key=self.config.api_token) - if stream: - return self._stream_chat_completion(request, client) - else: - return await self._nonstream_chat_completion(request, client) - - async def _nonstream_chat_completion( - self, request: ChatCompletionRequest, client: OpenAI - ) -> ChatCompletionResponse: - params = self._get_params(request) - r = client.completions.create(**params) - return process_chat_completion_response(r, request) - - async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator: - params = self._get_params(request) - - async def _to_async_generator(): - s = client.completions.create(**params) - for chunk in s: - yield chunk - - stream = _to_async_generator() - async for chunk in process_chat_completion_stream_response(stream, request): - yield chunk - def _get_params(self, request: ChatCompletionRequest) -> dict: return { "model": self.map_to_provider_model(request.model), diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index 703ee2c1b..0bb56da2b 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -5,25 +5,16 @@ # the root directory of this source tree. 
-from collections.abc import AsyncGenerator
-
 from huggingface_hub import AsyncInferenceClient, HfApi
 from pydantic import SecretStr
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
-    ChatCompletionResponse,
     Inference,
-    LogProbConfig,
-    Message,
     OpenAIEmbeddingsResponse,
     ResponseFormat,
     ResponseFormatType,
     SamplingParams,
-    ToolChoice,
-    ToolConfig,
-    ToolDefinition,
-    ToolPromptFormat,
 )
 from llama_stack.apis.models import Model
 from llama_stack.apis.models.models import ModelType
@@ -34,11 +25,7 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_hf_repo_model_entry,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
-    OpenAICompatCompletionChoice,
-    OpenAICompatCompletionResponse,
     get_sampling_options,
-    process_chat_completion_response,
-    process_chat_completion_stream_response,
 )
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -146,68 +133,6 @@ class _HfAdapter(
         return options
-    async def chat_completion(
-        self,
-        model_id: str,
-        messages: list[Message],
-        sampling_params: SamplingParams | None = None,
-        tools: list[ToolDefinition] | None = None,
-        tool_choice: ToolChoice | None = ToolChoice.auto,
-        tool_prompt_format: ToolPromptFormat | None = None,
-        response_format: ResponseFormat | None = None,
-        stream: bool | None = False,
-        logprobs: LogProbConfig | None = None,
-        tool_config: ToolConfig | None = None,
-    ) -> AsyncGenerator:
-        if sampling_params is None:
-            sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
-        request = ChatCompletionRequest(
-            model=model.provider_resource_id,
-            messages=messages,
-            sampling_params=sampling_params,
-            tools=tools or [],
-            response_format=response_format,
-            stream=stream,
-            logprobs=logprobs,
-            tool_config=tool_config,
-        )
-
-        if stream:
-            return self._stream_chat_completion(request)
-        else:
-            return await self._nonstream_chat_completion(request)
-
-    async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
-        params = await self._get_params(request)
-        r = await self.hf_client.text_generation(**params)
-
-        choice = OpenAICompatCompletionChoice(
-            finish_reason=r.details.finish_reason,
-            text="".join(t.text for t in r.details.tokens),
-        )
-        response = OpenAICompatCompletionResponse(
-            choices=[choice],
-        )
-        return process_chat_completion_response(response, request)
-
-    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
-        params = await self._get_params(request)
-
-        async def _generate_and_convert_to_openai_compat():
-            s = await self.hf_client.text_generation(**params)
-            async for chunk in s:
-                token_result = chunk.token
-
-                choice = OpenAICompatCompletionChoice(text=token_result.text)
-                yield OpenAICompatCompletionResponse(
-                    choices=[choice],
-                )
-
-        stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, request):
-            yield chunk
-
     async def _get_params(self, request: ChatCompletionRequest) -> dict:
         prompt, input_tokens = await chat_completion_request_to_model_input_info(
             request, self.register_helper.get_llama_model(request.model)
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 1f7a92d69..6f7a19743 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from collections.abc import AsyncGenerator
 from openai import AsyncOpenAI
 from together import AsyncTogether
@@ -12,18 +11,12 @@ from together.constants import BASE_URL
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
-    ChatCompletionResponse,
     Inference,
     LogProbConfig,
-    Message,
     OpenAIEmbeddingsResponse,
     ResponseFormat,
     ResponseFormatType,
     SamplingParams,
-    ToolChoice,
-    ToolConfig,
-    ToolDefinition,
-    ToolPromptFormat,
 )
 from llama_stack.apis.inference.inference import OpenAIEmbeddingUsage
 from llama_stack.apis.models import Model, ModelType
@@ -33,8 +26,6 @@ from llama_stack.providers.utils.inference.model_registry import ModelRegistryHe
 from llama_stack.providers.utils.inference.openai_compat import (
     convert_message_to_openai_dict,
     get_sampling_options,
-    process_chat_completion_response,
-    process_chat_completion_stream_response,
 )
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -122,58 +113,6 @@ class TogetherInferenceAdapter(OpenAIMixin, Inference, NeedsRequestProviderData)
         return options
-    async def chat_completion(
-        self,
-        model_id: str,
-        messages: list[Message],
-        sampling_params: SamplingParams | None = None,
-        tools: list[ToolDefinition] | None = None,
-        tool_choice: ToolChoice | None = ToolChoice.auto,
-        tool_prompt_format: ToolPromptFormat | None = None,
-        response_format: ResponseFormat | None = None,
-        stream: bool | None = False,
-        logprobs: LogProbConfig | None = None,
-        tool_config: ToolConfig | None = None,
-    ) -> AsyncGenerator:
-        if sampling_params is None:
-            sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
-        request = ChatCompletionRequest(
-            model=model.provider_resource_id,
-            messages=messages,
-            sampling_params=sampling_params,
-            tools=tools or [],
-            response_format=response_format,
-            stream=stream,
-            logprobs=logprobs,
-            tool_config=tool_config,
-        )
-
-        if stream:
-            return self._stream_chat_completion(request)
-        else:
-            return await self._nonstream_chat_completion(request)
-
-    async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
-        params = await self._get_params(request)
-        client = self._get_client()
-        if "messages" in params:
-            r = await client.chat.completions.create(**params)
-        else:
-            r = await client.completions.create(**params)
-        return process_chat_completion_response(r, request)
-
-    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
-        params = await self._get_params(request)
-        client = self._get_client()
-        if "messages" in params:
-            stream = await client.chat.completions.create(**params)
-        else:
-            stream = await client.completions.create(**params)
-
-        async for chunk in process_chat_completion_stream_response(stream, request):
-            yield chunk
-
     async def _get_params(self, request: ChatCompletionRequest) -> dict:
         input_dict = {}
         media_present = request_has_media(request)
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 2b58b4262..54ac8e1dc 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -9,7 +9,7 @@ from typing import Any
 from urllib.parse import urljoin
 import httpx
-from openai import APIConnectionError, AsyncOpenAI
+from openai import APIConnectionError
 from openai.types.chat.chat_completion_chunk import (
     ChatCompletionChunk as OpenAIChatCompletionChunk,
 )
@@ -21,23 +21,18 @@ from llama_stack.apis.common.content_types import (
 )
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
-    ChatCompletionResponse,
     ChatCompletionResponseEvent,
     ChatCompletionResponseEventType,
     ChatCompletionResponseStreamChunk,
-    CompletionMessage,
     GrammarResponseFormat,
     Inference,
     JsonSchemaResponseFormat,
-    LogProbConfig,
-    Message,
     ModelStore,
-    ResponseFormat,
-    SamplingParams,
+    OpenAIChatCompletion,
+    OpenAIMessageParam,
+    OpenAIResponseFormatParam,
     ToolChoice,
-    ToolConfig,
     ToolDefinition,
-    ToolPromptFormat,
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.log import get_logger
@@ -56,10 +51,8 @@ from llama_stack.providers.utils.inference.model_registry import (
 from llama_stack.providers.utils.inference.openai_compat import (
     UnparseableToolCall,
     convert_message_to_openai_dict,
-    convert_openai_chat_completion_stream,
     convert_tool_call,
     get_sampling_options,
-    process_chat_completion_stream_response,
 )
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@@ -339,90 +332,6 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
     def get_extra_client_params(self):
         return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
-    async def chat_completion(
-        self,
-        model_id: str,
-        messages: list[Message],
-        sampling_params: SamplingParams | None = None,
-        tools: list[ToolDefinition] | None = None,
-        tool_choice: ToolChoice | None = ToolChoice.auto,
-        tool_prompt_format: ToolPromptFormat | None = None,
-        response_format: ResponseFormat | None = None,
-        stream: bool | None = False,
-        logprobs: LogProbConfig | None = None,
-        tool_config: ToolConfig | None = None,
-    ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
-        if sampling_params is None:
-            sampling_params = SamplingParams()
-        model = await self._get_model(model_id)
-        if model.provider_resource_id is None:
-            raise ValueError(f"Model {model_id} has no provider_resource_id set")
-        # This is to be consistent with OpenAI API and support vLLM <= v0.6.3
-        # References:
-        # * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
-        # * https://github.com/vllm-project/vllm/pull/10000
-        if not tools and tool_config is not None:
-            tool_config.tool_choice = ToolChoice.none
-        request = ChatCompletionRequest(
-            model=model.provider_resource_id,
-            messages=messages,
-            sampling_params=sampling_params,
-            tools=tools or [],
-            stream=stream,
-            logprobs=logprobs,
-            response_format=response_format,
-            tool_config=tool_config,
-        )
-        if stream:
-            return self._stream_chat_completion_with_client(request, self.client)
-        else:
-            return await self._nonstream_chat_completion(request, self.client)
-
-    async def _nonstream_chat_completion(
-        self, request: ChatCompletionRequest, client: AsyncOpenAI
-    ) -> ChatCompletionResponse:
-        assert self.client is not None
-        params = await self._get_params(request)
-        r = await client.chat.completions.create(**params)
-        choice = r.choices[0]
-        result = ChatCompletionResponse(
-            completion_message=CompletionMessage(
-                content=choice.message.content or "",
-                stop_reason=_convert_to_vllm_finish_reason(choice.finish_reason),
-                tool_calls=_convert_to_vllm_tool_calls_in_response(choice.message.tool_calls),
-            ),
-            logprobs=None,
-        )
-        return result
-
-    async def _stream_chat_completion(self, response: Any) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
-        # This method is called from LiteLLMOpenAIMixin.chat_completion
-        # The response parameter contains the litellm response
-        # We need to convert it to our format
-        async def _stream_generator():
-            async for chunk in response:
-                yield chunk
-
-        async for chunk in convert_openai_chat_completion_stream(
-            _stream_generator(), enable_incremental_tool_calls=True
-        ):
-            yield chunk
-
-    async def _stream_chat_completion_with_client(
-        self, request: ChatCompletionRequest, client: AsyncOpenAI
-    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
-        """Helper method for streaming with explicit client parameter."""
-        assert self.client is not None
-        params = await self._get_params(request)
-
-        stream = await client.chat.completions.create(**params)
-        if request.tools:
-            res = _process_vllm_chat_completion_stream_response(stream)
-        else:
-            res = process_chat_completion_stream_response(stream, request)
-        async for chunk in res:
-            yield chunk
-
     async def register_model(self, model: Model) -> Model:
         try:
             model = await self.register_helper.register_model(model)
@@ -471,3 +380,64 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
             "stream": request.stream,
             **options,
         }
+
+    async def openai_chat_completion(
+        self,
+        model: str,
+        messages: list[OpenAIMessageParam],
+        frequency_penalty: float | None = None,
+        function_call: str | dict[str, Any] | None = None,
+        functions: list[dict[str, Any]] | None = None,
+        logit_bias: dict[str, float] | None = None,
+        logprobs: bool | None = None,
+        max_completion_tokens: int | None = None,
+        max_tokens: int | None = None,
+        n: int | None = None,
+        parallel_tool_calls: bool | None = None,
+        presence_penalty: float | None = None,
+        response_format: OpenAIResponseFormatParam | None = None,
+        seed: int | None = None,
+        stop: str | list[str] | None = None,
+        stream: bool | None = None,
+        stream_options: dict[str, Any] | None = None,
+        temperature: float | None = None,
+        tool_choice: str | dict[str, Any] | None = None,
+        tools: list[dict[str, Any]] | None = None,
+        top_logprobs: int | None = None,
+        top_p: float | None = None,
+        user: str | None = None,
+    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
+        max_tokens = max_tokens or self.config.max_tokens
+
+        # This is to be consistent with OpenAI API and support vLLM <= v0.6.3
+        # References:
+        # * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
+        # * https://github.com/vllm-project/vllm/pull/10000
+        if not tools and tool_choice is not None:
+            tool_choice = ToolChoice.none.value
+
+        return await super().openai_chat_completion(
+            model=model,
+            messages=messages,
+            frequency_penalty=frequency_penalty,
+            function_call=function_call,
+            functions=functions,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_completion_tokens=max_completion_tokens,
+            max_tokens=max_tokens,
+            n=n,
+            parallel_tool_calls=parallel_tool_calls,
+            presence_penalty=presence_penalty,
+            response_format=response_format,
+            seed=seed,
+            stop=stop,
+            stream=stream,
+            stream_options=stream_options,
+            temperature=temperature,
+            tool_choice=tool_choice,
+            tools=tools,
+            top_logprobs=top_logprobs,
+            top_p=top_p,
+            user=user,
+        )
diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py
index cb9d61102..0557aff5f 100644
--- a/llama_stack/providers/remote/inference/watsonx/watsonx.py
+++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py
@@ -13,35 +13,22 @@ from openai import AsyncOpenAI
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
-    ChatCompletionResponse,
     CompletionRequest,
     GreedySamplingStrategy,
     Inference,
-    LogProbConfig,
-    Message,
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAICompletion,
     OpenAIEmbeddingsResponse,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
-    ResponseFormat,
-    SamplingParams,
-    ToolChoice,
-    ToolConfig,
-    ToolDefinition,
-    ToolPromptFormat,
     TopKSamplingStrategy,
     TopPSamplingStrategy,
 )
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
-    OpenAICompatCompletionChoice,
-    OpenAICompatCompletionResponse,
     prepare_openai_completion_params,
-    process_chat_completion_response,
-    process_chat_completion_stream_response,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
@@ -100,74 +87,6 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
             )
         return self._openai_client
-    async def chat_completion(
-        self,
-        model_id: str,
-        messages: list[Message],
-        sampling_params: SamplingParams | None = None,
-        tools: list[ToolDefinition] | None = None,
-        tool_choice: ToolChoice | None = ToolChoice.auto,
-        tool_prompt_format: ToolPromptFormat | None = None,
-        response_format: ResponseFormat | None = None,
-        stream: bool | None = False,
-        logprobs: LogProbConfig | None = None,
-        tool_config: ToolConfig | None = None,
-    ) -> AsyncGenerator:
-        if sampling_params is None:
-            sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
-        request = ChatCompletionRequest(
-            model=model.provider_resource_id,
-            messages=messages,
-            sampling_params=sampling_params,
-            tools=tools or [],
-            response_format=response_format,
-            stream=stream,
-            logprobs=logprobs,
-            tool_config=tool_config,
-        )
-
-        if stream:
-            return self._stream_chat_completion(request)
-        else:
-            return await self._nonstream_chat_completion(request)
-
-    async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
-        params = await self._get_params(request)
-        r = self._get_client(request.model).generate(**params)
-        choices = []
-        if "results" in r:
-            for result in r["results"]:
-                choice = OpenAICompatCompletionChoice(
-                    finish_reason=result["stop_reason"] if result["stop_reason"] else None,
-                    text=result["generated_text"],
-                )
-                choices.append(choice)
-        response = OpenAICompatCompletionResponse(
-            choices=choices,
-        )
-        return process_chat_completion_response(response, request)
-
-    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
-        params = await self._get_params(request)
-        model_id = request.model
-
-        # if we shift to TogetherAsyncClient, we won't need this wrapper
-        async def _to_async_generator():
-            s = self._get_client(model_id).generate_text_stream(**params)
-            for chunk in s:
-                choice = OpenAICompatCompletionChoice(
-                    finish_reason=None,
-                    text=chunk,
-                )
-                yield OpenAICompatCompletionResponse(
-                    choices=[choice],
-                )
-
-        stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, request):
-            yield chunk
-
     async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
         input_dict = {"params": {}}
         media_present = request_has_media(request)
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index c8d3bddc7..6c8f61c3b 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -11,12 +11,8 @@ import litellm
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseStreamChunk,
     InferenceProvider,
     JsonSchemaResponseFormat,
-    LogProbConfig,
-    Message,
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAICompletion,
@@ -24,12 +20,7 @@ from llama_stack.apis.inference import (
     OpenAIEmbeddingUsage,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
-    ResponseFormat,
-    SamplingParams,
     ToolChoice,
-    ToolConfig,
-    ToolDefinition,
-    ToolPromptFormat,
 )
 from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
@@ -37,8 +28,6 @@ from llama_stack.providers.utils.inference.model_registry import ModelRegistryHe
 from llama_stack.providers.utils.inference.openai_compat import (
     b64_encode_openai_embeddings_response,
     convert_message_to_openai_dict_new,
-    convert_openai_chat_completion_choice,
-    convert_openai_chat_completion_stream,
     convert_tooldef_to_openai_tool,
     get_sampling_options,
     prepare_openai_completion_params,
@@ -105,57 +94,6 @@ class LiteLLMOpenAIMixin(
             else model_id
         )
-    async def chat_completion(
-        self,
-        model_id: str,
-        messages: list[Message],
-        sampling_params: SamplingParams | None = None,
-        tools: list[ToolDefinition] | None = None,
-        tool_choice: ToolChoice | None = ToolChoice.auto,
-        tool_prompt_format: ToolPromptFormat | None = None,
-        response_format: ResponseFormat | None = None,
-        stream: bool | None = False,
-        logprobs: LogProbConfig | None = None,
-        tool_config: ToolConfig | None = None,
-    ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
-        if sampling_params is None:
-            sampling_params = SamplingParams()
-
-        model = await self.model_store.get_model(model_id)
-        request = ChatCompletionRequest(
-            model=model.provider_resource_id,
-            messages=messages,
-            sampling_params=sampling_params,
-            tools=tools or [],
-            response_format=response_format,
-            stream=stream,
-            logprobs=logprobs,
-            tool_config=tool_config,
-        )
-
-        params = await self._get_params(request)
-        params["model"] = self.get_litellm_model_name(params["model"])
-
-        logger.debug(f"params to litellm (openai compat): {params}")
-        # see https://docs.litellm.ai/docs/completion/stream#async-completion
-        response = await litellm.acompletion(**params)
-        if stream:
-            return self._stream_chat_completion(response)
-        else:
-            return convert_openai_chat_completion_choice(response.choices[0])
-
-    async def _stream_chat_completion(
-        self, response: litellm.ModelResponse
-    ) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
-        async def _stream_generator():
-            async for chunk in response:
-                yield chunk
-
-        async for chunk in convert_openai_chat_completion_stream(
-            _stream_generator(), enable_incremental_tool_calls=True
-        ):
-            yield chunk
-
     def _add_additional_properties_recursive(self, schema):
         """
         Recursively add additionalProperties: False to all object schemas
diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py
index bb560d378..cd31e4943 100644
--- a/tests/unit/providers/inference/test_remote_vllm.py
+++ b/tests/unit/providers/inference/test_remote_vllm.py
@@ -30,18 +30,14 @@ from openai.types.model import Model as OpenAIModel
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponseEventType,
-    CompletionMessage,
     OpenAIAssistantMessageParam,
     OpenAIChatCompletion,
     OpenAIChoice,
-    SystemMessage,
     ToolChoice,
-    ToolConfig,
-    ToolResponseMessage,
     UserMessage,
 )
 from llama_stack.apis.models import Model
-from llama_stack.models.llama.datatypes import StopReason, ToolCall
+from llama_stack.models.llama.datatypes import StopReason
 from llama_stack.providers.datatypes import HealthStatus
 from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
 from llama_stack.providers.remote.inference.vllm.vllm import (
@@ -99,66 +95,24 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
     mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
     vllm_inference_adapter.model_store.get_model.return_value = mock_model
-    with patch.object(vllm_inference_adapter, "_nonstream_chat_completion") as mock_nonstream_completion:
+    # Patch the client property to avoid instantiating a real AsyncOpenAI client
+    with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock()
+        mock_client_property.return_value = mock_client
+
         # No tools but auto tool choice
-        await vllm_inference_adapter.chat_completion(
+        await vllm_inference_adapter.openai_chat_completion(
             "mock-model",
             [],
             stream=False,
             tools=None,
-            tool_config=ToolConfig(tool_choice=ToolChoice.auto),
+            tool_choice=ToolChoice.auto.value,
         )
-        mock_nonstream_completion.assert_called()
-        request = mock_nonstream_completion.call_args.args[0]
+        mock_client.chat.completions.create.assert_called()
+        call_args = mock_client.chat.completions.create.call_args
         # Ensure tool_choice gets converted to none for older vLLM versions
-        assert request.tool_config.tool_choice == ToolChoice.none
-
-
-async def test_tool_call_response(vllm_inference_adapter):
-    """Verify that tool call arguments from a CompletionMessage are correctly converted
-    into the expected JSON format."""
-
-    # Patch the client property to avoid instantiating a real AsyncOpenAI client
-    with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
-        mock_client = MagicMock()
-        mock_client.chat.completions.create = AsyncMock()
-        mock_create_client.return_value = mock_client
-
-        # Mock the model to return a proper provider_resource_id
-        mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
-        vllm_inference_adapter.model_store.get_model.return_value = mock_model
-
-        messages = [
-            SystemMessage(content="You are a helpful assistant"),
-            UserMessage(content="How many?"),
-            CompletionMessage(
-                content="",
-                stop_reason=StopReason.end_of_turn,
-                tool_calls=[
-                    ToolCall(
-                        call_id="foo",
-                        tool_name="knowledge_search",
-                        arguments='{"query": "How many?"}',
-                    )
-                ],
-            ),
-            ToolResponseMessage(call_id="foo", content="knowledge_search found 5...."),
-        ]
-        await vllm_inference_adapter.chat_completion(
-            "mock-model",
-            messages,
-            stream=False,
-            tools=[],
-            tool_config=ToolConfig(tool_choice=ToolChoice.auto),
-        )
-
-        assert mock_client.chat.completions.create.call_args.kwargs["messages"][2]["tool_calls"] == [
-            {
-                "id": "foo",
-                "type": "function",
-                "function": {"name": "knowledge_search", "arguments": '{"query": "How many?"}'},
-            }
-        ]
+        assert call_args.kwargs["tool_choice"] == ToolChoice.none.value
 async def test_tool_call_delta_empty_tool_call_buf():
@@ -744,12 +698,10 @@ async def test_provider_data_var_context_propagation(vllm_inference_adapter):
     try:
         # Execute chat completion
-        await vllm_inference_adapter.chat_completion(
-            "test-model",
-            [UserMessage(content="Hello")],
+        await vllm_inference_adapter.openai_chat_completion(
+            model="test-model",
+            messages=[UserMessage(content="Hello")],
             stream=False,
-            tools=None,
-            tool_config=ToolConfig(tool_choice=ToolChoice.auto),
         )
         # Verify that ALL client calls were made with the correct parameters