chore: remove deprecated inference.chat_completion implementations

vllm -
 - requires max_tokens be set, use config value
 - set tool_choice to none if no tools provided
This commit is contained in:
Matthew Farrellee 2025-10-01 11:28:42 -04:00
parent f1748e2f92
commit f754e1b65b
18 changed files with 193 additions and 1411 deletions

View file

@ -1008,45 +1008,6 @@ class InferenceProvider(Protocol):
model_store: ModelStore | None = None
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
"""Generate a chat completion for the given messages using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param messages: List of messages in the conversation.
:param sampling_params: Parameters to control the sampling strategy.
:param tools: (Optional) List of tool definitions available to the model.
:param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
.. deprecated::
Use tool_config instead.
:param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
- `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
.. deprecated::
Use tool_config instead.
:param response_format: (Optional) Grammar specification for guided (structured) decoding. There are two options:
- `ResponseFormat.json_schema`: The grammar is a JSON schema. Most providers support this format.
- `ResponseFormat.grammar`: The grammar is a BNF grammar. This format is more flexible, but not all providers support it.
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
:param tool_config: (Optional) Configuration for tool use.
:returns: If stream=False, returns a ChatCompletionResponse with the full completion.
If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk.
"""
...
@webmethod(route="/inference/rerank", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def rerank(
self,

View file

@ -27,7 +27,6 @@ from llama_stack.apis.inference import (
CompletionResponseStreamChunk,
Inference,
ListOpenAIChatCompletionResponse,
LogProbConfig,
Message,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
@ -42,12 +41,7 @@ from llama_stack.apis.inference import (
OpenAIMessageParam,
OpenAIResponseFormatParam,
Order,
ResponseFormat,
SamplingParams,
StopReason,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
@ -185,88 +179,6 @@ class InferenceRouter(Inference):
raise ModelTypeError(model_id, model.model_type, expected_model_type)
return model
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = None,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
logger.debug(
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
)
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id, ModelType.llm)
if tool_config:
if tool_choice and tool_choice != tool_config.tool_choice:
raise ValueError("tool_choice and tool_config.tool_choice must match")
if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
else:
params = {}
if tool_choice:
params["tool_choice"] = tool_choice
if tool_prompt_format:
params["tool_prompt_format"] = tool_prompt_format
tool_config = ToolConfig(**params)
tools = tools or []
if tool_config.tool_choice == ToolChoice.none:
tools = []
elif tool_config.tool_choice == ToolChoice.auto:
pass
elif tool_config.tool_choice == ToolChoice.required:
pass
else:
# verify tool_choice is one of the tools
tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
if tool_config.tool_choice not in tool_names:
raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
params = dict(
model_id=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
provider = await self.routing_table.get_provider_impl(model_id)
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
if stream:
response_stream = await provider.chat_completion(**params)
return self.stream_tokens_and_compute_metrics(
response=response_stream,
prompt_tokens=prompt_tokens,
model=model,
tool_prompt_format=tool_config.tool_prompt_format,
)
response = await provider.chat_completion(**params)
metrics = await self.count_tokens_and_compute_metrics(
response=response,
prompt_tokens=prompt_tokens,
model=model,
tool_prompt_format=tool_config.tool_prompt_format,
)
# these metrics will show up in the client response.
response.metrics = (
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
)
return response
async def openai_completion(
self,
model: str,

View file

@ -5,37 +5,17 @@
# the root directory of this source tree.
import asyncio
import os
import sys
from collections.abc import AsyncGenerator
from collections.abc import AsyncIterator
from typing import Any
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.apis.common.content_types import (
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
InferenceProvider,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
StopReason,
TokenLogProbs,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
UserMessage,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
@ -53,13 +33,6 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
convert_request_to_raw,
)
from .config import MetaReferenceInferenceConfig
from .generators import LlamaGenerator
@ -76,7 +49,6 @@ def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_
class MetaReferenceInferenceImpl(
OpenAIChatCompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin,
InferenceProvider,
ModelsProtocolPrivate,
@ -161,10 +133,10 @@ class MetaReferenceInferenceImpl(
self.llama_model = llama_model
log.info("Warming up...")
await self.chat_completion(
model_id=model_id,
messages=[UserMessage(content="Hi how are you?")],
sampling_params=SamplingParams(max_tokens=20),
await self.openai_chat_completion(
model=model_id,
messages=[{"role": "user", "content": "Hi how are you?"}],
max_tokens=20,
)
log.info("Warmed up!")
@ -176,242 +148,30 @@ class MetaReferenceInferenceImpl(
elif request.model != self.model_id:
raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
async def chat_completion(
async def openai_chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = ChatCompletionRequest(
model=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config or ToolConfig(),
)
self.check_model(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
if request.stream:
return self._stream_chat_completion(request)
else:
results = await self._nonstream_chat_completion([request])
return results[0]
async def _nonstream_chat_completion(
self, request_batch: list[ChatCompletionRequest]
) -> list[ChatCompletionResponse]:
tokenizer = self.generator.formatter.tokenizer
first_request = request_batch[0]
class ItemState(BaseModel):
tokens: list[int] = []
logprobs: list[TokenLogProbs] = []
stop_reason: StopReason | None = None
finished: bool = False
def impl():
states = [ItemState() for _ in request_batch]
for token_results in self.generator.chat_completion(request_batch):
first = token_results[0]
if not first.finished and not first.ignore_token:
if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"):
cprint(first.text, color="cyan", end="", file=sys.stderr)
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
cprint(f"<{first.token}>", color="magenta", end="", file=sys.stderr)
for result in token_results:
idx = result.batch_idx
state = states[idx]
if state.finished or result.ignore_token:
continue
state.finished = result.finished
if first_request.logprobs:
state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
state.tokens.append(result.token)
if result.token == tokenizer.eot_id:
state.stop_reason = StopReason.end_of_turn
elif result.token == tokenizer.eom_id:
state.stop_reason = StopReason.end_of_message
results = []
for state in states:
if state.stop_reason is None:
state.stop_reason = StopReason.out_of_tokens
raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
results.append(
ChatCompletionResponse(
completion_message=CompletionMessage(
content=raw_message.content,
stop_reason=raw_message.stop_reason,
tool_calls=raw_message.tool_calls,
),
logprobs=state.logprobs if first_request.logprobs else None,
)
)
return results
if self.config.create_distributed_process_group:
async with SEMAPHORE:
return impl()
else:
return impl()
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
tokenizer = self.generator.formatter.tokenizer
def impl():
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta=TextDelta(text=""),
)
)
tokens = []
logprobs = []
stop_reason = None
ipython = False
for token_results in self.generator.chat_completion([request]):
token_result = token_results[0]
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
cprint(token_result.text, color="cyan", end="", file=sys.stderr)
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
cprint(f"<{token_result.token}>", color="magenta", end="", file=sys.stderr)
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
tokens.append(token_result.token)
if not ipython and token_result.text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call="",
parse_status=ToolCallParseStatus.started,
),
)
)
continue
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if ipython:
delta = ToolCallDelta(
tool_call=text,
parse_status=ToolCallParseStatus.in_progress,
)
else:
delta = TextDelta(text=text)
if stop_reason is None:
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call="",
parse_status=ToolCallParseStatus.failed,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call=tool_call,
parse_status=ToolCallParseStatus.succeeded,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=TextDelta(text=""),
stop_reason=stop_reason,
)
)
if self.config.create_distributed_process_group:
async with SEMAPHORE:
for x in impl():
yield x
else:
for x in impl():
yield x
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by sentence transformers provider")

View file

@ -4,21 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from collections.abc import AsyncIterator
from typing import Any
from llama_stack.apis.inference import (
InferenceProvider,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import OpenAICompletion
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
@ -73,21 +71,6 @@ class SentenceTransformersInferenceImpl(
async def unregister_model(self, model_id: str) -> None:
pass
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
raise ValueError("Sentence transformers don't support chat completion")
async def openai_completion(
self,
# Standard OpenAI completion parameters
@ -115,3 +98,31 @@ class SentenceTransformersInferenceImpl(
suffix: str | None = None,
) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by sentence transformers provider")
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by sentence transformers provider")

View file

@ -5,39 +5,30 @@
# the root directory of this source tree.
import json
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import AsyncIterator
from typing import Any
from botocore.client import BaseClient
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import OpenAICompletion
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
get_sampling_strategy_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
@ -86,7 +77,6 @@ def _to_inference_profile_id(model_id: str, region: str = None) -> str:
class BedrockInferenceAdapter(
ModelRegistryHelper,
Inference,
OpenAIChatCompletionToLlamaStackMixin,
):
def __init__(self, config: BedrockConfig) -> None:
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
@ -106,71 +96,6 @@ class BedrockInferenceAdapter(
if self._client is not None:
self._client.close()
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
if stream:
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params_for_chat_completion(request)
res = self.client.invoke_model(**params)
chunk = next(res["body"])
result = json.loads(chunk.decode("utf-8"))
choice = OpenAICompatCompletionChoice(
finish_reason=result["stop_reason"],
text=result["generation"],
)
response = OpenAICompatCompletionResponse(choices=[choice])
return process_chat_completion_response(response, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params_for_chat_completion(request)
res = self.client.invoke_model_with_response_stream(**params)
event_stream = res["body"]
async def _generate_and_convert_to_openai_compat():
for chunk in event_stream:
chunk = chunk["chunk"]["bytes"]
result = json.loads(chunk.decode("utf-8"))
choice = OpenAICompatCompletionChoice(
finish_reason=result["stop_reason"],
text=result["generation"],
)
yield OpenAICompatCompletionResponse(choices=[choice])
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> dict:
bedrock_model = request.model
@ -235,3 +160,31 @@ class BedrockInferenceAdapter(
suffix: str | None = None,
) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by the Bedrock provider")
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from urllib.parse import urljoin
from cerebras.cloud.sdk import AsyncCerebras
@ -12,17 +11,8 @@ from cerebras.cloud.sdk import AsyncCerebras
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
CompletionResponse,
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
TopKSamplingStrategy,
)
from llama_stack.providers.utils.inference.model_registry import (
@ -30,8 +20,6 @@ from llama_stack.providers.utils.inference.model_registry import (
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
@ -68,55 +56,6 @@ class CerebrasInferenceAdapter(
async def shutdown(self) -> None:
pass
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
if stream:
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request)
r = await self._cerebras_client.completions.create(**params)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
stream = await self._cerebras_client.completions.create(**params)
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
if request.sampling_params and isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
raise ValueError("`top_k` not supported by Cerebras")

View file

@ -4,25 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncIterator
from typing import Any
from databricks.sdk import WorkspaceClient
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
Inference,
LogProbConfig,
Message,
Model,
OpenAICompletion,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
@ -83,21 +72,6 @@ class DatabricksInferenceAdapter(
) -> OpenAICompletion:
raise NotImplementedError()
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
raise NotImplementedError()
async def list_models(self) -> list[Model] | None:
self._model_cache = {} # from OpenAIMixin
ws_client = WorkspaceClient(host=self.config.url, token=self.get_api_key()) # TODO: this is not async

View file

@ -4,23 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from fireworks.client import Fireworks
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
Inference,
LogProbConfig,
Message,
ResponseFormat,
ResponseFormatType,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
@ -30,8 +23,6 @@ from llama_stack.providers.utils.inference.model_registry import (
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
@ -80,67 +71,6 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
fireworks_api_key = self.get_api_key()
return Fireworks(api_key=fireworks_api_key)
def _preprocess_prompt_for_fireworks(self, prompt: str) -> str:
"""Remove BOS token as Fireworks automatically prepends it"""
if prompt.startswith("<|begin_of_text|>"):
return prompt[len("<|begin_of_text|>") :]
return prompt
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
if stream:
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
if "messages" in params:
r = await self._get_client().chat.completions.acreate(**params)
else:
r = await self._get_client().completion.acreate(**params)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
async def _to_async_generator():
if "messages" in params:
stream = self._get_client().chat.completions.acreate(**params)
else:
stream = self._get_client().completion.acreate(**params)
async for chunk in stream:
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
def _build_options(
self,
sampling_params: SamplingParams | None,

View file

@ -4,38 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import warnings
from collections.abc import AsyncIterator
from openai import NOT_GIVEN, APIConnectionError
from openai import NOT_GIVEN
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolConfig,
)
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
from llama_stack.providers.utils.inference.openai_compat import (
convert_openai_chat_completion_choice,
convert_openai_chat_completion_stream,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from . import NVIDIAConfig
from .openai_utils import (
convert_chat_completion_request,
)
from .utils import _is_nvidia_hosted
logger = get_logger(name=__name__, category="inference::nvidia")
@ -149,49 +130,3 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
model=response.model,
usage=usage,
)
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
if sampling_params is None:
sampling_params = SamplingParams()
if tool_prompt_format:
warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring", stacklevel=2)
# await check_health(self._config) # this raises errors
provider_model_id = await self._get_provider_model_id(model_id)
request = await convert_chat_completion_request(
request=ChatCompletionRequest(
model=provider_model_id,
messages=messages,
sampling_params=sampling_params,
response_format=response_format,
tools=tools,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
),
n=1,
)
try:
response = await self.client.chat.completions.create(**request)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
if stream:
return convert_openai_chat_completion_stream(response, enable_incremental_tool_calls=False)
else:
# we pass n=1 to get only one completion
return convert_openai_chat_completion_choice(response.choices[0])

View file

@ -6,7 +6,6 @@
import asyncio
from collections.abc import AsyncGenerator
from typing import Any
from ollama import AsyncClient as AsyncOllamaClient
@ -18,19 +17,10 @@ from llama_stack.apis.common.content_types import (
from llama_stack.apis.common.errors import UnsupportedModelError
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
GrammarResponseFormat,
InferenceProvider,
JsonSchemaResponseFormat,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.log import get_logger
@ -46,11 +36,7 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
@ -161,39 +147,6 @@ class OllamaInferenceAdapter(
raise ValueError("Model store not set")
return await self.model_store.get_model(model_id)
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
if model.provider_resource_id is None:
raise ValueError(f"Model {model_id} has no provider_resource_id set")
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
stream=stream,
logprobs=logprobs,
response_format=response_format,
tool_config=tool_config,
)
if stream:
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request)
async def _get_params(self, request: ChatCompletionRequest) -> dict:
sampling_options = get_sampling_options(request.sampling_params)
# This is needed since the Ollama API expects num_predict to be set
@ -233,57 +186,6 @@ class OllamaInferenceAdapter(
return params
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
if "messages" in params:
r = await self.ollama_client.chat(**params)
else:
r = await self.ollama_client.generate(**params)
if "message" in r:
choice = OpenAICompatCompletionChoice(
finish_reason=r["done_reason"] if r["done"] else None,
text=r["message"]["content"],
)
else:
choice = OpenAICompatCompletionChoice(
finish_reason=r["done_reason"] if r["done"] else None,
text=r["response"],
)
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_chat_completion_response(response, request)
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
params = await self._get_params(request)
async def _generate_and_convert_to_openai_compat():
if "messages" in params:
s = await self.ollama_client.chat(**params)
else:
s = await self.ollama_client.generate(**params)
async for chunk in s:
if "message" in chunk:
choice = OpenAICompatCompletionChoice(
finish_reason=chunk["done_reason"] if chunk["done"] else None,
text=chunk["message"]["content"],
)
else:
choice = OpenAICompatCompletionChoice(
finish_reason=chunk["done_reason"] if chunk["done"] else None,
text=chunk["response"],
)
yield OpenAICompatCompletionResponse(
choices=[choice],
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def register_model(self, model: Model) -> Model:
if await self.check_model_availability(model.provider_model_id):
return model

View file

@ -4,33 +4,22 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import AsyncIterator
from typing import Any
from llama_stack_client import AsyncLlamaStackClient
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionMessage,
Inference,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.core.library_client import convert_pydantic_to_json_value, convert_to_pydantic
from llama_stack.core.library_client import convert_pydantic_to_json_value
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
@ -85,76 +74,6 @@ class PassthroughInferenceAdapter(Inference):
provider_data=provider_data,
)
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
# TODO: revisit this remove tool_calls from messages logic
for message in messages:
if hasattr(message, "tool_calls"):
message.tool_calls = None
request_params = {
"model_id": model.provider_resource_id,
"messages": messages,
"sampling_params": sampling_params,
"tools": tools,
"tool_choice": tool_choice,
"tool_prompt_format": tool_prompt_format,
"response_format": response_format,
"stream": stream,
"logprobs": logprobs,
}
# only pass through the not None params
request_params = {key: value for key, value in request_params.items() if value is not None}
# cast everything to json dict
json_params = self.cast_value_to_json_dict(request_params)
if stream:
return self._stream_chat_completion(json_params)
else:
return await self._nonstream_chat_completion(json_params)
async def _nonstream_chat_completion(self, json_params: dict[str, Any]) -> ChatCompletionResponse:
client = self._get_client()
response = await client.inference.chat_completion(**json_params)
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=response.completion_message.content.text,
stop_reason=response.completion_message.stop_reason,
tool_calls=response.completion_message.tool_calls,
),
logprobs=response.logprobs,
)
async def _stream_chat_completion(self, json_params: dict[str, Any]) -> AsyncGenerator:
client = self._get_client()
stream_response = await client.inference.chat_completion(**json_params)
async for chunk in stream_response:
chunk = chunk.to_dict()
# temporary hack to remove the metrics from the response
chunk["metrics"] = []
chunk = convert_to_pydantic(ChatCompletionResponseStreamChunk, chunk)
yield chunk
async def openai_embeddings(
self,
model: str,

View file

@ -3,9 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from openai import OpenAI
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.inference import OpenAIEmbeddingsResponse
@ -13,10 +11,7 @@ from llama_stack.apis.inference import OpenAIEmbeddingsResponse
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, build_hf_repo_model_entry
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
@ -53,7 +48,6 @@ MODEL_ENTRIES = [
class RunpodInferenceAdapter(
ModelRegistryHelper,
Inference,
OpenAIChatCompletionToLlamaStackMixin,
):
def __init__(self, config: RunpodImplConfig) -> None:
ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
@ -65,56 +59,6 @@ class RunpodInferenceAdapter(
async def shutdown(self) -> None:
pass
async def chat_completion(
self,
model: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
if stream:
return self._stream_chat_completion(request, client)
else:
return await self._nonstream_chat_completion(request, client)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
) -> ChatCompletionResponse:
params = self._get_params(request)
r = client.completions.create(**params)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
params = self._get_params(request)
async def _to_async_generator():
s = client.completions.create(**params)
for chunk in s:
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
"model": self.map_to_provider_model(request.model),

View file

@ -5,25 +5,16 @@
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from huggingface_hub import AsyncInferenceClient, HfApi
from pydantic import SecretStr
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
ResponseFormatType,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.apis.models.models import ModelType
@ -35,11 +26,7 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
@ -148,68 +135,6 @@ class _HfAdapter(
return options
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
if stream:
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
r = await self.hf_client.text_generation(**params)
choice = OpenAICompatCompletionChoice(
finish_reason=r.details.finish_reason,
text="".join(t.text for t in r.details.tokens),
)
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_chat_completion_response(response, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
async def _generate_and_convert_to_openai_compat():
s = await self.hf_client.text_generation(**params)
async for chunk in s:
token_result = chunk.token
choice = OpenAICompatCompletionChoice(text=token_result.text)
yield OpenAICompatCompletionResponse(
choices=[choice],
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: ChatCompletionRequest) -> dict:
prompt, input_tokens = await chat_completion_request_to_model_input_info(
request, self.register_helper.get_llama_model(request.model)

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from openai import AsyncOpenAI
from together import AsyncTogether
@ -12,18 +11,12 @@ from together.constants import BASE_URL
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
ResponseFormatType,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import OpenAIEmbeddingUsage
from llama_stack.apis.models import Model, ModelType
@ -33,8 +26,6 @@ from llama_stack.providers.utils.inference.model_registry import ModelRegistryHe
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
@ -122,58 +113,6 @@ class TogetherInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Need
return options
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
if stream:
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
client = self._get_client()
if "messages" in params:
r = await client.chat.completions.create(**params)
else:
r = await client.completions.create(**params)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
client = self._get_client()
if "messages" in params:
stream = await client.chat.completions.create(**params)
else:
stream = await client.completions.create(**params)
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: ChatCompletionRequest) -> dict:
input_dict = {}
media_present = request_has_media(request)

View file

@ -9,7 +9,7 @@ from typing import Any
from urllib.parse import urljoin
import httpx
from openai import APIConnectionError, AsyncOpenAI
from openai import APIConnectionError
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk as OpenAIChatCompletionChunk,
)
@ -21,23 +21,18 @@ from llama_stack.apis.common.content_types import (
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
GrammarResponseFormat,
Inference,
JsonSchemaResponseFormat,
LogProbConfig,
Message,
ModelStore,
ResponseFormat,
SamplingParams,
OpenAIChatCompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
@ -56,10 +51,8 @@ from llama_stack.providers.utils.inference.model_registry import (
from llama_stack.providers.utils.inference.openai_compat import (
UnparseableToolCall,
convert_message_to_openai_dict,
convert_openai_chat_completion_stream,
convert_tool_call,
get_sampling_options,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -353,90 +346,6 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
def get_extra_client_params(self):
return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
if model.provider_resource_id is None:
raise ValueError(f"Model {model_id} has no provider_resource_id set")
# This is to be consistent with OpenAI API and support vLLM <= v0.6.3
# References:
# * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
# * https://github.com/vllm-project/vllm/pull/10000
if not tools and tool_config is not None:
tool_config.tool_choice = ToolChoice.none
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
stream=stream,
logprobs=logprobs,
response_format=response_format,
tool_config=tool_config,
)
if stream:
return self._stream_chat_completion_with_client(request, self.client)
else:
return await self._nonstream_chat_completion(request, self.client)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: AsyncOpenAI
) -> ChatCompletionResponse:
assert self.client is not None
params = await self._get_params(request)
r = await client.chat.completions.create(**params)
choice = r.choices[0]
result = ChatCompletionResponse(
completion_message=CompletionMessage(
content=choice.message.content or "",
stop_reason=_convert_to_vllm_finish_reason(choice.finish_reason),
tool_calls=_convert_to_vllm_tool_calls_in_response(choice.message.tool_calls),
),
logprobs=None,
)
return result
async def _stream_chat_completion(self, response: Any) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
# This method is called from LiteLLMOpenAIMixin.chat_completion
# The response parameter contains the litellm response
# We need to convert it to our format
async def _stream_generator():
async for chunk in response:
yield chunk
async for chunk in convert_openai_chat_completion_stream(
_stream_generator(), enable_incremental_tool_calls=True
):
yield chunk
async def _stream_chat_completion_with_client(
self, request: ChatCompletionRequest, client: AsyncOpenAI
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
"""Helper method for streaming with explicit client parameter."""
assert self.client is not None
params = await self._get_params(request)
stream = await client.chat.completions.create(**params)
if request.tools:
res = _process_vllm_chat_completion_stream_response(stream)
else:
res = process_chat_completion_stream_response(stream, request)
async for chunk in res:
yield chunk
async def register_model(self, model: Model) -> Model:
try:
model = await self.register_helper.register_model(model)
@ -485,3 +394,64 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
"stream": request.stream,
**options,
}
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
max_tokens = max_tokens or self.config.max_tokens
# This is to be consistent with OpenAI API and support vLLM <= v0.6.3
# References:
# * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
# * https://github.com/vllm-project/vllm/pull/10000
if not tools and tool_choice is not None:
tool_choice = ToolChoice.none.value
return await super().openai_chat_completion(
model=model,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)

View file

@ -13,35 +13,22 @@ from openai import AsyncOpenAI
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
GreedySamplingStrategy,
Inference,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
prepare_openai_completion_params,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
@ -100,74 +87,6 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
)
return self._openai_client
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
if stream:
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
r = self._get_client(request.model).generate(**params)
choices = []
if "results" in r:
for result in r["results"]:
choice = OpenAICompatCompletionChoice(
finish_reason=result["stop_reason"] if result["stop_reason"] else None,
text=result["generated_text"],
)
choices.append(choice)
response = OpenAICompatCompletionResponse(
choices=choices,
)
return process_chat_completion_response(response, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
model_id = request.model
# if we shift to TogetherAsyncClient, we won't need this wrapper
async def _to_async_generator():
s = self._get_client(model_id).generate_text_stream(**params)
for chunk in s:
choice = OpenAICompatCompletionChoice(
finish_reason=None,
text=chunk,
)
yield OpenAICompatCompletionResponse(
choices=[choice],
)
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
input_dict = {"params": {}}
media_present = request_has_media(request)

View file

@ -11,12 +11,8 @@ import litellm
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
InferenceProvider,
JsonSchemaResponseFormat,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
@ -24,12 +20,7 @@ from llama_stack.apis.inference import (
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
@ -37,8 +28,6 @@ from llama_stack.providers.utils.inference.model_registry import ModelRegistryHe
from llama_stack.providers.utils.inference.openai_compat import (
b64_encode_openai_embeddings_response,
convert_message_to_openai_dict_new,
convert_openai_chat_completion_choice,
convert_openai_chat_completion_stream,
convert_tooldef_to_openai_tool,
get_sampling_options,
prepare_openai_completion_params,
@ -105,57 +94,6 @@ class LiteLLMOpenAIMixin(
else model_id
)
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
params = await self._get_params(request)
params["model"] = self.get_litellm_model_name(params["model"])
logger.debug(f"params to litellm (openai compat): {params}")
# see https://docs.litellm.ai/docs/completion/stream#async-completion
response = await litellm.acompletion(**params)
if stream:
return self._stream_chat_completion(response)
else:
return convert_openai_chat_completion_choice(response.choices[0])
async def _stream_chat_completion(
self, response: litellm.ModelResponse
) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
async def _stream_generator():
async for chunk in response:
yield chunk
async for chunk in convert_openai_chat_completion_stream(
_stream_generator(), enable_incremental_tool_calls=True
):
yield chunk
def _add_additional_properties_recursive(self, schema):
"""
Recursively add additionalProperties: False to all object schemas

View file

@ -30,18 +30,14 @@ from openai.types.model import Model as OpenAIModel
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponseEventType,
CompletionMessage,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChoice,
SystemMessage,
ToolChoice,
ToolConfig,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.models import Model
from llama_stack.models.llama.datatypes import StopReason, ToolCall
from llama_stack.models.llama.datatypes import StopReason
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import (
@ -99,67 +95,24 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
vllm_inference_adapter.model_store.get_model.return_value = mock_model
with patch.object(vllm_inference_adapter, "_nonstream_chat_completion") as mock_nonstream_completion:
# Patch the client property to avoid instantiating a real AsyncOpenAI client
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock()
mock_client_property.return_value = mock_client
# No tools but auto tool choice
await vllm_inference_adapter.chat_completion(
await vllm_inference_adapter.openai_chat_completion(
"mock-model",
[],
stream=False,
tools=None,
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
tool_choice=ToolChoice.auto.value,
)
mock_nonstream_completion.assert_called()
request = mock_nonstream_completion.call_args.args[0]
mock_client.chat.completions.create.assert_called()
call_args = mock_client.chat.completions.create.call_args
# Ensure tool_choice gets converted to none for older vLLM versions
assert request.tool_config.tool_choice == ToolChoice.none
async def test_tool_call_response(vllm_inference_adapter):
"""Verify that tool call arguments from a CompletionMessage are correctly converted
into the expected JSON format."""
# Patch the client property to avoid instantiating a real AsyncOpenAI client
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock()
mock_create_client.return_value = mock_client
# Mock the model to return a proper provider_resource_id
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
vllm_inference_adapter.model_store.get_model.return_value = mock_model
messages = [
SystemMessage(content="You are a helpful assistant"),
UserMessage(content="How many?"),
CompletionMessage(
content="",
stop_reason=StopReason.end_of_turn,
tool_calls=[
ToolCall(
call_id="foo",
tool_name="knowledge_search",
arguments={"query": "How many?"},
arguments_json='{"query": "How many?"}',
)
],
),
ToolResponseMessage(call_id="foo", content="knowledge_search found 5...."),
]
await vllm_inference_adapter.chat_completion(
"mock-model",
messages,
stream=False,
tools=[],
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
)
assert mock_client.chat.completions.create.call_args.kwargs["messages"][2]["tool_calls"] == [
{
"id": "foo",
"type": "function",
"function": {"name": "knowledge_search", "arguments": '{"query": "How many?"}'},
}
]
assert call_args.kwargs["tool_choice"] == ToolChoice.none.value
async def test_tool_call_delta_empty_tool_call_buf():
@ -745,12 +698,10 @@ async def test_provider_data_var_context_propagation(vllm_inference_adapter):
try:
# Execute chat completion
await vllm_inference_adapter.chat_completion(
"test-model",
[UserMessage(content="Hello")],
await vllm_inference_adapter.openai_chat_completion(
model="test-model",
messages=[UserMessage(content="Hello")],
stream=False,
tools=None,
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
)
# Verify that ALL client calls were made with the correct parameters