mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-11 05:38:38 +00:00
test
# What does this PR do? ## Test Plan
This commit is contained in:
parent
f50ce11a3b
commit
4a3d1e33f8
31 changed files with 727 additions and 892 deletions
|
@ -8,11 +8,11 @@ import asyncio
|
|||
import time
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from datetime import UTC, datetime
|
||||
from typing import Annotated, Any
|
||||
from typing import Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
|
||||
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
|
||||
from pydantic import Field, TypeAdapter
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
|
@ -31,15 +31,16 @@ from llama_stack.apis.inference import (
|
|||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenaiChatCompletionRequest,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIChoice,
|
||||
OpenAIChoiceLogprobs,
|
||||
OpenAICompletion,
|
||||
OpenAICompletionRequest,
|
||||
OpenAICompletionWithInputMessages,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
Order,
|
||||
StopReason,
|
||||
ToolPromptFormat,
|
||||
|
@ -181,61 +182,23 @@ class InferenceRouter(Inference):
|
|||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
params: OpenAICompletionRequest,
|
||||
) -> OpenAICompletion:
|
||||
logger.debug(
|
||||
f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
|
||||
)
|
||||
model_obj = await self._get_model(model, ModelType.llm)
|
||||
params = dict(
|
||||
model=model_obj.identifier,
|
||||
prompt=prompt,
|
||||
best_of=best_of,
|
||||
echo=echo,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
guided_choice=guided_choice,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
suffix=suffix,
|
||||
f"InferenceRouter.openai_completion: model={params.model}, stream={params.stream}, prompt={params.prompt}",
|
||||
)
|
||||
model_obj = await self._get_model(params.model, ModelType.llm)
|
||||
|
||||
# Update params with the resolved model identifier
|
||||
params.model = model_obj.identifier
|
||||
|
||||
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
if stream:
|
||||
return await provider.openai_completion(**params)
|
||||
if params.stream:
|
||||
return await provider.openai_completion(params)
|
||||
# TODO: Metrics do NOT work with openai_completion stream=True due to the fact
|
||||
# that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.
|
||||
# response_stream = await provider.openai_completion(**params)
|
||||
|
||||
response = await provider.openai_completion(**params)
|
||||
response = await provider.openai_completion(params)
|
||||
if self.telemetry:
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
|
@ -254,93 +217,49 @@ class InferenceRouter(Inference):
|
|||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
params: OpenaiChatCompletionRequest,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
logger.debug(
|
||||
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
|
||||
f"InferenceRouter.openai_chat_completion: model={params.model}, stream={params.stream}, messages={params.messages}",
|
||||
)
|
||||
model_obj = await self._get_model(model, ModelType.llm)
|
||||
model_obj = await self._get_model(params.model, ModelType.llm)
|
||||
|
||||
# Use the OpenAI client for a bit of extra input validation without
|
||||
# exposing the OpenAI client itself as part of our API surface
|
||||
if tool_choice:
|
||||
TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice)
|
||||
if tools is None:
|
||||
if params.tool_choice:
|
||||
TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(params.tool_choice)
|
||||
if params.tools is None:
|
||||
raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
|
||||
if tools:
|
||||
for tool in tools:
|
||||
if params.tools:
|
||||
for tool in params.tools:
|
||||
TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)
|
||||
|
||||
# Some providers make tool calls even when tool_choice is "none"
|
||||
# so just clear them both out to avoid unexpected tool calls
|
||||
if tool_choice == "none" and tools is not None:
|
||||
tool_choice = None
|
||||
tools = None
|
||||
if params.tool_choice == "none" and params.tools is not None:
|
||||
params.tool_choice = None
|
||||
params.tools = None
|
||||
|
||||
# Update params with the resolved model identifier
|
||||
params.model = model_obj.identifier
|
||||
|
||||
params = dict(
|
||||
model=model_obj.identifier,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
if stream:
|
||||
response_stream = await provider.openai_chat_completion(**params)
|
||||
if params.stream:
|
||||
response_stream = await provider.openai_chat_completion(params)
|
||||
|
||||
# For streaming, the provider returns AsyncIterator[OpenAIChatCompletionChunk]
|
||||
# We need to add metrics to each chunk and store the final completion
|
||||
return self.stream_tokens_and_compute_metrics_openai_chat(
|
||||
response=response_stream,
|
||||
model=model_obj,
|
||||
messages=messages,
|
||||
messages=params.messages,
|
||||
)
|
||||
|
||||
response = await self._nonstream_openai_chat_completion(provider, params)
|
||||
|
||||
# Store the response with the ID that will be returned to the client
|
||||
if self.store:
|
||||
asyncio.create_task(self.store.store_chat_completion(response, messages))
|
||||
asyncio.create_task(self.store.store_chat_completion(response, params.messages))
|
||||
|
||||
if self.telemetry:
|
||||
metrics = self._construct_metrics(
|
||||
|
@ -396,8 +315,10 @@ class InferenceRouter(Inference):
|
|||
return await self.store.get_chat_completion(completion_id)
|
||||
raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
|
||||
|
||||
async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
|
||||
response = await provider.openai_chat_completion(**params)
|
||||
async def _nonstream_openai_chat_completion(
|
||||
self, provider: Inference, params: OpenaiChatCompletionRequest
|
||||
) -> OpenAIChatCompletion:
|
||||
response = await provider.openai_chat_completion(params)
|
||||
for choice in response.choices:
|
||||
# some providers return an empty list for no tool calls in non-streaming responses
|
||||
# but the OpenAI API returns None. So, set tool_calls to None if it's empty
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue