mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-10-11 13:44:38 +00:00
test

# What does this PR do?

## Test Plan
This commit is contained in:
parent
f50ce11a3b
commit
4a3d1e33f8
31 changed files with 727 additions and 892 deletions
@@ -383,7 +383,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         body, field_names = self._handle_file_uploads(options, body)
 
-        body = self._convert_body(path, options.method, body, exclude_params=set(field_names))
+        body = self._convert_body(matched_func, body, exclude_params=set(field_names))
 
         trace_path = webmethod.descriptive_name or route_path
         await start_trace(trace_path, {"__location__": "library_client"})
@@ -446,7 +446,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
         body |= path_params
 
-        body = self._convert_body(path, options.method, body)
+        # Prepare body for the function call (handles both Pydantic and traditional params)
+        body = self._convert_body(func, body)
 
         trace_path = webmethod.descriptive_name or route_path
         await start_trace(trace_path, {"__location__": "library_client"})
@@ -493,17 +494,27 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         )
         return await response.parse()
 
-    def _convert_body(
-        self, path: str, method: str, body: dict | None = None, exclude_params: set[str] | None = None
-    ) -> dict:
+    def _convert_body(self, func: Any, body: dict | None = None, exclude_params: set[str] | None = None) -> dict:
         if not body:
             return {}
 
-        assert self.route_impls is not None  # Should be guaranteed by request() method, assertion for mypy
         exclude_params = exclude_params or set()
 
-        func, _, _, _ = find_matching_route(method, path, self.route_impls)
         sig = inspect.signature(func)
         params_list = [p for p in sig.parameters.values() if p.name != "self"]
 
+        # Check if the method expects a single Pydantic model parameter
+        if len(params_list) == 1:
+            param = params_list[0]
+            param_type = param.annotation
+            if issubclass(param_type, BaseModel):
+                # Strip NOT_GIVENs before passing to Pydantic
+                clean_body = {k: v for k, v in body.items() if v is not NOT_GIVEN}
+
+                # If the body has a single key matching the parameter name, unwrap it
+                if len(clean_body) == 1 and param.name in clean_body:
+                    clean_body = clean_body[param.name]
+
+                return {param.name: param_type(**clean_body)}
+
+        # Strip NOT_GIVENs to use the defaults in signature
         body = {k: v for k, v in body.items() if v is not NOT_GIVEN}
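For orientation, a minimal sketch (illustrative names, not part of the diff) of the new single-Pydantic-parameter branch in `_convert_body`: when the matched function takes exactly one non-`self` parameter annotated with a `BaseModel` subclass, the flat request body is wrapped into that model.

```python
import inspect

from pydantic import BaseModel


class GreetRequest(BaseModel):  # hypothetical request model, not from the diff
    name: str
    excited: bool = False


async def greet(params: GreetRequest) -> str:  # hypothetical endpoint function
    return f"Hello, {params.name}{'!' if params.excited else '.'}"


# Mirrors the single-Pydantic-parameter branch: the one non-self parameter is
# annotated with a BaseModel subclass, so the flat body is wrapped into it.
body = {"name": "llama", "excited": True}
param = next(iter(inspect.signature(greet).parameters.values()))
assert issubclass(param.annotation, BaseModel)
kwargs = {param.name: param.annotation(**body)}
assert kwargs["params"].excited is True  # ready for `await greet(**kwargs)`
```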
@@ -8,11 +8,11 @@ import asyncio
 import time
 from collections.abc import AsyncGenerator, AsyncIterator
 from datetime import UTC, datetime
-from typing import Annotated, Any
+from typing import Any
 
 from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
 from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
-from pydantic import Field, TypeAdapter
+from pydantic import TypeAdapter
 
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -31,15 +31,16 @@ from llama_stack.apis.inference import (
     OpenAIAssistantMessageParam,
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
+    OpenaiChatCompletionRequest,
     OpenAIChatCompletionToolCall,
     OpenAIChatCompletionToolCallFunction,
     OpenAIChoice,
     OpenAIChoiceLogprobs,
     OpenAICompletion,
+    OpenAICompletionRequest,
     OpenAICompletionWithInputMessages,
     OpenAIEmbeddingsResponse,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
     Order,
     StopReason,
     ToolPromptFormat,
@@ -181,61 +182,23 @@ class InferenceRouter(Inference):
 
     async def openai_completion(
         self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
+        params: OpenAICompletionRequest,
     ) -> OpenAICompletion:
         logger.debug(
-            f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
+            f"InferenceRouter.openai_completion: model={params.model}, stream={params.stream}, prompt={params.prompt}",
         )
-        model_obj = await self._get_model(model, ModelType.llm)
-        params = dict(
-            model=model_obj.identifier,
-            prompt=prompt,
-            best_of=best_of,
-            echo=echo,
-            frequency_penalty=frequency_penalty,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_tokens=max_tokens,
-            n=n,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            top_p=top_p,
-            user=user,
-            guided_choice=guided_choice,
-            prompt_logprobs=prompt_logprobs,
-            suffix=suffix,
-        )
+        model_obj = await self._get_model(params.model, ModelType.llm)
+
+        # Update params with the resolved model identifier
+        params.model = model_obj.identifier
+
         provider = await self.routing_table.get_provider_impl(model_obj.identifier)
-        if stream:
-            return await provider.openai_completion(**params)
+        if params.stream:
+            return await provider.openai_completion(params)
         # TODO: Metrics do NOT work with openai_completion stream=True due to the fact
         # that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.
         # response_stream = await provider.openai_completion(**params)
 
-        response = await provider.openai_completion(**params)
+        response = await provider.openai_completion(params)
         if self.telemetry:
             metrics = self._construct_metrics(
                 prompt_tokens=response.usage.prompt_tokens,
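As a usage sketch (hypothetical model id and router instance; field names as used in the diff), the many keyword arguments of the old API now travel in one request model:

```python
from llama_stack.apis.inference import OpenAICompletionRequest


async def demo(router) -> None:  # `router` is assumed to be an InferenceRouter
    # Build the request model once; unset fields keep their defaults.
    params = OpenAICompletionRequest(
        model="llama3.2:3b",  # hypothetical model id; the router rewrites it to model_obj.identifier
        prompt="Say hello in French.",
        max_tokens=32,
        stream=False,
    )
    response = await router.openai_completion(params)
    print(response.choices[0].text)
```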
@@ -254,93 +217,49 @@ class InferenceRouter(Inference):
 
     async def openai_chat_completion(
         self,
-        model: str,
-        messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
+        params: OpenaiChatCompletionRequest,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         logger.debug(
-            f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
+            f"InferenceRouter.openai_chat_completion: model={params.model}, stream={params.stream}, messages={params.messages}",
         )
-        model_obj = await self._get_model(model, ModelType.llm)
+        model_obj = await self._get_model(params.model, ModelType.llm)
 
         # Use the OpenAI client for a bit of extra input validation without
         # exposing the OpenAI client itself as part of our API surface
-        if tool_choice:
-            TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice)
-            if tools is None:
+        if params.tool_choice:
+            TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(params.tool_choice)
+            if params.tools is None:
                 raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
-        if tools:
-            for tool in tools:
+        if params.tools:
+            for tool in params.tools:
                 TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)
 
         # Some providers make tool calls even when tool_choice is "none"
         # so just clear them both out to avoid unexpected tool calls
-        if tool_choice == "none" and tools is not None:
-            tool_choice = None
-            tools = None
+        if params.tool_choice == "none" and params.tools is not None:
+            params.tool_choice = None
+            params.tools = None
 
-        params = dict(
-            model=model_obj.identifier,
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
+        # Update params with the resolved model identifier
+        params.model = model_obj.identifier
+
         provider = await self.routing_table.get_provider_impl(model_obj.identifier)
-        if stream:
-            response_stream = await provider.openai_chat_completion(**params)
+        if params.stream:
+            response_stream = await provider.openai_chat_completion(params)
 
             # For streaming, the provider returns AsyncIterator[OpenAIChatCompletionChunk]
             # We need to add metrics to each chunk and store the final completion
             return self.stream_tokens_and_compute_metrics_openai_chat(
                 response=response_stream,
                 model=model_obj,
-                messages=messages,
+                messages=params.messages,
             )
 
         response = await self._nonstream_openai_chat_completion(provider, params)
 
         # Store the response with the ID that will be returned to the client
         if self.store:
-            asyncio.create_task(self.store.store_chat_completion(response, messages))
+            asyncio.create_task(self.store.store_chat_completion(response, params.messages))
 
         if self.telemetry:
             metrics = self._construct_metrics(
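A similar sketch for the chat path (hypothetical model id; the message and tool dicts follow the OpenAI wire format, which the request model is assumed to coerce): note how `tool_choice="none"` plus tools triggers the clearing logic in the hunk above.

```python
from llama_stack.apis.inference import OpenaiChatCompletionRequest


async def chat_demo(router) -> None:  # `router` is assumed to be an InferenceRouter
    params = OpenaiChatCompletionRequest(
        model="llama3.2:3b",  # hypothetical model id
        messages=[{"role": "user", "content": "What is the capital of France?"}],
        tools=[
            {
                "type": "function",
                "function": {"name": "lookup", "parameters": {"type": "object", "properties": {}}},
            }
        ],
        tool_choice="none",
    )
    response = await router.openai_chat_completion(params)
    # The router validates the tools, then clears both fields because
    # tool_choice == "none", so providers cannot emit spurious tool calls.
    print(response.choices[0].message.content)
```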
@@ -396,8 +315,10 @@ class InferenceRouter(Inference):
             return await self.store.get_chat_completion(completion_id)
         raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
 
-    async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
-        response = await provider.openai_chat_completion(**params)
+    async def _nonstream_openai_chat_completion(
+        self, provider: Inference, params: OpenaiChatCompletionRequest
+    ) -> OpenAIChatCompletion:
+        response = await provider.openai_chat_completion(params)
         for choice in response.choices:
             # some providers return an empty list for no tool calls in non-streaming responses
             # but the OpenAI API returns None. So, set tool_calls to None if it's empty
@@ -13,12 +13,13 @@ import logging  # allow-direct-logging
 import os
 import sys
 import traceback
+import types
 import warnings
 from collections.abc import Callable
 from contextlib import asynccontextmanager
 from importlib.metadata import version as parse_version
 from pathlib import Path
-from typing import Annotated, Any, get_origin
+from typing import Annotated, Any, Union, get_origin
 
 import httpx
 import rich.pretty
@@ -177,7 +178,17 @@ async def lifespan(app: StackApp):
 
 def is_streaming_request(func_name: str, request: Request, **kwargs):
     # TODO: pass the api method and punt it to the Protocol definition directly
-    return kwargs.get("stream", False)
+    # Check for stream parameter at top level (old API style)
+    if "stream" in kwargs:
+        return kwargs["stream"]
+
+    # Check for stream parameter inside Pydantic request params (new API style)
+    if "params" in kwargs:
+        params = kwargs["params"]
+        if hasattr(params, "stream"):
+            return params.stream
+
+    return False
 
 
 async def maybe_await(value):
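To show what the widened detection accepts, a small sketch (stand-in objects; `request` is irrelevant to the logic and passed as None):

```python
from types import SimpleNamespace

# Old style: `stream` arrives as a top-level keyword argument.
assert is_streaming_request("openai_completion", None, stream=True) is True

# New style: `stream` lives on the Pydantic request model bound to `params`.
fake_params = SimpleNamespace(stream=True)  # stand-in for OpenAICompletionRequest
assert is_streaming_request("openai_completion", None, params=fake_params) is True

# Neither present: treated as non-streaming.
assert is_streaming_request("openai_completion", None) is False
```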
@@ -282,21 +293,42 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
     if method == "post":
         # Annotate parameters that are in the path with Path(...) and others with Body(...),
         # but preserve existing File() and Form() annotations for multipart form data
-        new_params = (
-            [new_params[0]]
-            + [
-                (
-                    param.replace(annotation=Annotated[param.annotation, FastapiPath(..., title=param.name)])
-                    if param.name in path_params
-                    else (
-                        param  # Keep original annotation if it's already an Annotated type
-                        if get_origin(param.annotation) is Annotated
-                        else param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
-                    )
-                )
-                for param in new_params[1:]
-            ]
-        )
+        def get_body_embed_value(param: inspect.Parameter) -> bool:
+            """Determine if Body should use embed=True or embed=False.
+
+            For OpenAI-compatible endpoints (param name is 'params'), use embed=False
+            so the request body is parsed directly as the model (not nested).
+            This allows OpenAI clients to send standard OpenAI format.
+            For other endpoints, use embed=True for SDK compatibility.
+            """
+            # Get the actual type, stripping Optional/Union if present
+            param_type = param.annotation
+            origin = get_origin(param_type)
+            # Check for Union types (both typing.Union and types.UnionType for | syntax)
+            if origin is Union or (hasattr(types, "UnionType") and origin is types.UnionType):
+                # Handle Optional[T] / T | None
+                args = param_type.__args__ if hasattr(param_type, "__args__") else []
+                param_type = next((arg for arg in args if arg is not type(None)), param_type)
+
+            # Check if it's a Pydantic BaseModel and param name is 'params' (OpenAI-compatible)
+            is_basemodel = issubclass(param_type, BaseModel)
+            if is_basemodel and param.name == "params":
+                return False  # Use embed=False for OpenAI-compatible endpoints
+            return True  # Use embed=True for everything else
+
+        original_params = new_params[1:]  # Skip request parameter
+        new_params = [new_params[0]]  # Keep request parameter
+
+        for param in original_params:
+            if param.name in path_params:
+                new_params.append(
+                    param.replace(annotation=Annotated[param.annotation, FastapiPath(..., title=param.name)])
+                )
+            elif get_origin(param.annotation) is Annotated:
+                new_params.append(param)  # Keep existing annotation
+            else:
+                embed = get_body_embed_value(param)
+                new_params.append(param.replace(annotation=Annotated[param.annotation, Body(..., embed=embed)]))
 
     route_handler.__signature__ = sig.replace(parameters=new_params)
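Concretely, the embed flag controls the JSON shape FastAPI expects; a hedged sketch (hypothetical routes and model, standard FastAPI `Body` semantics) of the two cases:

```python
from typing import Annotated

from fastapi import Body, FastAPI
from pydantic import BaseModel

app = FastAPI()


class CompletionParams(BaseModel):  # hypothetical stand-in for OpenAICompletionRequest
    model: str
    prompt: str


# embed=False: the body *is* the model -- {"model": "...", "prompt": "..."} --
# which is exactly what stock OpenAI clients send.
@app.post("/v1/completions")
async def completions(params: Annotated[CompletionParams, Body(embed=False)]):
    return {"model": params.model}


# embed=True: the body nests the argument under its parameter name --
# {"item": {"model": "...", "prompt": "..."}} -- the SDK-compatible shape.
@app.post("/v1/other")
async def other(item: Annotated[CompletionParams, Body(embed=True)]):
    return {"model": item.model}
```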