# What does this PR do?


## Test Plan
This commit is contained in:
Eric Huang 2025-10-09 20:53:19 -07:00
parent f50ce11a3b
commit 4a3d1e33f8
31 changed files with 727 additions and 892 deletions

View file

@ -49,6 +49,7 @@ from llama_stack.apis.inference import (
Inference,
Message,
OpenAIAssistantMessageParam,
OpenaiChatCompletionRequest,
OpenAIDeveloperMessageParam,
OpenAIMessageParam,
OpenAISystemMessageParam,
@ -582,7 +583,7 @@ class ChatAgent(ShieldRunnerMixin):
max_tokens = getattr(sampling_params, "max_tokens", None)
# Use OpenAI chat completion
openai_stream = await self.inference_api.openai_chat_completion(
params = OpenaiChatCompletionRequest(
model=self.agent_config.model,
messages=openai_messages,
tools=openai_tools if openai_tools else None,
@ -593,6 +594,7 @@ class ChatAgent(ShieldRunnerMixin):
max_tokens=max_tokens,
stream=True,
)
openai_stream = await self.inference_api.openai_chat_completion(params)
# Convert OpenAI stream back to Llama Stack format
response_stream = convert_openai_chat_completion_stream(

View file

@ -41,6 +41,7 @@ from llama_stack.apis.inference import (
Inference,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenaiChatCompletionRequest,
OpenAIChatCompletionToolCall,
OpenAIChoice,
OpenAIMessageParam,
@ -130,7 +131,7 @@ class StreamingResponseOrchestrator:
# (some providers don't support non-empty response_format when tools are present)
response_format = None if self.ctx.response_format.type == "text" else self.ctx.response_format
logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")
completion_result = await self.inference_api.openai_chat_completion(
params = OpenaiChatCompletionRequest(
model=self.ctx.model,
messages=messages,
tools=self.ctx.chat_tools,
@ -138,6 +139,7 @@ class StreamingResponseOrchestrator:
temperature=self.ctx.temperature,
response_format=response_format,
)
completion_result = await self.inference_api.openai_chat_completion(params)
# Process streaming chunks and build complete response
completion_result_data = None

View file

@ -22,6 +22,8 @@ from llama_stack.apis.files import Files, OpenAIFilePurpose
from llama_stack.apis.inference import (
Inference,
OpenAIAssistantMessageParam,
OpenaiChatCompletionRequest,
OpenAICompletionRequest,
OpenAIDeveloperMessageParam,
OpenAIMessageParam,
OpenAISystemMessageParam,
@ -601,7 +603,8 @@ class ReferenceBatchesImpl(Batches):
# TODO(SECURITY): review body for security issues
if request.url == "/v1/chat/completions":
request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
chat_response = await self.inference_api.openai_chat_completion(**request.body)
params = OpenaiChatCompletionRequest(**request.body)
chat_response = await self.inference_api.openai_chat_completion(params)
# this is for mypy, we don't allow streaming so we'll get the right type
assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
@ -615,7 +618,8 @@ class ReferenceBatchesImpl(Batches):
},
}
else: # /v1/completions
completion_response = await self.inference_api.openai_completion(**request.body)
params = OpenAICompletionRequest(**request.body)
completion_response = await self.inference_api.openai_completion(params)
# this is for mypy, we don't allow streaming so we'll get the right type
assert hasattr(completion_response, "model_dump_json"), (

View file

@ -12,7 +12,14 @@ from llama_stack.apis.agents import Agents, StepType
from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference import Inference, OpenAISystemMessageParam, OpenAIUserMessageParam, UserMessage
from llama_stack.apis.inference import (
Inference,
OpenaiChatCompletionRequest,
OpenAICompletionRequest,
OpenAISystemMessageParam,
OpenAIUserMessageParam,
UserMessage,
)
from llama_stack.apis.scoring import Scoring
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
@ -168,11 +175,12 @@ class MetaReferenceEvalImpl(
sampling_params["stop"] = candidate.sampling_params.stop
input_content = json.loads(x[ColumnName.completion_input.value])
response = await self.inference_api.openai_completion(
params = OpenAICompletionRequest(
model=candidate.model,
prompt=input_content,
**sampling_params,
)
response = await self.inference_api.openai_completion(params)
generations.append({ColumnName.generated_answer.value: response.choices[0].text})
elif ColumnName.chat_completion_input.value in x:
chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
@ -187,11 +195,12 @@ class MetaReferenceEvalImpl(
messages += [OpenAISystemMessageParam(**x) for x in chat_completion_input_json if x["role"] == "system"]
messages += input_messages
response = await self.inference_api.openai_chat_completion(
params = OpenaiChatCompletionRequest(
model=candidate.model,
messages=messages,
**sampling_params,
)
response = await self.inference_api.openai_chat_completion(params)
generations.append({ColumnName.generated_answer.value: response.choices[0].message.content})
else:
raise ValueError("Invalid input row")

View file

@ -6,16 +6,16 @@
import asyncio
from collections.abc import AsyncIterator
from typing import Any
from llama_stack.apis.inference import (
InferenceProvider,
OpenaiChatCompletionRequest,
OpenAICompletionRequest,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIMessageParam,
OpenAIResponseFormatParam,
OpenAICompletion,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
@ -65,7 +65,10 @@ class MetaReferenceInferenceImpl(
if self.config.create_distributed_process_group:
self.generator.stop()
async def openai_completion(self, *args, **kwargs):
async def openai_completion(
self,
params: OpenAICompletionRequest,
) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by meta reference provider")
async def should_refresh_models(self) -> bool:
@ -150,28 +153,6 @@ class MetaReferenceInferenceImpl(
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: OpenaiChatCompletionRequest,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider")

View file

@ -5,17 +5,16 @@
# the root directory of this source tree.
from collections.abc import AsyncIterator
from typing import Any
from llama_stack.apis.inference import (
InferenceProvider,
OpenaiChatCompletionRequest,
OpenAICompletionRequest,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
@ -73,56 +72,12 @@ class SentenceTransformersInferenceImpl(
async def openai_completion(
self,
# Standard OpenAI completion parameters
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
# vLLM-specific parameters
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
# for fill-in-the-middle type completion
suffix: str | None = None,
params: OpenAICompletionRequest,
) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by sentence transformers provider")
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: OpenaiChatCompletionRequest,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by sentence transformers provider")

View file

@ -10,7 +10,13 @@ from string import Template
from typing import Any
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
from llama_stack.apis.inference import Inference, Message, UserMessage
from llama_stack.apis.inference import (
Inference,
Message,
OpenaiChatCompletionRequest,
OpenAIUserMessageParam,
UserMessage,
)
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
@ -290,20 +296,21 @@ class LlamaGuardShield:
else:
shield_input_message = self.build_text_shield_input(messages)
response = await self.inference_api.openai_chat_completion(
params = OpenaiChatCompletionRequest(
model=self.model,
messages=[shield_input_message],
stream=False,
temperature=0.0, # default is 1, which is too high for safety
)
response = await self.inference_api.openai_chat_completion(params)
content = response.choices[0].message.content
content = content.strip()
return self.get_shield_response(content)
def build_text_shield_input(self, messages: list[Message]) -> UserMessage:
return UserMessage(content=self.build_prompt(messages))
def build_text_shield_input(self, messages: list[Message]) -> OpenAIUserMessageParam:
return OpenAIUserMessageParam(role="user", content=self.build_prompt(messages))
def build_vision_shield_input(self, messages: list[Message]) -> UserMessage:
def build_vision_shield_input(self, messages: list[Message]) -> OpenAIUserMessageParam:
conversation = []
most_recent_img = None
@ -335,7 +342,7 @@ class LlamaGuardShield:
prompt.append(most_recent_img)
prompt.append(self.build_prompt(conversation[::-1]))
return UserMessage(content=prompt)
return OpenAIUserMessageParam(role="user", content=prompt)
def build_prompt(self, messages: list[Message]) -> str:
categories = self.get_safety_categories()
@ -377,11 +384,12 @@ class LlamaGuardShield:
# TODO: Add Image based support for OpenAI Moderations
shield_input_message = self.build_text_shield_input(messages)
response = await self.inference_api.openai_chat_completion(
params = OpenaiChatCompletionRequest(
model=self.model,
messages=[shield_input_message],
stream=False,
)
response = await self.inference_api.openai_chat_completion(params)
content = response.choices[0].message.content
content = content.strip()
return self.get_moderation_object(content)

View file

@ -6,7 +6,7 @@
import re
from typing import Any
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import Inference, OpenaiChatCompletionRequest
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
@ -55,7 +55,7 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
generated_answer=generated_answer,
)
judge_response = await self.inference_api.openai_chat_completion(
params = OpenaiChatCompletionRequest(
model=fn_def.params.judge_model,
messages=[
{
@ -64,6 +64,7 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
}
],
)
judge_response = await self.inference_api.openai_chat_completion(params)
content = judge_response.choices[0].message.content
rating_regexes = fn_def.params.judge_score_regexes

View file

@ -8,7 +8,7 @@
from jinja2 import Template
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import OpenAIUserMessageParam
from llama_stack.apis.inference import OpenaiChatCompletionRequest, OpenAIUserMessageParam
from llama_stack.apis.tools.rag_tool import (
DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig,
@ -65,11 +65,12 @@ async def llm_rag_query_generator(
model = config.model
message = OpenAIUserMessageParam(content=rendered_content)
response = await inference_api.openai_chat_completion(
params = OpenaiChatCompletionRequest(
model=model,
messages=[message],
stream=False,
)
response = await inference_api.openai_chat_completion(params)
query = response.choices[0].message.content