Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-14 16:52:37 +00:00)

Commit 9fc0d966f6 ("Ran precommit"), parent 9886520b40
7 changed files with 153 additions and 310 deletions
@@ -8,6 +8,9 @@ from collections.abc import AsyncIterator
 from enum import Enum
 from typing import Annotated, Any, Literal, Protocol, runtime_checkable
 
+from pydantic import BaseModel, Field, field_validator
+from typing_extensions import TypedDict
+
 from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
 from llama_stack.apis.common.responses import Order
 from llama_stack.apis.models import Model
@@ -23,9 +26,6 @@ from llama_stack.models.llama.datatypes import (
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
 
-from pydantic import BaseModel, Field, field_validator
-from typing_extensions import TypedDict
-
 register_schema(ToolCall)
 register_schema(ToolDefinition)
 
@@ -381,9 +381,7 @@ class ToolConfig(BaseModel):
 
     tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto)
     tool_prompt_format: ToolPromptFormat | None = Field(default=None)
-    system_message_behavior: SystemMessageBehavior | None = Field(
-        default=SystemMessageBehavior.append
-    )
+    system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append)
 
     def model_post_init(self, __context: Any) -> None:
         if isinstance(self.tool_choice, str):
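Note: the ToolConfig hunk above touches a common Pydantic v2 pattern — a field accepts either an enum member or a raw string and is normalized in model_post_init. A minimal standalone sketch of that pattern (ChoiceKind and ToolPreference below are illustrative names, not llama-stack classes):

from enum import Enum
from typing import Any

from pydantic import BaseModel, Field


class ChoiceKind(Enum):
    auto = "auto"
    required = "required"


class ToolPreference(BaseModel):
    # Accept either a ChoiceKind member or a free-form string (e.g. a tool name).
    choice: ChoiceKind | str | None = Field(default=ChoiceKind.auto)

    def model_post_init(self, __context: Any) -> None:
        # Normalize known string values to the enum; leave unknown strings untouched.
        if isinstance(self.choice, str):
            try:
                self.choice = ChoiceKind(self.choice)
            except ValueError:
                pass


print(ToolPreference(choice="auto").choice)         # ChoiceKind.auto
print(ToolPreference(choice="get_weather").choice)  # "get_weather"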
@@ -512,21 +510,15 @@ class OpenAIFile(BaseModel):
 
 
 OpenAIChatCompletionContentPartParam = Annotated[
-    OpenAIChatCompletionContentPartTextParam
-    | OpenAIChatCompletionContentPartImageParam
-    | OpenAIFile,
+    OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam | OpenAIFile,
    Field(discriminator="type"),
 ]
-register_schema(
-    OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam"
-)
+register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")
 
 
 OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]
 
-OpenAIChatCompletionTextOnlyMessageContent = (
-    str | list[OpenAIChatCompletionContentPartTextParam]
-)
+OpenAIChatCompletionTextOnlyMessageContent = str | list[OpenAIChatCompletionContentPartTextParam]
 
 
 @json_schema_type
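Note: this hunk reflows a Pydantic discriminated union. As a reference for the pattern itself (TextPart/ImagePart below are illustrative stand-ins, not llama-stack's actual classes), a tagged union keyed on a literal "type" field looks like:

from typing import Annotated, Literal

from pydantic import BaseModel, Field, TypeAdapter


class TextPart(BaseModel):
    type: Literal["text"] = "text"
    text: str


class ImagePart(BaseModel):
    type: Literal["image_url"] = "image_url"
    url: str


# The discriminator tells Pydantic which model to pick based on the "type" field.
ContentPart = Annotated[TextPart | ImagePart, Field(discriminator="type")]

part = TypeAdapter(ContentPart).validate_python({"type": "text", "text": "hello"})
assert isinstance(part, TextPart)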
@@ -694,9 +686,7 @@ class OpenAIResponseFormatJSONObject(BaseModel):
 
 
 OpenAIResponseFormatParam = Annotated[
-    OpenAIResponseFormatText
-    | OpenAIResponseFormatJSONSchema
-    | OpenAIResponseFormatJSONObject,
+    OpenAIResponseFormatText | OpenAIResponseFormatJSONSchema | OpenAIResponseFormatJSONObject,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
@@ -986,16 +976,8 @@ class InferenceProvider(Protocol):
     async def rerank(
         self,
         model: str,
-        query: (
-            str
-            | OpenAIChatCompletionContentPartTextParam
-            | OpenAIChatCompletionContentPartImageParam
-        ),
-        items: list[
-            str
-            | OpenAIChatCompletionContentPartTextParam
-            | OpenAIChatCompletionContentPartImageParam
-        ],
+        query: (str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam),
+        items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
         max_num_results: int | None = None,
     ) -> RerankResponse:
         """Rerank a list of documents based on their relevance to a query.
@@ -7,9 +7,17 @@
 import asyncio
 import time
 from collections.abc import AsyncGenerator, AsyncIterator
-from datetime import datetime, UTC
+from datetime import UTC, datetime
 from typing import Annotated, Any
 
+from openai.types.chat import (
+    ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam,
+)
+from openai.types.chat import (
+    ChatCompletionToolParam as OpenAIChatCompletionToolParam,
+)
+from pydantic import Field, TypeAdapter
+
 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
 from llama_stack.apis.inference import (
@@ -48,12 +56,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
     get_current_span,
 )
 
-from openai.types.chat import (
-    ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam,
-    ChatCompletionToolParam as OpenAIChatCompletionToolParam,
-)
-from pydantic import Field, TypeAdapter
-
 logger = get_logger(name=__name__, category="core::routers")
 
 
@@ -96,9 +98,7 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
         )
-        await self.routing_table.register_model(
-            model_id, provider_model_id, provider_id, metadata, model_type
-        )
+        await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
 
     def _construct_metrics(
         self,
@@ -153,16 +153,11 @@ class InferenceRouter(Inference):
         total_tokens: int,
         model: Model,
     ) -> list[MetricInResponse]:
-        metrics = self._construct_metrics(
-            prompt_tokens, completion_tokens, total_tokens, model
-        )
+        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
         if self.telemetry:
             for metric in metrics:
                 enqueue_event(metric)
-        return [
-            MetricInResponse(metric=metric.metric, value=metric.value)
-            for metric in metrics
-        ]
+        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
 
     async def _count_tokens(
         self,
@@ -256,9 +251,7 @@ class InferenceRouter(Inference):
 
         # these metrics will show up in the client response.
         response.metrics = (
-            metrics
-            if not hasattr(response, "metrics") or response.metrics is None
-            else response.metrics + metrics
+            metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
         )
         return response
 
@@ -296,13 +289,9 @@ class InferenceRouter(Inference):
         # Use the OpenAI client for a bit of extra input validation without
         # exposing the OpenAI client itself as part of our API surface
         if tool_choice:
-            TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(
-                tool_choice
-            )
+            TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice)
             if tools is None:
-                raise ValueError(
-                    "'tool_choice' is only allowed when 'tools' is also provided"
-                )
+                raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
         if tools:
             for tool in tools:
                 TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)
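Note: the validation shown in this hunk — running request dicts through a Pydantic TypeAdapter built from the OpenAI SDK's typed params, purely for input checking — can be exercised in isolation. A rough sketch, assuming the openai and pydantic packages are installed:

from openai.types.chat import ChatCompletionToolParam
from pydantic import TypeAdapter, ValidationError

tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    },
}

# validate_python checks the dict against the TypedDict's schema without
# constructing a client or sending any request.
TypeAdapter(ChatCompletionToolParam).validate_python(tool)

try:
    TypeAdapter(ChatCompletionToolParam).validate_python({"type": "bogus"})
except ValidationError as e:
    print(f"rejected invalid tool definition: {e.error_count()} errors")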
@@ -367,9 +356,7 @@ class InferenceRouter(Inference):
                 enqueue_event(metric)
         # these metrics will show up in the client response.
         response.metrics = (
-            metrics
-            if not hasattr(response, "metrics") or response.metrics is None
-            else response.metrics + metrics
+            metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
         )
         return response
 
@@ -405,31 +392,19 @@ class InferenceRouter(Inference):
     ) -> ListOpenAIChatCompletionResponse:
         if self.store:
             return await self.store.list_chat_completions(after, limit, model, order)
-        raise NotImplementedError(
-            "List chat completions is not supported: inference store is not configured."
-        )
+        raise NotImplementedError("List chat completions is not supported: inference store is not configured.")
 
-    async def get_chat_completion(
-        self, completion_id: str
-    ) -> OpenAICompletionWithInputMessages:
+    async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
         if self.store:
             return await self.store.get_chat_completion(completion_id)
-        raise NotImplementedError(
-            "Get chat completion is not supported: inference store is not configured."
-        )
+        raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
 
-    async def _nonstream_openai_chat_completion(
-        self, provider: Inference, params: dict
-    ) -> OpenAIChatCompletion:
+    async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
         response = await provider.openai_chat_completion(**params)
         for choice in response.choices:
             # some providers return an empty list for no tool calls in non-streaming responses
             # but the OpenAI API returns None. So, set tool_calls to None if it's empty
-            if (
-                choice.message
-                and choice.message.tool_calls is not None
-                and len(choice.message.tool_calls) == 0
-            ):
+            if choice.message and choice.message.tool_calls is not None and len(choice.message.tool_calls) == 0:
                 choice.message.tool_calls = None
         return response
 
@@ -449,9 +424,7 @@ class InferenceRouter(Inference):
                     message=f"Health check timed out after {timeout} seconds",
                 )
             except NotImplementedError:
-                health_statuses[provider_id] = HealthResponse(
-                    status=HealthStatus.NOT_IMPLEMENTED
-                )
+                health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
             except Exception as e:
                 health_statuses[provider_id] = HealthResponse(
                     status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
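Note: the health-status mapping being reflowed here follows a common asyncio pattern — bound each provider probe with a timeout and translate exceptions into a status map. A standalone sketch of that pattern (plain dicts stand in for llama-stack's HealthResponse model):

import asyncio


async def check_providers(providers: dict, timeout: float = 1.0) -> dict[str, dict]:
    statuses: dict[str, dict] = {}
    for provider_id, impl in providers.items():
        try:
            statuses[provider_id] = await asyncio.wait_for(impl.health(), timeout=timeout)
        except (TimeoutError, asyncio.TimeoutError):
            statuses[provider_id] = {"status": "ERROR", "message": f"Health check timed out after {timeout} seconds"}
        except NotImplementedError:
            statuses[provider_id] = {"status": "NOT_IMPLEMENTED"}
        except Exception as e:
            statuses[provider_id] = {"status": "ERROR", "message": f"Health check failed: {str(e)}"}
    return statuses


class _StubProvider:
    async def health(self):
        return {"status": "OK"}


print(asyncio.run(check_providers({"stub": _StubProvider()})))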
@@ -486,11 +459,7 @@ class InferenceRouter(Inference):
             else:
                 if hasattr(chunk, "delta"):
                     completion_text += chunk.delta
-                if (
-                    hasattr(chunk, "stop_reason")
-                    and chunk.stop_reason
-                    and self.telemetry
-                ):
+                if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
                     complete = True
                     completion_tokens = await self._count_tokens(completion_text)
             # if we are done receiving tokens
@@ -515,14 +484,9 @@ class InferenceRouter(Inference):
 
                 # Return metrics in response
                 async_metrics = [
-                    MetricInResponse(metric=metric.metric, value=metric.value)
-                    for metric in completion_metrics
+                    MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
                 ]
-                chunk.metrics = (
-                    async_metrics
-                    if chunk.metrics is None
-                    else chunk.metrics + async_metrics
-                )
+                chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
             else:
                 # Fallback if no telemetry
                 completion_metrics = self._construct_metrics(
@@ -532,14 +496,9 @@ class InferenceRouter(Inference):
                     model,
                 )
                 async_metrics = [
-                    MetricInResponse(metric=metric.metric, value=metric.value)
-                    for metric in completion_metrics
+                    MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
                 ]
-                chunk.metrics = (
-                    async_metrics
-                    if chunk.metrics is None
-                    else chunk.metrics + async_metrics
-                )
+                chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
             yield chunk
 
     async def count_tokens_and_compute_metrics(
@@ -553,9 +512,7 @@ class InferenceRouter(Inference):
             content = [response.completion_message]
         else:
             content = response.content
-        completion_tokens = await self._count_tokens(
-            messages=content, tool_prompt_format=tool_prompt_format
-        )
+        completion_tokens = await self._count_tokens(messages=content, tool_prompt_format=tool_prompt_format)
         total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
 
         # Create a separate span for completion metrics
@@ -575,10 +532,7 @@ class InferenceRouter(Inference):
                 enqueue_event(metric)
 
            # Return metrics in response
-            return [
-                MetricInResponse(metric=metric.metric, value=metric.value)
-                for metric in completion_metrics
-            ]
+            return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]
 
        # Fallback if no telemetry
        metrics = self._construct_metrics(
@@ -587,10 +541,7 @@ class InferenceRouter(Inference):
            total_tokens,
            model,
        )
-        return [
-            MetricInResponse(metric=metric.metric, value=metric.value)
-            for metric in metrics
-        ]
+        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
 
     async def stream_tokens_and_compute_metrics_openai_chat(
         self,
@@ -631,48 +582,33 @@ class InferenceRouter(Inference):
                 if choice_delta.delta:
                     delta = choice_delta.delta
                     if delta.content:
-                        current_choice_data["content_parts"].append(
-                            delta.content
-                        )
+                        current_choice_data["content_parts"].append(delta.content)
                     if delta.tool_calls:
                         for tool_call_delta in delta.tool_calls:
                             tc_idx = tool_call_delta.index
-                            if (
-                                tc_idx
-                                not in current_choice_data["tool_calls_builder"]
-                            ):
-                                current_choice_data["tool_calls_builder"][
-                                    tc_idx
-                                ] = {
+                            if tc_idx not in current_choice_data["tool_calls_builder"]:
+                                current_choice_data["tool_calls_builder"][tc_idx] = {
                                     "id": None,
                                     "type": "function",
                                     "function_name_parts": [],
                                     "function_arguments_parts": [],
                                 }
-                            builder = current_choice_data["tool_calls_builder"][
-                                tc_idx
-                            ]
+                            builder = current_choice_data["tool_calls_builder"][tc_idx]
                             if tool_call_delta.id:
                                 builder["id"] = tool_call_delta.id
                             if tool_call_delta.type:
                                 builder["type"] = tool_call_delta.type
                             if tool_call_delta.function:
                                 if tool_call_delta.function.name:
-                                    builder["function_name_parts"].append(
-                                        tool_call_delta.function.name
-                                    )
+                                    builder["function_name_parts"].append(tool_call_delta.function.name)
                                 if tool_call_delta.function.arguments:
                                     builder["function_arguments_parts"].append(
                                         tool_call_delta.function.arguments
                                     )
                 if choice_delta.finish_reason:
-                    current_choice_data["finish_reason"] = (
-                        choice_delta.finish_reason
-                    )
+                    current_choice_data["finish_reason"] = choice_delta.finish_reason
                 if choice_delta.logprobs and choice_delta.logprobs.content:
-                    current_choice_data["logprobs_content_parts"].extend(
-                        choice_delta.logprobs.content
-                    )
+                    current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)
 
             # Compute metrics on final chunk
             if chunk.choices and chunk.choices[0].finish_reason:
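Note: the builder dict in this hunk accumulates OpenAI streaming tool-call deltas (id, name fragments, argument fragments keyed by index) into complete calls. A self-contained sketch of the same accumulation idea, with illustrative field names rather than llama-stack's:

# Sketch: accumulate OpenAI-style streamed tool-call deltas into full calls.
def assemble_tool_calls(deltas: list[dict]) -> list[dict]:
    builders: dict[int, dict] = {}
    for d in deltas:
        b = builders.setdefault(d["index"], {"id": None, "name_parts": [], "args_parts": []})
        if d.get("id"):
            b["id"] = d["id"]
        if d.get("name"):
            b["name_parts"].append(d["name"])
        if d.get("arguments"):
            b["args_parts"].append(d["arguments"])
    return [
        {"id": b["id"], "name": "".join(b["name_parts"]), "arguments": "".join(b["args_parts"])}
        for b in builders.values()
    ]


chunks = [
    {"index": 0, "id": "call_1", "name": "get_weather"},
    {"index": 0, "arguments": '{"city": '},
    {"index": 0, "arguments": '"Paris"}'},
]
print(assemble_tool_calls(chunks))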
@@ -702,12 +638,8 @@ class InferenceRouter(Inference):
             if choice_data["tool_calls_builder"]:
                 for tc_build_data in choice_data["tool_calls_builder"].values():
                     if tc_build_data["id"]:
-                        func_name = "".join(
-                            tc_build_data["function_name_parts"]
-                        )
-                        func_args = "".join(
-                            tc_build_data["function_arguments_parts"]
-                        )
+                        func_name = "".join(tc_build_data["function_name_parts"])
+                        func_args = "".join(tc_build_data["function_arguments_parts"])
                         assembled_tool_calls.append(
                             OpenAIChatCompletionToolCall(
                                 id=tc_build_data["id"],
@@ -720,16 +652,10 @@ class InferenceRouter(Inference):
             message = OpenAIAssistantMessageParam(
                 role="assistant",
                 content=content_str if content_str else None,
-                tool_calls=(
-                    assembled_tool_calls if assembled_tool_calls else None
-                ),
+                tool_calls=(assembled_tool_calls if assembled_tool_calls else None),
             )
             logprobs_content = choice_data["logprobs_content_parts"]
-            final_logprobs = (
-                OpenAIChoiceLogprobs(content=logprobs_content)
-                if logprobs_content
-                else None
-            )
+            final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None
 
             assembled_choices.append(
                 OpenAIChoice(
@@ -748,6 +674,4 @@ class InferenceRouter(Inference):
             object="chat.completion",
         )
         logger.debug(f"InferenceRouter.completion_response: {final_response}")
-        asyncio.create_task(
-            self.store.store_chat_completion(final_response, messages)
-        )
+        asyncio.create_task(self.store.store_chat_completion(final_response, messages))
@@ -7,10 +7,9 @@
 
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.inference import OpenAIEmbeddingsResponse
 
 from llama_stack.providers.utils.inference.model_registry import (
-    build_hf_repo_model_entry,
     ModelRegistryHelper,
+    build_hf_repo_model_entry,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
@@ -51,9 +50,7 @@ class RunpodInferenceAdapter(
     Inference,
 ):
     def __init__(self, config: RunpodImplConfig) -> None:
-        ModelRegistryHelper.__init__(
-            self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS
-        )
+        ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
         self.config = config
 
     def _get_params(self, request: ChatCompletionRequest) -> dict:
@@ -9,6 +9,7 @@ from typing import Any
 
 from ibm_watsonx_ai.foundation_models import Model
 from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
+from openai import AsyncOpenAI
 
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
@@ -33,7 +34,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     completion_request_to_prompt,
     request_has_media,
 )
-from openai import AsyncOpenAI
 
 from . import WatsonXConfig
 from .models import MODEL_ENTRIES
@@ -65,9 +65,7 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
         self._project_id = self._config.project_id
 
     def _get_client(self, model_id) -> Model:
-        config_api_key = (
-            self._config.api_key.get_secret_value() if self._config.api_key else None
-        )
+        config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
         config_url = self._config.url
         project_id = self._config.project_id
         credentials = {"url": config_url, "apikey": config_api_key}
@@ -82,46 +80,28 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
             )
         return self._openai_client
 
 
     async def _get_params(self, request: ChatCompletionRequest) -> dict:
 
         input_dict = {"params": {}}
         media_present = request_has_media(request)
         llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            input_dict["prompt"] = await chat_completion_request_to_prompt(
-                request, llama_model
-            )
+            input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
         else:
-            assert (
-                not media_present
-            ), "Together does not support media for Completion requests"
+            assert not media_present, "Together does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)
         if request.sampling_params:
             if request.sampling_params.strategy:
-                input_dict["params"][
-                    GenParams.DECODING_METHOD
-                ] = request.sampling_params.strategy.type
+                input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type
             if request.sampling_params.max_tokens:
-                input_dict["params"][
-                    GenParams.MAX_NEW_TOKENS
-                ] = request.sampling_params.max_tokens
+                input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
             if request.sampling_params.repetition_penalty:
-                input_dict["params"][
-                    GenParams.REPETITION_PENALTY
-                ] = request.sampling_params.repetition_penalty
+                input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
 
             if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
-                input_dict["params"][
-                    GenParams.TOP_P
-                ] = request.sampling_params.strategy.top_p
-                input_dict["params"][
-                    GenParams.TEMPERATURE
-                ] = request.sampling_params.strategy.temperature
+                input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p
+                input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature
             if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
-                input_dict["params"][
-                    GenParams.TOP_K
-                ] = request.sampling_params.strategy.top_k
+                input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k
             if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
                 input_dict["params"][GenParams.TEMPERATURE] = 0.0
 
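Note: the hunk above maps sampling-strategy objects onto watsonx GenParams keys. The general mapping technique can be tried out on its own; a rough, self-contained sketch where plain dataclasses stand in for llama-stack's sampling strategy types:

from dataclasses import dataclass


@dataclass
class TopPSampling:
    top_p: float
    temperature: float


@dataclass
class TopKSampling:
    top_k: int


@dataclass
class GreedySampling:
    pass


def sampling_to_params(strategy) -> dict:
    # Mirrors the branch structure above: each strategy maps onto provider params.
    if isinstance(strategy, TopPSampling):
        return {"top_p": strategy.top_p, "temperature": strategy.temperature}
    if isinstance(strategy, TopKSampling):
        return {"top_k": strategy.top_k}
    if isinstance(strategy, GreedySampling):
        return {"temperature": 0.0}
    raise ValueError(f"Unsupported sampling strategy: {strategy!r}")


print(sampling_to_params(TopPSampling(top_p=0.9, temperature=0.7)))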
@@ -15,9 +15,17 @@ from typing import Any
 from openai import AsyncStream
 from openai.types.chat import (
     ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
+)
+from openai.types.chat import (
     ChatCompletionChunk as OpenAIChatCompletionChunk,
+)
+from openai.types.chat import (
     ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam,
+)
+from openai.types.chat import (
     ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
+)
+from openai.types.chat import (
     ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam,
 )
 
@@ -29,15 +37,56 @@ except ImportError:
 from openai.types.chat.chat_completion_message_tool_call import (
     ChatCompletionMessageToolCall as OpenAIChatCompletionMessageFunctionToolCall,
 )
+from openai.types.chat import (
+    ChatCompletionMessageParam as OpenAIChatCompletionMessage,
+)
+from openai.types.chat import (
+    ChatCompletionMessageToolCall,
+)
+from openai.types.chat import (
+    ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
+)
+from openai.types.chat import (
+    ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
+)
+from openai.types.chat import (
+    ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
+)
+from openai.types.chat.chat_completion import (
+    Choice as OpenAIChoice,
+)
+from openai.types.chat.chat_completion import (
+    ChoiceLogprobs as OpenAIChoiceLogprobs,  # same as chat_completion_chunk ChoiceLogprobs
+)
+from openai.types.chat.chat_completion_chunk import (
+    Choice as OpenAIChatCompletionChunkChoice,
+)
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDelta as OpenAIChoiceDelta,
+)
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
+)
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
+)
+from openai.types.chat.chat_completion_content_part_image_param import (
+    ImageURL as OpenAIImageURL,
+)
+from openai.types.chat.chat_completion_message_tool_call import (
+    Function as OpenAIFunction,
+)
+from pydantic import BaseModel
+
 from llama_stack.apis.common.content_types import (
-    _URLOrData,
+    URL,
     ImageContentItem,
     InterleavedContent,
     TextContentItem,
     TextDelta,
     ToolCallDelta,
     ToolCallParseStatus,
-    URL,
+    _URLOrData,
 )
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
@@ -74,30 +123,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_content_to_url,
     decode_assistant_message,
 )
-from openai.types.chat import (
-    ChatCompletionMessageParam as OpenAIChatCompletionMessage,
-    ChatCompletionMessageToolCall,
-    ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
-    ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
-    ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
-)
-from openai.types.chat.chat_completion import (
-    Choice as OpenAIChoice,
-    ChoiceLogprobs as OpenAIChoiceLogprobs,  # same as chat_completion_chunk ChoiceLogprobs
-)
-from openai.types.chat.chat_completion_chunk import (
-    Choice as OpenAIChatCompletionChunkChoice,
-    ChoiceDelta as OpenAIChoiceDelta,
-    ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
-    ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
-)
-from openai.types.chat.chat_completion_content_part_image_param import (
-    ImageURL as OpenAIImageURL,
-)
-from openai.types.chat.chat_completion_message_tool_call import (
-    Function as OpenAIFunction,
-)
-from pydantic import BaseModel
-
 logger = get_logger(name=__name__, category="providers::utils")
 
 
@@ -196,16 +221,12 @@ def convert_openai_completion_logprobs(
     if logprobs.tokens and logprobs.token_logprobs:
         return [
             TokenLogProbs(logprobs_by_token={token: token_lp})
-            for token, token_lp in zip(
-                logprobs.tokens, logprobs.token_logprobs, strict=False
-            )
+            for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs, strict=False)
         ]
     return None
 
 
-def convert_openai_completion_logprobs_stream(
-    text: str, logprobs: float | OpenAICompatLogprobs | None
-):
+def convert_openai_completion_logprobs_stream(text: str, logprobs: float | OpenAICompatLogprobs | None):
     if logprobs is None:
         return None
     if isinstance(logprobs, float):
@@ -250,9 +271,7 @@ def process_chat_completion_response(
         if not choice.message or not choice.message.tool_calls:
             raise ValueError("Tool calls are not present in the response")
 
-        tool_calls = [
-            convert_tool_call(tool_call) for tool_call in choice.message.tool_calls
-        ]
+        tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls]
         if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
             # If we couldn't parse a tool call, jsonify the tool calls and return them
             return ChatCompletionResponse(
@@ -276,9 +295,7 @@ def process_chat_completion_response(
 
     # TODO: This does not work well with tool calls for vLLM remote provider
    # Ref: https://github.com/meta-llama/llama-stack/issues/1058
-    raw_message = decode_assistant_message(
-        text_from_choice(choice), get_stop_reason(choice.finish_reason)
-    )
+    raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason))
 
     # NOTE: If we do not set tools in chat-completion request, we should not
     # expect the ToolCall in the response. Instead, we should return the raw
@@ -479,17 +496,13 @@ async def process_chat_completion_stream_response(
     )
 
 
-async def convert_message_to_openai_dict(
-    message: Message, download: bool = False
-) -> dict:
+async def convert_message_to_openai_dict(message: Message, download: bool = False) -> dict:
     async def _convert_content(content) -> dict:
         if isinstance(content, ImageContentItem):
             return {
                 "type": "image_url",
                 "image_url": {
-                    "url": await convert_image_content_to_url(
-                        content, download=download
-                    ),
+                    "url": await convert_image_content_to_url(content, download=download),
                 },
             }
         else:
@@ -574,11 +587,7 @@ async def convert_message_to_openai_dict_new(
 ) -> str | Iterable[OpenAIChatCompletionContentPartParam]:
     async def impl(
         content_: InterleavedContent,
-    ) -> (
-        str
-        | OpenAIChatCompletionContentPartParam
-        | list[OpenAIChatCompletionContentPartParam]
-    ):
+    ) -> str | OpenAIChatCompletionContentPartParam | list[OpenAIChatCompletionContentPartParam]:
         # Llama Stack and OpenAI spec match for str and text input
         if isinstance(content_, str):
             return content_
@@ -591,9 +600,7 @@ async def convert_message_to_openai_dict_new(
             return OpenAIChatCompletionContentPartImageParam(
                 type="image_url",
                 image_url=OpenAIImageURL(
-                    url=await convert_image_content_to_url(
-                        content_, download=download_images
-                    )
+                    url=await convert_image_content_to_url(content_, download=download_images)
                 ),
             )
         elif isinstance(content_, list):
@@ -620,11 +627,7 @@ async def convert_message_to_openai_dict_new(
             OpenAIChatCompletionMessageFunctionToolCall(
                 id=tool.call_id,
                 function=OpenAIFunction(
-                    name=(
-                        tool.tool_name
-                        if not isinstance(tool.tool_name, BuiltinTool)
-                        else tool.tool_name.value
-                    ),
+                    name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
                     arguments=tool.arguments,  # Already a JSON string, don't double-encode
                 ),
                 type="function",
@@ -804,9 +807,7 @@ def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
     }.get(finish_reason, StopReason.end_of_turn)
 
 
-def _convert_openai_request_tool_config(
-    tool_choice: str | dict[str, Any] | None = None
-) -> ToolConfig:
+def _convert_openai_request_tool_config(tool_choice: str | dict[str, Any] | None = None) -> ToolConfig:
     tool_config = ToolConfig()
     if tool_choice:
         try:
@@ -817,9 +818,7 @@ def _convert_openai_request_tool_config(
     return tool_config
 
 
-def _convert_openai_request_tools(
-    tools: list[dict[str, Any]] | None = None
-) -> list[ToolDefinition]:
+def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) -> list[ToolDefinition]:
     lls_tools = []
     if not tools:
         return lls_tools
@@ -918,11 +917,7 @@ def _convert_openai_logprobs(
         return None
 
     return [
-        TokenLogProbs(
-            logprobs_by_token={
-                logprobs.token: logprobs.logprob for logprobs in content.top_logprobs
-            }
-        )
+        TokenLogProbs(logprobs_by_token={logprobs.token: logprobs.logprob for logprobs in content.top_logprobs})
         for content in logprobs.content
     ]
 
@@ -961,13 +956,9 @@ def openai_messages_to_messages(
     converted_messages = []
     for message in messages:
         if message.role == "system":
-            converted_message = SystemMessage(
-                content=openai_content_to_content(message.content)
-            )
+            converted_message = SystemMessage(content=openai_content_to_content(message.content))
         elif message.role == "user":
-            converted_message = UserMessage(
-                content=openai_content_to_content(message.content)
-            )
+            converted_message = UserMessage(content=openai_content_to_content(message.content))
         elif message.role == "assistant":
             converted_message = CompletionMessage(
                 content=openai_content_to_content(message.content),
@@ -999,9 +990,7 @@ def openai_content_to_content(
         if content.type == "text":
             return TextContentItem(type="text", text=content.text)
         elif content.type == "image_url":
-            return ImageContentItem(
-                type="image", image=_URLOrData(url=URL(uri=content.image_url.url))
-            )
+            return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url)))
         else:
             raise ValueError(f"Unknown content type: {content.type}")
     else:
@@ -1041,17 +1030,14 @@ def convert_openai_chat_completion_choice(
         end_of_message = "end_of_message"
         out_of_tokens = "out_of_tokens"
     """
-    assert (
-        hasattr(choice, "message") and choice.message
-    ), "error in server response: message not found"
-    assert (
-        hasattr(choice, "finish_reason") and choice.finish_reason
-    ), "error in server response: finish_reason not found"
+    assert hasattr(choice, "message") and choice.message, "error in server response: message not found"
+    assert hasattr(choice, "finish_reason") and choice.finish_reason, (
+        "error in server response: finish_reason not found"
+    )
 
     return ChatCompletionResponse(
         completion_message=CompletionMessage(
-            content=choice.message.content
-            or "",  # CompletionMessage content is not optional
+            content=choice.message.content or "",  # CompletionMessage content is not optional
             stop_reason=_convert_openai_finish_reason(choice.finish_reason),
             tool_calls=_convert_openai_tool_calls(choice.message.tool_calls),
         ),
@@ -1291,9 +1277,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
             outstanding_responses.append(response)
 
         if stream:
-            return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(
-                self, model, outstanding_responses
-            )
+            return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses)
 
         return await OpenAIChatCompletionToLlamaStackMixin._process_non_stream_response(
             self, model, outstanding_responses
@@ -1302,29 +1286,21 @@ class OpenAIChatCompletionToLlamaStackMixin:
     async def _process_stream_response(
         self,
         model: str,
-        outstanding_responses: list[
-            Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]
-        ],
+        outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]],
     ):
         id = f"chatcmpl-{uuid.uuid4()}"
         for i, outstanding_response in enumerate(outstanding_responses):
             response = await outstanding_response
             async for chunk in response:
                 event = chunk.event
-                finish_reason = _convert_stop_reason_to_openai_finish_reason(
-                    event.stop_reason
-                )
+                finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)
 
                 if isinstance(event.delta, TextDelta):
                     text_delta = event.delta.text
                     delta = OpenAIChoiceDelta(content=text_delta)
                     yield OpenAIChatCompletionChunk(
                         id=id,
-                        choices=[
-                            OpenAIChatCompletionChunkChoice(
-                                index=i, finish_reason=finish_reason, delta=delta
-                            )
-                        ],
+                        choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)],
                         created=int(time.time()),
                         model=model,
                         object="chat.completion.chunk",
@@ -1346,9 +1322,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
                     yield OpenAIChatCompletionChunk(
                         id=id,
                         choices=[
-                            OpenAIChatCompletionChunkChoice(
-                                index=i, finish_reason=finish_reason, delta=delta
-                            )
+                            OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
                         ],
                         created=int(time.time()),
                         model=model,
@@ -1365,9 +1339,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
                 yield OpenAIChatCompletionChunk(
                     id=id,
                     choices=[
-                        OpenAIChatCompletionChunkChoice(
-                            index=i, finish_reason=finish_reason, delta=delta
-                        )
+                        OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
                     ],
                     created=int(time.time()),
                     model=model,
@@ -1382,9 +1354,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
             response = await outstanding_response
             completion_message = response.completion_message
             message = await convert_message_to_openai_dict_new(completion_message)
-            finish_reason = _convert_stop_reason_to_openai_finish_reason(
-                completion_message.stop_reason
-            )
+            finish_reason = _convert_stop_reason_to_openai_finish_reason(completion_message.stop_reason)
 
             choice = OpenAIChatCompletionChoice(
                 index=len(choices),
@@ -87,9 +87,7 @@ def pytest_configure(config):
     suite = config.getoption("--suite")
     if suite:
         if suite not in SUITE_DEFINITIONS:
-            raise pytest.UsageError(
-                f"Unknown suite: {suite}. Available: {', '.join(sorted(SUITE_DEFINITIONS.keys()))}"
-            )
+            raise pytest.UsageError(f"Unknown suite: {suite}. Available: {', '.join(sorted(SUITE_DEFINITIONS.keys()))}")
 
     # Apply setups (global parameterizations): env + defaults
     setup = config.getoption("--setup")
@@ -127,9 +125,7 @@ def pytest_addoption(parser):
         """
         ),
     )
-    parser.addoption(
-        "--env", action="append", help="Set environment variables, e.g. --env KEY=value"
-    )
+    parser.addoption("--env", action="append", help="Set environment variables, e.g. --env KEY=value")
     parser.addoption(
         "--text-model",
         help="comma-separated list of text models. Fixture name: text_model_id",
@@ -169,7 +165,9 @@ def pytest_addoption(parser):
     )
 
     available_suites = ", ".join(sorted(SUITE_DEFINITIONS.keys()))
-    suite_help = f"Single test suite to run (narrows collection). Available: {available_suites}. Example: --suite=responses"
+    suite_help = (
+        f"Single test suite to run (narrows collection). Available: {available_suites}. Example: --suite=responses"
+    )
     parser.addoption("--suite", help=suite_help)
 
     # Global setups for any suite
@@ -241,11 +239,7 @@ def pytest_generate_tests(metafunc):
 
     # Generate test IDs
     test_ids = []
-    non_empty_params = [
-        (i, values)
-        for i, values in enumerate(param_values.values())
-        if values[0] is not None
-    ]
+    non_empty_params = [(i, values) for i, values in enumerate(param_values.values()) if values[0] is not None]
 
     # Get actual function parameters using inspect
     test_func_params = set(inspect.signature(metafunc.function).parameters.keys())
@@ -262,9 +256,7 @@ def pytest_generate_tests(metafunc):
         if parts:
             test_ids.append(":".join(parts))
 
-    metafunc.parametrize(
-        params, value_combinations, scope="session", ids=test_ids if test_ids else None
-    )
+    metafunc.parametrize(params, value_combinations, scope="session", ids=test_ids if test_ids else None)
 
 
 def pytest_ignore_collect(path: str, config: pytest.Config) -> bool:
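Note: for readers unfamiliar with the hook being reformatted here, metafunc.parametrize is how pytest_generate_tests drives dynamic parametrization from command-line options. A minimal, generic sketch (the --text-model option must be registered in pytest_addoption; the fallback value is made up for illustration):

# conftest.py (sketch)
def pytest_generate_tests(metafunc):
    # Parametrize any test that declares a "model_id" argument from a CLI option.
    if "model_id" in metafunc.fixturenames:
        models = metafunc.config.getoption("--text-model", default=None) or "stub-model"
        values = [m.strip() for m in models.split(",")]
        metafunc.parametrize("model_id", values, scope="session", ids=values)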
@@ -274,9 +266,7 @@ def pytest_ignore_collect(path: str, config: pytest.Config) -> bool:
         return False
 
     sobj = SUITE_DEFINITIONS.get(suite)
-    roots: list[str] = (
-        sobj.get("roots", []) if isinstance(sobj, dict) else getattr(sobj, "roots", [])
-    )
+    roots: list[str] = sobj.get("roots", []) if isinstance(sobj, dict) else getattr(sobj, "roots", [])
     if not roots:
         return False
 
@@ -9,15 +9,15 @@ import sys
 from typing import Any, Protocol
 from unittest.mock import AsyncMock, MagicMock
 
-from llama_stack.apis.inference import Inference, SamplingParams
+from pydantic import BaseModel, Field
 
+from llama_stack.apis.inference import Inference
 from llama_stack.core.datatypes import Api, Provider, StackRunConfig
 from llama_stack.core.resolver import resolve_impls
 from llama_stack.core.routers.inference import InferenceRouter
 from llama_stack.core.routing_tables.models import ModelsRoutingTable
 from llama_stack.providers.datatypes import InlineProviderSpec, ProviderSpec
 
-from pydantic import BaseModel, Field
 
 def add_protocol_methods(cls: type, protocol: type[Protocol]) -> None:
     """Dynamically add protocol methods to a class by inspecting the protocol."""