mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-11 05:38:38 +00:00
chore: refactor (chat)completions endpoints to use shared params struct (#3761)
# What does this PR do? Converts openai(_chat)_completions params to pydantic BaseModel to reduce code duplication across all providers. ## Test Plan CI --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/llamastack/llama-stack/pull/3761). * #3777 * __->__ #3761
This commit is contained in:
parent
6954fe2274
commit
80d58ab519
33 changed files with 599 additions and 890 deletions
|
@ -54,6 +54,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
|||
setup_logger,
|
||||
start_trace,
|
||||
)
|
||||
from llama_stack.strong_typing.inspection import is_unwrapped_body_param
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
@ -383,7 +384,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
|
||||
body, field_names = self._handle_file_uploads(options, body)
|
||||
|
||||
body = self._convert_body(path, options.method, body, exclude_params=set(field_names))
|
||||
body = self._convert_body(matched_func, body, exclude_params=set(field_names))
|
||||
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
await start_trace(trace_path, {"__location__": "library_client"})
|
||||
|
@ -446,7 +447,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
|
||||
body |= path_params
|
||||
|
||||
body = self._convert_body(path, options.method, body)
|
||||
# Prepare body for the function call (handles both Pydantic and traditional params)
|
||||
body = self._convert_body(func, body)
|
||||
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
await start_trace(trace_path, {"__location__": "library_client"})
|
||||
|
@ -493,17 +495,20 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
)
|
||||
return await response.parse()
|
||||
|
||||
def _convert_body(
|
||||
self, path: str, method: str, body: dict | None = None, exclude_params: set[str] | None = None
|
||||
) -> dict:
|
||||
def _convert_body(self, func: Any, body: dict | None = None, exclude_params: set[str] | None = None) -> dict:
|
||||
if not body:
|
||||
return {}
|
||||
|
||||
assert self.route_impls is not None # Should be guaranteed by request() method, assertion for mypy
|
||||
exclude_params = exclude_params or set()
|
||||
|
||||
func, _, _, _ = find_matching_route(method, path, self.route_impls)
|
||||
sig = inspect.signature(func)
|
||||
params_list = [p for p in sig.parameters.values() if p.name != "self"]
|
||||
# Flatten if there's a single unwrapped body parameter (BaseModel or Annotated[BaseModel, Body(embed=False)])
|
||||
if len(params_list) == 1:
|
||||
param = params_list[0]
|
||||
param_type = param.annotation
|
||||
if is_unwrapped_body_param(param_type):
|
||||
base_type = get_args(param_type)[0]
|
||||
return {param.name: base_type(**body)}
|
||||
|
||||
# Strip NOT_GIVENs to use the defaults in signature
|
||||
body = {k: v for k, v in body.items() if v is not NOT_GIVEN}
|
||||
|
|
|
@ -10,9 +10,10 @@ from collections.abc import AsyncGenerator, AsyncIterator
|
|||
from datetime import UTC, datetime
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import Body
|
||||
from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
|
||||
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
|
||||
from pydantic import Field, TypeAdapter
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
|
@ -31,15 +32,16 @@ from llama_stack.apis.inference import (
|
|||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionRequest,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIChoice,
|
||||
OpenAIChoiceLogprobs,
|
||||
OpenAICompletion,
|
||||
OpenAICompletionRequest,
|
||||
OpenAICompletionWithInputMessages,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
Order,
|
||||
StopReason,
|
||||
ToolPromptFormat,
|
||||
|
@ -181,61 +183,23 @@ class InferenceRouter(Inference):
|
|||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
params: Annotated[OpenAICompletionRequest, Body(...)],
|
||||
) -> OpenAICompletion:
|
||||
logger.debug(
|
||||
f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
|
||||
)
|
||||
model_obj = await self._get_model(model, ModelType.llm)
|
||||
params = dict(
|
||||
model=model_obj.identifier,
|
||||
prompt=prompt,
|
||||
best_of=best_of,
|
||||
echo=echo,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
guided_choice=guided_choice,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
suffix=suffix,
|
||||
f"InferenceRouter.openai_completion: model={params.model}, stream={params.stream}, prompt={params.prompt}",
|
||||
)
|
||||
model_obj = await self._get_model(params.model, ModelType.llm)
|
||||
|
||||
# Update params with the resolved model identifier
|
||||
params.model = model_obj.identifier
|
||||
|
||||
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
if stream:
|
||||
return await provider.openai_completion(**params)
|
||||
if params.stream:
|
||||
return await provider.openai_completion(params)
|
||||
# TODO: Metrics do NOT work with openai_completion stream=True due to the fact
|
||||
# that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.
|
||||
# response_stream = await provider.openai_completion(**params)
|
||||
|
||||
response = await provider.openai_completion(**params)
|
||||
response = await provider.openai_completion(params)
|
||||
if self.telemetry:
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
|
@ -254,93 +218,49 @@ class InferenceRouter(Inference):
|
|||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
params: Annotated[OpenAIChatCompletionRequest, Body(...)],
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
logger.debug(
|
||||
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
|
||||
f"InferenceRouter.openai_chat_completion: model={params.model}, stream={params.stream}, messages={params.messages}",
|
||||
)
|
||||
model_obj = await self._get_model(model, ModelType.llm)
|
||||
model_obj = await self._get_model(params.model, ModelType.llm)
|
||||
|
||||
# Use the OpenAI client for a bit of extra input validation without
|
||||
# exposing the OpenAI client itself as part of our API surface
|
||||
if tool_choice:
|
||||
TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice)
|
||||
if tools is None:
|
||||
if params.tool_choice:
|
||||
TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(params.tool_choice)
|
||||
if params.tools is None:
|
||||
raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
|
||||
if tools:
|
||||
for tool in tools:
|
||||
if params.tools:
|
||||
for tool in params.tools:
|
||||
TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)
|
||||
|
||||
# Some providers make tool calls even when tool_choice is "none"
|
||||
# so just clear them both out to avoid unexpected tool calls
|
||||
if tool_choice == "none" and tools is not None:
|
||||
tool_choice = None
|
||||
tools = None
|
||||
if params.tool_choice == "none" and params.tools is not None:
|
||||
params.tool_choice = None
|
||||
params.tools = None
|
||||
|
||||
# Update params with the resolved model identifier
|
||||
params.model = model_obj.identifier
|
||||
|
||||
params = dict(
|
||||
model=model_obj.identifier,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
if stream:
|
||||
response_stream = await provider.openai_chat_completion(**params)
|
||||
if params.stream:
|
||||
response_stream = await provider.openai_chat_completion(params)
|
||||
|
||||
# For streaming, the provider returns AsyncIterator[OpenAIChatCompletionChunk]
|
||||
# We need to add metrics to each chunk and store the final completion
|
||||
return self.stream_tokens_and_compute_metrics_openai_chat(
|
||||
response=response_stream,
|
||||
model=model_obj,
|
||||
messages=messages,
|
||||
messages=params.messages,
|
||||
)
|
||||
|
||||
response = await self._nonstream_openai_chat_completion(provider, params)
|
||||
|
||||
# Store the response with the ID that will be returned to the client
|
||||
if self.store:
|
||||
asyncio.create_task(self.store.store_chat_completion(response, messages))
|
||||
asyncio.create_task(self.store.store_chat_completion(response, params.messages))
|
||||
|
||||
if self.telemetry:
|
||||
metrics = self._construct_metrics(
|
||||
|
@ -396,8 +316,10 @@ class InferenceRouter(Inference):
|
|||
return await self.store.get_chat_completion(completion_id)
|
||||
raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
|
||||
|
||||
async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
|
||||
response = await provider.openai_chat_completion(**params)
|
||||
async def _nonstream_openai_chat_completion(
|
||||
self, provider: Inference, params: OpenAIChatCompletionRequest
|
||||
) -> OpenAIChatCompletion:
|
||||
response = await provider.openai_chat_completion(params)
|
||||
for choice in response.choices:
|
||||
# some providers return an empty list for no tool calls in non-streaming responses
|
||||
# but the OpenAI API returns None. So, set tool_calls to None if it's empty
|
||||
|
|
|
@ -184,7 +184,17 @@ async def lifespan(app: StackApp):
|
|||
|
||||
def is_streaming_request(func_name: str, request: Request, **kwargs):
|
||||
# TODO: pass the api method and punt it to the Protocol definition directly
|
||||
return kwargs.get("stream", False)
|
||||
# If there's a stream parameter at top level, use it
|
||||
if "stream" in kwargs:
|
||||
return kwargs["stream"]
|
||||
|
||||
# If there's a stream parameter inside a "params" parameter, e.g. openai_chat_completion() use it
|
||||
if "params" in kwargs:
|
||||
params = kwargs["params"]
|
||||
if hasattr(params, "stream"):
|
||||
return params.stream
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def maybe_await(value):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue