Ran precommit

Omar Abdelwahab 2025-10-06 13:27:19 -07:00
parent 9886520b40
commit 9fc0d966f6
7 changed files with 153 additions and 310 deletions
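The commit message indicates a pre-commit run; the hunks below are the resulting formatting and import-ordering changes. A minimal sketch of the usual invocation, assuming the repository wires its formatters through the pre-commit tool (the exact hook list is not part of this commit):

pre-commit install              # register the git hook once per clone
pre-commit run --all-files      # re-run all configured hooks over the whole tree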

View file

@@ -8,6 +8,9 @@ from collections.abc import AsyncIterator
 from enum import Enum
 from typing import Annotated, Any, Literal, Protocol, runtime_checkable
+from pydantic import BaseModel, Field, field_validator
+from typing_extensions import TypedDict
 from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
 from llama_stack.apis.common.responses import Order
 from llama_stack.apis.models import Model
@@ -23,9 +26,6 @@ from llama_stack.models.llama.datatypes import (
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
-from pydantic import BaseModel, Field, field_validator
-from typing_extensions import TypedDict
 register_schema(ToolCall)
 register_schema(ToolDefinition)
@@ -381,9 +381,7 @@ class ToolConfig(BaseModel):
 tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto)
 tool_prompt_format: ToolPromptFormat | None = Field(default=None)
-system_message_behavior: SystemMessageBehavior | None = Field(
-default=SystemMessageBehavior.append
-)
+system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append)
 def model_post_init(self, __context: Any) -> None:
 if isinstance(self.tool_choice, str):
@@ -512,21 +510,15 @@ class OpenAIFile(BaseModel):
 OpenAIChatCompletionContentPartParam = Annotated[
-OpenAIChatCompletionContentPartTextParam
-| OpenAIChatCompletionContentPartImageParam
-| OpenAIFile,
+OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam | OpenAIFile,
 Field(discriminator="type"),
 ]
-register_schema(
-OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam"
-)
+register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")
 OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]
-OpenAIChatCompletionTextOnlyMessageContent = (
-str | list[OpenAIChatCompletionContentPartTextParam]
-)
+OpenAIChatCompletionTextOnlyMessageContent = str | list[OpenAIChatCompletionContentPartTextParam]
 @json_schema_type
@@ -694,9 +686,7 @@ class OpenAIResponseFormatJSONObject(BaseModel):
 OpenAIResponseFormatParam = Annotated[
-OpenAIResponseFormatText
-| OpenAIResponseFormatJSONSchema
-| OpenAIResponseFormatJSONObject,
+OpenAIResponseFormatText | OpenAIResponseFormatJSONSchema | OpenAIResponseFormatJSONObject,
 Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
@@ -986,16 +976,8 @@ class InferenceProvider(Protocol):
 async def rerank(
 self,
 model: str,
-query: (
-str
-| OpenAIChatCompletionContentPartTextParam
-| OpenAIChatCompletionContentPartImageParam
-),
-items: list[
-str
-| OpenAIChatCompletionContentPartTextParam
-| OpenAIChatCompletionContentPartImageParam
-],
+query: (str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam),
+items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
 max_num_results: int | None = None,
 ) -> RerankResponse:
 """Rerank a list of documents based on their relevance to a query.

View file

@@ -7,9 +7,17 @@
 import asyncio
 import time
 from collections.abc import AsyncGenerator, AsyncIterator
-from datetime import datetime, UTC
+from datetime import UTC, datetime
 from typing import Annotated, Any
+from openai.types.chat import (
+ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam,
+)
+from openai.types.chat import (
+ChatCompletionToolParam as OpenAIChatCompletionToolParam,
+)
+from pydantic import Field, TypeAdapter
 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
 from llama_stack.apis.inference import (
@@ -48,12 +56,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
 get_current_span,
 )
-from openai.types.chat import (
-ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam,
-ChatCompletionToolParam as OpenAIChatCompletionToolParam,
-)
-from pydantic import Field, TypeAdapter
 logger = get_logger(name=__name__, category="core::routers")
@@ -96,9 +98,7 @@ class InferenceRouter(Inference):
 logger.debug(
 f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
 )
-await self.routing_table.register_model(
-model_id, provider_model_id, provider_id, metadata, model_type
-)
+await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
 def _construct_metrics(
 self,
@@ -153,16 +153,11 @@ class InferenceRouter(Inference):
 total_tokens: int,
 model: Model,
 ) -> list[MetricInResponse]:
-metrics = self._construct_metrics(
-prompt_tokens, completion_tokens, total_tokens, model
-)
+metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
 if self.telemetry:
 for metric in metrics:
 enqueue_event(metric)
-return [
-MetricInResponse(metric=metric.metric, value=metric.value)
-for metric in metrics
-]
+return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
 async def _count_tokens(
 self,
@@ -256,9 +251,7 @@ class InferenceRouter(Inference):
 # these metrics will show up in the client response.
 response.metrics = (
-metrics
-if not hasattr(response, "metrics") or response.metrics is None
-else response.metrics + metrics
+metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
 )
 return response
@@ -296,13 +289,9 @@ class InferenceRouter(Inference):
 # Use the OpenAI client for a bit of extra input validation without
 # exposing the OpenAI client itself as part of our API surface
 if tool_choice:
-TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(
-tool_choice
-)
+TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice)
 if tools is None:
-raise ValueError(
-"'tool_choice' is only allowed when 'tools' is also provided"
-)
+raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
 if tools:
 for tool in tools:
 TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)
@@ -367,9 +356,7 @@ class InferenceRouter(Inference):
 enqueue_event(metric)
 # these metrics will show up in the client response.
 response.metrics = (
-metrics
-if not hasattr(response, "metrics") or response.metrics is None
-else response.metrics + metrics
+metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
 )
 return response
@@ -405,31 +392,19 @@ class InferenceRouter(Inference):
 ) -> ListOpenAIChatCompletionResponse:
 if self.store:
 return await self.store.list_chat_completions(after, limit, model, order)
-raise NotImplementedError(
-"List chat completions is not supported: inference store is not configured."
-)
+raise NotImplementedError("List chat completions is not supported: inference store is not configured.")
-async def get_chat_completion(
-self, completion_id: str
-) -> OpenAICompletionWithInputMessages:
+async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
 if self.store:
 return await self.store.get_chat_completion(completion_id)
-raise NotImplementedError(
-"Get chat completion is not supported: inference store is not configured."
-)
+raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
-async def _nonstream_openai_chat_completion(
-self, provider: Inference, params: dict
-) -> OpenAIChatCompletion:
+async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
 response = await provider.openai_chat_completion(**params)
 for choice in response.choices:
 # some providers return an empty list for no tool calls in non-streaming responses
 # but the OpenAI API returns None. So, set tool_calls to None if it's empty
-if (
-choice.message
-and choice.message.tool_calls is not None
-and len(choice.message.tool_calls) == 0
-):
+if choice.message and choice.message.tool_calls is not None and len(choice.message.tool_calls) == 0:
 choice.message.tool_calls = None
 return response
@@ -449,9 +424,7 @@ class InferenceRouter(Inference):
 message=f"Health check timed out after {timeout} seconds",
 )
 except NotImplementedError:
-health_statuses[provider_id] = HealthResponse(
-status=HealthStatus.NOT_IMPLEMENTED
-)
+health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
 except Exception as e:
 health_statuses[provider_id] = HealthResponse(
 status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
@@ -486,11 +459,7 @@ class InferenceRouter(Inference):
 else:
 if hasattr(chunk, "delta"):
 completion_text += chunk.delta
-if (
-hasattr(chunk, "stop_reason")
-and chunk.stop_reason
-and self.telemetry
-):
+if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
 complete = True
 completion_tokens = await self._count_tokens(completion_text)
 # if we are done receiving tokens
@@ -515,14 +484,9 @@ class InferenceRouter(Inference):
 # Return metrics in response
 async_metrics = [
-MetricInResponse(metric=metric.metric, value=metric.value)
-for metric in completion_metrics
+MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
 ]
-chunk.metrics = (
-async_metrics
-if chunk.metrics is None
-else chunk.metrics + async_metrics
-)
+chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
 else:
 # Fallback if no telemetry
 completion_metrics = self._construct_metrics(
@@ -532,14 +496,9 @@ class InferenceRouter(Inference):
 model,
 )
 async_metrics = [
-MetricInResponse(metric=metric.metric, value=metric.value)
-for metric in completion_metrics
+MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
 ]
-chunk.metrics = (
-async_metrics
-if chunk.metrics is None
-else chunk.metrics + async_metrics
-)
+chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
 yield chunk
 async def count_tokens_and_compute_metrics(
@@ -553,9 +512,7 @@ class InferenceRouter(Inference):
 content = [response.completion_message]
 else:
 content = response.content
-completion_tokens = await self._count_tokens(
-messages=content, tool_prompt_format=tool_prompt_format
-)
+completion_tokens = await self._count_tokens(messages=content, tool_prompt_format=tool_prompt_format)
 total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
 # Create a separate span for completion metrics
@@ -575,10 +532,7 @@ class InferenceRouter(Inference):
 enqueue_event(metric)
 # Return metrics in response
-return [
-MetricInResponse(metric=metric.metric, value=metric.value)
-for metric in completion_metrics
-]
+return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]
 # Fallback if no telemetry
 metrics = self._construct_metrics(
@@ -587,10 +541,7 @@ class InferenceRouter(Inference):
 total_tokens,
 model,
 )
-return [
-MetricInResponse(metric=metric.metric, value=metric.value)
-for metric in metrics
-]
+return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
 async def stream_tokens_and_compute_metrics_openai_chat(
 self,
@@ -631,48 +582,33 @@ class InferenceRouter(Inference):
 if choice_delta.delta:
 delta = choice_delta.delta
 if delta.content:
-current_choice_data["content_parts"].append(
-delta.content
-)
+current_choice_data["content_parts"].append(delta.content)
 if delta.tool_calls:
 for tool_call_delta in delta.tool_calls:
 tc_idx = tool_call_delta.index
-if (
-tc_idx
-not in current_choice_data["tool_calls_builder"]
-):
-current_choice_data["tool_calls_builder"][
-tc_idx
-] = {
+if tc_idx not in current_choice_data["tool_calls_builder"]:
+current_choice_data["tool_calls_builder"][tc_idx] = {
 "id": None,
 "type": "function",
 "function_name_parts": [],
 "function_arguments_parts": [],
 }
-builder = current_choice_data["tool_calls_builder"][
-tc_idx
-]
+builder = current_choice_data["tool_calls_builder"][tc_idx]
 if tool_call_delta.id:
 builder["id"] = tool_call_delta.id
 if tool_call_delta.type:
 builder["type"] = tool_call_delta.type
 if tool_call_delta.function:
 if tool_call_delta.function.name:
-builder["function_name_parts"].append(
-tool_call_delta.function.name
-)
+builder["function_name_parts"].append(tool_call_delta.function.name)
 if tool_call_delta.function.arguments:
 builder["function_arguments_parts"].append(
 tool_call_delta.function.arguments
 )
 if choice_delta.finish_reason:
-current_choice_data["finish_reason"] = (
-choice_delta.finish_reason
-)
+current_choice_data["finish_reason"] = choice_delta.finish_reason
 if choice_delta.logprobs and choice_delta.logprobs.content:
-current_choice_data["logprobs_content_parts"].extend(
-choice_delta.logprobs.content
-)
+current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)
 # Compute metrics on final chunk
 if chunk.choices and chunk.choices[0].finish_reason:
@@ -702,12 +638,8 @@ class InferenceRouter(Inference):
 if choice_data["tool_calls_builder"]:
 for tc_build_data in choice_data["tool_calls_builder"].values():
 if tc_build_data["id"]:
-func_name = "".join(
-tc_build_data["function_name_parts"]
-)
-func_args = "".join(
-tc_build_data["function_arguments_parts"]
-)
+func_name = "".join(tc_build_data["function_name_parts"])
+func_args = "".join(tc_build_data["function_arguments_parts"])
 assembled_tool_calls.append(
 OpenAIChatCompletionToolCall(
 id=tc_build_data["id"],
@@ -720,16 +652,10 @@ class InferenceRouter(Inference):
 message = OpenAIAssistantMessageParam(
 role="assistant",
 content=content_str if content_str else None,
-tool_calls=(
-assembled_tool_calls if assembled_tool_calls else None
-),
+tool_calls=(assembled_tool_calls if assembled_tool_calls else None),
 )
 logprobs_content = choice_data["logprobs_content_parts"]
-final_logprobs = (
-OpenAIChoiceLogprobs(content=logprobs_content)
-if logprobs_content
-else None
-)
+final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None
 assembled_choices.append(
 OpenAIChoice(
@@ -748,6 +674,4 @@ class InferenceRouter(Inference):
 object="chat.completion",
 )
 logger.debug(f"InferenceRouter.completion_response: {final_response}")
-asyncio.create_task(
-self.store.store_chat_completion(final_response, messages)
-)
+asyncio.create_task(self.store.store_chat_completion(final_response, messages))

View file

@@ -7,10 +7,9 @@
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.inference import OpenAIEmbeddingsResponse
 from llama_stack.providers.utils.inference.model_registry import (
-build_hf_repo_model_entry,
 ModelRegistryHelper,
+build_hf_repo_model_entry,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
 get_sampling_options,
@@ -51,9 +50,7 @@ class RunpodInferenceAdapter(
 Inference,
 ):
 def __init__(self, config: RunpodImplConfig) -> None:
-ModelRegistryHelper.__init__(
-self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS
-)
+ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
 self.config = config
 def _get_params(self, request: ChatCompletionRequest) -> dict:

View file

@@ -9,6 +9,7 @@ from typing import Any
 from ibm_watsonx_ai.foundation_models import Model
 from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
+from openai import AsyncOpenAI
 from llama_stack.apis.inference import (
 ChatCompletionRequest,
@@ -33,7 +34,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 completion_request_to_prompt,
 request_has_media,
 )
-from openai import AsyncOpenAI
 from . import WatsonXConfig
 from .models import MODEL_ENTRIES
@@ -65,9 +65,7 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
 self._project_id = self._config.project_id
 def _get_client(self, model_id) -> Model:
-config_api_key = (
-self._config.api_key.get_secret_value() if self._config.api_key else None
-)
+config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
 config_url = self._config.url
 project_id = self._config.project_id
 credentials = {"url": config_url, "apikey": config_api_key}
@@ -82,46 +80,28 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
 )
 return self._openai_client
 async def _get_params(self, request: ChatCompletionRequest) -> dict:
 input_dict = {"params": {}}
 media_present = request_has_media(request)
 llama_model = self.get_llama_model(request.model)
 if isinstance(request, ChatCompletionRequest):
-input_dict["prompt"] = await chat_completion_request_to_prompt(
-request, llama_model
-)
+input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
 else:
-assert (
-not media_present
-), "Together does not support media for Completion requests"
+assert not media_present, "Together does not support media for Completion requests"
 input_dict["prompt"] = await completion_request_to_prompt(request)
 if request.sampling_params:
 if request.sampling_params.strategy:
-input_dict["params"][
-GenParams.DECODING_METHOD
-] = request.sampling_params.strategy.type
+input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type
 if request.sampling_params.max_tokens:
-input_dict["params"][
-GenParams.MAX_NEW_TOKENS
-] = request.sampling_params.max_tokens
+input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
 if request.sampling_params.repetition_penalty:
-input_dict["params"][
-GenParams.REPETITION_PENALTY
-] = request.sampling_params.repetition_penalty
+input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
 if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
-input_dict["params"][
-GenParams.TOP_P
-] = request.sampling_params.strategy.top_p
-input_dict["params"][
-GenParams.TEMPERATURE
-] = request.sampling_params.strategy.temperature
+input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p
+input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature
 if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
-input_dict["params"][
-GenParams.TOP_K
-] = request.sampling_params.strategy.top_k
+input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k
 if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
 input_dict["params"][GenParams.TEMPERATURE] = 0.0

View file

@@ -15,9 +15,17 @@ from typing import Any
 from openai import AsyncStream
 from openai.types.chat import (
 ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
+)
+from openai.types.chat import (
 ChatCompletionChunk as OpenAIChatCompletionChunk,
+)
+from openai.types.chat import (
 ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam,
+)
+from openai.types.chat import (
 ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
+)
+from openai.types.chat import (
 ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam,
 )
@@ -29,15 +37,56 @@ except ImportError:
 from openai.types.chat.chat_completion_message_tool_call import (
 ChatCompletionMessageToolCall as OpenAIChatCompletionMessageFunctionToolCall,
 )
+from openai.types.chat import (
+ChatCompletionMessageParam as OpenAIChatCompletionMessage,
+)
+from openai.types.chat import (
+ChatCompletionMessageToolCall,
+)
+from openai.types.chat import (
+ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
+)
+from openai.types.chat import (
+ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
+)
+from openai.types.chat import (
+ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
+)
+from openai.types.chat.chat_completion import (
+Choice as OpenAIChoice,
+)
+from openai.types.chat.chat_completion import (
+ChoiceLogprobs as OpenAIChoiceLogprobs,  # same as chat_completion_chunk ChoiceLogprobs
+)
+from openai.types.chat.chat_completion_chunk import (
+Choice as OpenAIChatCompletionChunkChoice,
+)
+from openai.types.chat.chat_completion_chunk import (
+ChoiceDelta as OpenAIChoiceDelta,
+)
+from openai.types.chat.chat_completion_chunk import (
+ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
+)
+from openai.types.chat.chat_completion_chunk import (
+ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
+)
+from openai.types.chat.chat_completion_content_part_image_param import (
+ImageURL as OpenAIImageURL,
+)
+from openai.types.chat.chat_completion_message_tool_call import (
+Function as OpenAIFunction,
+)
+from pydantic import BaseModel
 from llama_stack.apis.common.content_types import (
-_URLOrData,
+URL,
 ImageContentItem,
 InterleavedContent,
 TextContentItem,
 TextDelta,
 ToolCallDelta,
 ToolCallParseStatus,
-URL,
+_URLOrData,
 )
 from llama_stack.apis.inference import (
 ChatCompletionRequest,
@@ -74,30 +123,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 convert_image_content_to_url,
 decode_assistant_message,
 )
-from openai.types.chat import (
-ChatCompletionMessageParam as OpenAIChatCompletionMessage,
-ChatCompletionMessageToolCall,
-ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
-ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
-ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
-)
-from openai.types.chat.chat_completion import (
-Choice as OpenAIChoice,
-ChoiceLogprobs as OpenAIChoiceLogprobs,  # same as chat_completion_chunk ChoiceLogprobs
-)
-from openai.types.chat.chat_completion_chunk import (
-Choice as OpenAIChatCompletionChunkChoice,
-ChoiceDelta as OpenAIChoiceDelta,
-ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
-ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
-)
-from openai.types.chat.chat_completion_content_part_image_param import (
-ImageURL as OpenAIImageURL,
-)
-from openai.types.chat.chat_completion_message_tool_call import (
-Function as OpenAIFunction,
-)
-from pydantic import BaseModel
 logger = get_logger(name=__name__, category="providers::utils")
@@ -196,16 +221,12 @@ def convert_openai_completion_logprobs(
 if logprobs.tokens and logprobs.token_logprobs:
 return [
 TokenLogProbs(logprobs_by_token={token: token_lp})
-for token, token_lp in zip(
-logprobs.tokens, logprobs.token_logprobs, strict=False
-)
+for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs, strict=False)
 ]
 return None
-def convert_openai_completion_logprobs_stream(
-text: str, logprobs: float | OpenAICompatLogprobs | None
-):
+def convert_openai_completion_logprobs_stream(text: str, logprobs: float | OpenAICompatLogprobs | None):
 if logprobs is None:
 return None
 if isinstance(logprobs, float):
@@ -250,9 +271,7 @@ def process_chat_completion_response(
 if not choice.message or not choice.message.tool_calls:
 raise ValueError("Tool calls are not present in the response")
-tool_calls = [
-convert_tool_call(tool_call) for tool_call in choice.message.tool_calls
-]
+tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls]
 if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
 # If we couldn't parse a tool call, jsonify the tool calls and return them
 return ChatCompletionResponse(
@@ -276,9 +295,7 @@ def process_chat_completion_response(
 # TODO: This does not work well with tool calls for vLLM remote provider
 # Ref: https://github.com/meta-llama/llama-stack/issues/1058
-raw_message = decode_assistant_message(
-text_from_choice(choice), get_stop_reason(choice.finish_reason)
-)
+raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason))
 # NOTE: If we do not set tools in chat-completion request, we should not
 # expect the ToolCall in the response. Instead, we should return the raw
@@ -479,17 +496,13 @@ async def process_chat_completion_stream_response(
 )
-async def convert_message_to_openai_dict(
-message: Message, download: bool = False
-) -> dict:
+async def convert_message_to_openai_dict(message: Message, download: bool = False) -> dict:
 async def _convert_content(content) -> dict:
 if isinstance(content, ImageContentItem):
 return {
 "type": "image_url",
 "image_url": {
-"url": await convert_image_content_to_url(
-content, download=download
-),
+"url": await convert_image_content_to_url(content, download=download),
 },
 }
 else:
@@ -574,11 +587,7 @@ async def convert_message_to_openai_dict_new(
 ) -> str | Iterable[OpenAIChatCompletionContentPartParam]:
 async def impl(
 content_: InterleavedContent,
-) -> (
-str
-| OpenAIChatCompletionContentPartParam
-| list[OpenAIChatCompletionContentPartParam]
-):
+) -> str | OpenAIChatCompletionContentPartParam | list[OpenAIChatCompletionContentPartParam]:
 # Llama Stack and OpenAI spec match for str and text input
 if isinstance(content_, str):
 return content_
@@ -591,9 +600,7 @@ async def convert_message_to_openai_dict_new(
 return OpenAIChatCompletionContentPartImageParam(
 type="image_url",
 image_url=OpenAIImageURL(
-url=await convert_image_content_to_url(
-content_, download=download_images
-)
+url=await convert_image_content_to_url(content_, download=download_images)
 ),
 )
 elif isinstance(content_, list):
@@ -620,11 +627,7 @@ async def convert_message_to_openai_dict_new(
 OpenAIChatCompletionMessageFunctionToolCall(
 id=tool.call_id,
 function=OpenAIFunction(
-name=(
-tool.tool_name
-if not isinstance(tool.tool_name, BuiltinTool)
-else tool.tool_name.value
-),
+name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
 arguments=tool.arguments,  # Already a JSON string, don't double-encode
 ),
 type="function",
@@ -804,9 +807,7 @@ def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
 }.get(finish_reason, StopReason.end_of_turn)
-def _convert_openai_request_tool_config(
-tool_choice: str | dict[str, Any] | None = None
-) -> ToolConfig:
+def _convert_openai_request_tool_config(tool_choice: str | dict[str, Any] | None = None) -> ToolConfig:
 tool_config = ToolConfig()
 if tool_choice:
 try:
@@ -817,9 +818,7 @@ def _convert_openai_request_tool_config(
 return tool_config
-def _convert_openai_request_tools(
-tools: list[dict[str, Any]] | None = None
-) -> list[ToolDefinition]:
+def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) -> list[ToolDefinition]:
 lls_tools = []
 if not tools:
 return lls_tools
@@ -918,11 +917,7 @@ def _convert_openai_logprobs(
 return None
 return [
-TokenLogProbs(
-logprobs_by_token={
-logprobs.token: logprobs.logprob for logprobs in content.top_logprobs
-}
-)
+TokenLogProbs(logprobs_by_token={logprobs.token: logprobs.logprob for logprobs in content.top_logprobs})
 for content in logprobs.content
 ]
@@ -961,13 +956,9 @@ def openai_messages_to_messages(
 converted_messages = []
 for message in messages:
 if message.role == "system":
-converted_message = SystemMessage(
-content=openai_content_to_content(message.content)
-)
+converted_message = SystemMessage(content=openai_content_to_content(message.content))
 elif message.role == "user":
-converted_message = UserMessage(
-content=openai_content_to_content(message.content)
-)
+converted_message = UserMessage(content=openai_content_to_content(message.content))
 elif message.role == "assistant":
 converted_message = CompletionMessage(
 content=openai_content_to_content(message.content),
@@ -999,9 +990,7 @@ def openai_content_to_content(
 if content.type == "text":
 return TextContentItem(type="text", text=content.text)
 elif content.type == "image_url":
-return ImageContentItem(
-type="image", image=_URLOrData(url=URL(uri=content.image_url.url))
-)
+return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url)))
 else:
 raise ValueError(f"Unknown content type: {content.type}")
 else:
@@ -1041,17 +1030,14 @@ def convert_openai_chat_completion_choice(
 end_of_message = "end_of_message"
 out_of_tokens = "out_of_tokens"
 """
-assert (
-hasattr(choice, "message") and choice.message
-), "error in server response: message not found"
-assert (
-hasattr(choice, "finish_reason") and choice.finish_reason
-), "error in server response: finish_reason not found"
+assert hasattr(choice, "message") and choice.message, "error in server response: message not found"
+assert hasattr(choice, "finish_reason") and choice.finish_reason, (
+"error in server response: finish_reason not found"
+)
 return ChatCompletionResponse(
 completion_message=CompletionMessage(
-content=choice.message.content
-or "",  # CompletionMessage content is not optional
+content=choice.message.content or "",  # CompletionMessage content is not optional
 stop_reason=_convert_openai_finish_reason(choice.finish_reason),
 tool_calls=_convert_openai_tool_calls(choice.message.tool_calls),
 ),
@@ -1291,9 +1277,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
 outstanding_responses.append(response)
 if stream:
-return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(
-self, model, outstanding_responses
-)
+return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses)
 return await OpenAIChatCompletionToLlamaStackMixin._process_non_stream_response(
 self, model, outstanding_responses
@@ -1302,29 +1286,21 @@ class OpenAIChatCompletionToLlamaStackMixin:
 async def _process_stream_response(
 self,
 model: str,
-outstanding_responses: list[
-Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]
-],
+outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]],
 ):
 id = f"chatcmpl-{uuid.uuid4()}"
 for i, outstanding_response in enumerate(outstanding_responses):
 response = await outstanding_response
 async for chunk in response:
 event = chunk.event
-finish_reason = _convert_stop_reason_to_openai_finish_reason(
-event.stop_reason
-)
+finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)
 if isinstance(event.delta, TextDelta):
 text_delta = event.delta.text
 delta = OpenAIChoiceDelta(content=text_delta)
 yield OpenAIChatCompletionChunk(
 id=id,
-choices=[
-OpenAIChatCompletionChunkChoice(
-index=i, finish_reason=finish_reason, delta=delta
-)
-],
+choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)],
 created=int(time.time()),
 model=model,
 object="chat.completion.chunk",
@@ -1346,9 +1322,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
 yield OpenAIChatCompletionChunk(
 id=id,
 choices=[
-OpenAIChatCompletionChunkChoice(
-index=i, finish_reason=finish_reason, delta=delta
-)
+OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
 ],
 created=int(time.time()),
 model=model,
@@ -1365,9 +1339,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
 yield OpenAIChatCompletionChunk(
 id=id,
 choices=[
-OpenAIChatCompletionChunkChoice(
-index=i, finish_reason=finish_reason, delta=delta
-)
+OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
 ],
 created=int(time.time()),
 model=model,
@@ -1382,9 +1354,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
 response = await outstanding_response
 completion_message = response.completion_message
 message = await convert_message_to_openai_dict_new(completion_message)
-finish_reason = _convert_stop_reason_to_openai_finish_reason(
-completion_message.stop_reason
-)
+finish_reason = _convert_stop_reason_to_openai_finish_reason(completion_message.stop_reason)
 choice = OpenAIChatCompletionChoice(
 index=len(choices),

View file

@@ -87,9 +87,7 @@ def pytest_configure(config):
 suite = config.getoption("--suite")
 if suite:
 if suite not in SUITE_DEFINITIONS:
-raise pytest.UsageError(
-f"Unknown suite: {suite}. Available: {', '.join(sorted(SUITE_DEFINITIONS.keys()))}"
-)
+raise pytest.UsageError(f"Unknown suite: {suite}. Available: {', '.join(sorted(SUITE_DEFINITIONS.keys()))}")
 # Apply setups (global parameterizations): env + defaults
 setup = config.getoption("--setup")
@@ -127,9 +125,7 @@ def pytest_addoption(parser):
 """
 ),
 )
-parser.addoption(
-"--env", action="append", help="Set environment variables, e.g. --env KEY=value"
-)
+parser.addoption("--env", action="append", help="Set environment variables, e.g. --env KEY=value")
 parser.addoption(
 "--text-model",
 help="comma-separated list of text models. Fixture name: text_model_id",
@@ -169,7 +165,9 @@ def pytest_addoption(parser):
 )
 available_suites = ", ".join(sorted(SUITE_DEFINITIONS.keys()))
-suite_help = f"Single test suite to run (narrows collection). Available: {available_suites}. Example: --suite=responses"
+suite_help = (
+f"Single test suite to run (narrows collection). Available: {available_suites}. Example: --suite=responses"
+)
 parser.addoption("--suite", help=suite_help)
 # Global setups for any suite
@@ -241,11 +239,7 @@ def pytest_generate_tests(metafunc):
 # Generate test IDs
 test_ids = []
-non_empty_params = [
-(i, values)
-for i, values in enumerate(param_values.values())
-if values[0] is not None
-]
+non_empty_params = [(i, values) for i, values in enumerate(param_values.values()) if values[0] is not None]
 # Get actual function parameters using inspect
 test_func_params = set(inspect.signature(metafunc.function).parameters.keys())
@@ -262,9 +256,7 @@ def pytest_generate_tests(metafunc):
 if parts:
 test_ids.append(":".join(parts))
-metafunc.parametrize(
-params, value_combinations, scope="session", ids=test_ids if test_ids else None
-)
+metafunc.parametrize(params, value_combinations, scope="session", ids=test_ids if test_ids else None)
 def pytest_ignore_collect(path: str, config: pytest.Config) -> bool:
@@ -274,9 +266,7 @@ def pytest_ignore_collect(path: str, config: pytest.Config) -> bool:
 return False
 sobj = SUITE_DEFINITIONS.get(suite)
-roots: list[str] = (
-sobj.get("roots", []) if isinstance(sobj, dict) else getattr(sobj, "roots", [])
-)
+roots: list[str] = sobj.get("roots", []) if isinstance(sobj, dict) else getattr(sobj, "roots", [])
 if not roots:
 return False

View file

@@ -9,15 +9,15 @@ import sys
 from typing import Any, Protocol
 from unittest.mock import AsyncMock, MagicMock
-from llama_stack.apis.inference import Inference, SamplingParams
+from pydantic import BaseModel, Field
+from llama_stack.apis.inference import Inference
 from llama_stack.core.datatypes import Api, Provider, StackRunConfig
 from llama_stack.core.resolver import resolve_impls
 from llama_stack.core.routers.inference import InferenceRouter
 from llama_stack.core.routing_tables.models import ModelsRoutingTable
 from llama_stack.providers.datatypes import InlineProviderSpec, ProviderSpec
-from pydantic import BaseModel, Field
 def add_protocol_methods(cls: type, protocol: type[Protocol]) -> None:
 """Dynamically add protocol methods to a class by inspecting the protocol."""