mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-15 08:42:36 +00:00
Merge branch 'main' into dead_code_removal
This commit is contained in:
commit
9886520b40
927 changed files with 171924 additions and 102933 deletions
|
|
@ -28,7 +28,7 @@ from llama_stack.apis.inference import (
|
|||
from llama_stack.apis.safety import SafetyViolation
|
||||
from llama_stack.apis.tools import ToolDef
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
from llama_stack.schema_utils import ExtraBodyField, json_schema_type, register_schema, webmethod
|
||||
|
||||
from .openai_responses import (
|
||||
ListOpenAIResponseInputItem,
|
||||
|
|
@ -42,6 +42,20 @@ from .openai_responses import (
|
|||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ResponseShieldSpec(BaseModel):
|
||||
"""Specification for a shield to apply during response generation.
|
||||
|
||||
:param type: The type/identifier of the shield.
|
||||
"""
|
||||
|
||||
type: str
|
||||
# TODO: more fields to be added for shield configuration
|
||||
|
||||
|
||||
ResponseShield = str | ResponseShieldSpec
|
||||
|
||||
|
||||
class Attachment(BaseModel):
|
||||
"""An attachment to an agent turn.
|
||||
|
||||
|
|
@ -783,7 +797,7 @@ class Agents(Protocol):
|
|||
self,
|
||||
response_id: str,
|
||||
) -> OpenAIResponseObject:
|
||||
"""Retrieve an OpenAI response by its ID.
|
||||
"""Get a model response.
|
||||
|
||||
:param response_id: The ID of the OpenAI response to retrieve.
|
||||
:returns: An OpenAIResponseObject.
|
||||
|
|
@ -805,13 +819,20 @@ class Agents(Protocol):
|
|||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
include: list[str] | None = None,
|
||||
max_infer_iters: int | None = 10, # this is an extension to the OpenAI API
|
||||
shields: Annotated[
|
||||
list[ResponseShield] | None,
|
||||
ExtraBodyField(
|
||||
"List of shields to apply during response generation. Shields provide safety and content moderation."
|
||||
),
|
||||
] = None,
|
||||
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
|
||||
"""Create a new OpenAI response.
|
||||
"""Create a model response.
|
||||
|
||||
:param input: Input message(s) to create the response.
|
||||
:param model: The underlying LLM used for completions.
|
||||
:param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
|
||||
:param include: (Optional) Additional fields to include in the response.
|
||||
:param shields: (Optional) List of shields to apply during response generation. Can be shield IDs (strings) or shield specifications.
|
||||
:returns: An OpenAIResponseObject.
|
||||
"""
|
||||
...
|
||||
|
|
@ -825,7 +846,7 @@ class Agents(Protocol):
|
|||
model: str | None = None,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseObject:
|
||||
"""List all OpenAI responses.
|
||||
"""List all responses.
|
||||
|
||||
:param after: The ID of the last response to return.
|
||||
:param limit: The number of responses to return.
|
||||
|
|
@ -848,7 +869,7 @@ class Agents(Protocol):
|
|||
limit: int | None = 20,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseInputItem:
|
||||
"""List input items for a given OpenAI response.
|
||||
"""List input items.
|
||||
|
||||
:param response_id: The ID of the response to retrieve input items for.
|
||||
:param after: An item ID to list items after, used for pagination.
|
||||
|
|
@ -863,7 +884,7 @@ class Agents(Protocol):
|
|||
@webmethod(route="/openai/v1/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||
"""Delete an OpenAI response by its ID.
|
||||
"""Delete a response.
|
||||
|
||||
:param response_id: The ID of the OpenAI response to delete.
|
||||
:returns: An OpenAIDeleteResponseObject
|
||||
|
|
|
|||
|
|
@ -888,6 +888,10 @@ class OpenAIResponseObjectWithInput(OpenAIResponseObject):
|
|||
|
||||
input: list[OpenAIResponseInput]
|
||||
|
||||
def to_response_object(self) -> OpenAIResponseObject:
|
||||
"""Convert to OpenAIResponseObject by excluding input field."""
|
||||
return OpenAIResponseObject(**{k: v for k, v in self.model_dump().items() if k != "input"})
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListOpenAIResponseObject(BaseModel):
|
||||
|
|
|
|||
31
llama_stack/apis/conversations/__init__.py
Normal file
31
llama_stack/apis/conversations/__init__.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .conversations import (
|
||||
Conversation,
|
||||
ConversationCreateRequest,
|
||||
ConversationDeletedResource,
|
||||
ConversationItem,
|
||||
ConversationItemCreateRequest,
|
||||
ConversationItemDeletedResource,
|
||||
ConversationItemList,
|
||||
Conversations,
|
||||
ConversationUpdateRequest,
|
||||
Metadata,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Conversation",
|
||||
"ConversationCreateRequest",
|
||||
"ConversationDeletedResource",
|
||||
"ConversationItem",
|
||||
"ConversationItemCreateRequest",
|
||||
"ConversationItemDeletedResource",
|
||||
"ConversationItemList",
|
||||
"Conversations",
|
||||
"ConversationUpdateRequest",
|
||||
"Metadata",
|
||||
]
|
||||
260
llama_stack/apis/conversations/conversations.py
Normal file
260
llama_stack/apis/conversations/conversations.py
Normal file
|
|
@ -0,0 +1,260 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Annotated, Literal, Protocol, runtime_checkable
|
||||
|
||||
from openai import NOT_GIVEN
|
||||
from openai._types import NotGiven
|
||||
from openai.types.responses.response_includable import ResponseIncludable
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseOutputMessageFileSearchToolCall,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseOutputMessageMCPCall,
|
||||
OpenAIResponseOutputMessageMCPListTools,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
)
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
Metadata = dict[str, str]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Conversation(BaseModel):
|
||||
"""OpenAI-compatible conversation object."""
|
||||
|
||||
id: str = Field(..., description="The unique ID of the conversation.")
|
||||
object: Literal["conversation"] = Field(
|
||||
default="conversation", description="The object type, which is always conversation."
|
||||
)
|
||||
created_at: int = Field(
|
||||
..., description="The time at which the conversation was created, measured in seconds since the Unix epoch."
|
||||
)
|
||||
metadata: Metadata | None = Field(
|
||||
default=None,
|
||||
description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard.",
|
||||
)
|
||||
items: list[dict] | None = Field(
|
||||
default=None,
|
||||
description="Initial items to include in the conversation context. You may add up to 20 items at a time.",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ConversationMessage(BaseModel):
|
||||
"""OpenAI-compatible message item for conversations."""
|
||||
|
||||
id: str = Field(..., description="unique identifier for this message")
|
||||
content: list[dict] = Field(..., description="message content")
|
||||
role: str = Field(..., description="message role")
|
||||
status: str = Field(..., description="message status")
|
||||
type: Literal["message"] = "message"
|
||||
object: Literal["message"] = "message"
|
||||
|
||||
|
||||
ConversationItem = Annotated[
|
||||
OpenAIResponseMessage
|
||||
| OpenAIResponseOutputMessageFunctionToolCall
|
||||
| OpenAIResponseOutputMessageFileSearchToolCall
|
||||
| OpenAIResponseOutputMessageWebSearchToolCall
|
||||
| OpenAIResponseOutputMessageMCPCall
|
||||
| OpenAIResponseOutputMessageMCPListTools,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(ConversationItem, name="ConversationItem")
|
||||
|
||||
# Using OpenAI types directly caused issues but some notes for reference:
|
||||
# Note that ConversationItem is a Annotated Union of the types below:
|
||||
# from openai.types.responses import *
|
||||
# from openai.types.responses.response_item import *
|
||||
# from openai.types.conversations import ConversationItem
|
||||
# f = [
|
||||
# ResponseFunctionToolCallItem,
|
||||
# ResponseFunctionToolCallOutputItem,
|
||||
# ResponseFileSearchToolCall,
|
||||
# ResponseFunctionWebSearch,
|
||||
# ImageGenerationCall,
|
||||
# ResponseComputerToolCall,
|
||||
# ResponseComputerToolCallOutputItem,
|
||||
# ResponseReasoningItem,
|
||||
# ResponseCodeInterpreterToolCall,
|
||||
# LocalShellCall,
|
||||
# LocalShellCallOutput,
|
||||
# McpListTools,
|
||||
# McpApprovalRequest,
|
||||
# McpApprovalResponse,
|
||||
# McpCall,
|
||||
# ResponseCustomToolCall,
|
||||
# ResponseCustomToolCallOutput
|
||||
# ]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ConversationCreateRequest(BaseModel):
|
||||
"""Request body for creating a conversation."""
|
||||
|
||||
items: list[ConversationItem] | None = Field(
|
||||
default=[],
|
||||
description="Initial items to include in the conversation context. You may add up to 20 items at a time.",
|
||||
max_length=20,
|
||||
)
|
||||
metadata: Metadata | None = Field(
|
||||
default={},
|
||||
description="Set of 16 key-value pairs that can be attached to an object. Useful for storing additional information",
|
||||
max_length=16,
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ConversationUpdateRequest(BaseModel):
|
||||
"""Request body for updating a conversation."""
|
||||
|
||||
metadata: Metadata = Field(
|
||||
...,
|
||||
description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard. Keys are strings with a maximum length of 64 characters. Values are strings with a maximum length of 512 characters.",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ConversationDeletedResource(BaseModel):
|
||||
"""Response for deleted conversation."""
|
||||
|
||||
id: str = Field(..., description="The deleted conversation identifier")
|
||||
object: str = Field(default="conversation.deleted", description="Object type")
|
||||
deleted: bool = Field(default=True, description="Whether the object was deleted")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ConversationItemCreateRequest(BaseModel):
|
||||
"""Request body for creating conversation items."""
|
||||
|
||||
items: list[ConversationItem] = Field(
|
||||
...,
|
||||
description="Items to include in the conversation context. You may add up to 20 items at a time.",
|
||||
max_length=20,
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ConversationItemList(BaseModel):
|
||||
"""List of conversation items with pagination."""
|
||||
|
||||
object: str = Field(default="list", description="Object type")
|
||||
data: list[ConversationItem] = Field(..., description="List of conversation items")
|
||||
first_id: str | None = Field(default=None, description="The ID of the first item in the list")
|
||||
last_id: str | None = Field(default=None, description="The ID of the last item in the list")
|
||||
has_more: bool = Field(default=False, description="Whether there are more items available")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ConversationItemDeletedResource(BaseModel):
|
||||
"""Response for deleted conversation item."""
|
||||
|
||||
id: str = Field(..., description="The deleted item identifier")
|
||||
object: str = Field(default="conversation.item.deleted", description="Object type")
|
||||
deleted: bool = Field(default=True, description="Whether the object was deleted")
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Conversations(Protocol):
|
||||
"""Protocol for conversation management operations."""
|
||||
|
||||
@webmethod(route="/conversations", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def create_conversation(
|
||||
self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
|
||||
) -> Conversation:
|
||||
"""Create a conversation.
|
||||
|
||||
:param items: Initial items to include in the conversation context.
|
||||
:param metadata: Set of key-value pairs that can be attached to an object.
|
||||
:returns: The created conversation object.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/conversations/{conversation_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_conversation(self, conversation_id: str) -> Conversation:
|
||||
"""Get a conversation with the given ID.
|
||||
|
||||
:param conversation_id: The conversation identifier.
|
||||
:returns: The conversation object.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/conversations/{conversation_id}", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
|
||||
"""Update a conversation's metadata with the given ID.
|
||||
|
||||
:param conversation_id: The conversation identifier.
|
||||
:param metadata: Set of key-value pairs that can be attached to an object.
|
||||
:returns: The updated conversation object.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/conversations/{conversation_id}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
|
||||
"""Delete a conversation with the given ID.
|
||||
|
||||
:param conversation_id: The conversation identifier.
|
||||
:returns: The deleted conversation resource.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/conversations/{conversation_id}/items", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def add_items(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
|
||||
"""Create items in the conversation.
|
||||
|
||||
:param conversation_id: The conversation identifier.
|
||||
:param items: Items to include in the conversation context.
|
||||
:returns: List of created items.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
|
||||
"""Retrieve a conversation item.
|
||||
|
||||
:param conversation_id: The conversation identifier.
|
||||
:param item_id: The item identifier.
|
||||
:returns: The conversation item.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/conversations/{conversation_id}/items", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list(
|
||||
self,
|
||||
conversation_id: str,
|
||||
after: str | NotGiven = NOT_GIVEN,
|
||||
include: list[ResponseIncludable] | NotGiven = NOT_GIVEN,
|
||||
limit: int | NotGiven = NOT_GIVEN,
|
||||
order: Literal["asc", "desc"] | NotGiven = NOT_GIVEN,
|
||||
) -> ConversationItemList:
|
||||
"""List items in the conversation.
|
||||
|
||||
:param conversation_id: The conversation identifier.
|
||||
:param after: An item ID to list items after, used in pagination.
|
||||
:param include: Specify additional output data to include in the response.
|
||||
:param limit: A limit on the number of objects to be returned (1-100, default 20).
|
||||
:param order: The order to return items in (asc or desc, default desc).
|
||||
:returns: List of conversation items.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def openai_delete_conversation_item(
|
||||
self, conversation_id: str, item_id: str
|
||||
) -> ConversationItemDeletedResource:
|
||||
"""Delete a conversation item.
|
||||
|
||||
:param conversation_id: The conversation identifier.
|
||||
:param item_id: The item identifier.
|
||||
:returns: The deleted item resource.
|
||||
"""
|
||||
...
|
||||
|
|
@ -129,6 +129,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
|
|||
tool_groups = "tool_groups"
|
||||
files = "files"
|
||||
prompts = "prompts"
|
||||
conversations = "conversations"
|
||||
|
||||
# built-in API
|
||||
inspect = "inspect"
|
||||
|
|
|
|||
|
|
@ -104,6 +104,11 @@ class OpenAIFileDeleteResponse(BaseModel):
|
|||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Files(Protocol):
|
||||
"""Files
|
||||
|
||||
This API is used to upload documents that can be used with other Llama Stack APIs.
|
||||
"""
|
||||
|
||||
# OpenAI Files API Endpoints
|
||||
@webmethod(route="/openai/v1/files", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/files", method="POST", level=LLAMA_STACK_API_V1)
|
||||
|
|
@ -113,7 +118,8 @@ class Files(Protocol):
|
|||
purpose: Annotated[OpenAIFilePurpose, Form()],
|
||||
expires_after: Annotated[ExpiresAfter | None, Form()] = None,
|
||||
) -> OpenAIFileObject:
|
||||
"""
|
||||
"""Upload file.
|
||||
|
||||
Upload a file that can be used across various endpoints.
|
||||
|
||||
The file upload should be a multipart form request with:
|
||||
|
|
@ -137,7 +143,8 @@ class Files(Protocol):
|
|||
order: Order | None = Order.desc,
|
||||
purpose: OpenAIFilePurpose | None = None,
|
||||
) -> ListOpenAIFileResponse:
|
||||
"""
|
||||
"""List files.
|
||||
|
||||
Returns a list of files that belong to the user's organization.
|
||||
|
||||
:param after: A cursor for use in pagination. `after` is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list.
|
||||
|
|
@ -154,7 +161,8 @@ class Files(Protocol):
|
|||
self,
|
||||
file_id: str,
|
||||
) -> OpenAIFileObject:
|
||||
"""
|
||||
"""Retrieve file.
|
||||
|
||||
Returns information about a specific file.
|
||||
|
||||
:param file_id: The ID of the file to use for this request.
|
||||
|
|
@ -168,8 +176,7 @@ class Files(Protocol):
|
|||
self,
|
||||
file_id: str,
|
||||
) -> OpenAIFileDeleteResponse:
|
||||
"""
|
||||
Delete a file.
|
||||
"""Delete file.
|
||||
|
||||
:param file_id: The ID of the file to use for this request.
|
||||
:returns: An OpenAIFileDeleteResponse indicating successful deletion.
|
||||
|
|
@ -182,7 +189,8 @@ class Files(Protocol):
|
|||
self,
|
||||
file_id: str,
|
||||
) -> Response:
|
||||
"""
|
||||
"""Retrieve file content.
|
||||
|
||||
Returns the contents of the specified file.
|
||||
|
||||
:param file_id: The ID of the file to use for this request.
|
||||
|
|
|
|||
|
|
@ -982,45 +982,6 @@ class InferenceProvider(Protocol):
|
|||
|
||||
model_store: ModelStore | None = None
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
"""Generate a chat completion for the given messages using the specified model.
|
||||
|
||||
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param messages: List of messages in the conversation.
|
||||
:param sampling_params: Parameters to control the sampling strategy.
|
||||
:param tools: (Optional) List of tool definitions available to the model.
|
||||
:param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
|
||||
.. deprecated::
|
||||
Use tool_config instead.
|
||||
:param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
|
||||
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
|
||||
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
|
||||
- `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
|
||||
.. deprecated::
|
||||
Use tool_config instead.
|
||||
:param response_format: (Optional) Grammar specification for guided (structured) decoding. There are two options:
|
||||
- `ResponseFormat.json_schema`: The grammar is a JSON schema. Most providers support this format.
|
||||
- `ResponseFormat.grammar`: The grammar is a BNF grammar. This format is more flexible, but not all providers support it.
|
||||
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
|
||||
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
|
||||
:param tool_config: (Optional) Configuration for tool use.
|
||||
:returns: If stream=False, returns a ChatCompletionResponse with the full completion.
|
||||
If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/inference/rerank", method="POST", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def rerank(
|
||||
self,
|
||||
|
|
@ -1081,7 +1042,9 @@ class InferenceProvider(Protocol):
|
|||
# for fill-in-the-middle type completion
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
"""Generate an OpenAI-compatible completion for the given prompt using the specified model.
|
||||
"""Create completion.
|
||||
|
||||
Generate an OpenAI-compatible completion for the given prompt using the specified model.
|
||||
|
||||
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param prompt: The prompt to generate a completion for.
|
||||
|
|
@ -1138,7 +1101,9 @@ class InferenceProvider(Protocol):
|
|||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
"""Generate an OpenAI-compatible chat completion for the given messages using the specified model.
|
||||
"""Create chat completions.
|
||||
|
||||
Generate an OpenAI-compatible chat completion for the given messages using the specified model.
|
||||
|
||||
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param messages: List of messages in the conversation.
|
||||
|
|
@ -1182,7 +1147,9 @@ class InferenceProvider(Protocol):
|
|||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
"""Generate OpenAI-compatible embeddings for the given input using the specified model.
|
||||
"""Create embeddings.
|
||||
|
||||
Generate OpenAI-compatible embeddings for the given input using the specified model.
|
||||
|
||||
:param model: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
|
||||
:param input: Input text to embed, encoded as a string or array of strings. To embed multiple inputs in a single request, pass an array of strings.
|
||||
|
|
@ -1195,7 +1162,9 @@ class InferenceProvider(Protocol):
|
|||
|
||||
|
||||
class Inference(InferenceProvider):
|
||||
"""Llama Stack Inference API for generating completions, chat completions, and embeddings.
|
||||
"""Inference
|
||||
|
||||
Llama Stack Inference API for generating completions, chat completions, and embeddings.
|
||||
|
||||
This API provides the raw interface to the underlying models. Two kinds of models are supported:
|
||||
- LLM models: these models generate "raw" and "chat" (conversational) completions.
|
||||
|
|
@ -1216,7 +1185,7 @@ class Inference(InferenceProvider):
|
|||
model: str | None = None,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIChatCompletionResponse:
|
||||
"""List all chat completions.
|
||||
"""List chat completions.
|
||||
|
||||
:param after: The ID of the last chat completion to return.
|
||||
:param limit: The maximum number of chat completions to return.
|
||||
|
|
@ -1237,10 +1206,11 @@ class Inference(InferenceProvider):
|
|||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def get_chat_completion(
|
||||
self, completion_id: str
|
||||
) -> OpenAICompletionWithInputMessages:
|
||||
"""Describe a chat completion by its ID.
|
||||
@webmethod(route="/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
|
||||
"""Get chat completion.
|
||||
|
||||
Describe a chat completion by its ID.
|
||||
|
||||
:param completion_id: ID of the chat completion.
|
||||
:returns: A OpenAICompletionWithInputMessages.
|
||||
|
|
|
|||
|
|
@ -58,9 +58,16 @@ class ListRoutesResponse(BaseModel):
|
|||
|
||||
@runtime_checkable
|
||||
class Inspect(Protocol):
|
||||
"""Inspect
|
||||
|
||||
APIs for inspecting the Llama Stack service, including health status, available API routes with methods and implementing providers.
|
||||
"""
|
||||
|
||||
@webmethod(route="/inspect/routes", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_routes(self) -> ListRoutesResponse:
|
||||
"""List all available API routes with their methods and implementing providers.
|
||||
"""List routes.
|
||||
|
||||
List all available API routes with their methods and implementing providers.
|
||||
|
||||
:returns: Response containing information about all available routes.
|
||||
"""
|
||||
|
|
@ -68,7 +75,9 @@ class Inspect(Protocol):
|
|||
|
||||
@webmethod(route="/health", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def health(self) -> HealthInfo:
|
||||
"""Get the current health status of the service.
|
||||
"""Get health status.
|
||||
|
||||
Get the current health status of the service.
|
||||
|
||||
:returns: Health information indicating if the service is operational.
|
||||
"""
|
||||
|
|
@ -76,7 +85,9 @@ class Inspect(Protocol):
|
|||
|
||||
@webmethod(route="/version", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def version(self) -> VersionInfo:
|
||||
"""Get the version of the service.
|
||||
"""Get version.
|
||||
|
||||
Get the version of the service.
|
||||
|
||||
:returns: Version information containing the service version number.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -124,7 +124,9 @@ class Models(Protocol):
|
|||
self,
|
||||
model_id: str,
|
||||
) -> Model:
|
||||
"""Get a model by its identifier.
|
||||
"""Get model.
|
||||
|
||||
Get a model by its identifier.
|
||||
|
||||
:param model_id: The identifier of the model to get.
|
||||
:returns: A Model.
|
||||
|
|
@ -140,7 +142,9 @@ class Models(Protocol):
|
|||
metadata: dict[str, Any] | None = None,
|
||||
model_type: ModelType | None = None,
|
||||
) -> Model:
|
||||
"""Register a model.
|
||||
"""Register model.
|
||||
|
||||
Register a model.
|
||||
|
||||
:param model_id: The identifier of the model to register.
|
||||
:param provider_model_id: The identifier of the model in the provider.
|
||||
|
|
@ -156,7 +160,9 @@ class Models(Protocol):
|
|||
self,
|
||||
model_id: str,
|
||||
) -> None:
|
||||
"""Unregister a model.
|
||||
"""Unregister model.
|
||||
|
||||
Unregister a model.
|
||||
|
||||
:param model_id: The identifier of the model to unregister.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -94,7 +94,9 @@ class ListPromptsResponse(BaseModel):
|
|||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Prompts(Protocol):
|
||||
"""Protocol for prompt management operations."""
|
||||
"""Prompts
|
||||
|
||||
Protocol for prompt management operations."""
|
||||
|
||||
@webmethod(route="/prompts", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_prompts(self) -> ListPromptsResponse:
|
||||
|
|
@ -109,7 +111,9 @@ class Prompts(Protocol):
|
|||
self,
|
||||
prompt_id: str,
|
||||
) -> ListPromptsResponse:
|
||||
"""List all versions of a specific prompt.
|
||||
"""List prompt versions.
|
||||
|
||||
List all versions of a specific prompt.
|
||||
|
||||
:param prompt_id: The identifier of the prompt to list versions for.
|
||||
:returns: A ListPromptsResponse containing all versions of the prompt.
|
||||
|
|
@ -122,7 +126,9 @@ class Prompts(Protocol):
|
|||
prompt_id: str,
|
||||
version: int | None = None,
|
||||
) -> Prompt:
|
||||
"""Get a prompt by its identifier and optional version.
|
||||
"""Get prompt.
|
||||
|
||||
Get a prompt by its identifier and optional version.
|
||||
|
||||
:param prompt_id: The identifier of the prompt to get.
|
||||
:param version: The version of the prompt to get (defaults to latest).
|
||||
|
|
@ -136,7 +142,9 @@ class Prompts(Protocol):
|
|||
prompt: str,
|
||||
variables: list[str] | None = None,
|
||||
) -> Prompt:
|
||||
"""Create a new prompt.
|
||||
"""Create prompt.
|
||||
|
||||
Create a new prompt.
|
||||
|
||||
:param prompt: The prompt text content with variable placeholders.
|
||||
:param variables: List of variable names that can be used in the prompt template.
|
||||
|
|
@ -153,7 +161,9 @@ class Prompts(Protocol):
|
|||
variables: list[str] | None = None,
|
||||
set_as_default: bool = True,
|
||||
) -> Prompt:
|
||||
"""Update an existing prompt (increments version).
|
||||
"""Update prompt.
|
||||
|
||||
Update an existing prompt (increments version).
|
||||
|
||||
:param prompt_id: The identifier of the prompt to update.
|
||||
:param prompt: The updated prompt text content.
|
||||
|
|
@ -169,7 +179,9 @@ class Prompts(Protocol):
|
|||
self,
|
||||
prompt_id: str,
|
||||
) -> None:
|
||||
"""Delete a prompt.
|
||||
"""Delete prompt.
|
||||
|
||||
Delete a prompt.
|
||||
|
||||
:param prompt_id: The identifier of the prompt to delete.
|
||||
"""
|
||||
|
|
@ -181,7 +193,9 @@ class Prompts(Protocol):
|
|||
prompt_id: str,
|
||||
version: int,
|
||||
) -> Prompt:
|
||||
"""Set which version of a prompt should be the default in get_prompt (latest).
|
||||
"""Set prompt version.
|
||||
|
||||
Set which version of a prompt should be the default in get_prompt (latest).
|
||||
|
||||
:param prompt_id: The identifier of the prompt.
|
||||
:param version: The version to set as default.
|
||||
|
|
|
|||
|
|
@ -42,13 +42,16 @@ class ListProvidersResponse(BaseModel):
|
|||
|
||||
@runtime_checkable
|
||||
class Providers(Protocol):
|
||||
"""
|
||||
"""Providers
|
||||
|
||||
Providers API for inspecting, listing, and modifying providers and their configurations.
|
||||
"""
|
||||
|
||||
@webmethod(route="/providers", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_providers(self) -> ListProvidersResponse:
|
||||
"""List all available providers.
|
||||
"""List providers.
|
||||
|
||||
List all available providers.
|
||||
|
||||
:returns: A ListProvidersResponse containing information about all providers.
|
||||
"""
|
||||
|
|
@ -56,7 +59,9 @@ class Providers(Protocol):
|
|||
|
||||
@webmethod(route="/providers/{provider_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
|
||||
"""Get detailed information about a specific provider.
|
||||
"""Get provider.
|
||||
|
||||
Get detailed information about a specific provider.
|
||||
|
||||
:param provider_id: The ID of the provider to inspect.
|
||||
:returns: A ProviderInfo object containing the provider's details.
|
||||
|
|
|
|||
|
|
@ -96,6 +96,11 @@ class ShieldStore(Protocol):
|
|||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Safety(Protocol):
|
||||
"""Safety
|
||||
|
||||
OpenAI-compatible Moderations API.
|
||||
"""
|
||||
|
||||
shield_store: ShieldStore
|
||||
|
||||
@webmethod(route="/safety/run-shield", method="POST", level=LLAMA_STACK_API_V1)
|
||||
|
|
@ -105,7 +110,9 @@ class Safety(Protocol):
|
|||
messages: list[Message],
|
||||
params: dict[str, Any],
|
||||
) -> RunShieldResponse:
|
||||
"""Run a shield.
|
||||
"""Run shield.
|
||||
|
||||
Run a shield.
|
||||
|
||||
:param shield_id: The identifier of the shield to run.
|
||||
:param messages: The messages to run the shield on.
|
||||
|
|
@ -117,7 +124,9 @@ class Safety(Protocol):
|
|||
@webmethod(route="/openai/v1/moderations", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/moderations", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
||||
"""Classifies if text and/or image inputs are potentially harmful.
|
||||
"""Create moderation.
|
||||
|
||||
Classifies if text and/or image inputs are potentially harmful.
|
||||
:param input: Input (or inputs) to classify.
|
||||
Can be a single string, an array of strings, or an array of multi-modal input objects similar to other models.
|
||||
:param model: The content moderation model you would like to use.
|
||||
|
|
|
|||
|
|
@ -6,11 +6,18 @@
|
|||
|
||||
import argparse
|
||||
import os
|
||||
import ssl
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import uvicorn
|
||||
import yaml
|
||||
|
||||
from llama_stack.cli.stack.utils import ImageType
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.core.datatypes import LoggingConfig, StackRunConfig
|
||||
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars, validate_env_pair
|
||||
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
|
@ -146,23 +153,7 @@ class StackRun(Subcommand):
|
|||
# using the current environment packages.
|
||||
if not image_type and not image_name:
|
||||
logger.info("No image type or image name provided. Assuming environment packages.")
|
||||
from llama_stack.core.server.server import main as server_main
|
||||
|
||||
# Build the server args from the current args passed to the CLI
|
||||
server_args = argparse.Namespace()
|
||||
for arg in vars(args):
|
||||
# If this is a function, avoid passing it
|
||||
# "args" contains:
|
||||
# func=<bound method StackRun._run_stack_run_cmd of <llama_stack.cli.stack.run.StackRun object at 0x10484b010>>
|
||||
if callable(getattr(args, arg)):
|
||||
continue
|
||||
if arg == "config":
|
||||
server_args.config = str(config_file)
|
||||
else:
|
||||
setattr(server_args, arg, getattr(args, arg))
|
||||
|
||||
# Run the server
|
||||
server_main(server_args)
|
||||
self._uvicorn_run(config_file, args)
|
||||
else:
|
||||
run_args = formulate_run_args(image_type, image_name)
|
||||
|
||||
|
|
@ -184,6 +175,76 @@ class StackRun(Subcommand):
|
|||
|
||||
run_command(run_args)
|
||||
|
||||
def _uvicorn_run(self, config_file: Path | None, args: argparse.Namespace) -> None:
|
||||
if not config_file:
|
||||
self.parser.error("Config file is required")
|
||||
|
||||
# Set environment variables if provided
|
||||
if args.env:
|
||||
for env_pair in args.env:
|
||||
try:
|
||||
key, value = validate_env_pair(env_pair)
|
||||
logger.info(f"Setting environment variable {key} => {value}")
|
||||
os.environ[key] = value
|
||||
except ValueError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
self.parser.error(f"Invalid environment variable format: {env_pair}")
|
||||
|
||||
config_file = resolve_config_or_distro(str(config_file), Mode.RUN)
|
||||
with open(config_file) as fp:
|
||||
config_contents = yaml.safe_load(fp)
|
||||
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
|
||||
logger_config = LoggingConfig(**cfg)
|
||||
else:
|
||||
logger_config = None
|
||||
config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
|
||||
|
||||
port = args.port or config.server.port
|
||||
host = config.server.host or ["::", "0.0.0.0"]
|
||||
|
||||
# Set the config file in environment so create_app can find it
|
||||
os.environ["LLAMA_STACK_CONFIG"] = str(config_file)
|
||||
|
||||
uvicorn_config = {
|
||||
"factory": True,
|
||||
"host": host,
|
||||
"port": port,
|
||||
"lifespan": "on",
|
||||
"log_level": logger.getEffectiveLevel(),
|
||||
"log_config": logger_config,
|
||||
}
|
||||
|
||||
keyfile = config.server.tls_keyfile
|
||||
certfile = config.server.tls_certfile
|
||||
if keyfile and certfile:
|
||||
uvicorn_config["ssl_keyfile"] = config.server.tls_keyfile
|
||||
uvicorn_config["ssl_certfile"] = config.server.tls_certfile
|
||||
if config.server.tls_cafile:
|
||||
uvicorn_config["ssl_ca_certs"] = config.server.tls_cafile
|
||||
uvicorn_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED
|
||||
|
||||
logger.info(
|
||||
f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
|
||||
|
||||
logger.info(f"Listening on {host}:{port}")
|
||||
|
||||
# We need to catch KeyboardInterrupt because uvicorn's signal handling
|
||||
# re-raises SIGINT signals using signal.raise_signal(), which Python
|
||||
# converts to KeyboardInterrupt. Without this catch, we'd get a confusing
|
||||
# stack trace when using Ctrl+C or kill -2 (SIGINT).
|
||||
# SIGTERM (kill -15) works fine without this because Python doesn't
|
||||
# have a default handler for it.
|
||||
#
|
||||
# Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
|
||||
# signal handling but this is quite intrusive and not worth the effort.
|
||||
try:
|
||||
uvicorn.run("llama_stack.core.server.server:create_app", **uvicorn_config)
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
logger.info("Received interrupt signal, shutting down gracefully...")
|
||||
|
||||
def _start_ui_development_server(self, stack_server_port: int):
|
||||
logger.info("Attempting to start UI development server...")
|
||||
# Check if npm is available
|
||||
|
|
|
|||
|
|
@ -324,14 +324,14 @@ fi
|
|||
RUN pip uninstall -y uv
|
||||
EOF
|
||||
|
||||
# If a run config is provided, we use the --config flag
|
||||
# If a run config is provided, we use the llama stack CLI
|
||||
if [[ -n "$run_config" ]]; then
|
||||
add_to_container << EOF
|
||||
ENTRYPOINT ["python", "-m", "llama_stack.core.server.server", "$RUN_CONFIG_PATH"]
|
||||
ENTRYPOINT ["llama", "stack", "run", "$RUN_CONFIG_PATH"]
|
||||
EOF
|
||||
elif [[ "$distro_or_config" != *.yaml ]]; then
|
||||
add_to_container << EOF
|
||||
ENTRYPOINT ["python", "-m", "llama_stack.core.server.server", "$distro_or_config"]
|
||||
ENTRYPOINT ["llama", "stack", "run", "$distro_or_config"]
|
||||
EOF
|
||||
fi
|
||||
|
||||
|
|
|
|||
5
llama_stack/core/conversations/__init__.py
Normal file
5
llama_stack/core/conversations/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
306
llama_stack/core/conversations/conversations.py
Normal file
306
llama_stack/core/conversations/conversations.py
Normal file
|
|
@ -0,0 +1,306 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import secrets
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from openai import NOT_GIVEN
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from llama_stack.apis.conversations.conversations import (
|
||||
Conversation,
|
||||
ConversationDeletedResource,
|
||||
ConversationItem,
|
||||
ConversationItemDeletedResource,
|
||||
ConversationItemList,
|
||||
Conversations,
|
||||
Metadata,
|
||||
)
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import (
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreConfig,
|
||||
sqlstore_impl,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="openai::conversations")
|
||||
|
||||
|
||||
class ConversationServiceConfig(BaseModel):
|
||||
"""Configuration for the built-in conversation service.
|
||||
|
||||
:param conversations_store: SQL store configuration for conversations (defaults to SQLite)
|
||||
:param policy: Access control rules
|
||||
"""
|
||||
|
||||
conversations_store: SqlStoreConfig = SqliteSqlStoreConfig(
|
||||
db_path=(DISTRIBS_BASE_DIR / "conversations.db").as_posix()
|
||||
)
|
||||
policy: list[AccessRule] = []
|
||||
|
||||
|
||||
async def get_provider_impl(config: ConversationServiceConfig, deps: dict[Any, Any]):
|
||||
"""Get the conversation service implementation."""
|
||||
impl = ConversationServiceImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
class ConversationServiceImpl(Conversations):
|
||||
"""Built-in conversation service implementation using AuthorizedSqlStore."""
|
||||
|
||||
def __init__(self, config: ConversationServiceConfig, deps: dict[Any, Any]):
|
||||
self.config = config
|
||||
self.deps = deps
|
||||
self.policy = config.policy
|
||||
|
||||
base_sql_store = sqlstore_impl(config.conversations_store)
|
||||
self.sql_store = AuthorizedSqlStore(base_sql_store, self.policy)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the store and create tables."""
|
||||
if isinstance(self.config.conversations_store, SqliteSqlStoreConfig):
|
||||
os.makedirs(os.path.dirname(self.config.conversations_store.db_path), exist_ok=True)
|
||||
|
||||
await self.sql_store.create_table(
|
||||
"openai_conversations",
|
||||
{
|
||||
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
|
||||
"created_at": ColumnType.INTEGER,
|
||||
"items": ColumnType.JSON,
|
||||
"metadata": ColumnType.JSON,
|
||||
},
|
||||
)
|
||||
|
||||
await self.sql_store.create_table(
|
||||
"conversation_items",
|
||||
{
|
||||
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
|
||||
"conversation_id": ColumnType.STRING,
|
||||
"created_at": ColumnType.INTEGER,
|
||||
"item_data": ColumnType.JSON,
|
||||
},
|
||||
)
|
||||
|
||||
async def create_conversation(
|
||||
self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
|
||||
) -> Conversation:
|
||||
"""Create a conversation."""
|
||||
random_bytes = secrets.token_bytes(24)
|
||||
conversation_id = f"conv_{random_bytes.hex()}"
|
||||
created_at = int(time.time())
|
||||
|
||||
record_data = {
|
||||
"id": conversation_id,
|
||||
"created_at": created_at,
|
||||
"items": [],
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
await self.sql_store.insert(
|
||||
table="openai_conversations",
|
||||
data=record_data,
|
||||
)
|
||||
|
||||
if items:
|
||||
item_records = []
|
||||
for item in items:
|
||||
item_dict = item.model_dump()
|
||||
item_id = self._get_or_generate_item_id(item, item_dict)
|
||||
|
||||
item_record = {
|
||||
"id": item_id,
|
||||
"conversation_id": conversation_id,
|
||||
"created_at": created_at,
|
||||
"item_data": item_dict,
|
||||
}
|
||||
|
||||
item_records.append(item_record)
|
||||
|
||||
await self.sql_store.insert(table="conversation_items", data=item_records)
|
||||
|
||||
conversation = Conversation(
|
||||
id=conversation_id,
|
||||
created_at=created_at,
|
||||
metadata=metadata,
|
||||
object="conversation",
|
||||
)
|
||||
|
||||
logger.info(f"Created conversation {conversation_id}")
|
||||
return conversation
|
||||
|
||||
async def get_conversation(self, conversation_id: str) -> Conversation:
|
||||
"""Get a conversation with the given ID."""
|
||||
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
|
||||
|
||||
if record is None:
|
||||
raise ValueError(f"Conversation {conversation_id} not found")
|
||||
|
||||
return Conversation(
|
||||
id=record["id"], created_at=record["created_at"], metadata=record.get("metadata"), object="conversation"
|
||||
)
|
||||
|
||||
async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
|
||||
"""Update a conversation's metadata with the given ID"""
|
||||
await self.sql_store.update(
|
||||
table="openai_conversations", data={"metadata": metadata}, where={"id": conversation_id}
|
||||
)
|
||||
|
||||
return await self.get_conversation(conversation_id)
|
||||
|
||||
async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
|
||||
"""Delete a conversation with the given ID."""
|
||||
await self.sql_store.delete(table="openai_conversations", where={"id": conversation_id})
|
||||
|
||||
logger.info(f"Deleted conversation {conversation_id}")
|
||||
return ConversationDeletedResource(id=conversation_id)
|
||||
|
||||
def _validate_conversation_id(self, conversation_id: str) -> None:
|
||||
"""Validate conversation ID format."""
|
||||
if not conversation_id.startswith("conv_"):
|
||||
raise ValueError(
|
||||
f"Invalid 'conversation_id': '{conversation_id}'. Expected an ID that begins with 'conv_'."
|
||||
)
|
||||
|
||||
def _get_or_generate_item_id(self, item: ConversationItem, item_dict: dict) -> str:
|
||||
"""Get existing item ID or generate one if missing."""
|
||||
if item.id is None:
|
||||
random_bytes = secrets.token_bytes(24)
|
||||
if item.type == "message":
|
||||
item_id = f"msg_{random_bytes.hex()}"
|
||||
else:
|
||||
item_id = f"item_{random_bytes.hex()}"
|
||||
item_dict["id"] = item_id
|
||||
return item_id
|
||||
return item.id
|
||||
|
||||
async def _get_validated_conversation(self, conversation_id: str) -> Conversation:
|
||||
"""Validate conversation ID and return the conversation if it exists."""
|
||||
self._validate_conversation_id(conversation_id)
|
||||
return await self.get_conversation(conversation_id)
|
||||
|
||||
async def add_items(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
|
||||
"""Create (add) items to a conversation."""
|
||||
await self._get_validated_conversation(conversation_id)
|
||||
|
||||
created_items = []
|
||||
created_at = int(time.time())
|
||||
|
||||
for item in items:
|
||||
item_dict = item.model_dump()
|
||||
item_id = self._get_or_generate_item_id(item, item_dict)
|
||||
|
||||
item_record = {
|
||||
"id": item_id,
|
||||
"conversation_id": conversation_id,
|
||||
"created_at": created_at,
|
||||
"item_data": item_dict,
|
||||
}
|
||||
|
||||
# TODO: Add support for upsert in sql_store, this will fail first if ID exists and then update
|
||||
try:
|
||||
await self.sql_store.insert(table="conversation_items", data=item_record)
|
||||
except Exception:
|
||||
# If insert fails due to ID conflict, update existing record
|
||||
await self.sql_store.update(
|
||||
table="conversation_items",
|
||||
data={"created_at": created_at, "item_data": item_dict},
|
||||
where={"id": item_id},
|
||||
)
|
||||
|
||||
created_items.append(item_dict)
|
||||
|
||||
logger.info(f"Created {len(created_items)} items in conversation {conversation_id}")
|
||||
|
||||
# Convert created items (dicts) to proper ConversationItem types
|
||||
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
|
||||
response_items: list[ConversationItem] = [adapter.validate_python(item_dict) for item_dict in created_items]
|
||||
|
||||
return ConversationItemList(
|
||||
data=response_items,
|
||||
first_id=created_items[0]["id"] if created_items else None,
|
||||
last_id=created_items[-1]["id"] if created_items else None,
|
||||
has_more=False,
|
||||
)
|
||||
|
||||
async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
|
||||
"""Retrieve a conversation item."""
|
||||
if not conversation_id:
|
||||
raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
|
||||
if not item_id:
|
||||
raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
|
||||
|
||||
# Get item from conversation_items table
|
||||
record = await self.sql_store.fetch_one(
|
||||
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
|
||||
)
|
||||
|
||||
if record is None:
|
||||
raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
|
||||
|
||||
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
|
||||
return adapter.validate_python(record["item_data"])
|
||||
|
||||
async def list(self, conversation_id: str, after=NOT_GIVEN, include=NOT_GIVEN, limit=NOT_GIVEN, order=NOT_GIVEN):
|
||||
"""List items in the conversation."""
|
||||
result = await self.sql_store.fetch_all(table="conversation_items", where={"conversation_id": conversation_id})
|
||||
records = result.data
|
||||
|
||||
if order != NOT_GIVEN and order == "asc":
|
||||
records.sort(key=lambda x: x["created_at"])
|
||||
else:
|
||||
records.sort(key=lambda x: x["created_at"], reverse=True)
|
||||
|
||||
actual_limit = 20
|
||||
if limit != NOT_GIVEN and isinstance(limit, int):
|
||||
actual_limit = limit
|
||||
|
||||
records = records[:actual_limit]
|
||||
items = [record["item_data"] for record in records]
|
||||
|
||||
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
|
||||
response_items: list[ConversationItem] = [adapter.validate_python(item) for item in items]
|
||||
|
||||
first_id = response_items[0].id if response_items else None
|
||||
last_id = response_items[-1].id if response_items else None
|
||||
|
||||
return ConversationItemList(
|
||||
data=response_items,
|
||||
first_id=first_id,
|
||||
last_id=last_id,
|
||||
has_more=False,
|
||||
)
|
||||
|
||||
async def openai_delete_conversation_item(
|
||||
self, conversation_id: str, item_id: str
|
||||
) -> ConversationItemDeletedResource:
|
||||
"""Delete a conversation item."""
|
||||
if not conversation_id:
|
||||
raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
|
||||
if not item_id:
|
||||
raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
|
||||
|
||||
_ = await self._get_validated_conversation(conversation_id)
|
||||
|
||||
record = await self.sql_store.fetch_one(
|
||||
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
|
||||
)
|
||||
|
||||
if record is None:
|
||||
raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
|
||||
|
||||
await self.sql_store.delete(
|
||||
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
|
||||
)
|
||||
|
||||
logger.info(f"Deleted item {item_id} from conversation {conversation_id}")
|
||||
return ConversationItemDeletedResource(id=item_id)
|
||||
|
|
@ -475,6 +475,13 @@ InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (depreca
|
|||
If not specified, a default SQLite store will be used.""",
|
||||
)
|
||||
|
||||
conversations_store: SqlStoreConfig | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Configuration for the persistence store used by the conversations API.
|
||||
If not specified, a default SQLite store will be used.""",
|
||||
)
|
||||
|
||||
# registry of "resources" in the distribution
|
||||
models: list[ModelInput] = Field(default_factory=list)
|
||||
shields: list[ShieldInput] = Field(default_factory=list)
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ from llama_stack.providers.datatypes import (
|
|||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts}
|
||||
INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts, Api.conversations}
|
||||
|
||||
|
||||
def stack_apis() -> list[Api]:
|
||||
|
|
@ -243,6 +243,7 @@ def get_external_providers_from_module(
|
|||
spec = module.get_provider_spec()
|
||||
else:
|
||||
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run
|
||||
# in the case we are building we CANNOT import this module of course because it has not been installed.
|
||||
spec = ProviderSpec(
|
||||
api=Api(provider_api),
|
||||
provider_type=provider.provider_type,
|
||||
|
|
@ -251,9 +252,20 @@ def get_external_providers_from_module(
|
|||
config_class="",
|
||||
)
|
||||
provider_type = provider.provider_type
|
||||
# in the case we are building we CANNOT import this module of course because it has not been installed.
|
||||
# return a partially filled out spec that the build script will populate.
|
||||
registry[Api(provider_api)][provider_type] = spec
|
||||
if isinstance(spec, list):
|
||||
# optionally allow people to pass inline and remote provider specs as a returned list.
|
||||
# with the old method, users could pass in directories of specs using overlapping code
|
||||
# we want to ensure we preserve that flexibility in this method.
|
||||
logger.info(
|
||||
f"Detected a list of external provider specs from {provider.module} adding all to the registry"
|
||||
)
|
||||
for provider_spec in spec:
|
||||
if provider_spec.provider_type != provider.provider_type:
|
||||
continue
|
||||
logger.info(f"Adding {provider.provider_type} to registry")
|
||||
registry[Api(provider_api)][provider.provider_type] = provider_spec
|
||||
else:
|
||||
registry[Api(provider_api)][provider_type] = spec
|
||||
except ModuleNotFoundError as exc:
|
||||
raise ValueError(
|
||||
"get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available"
|
||||
|
|
|
|||
|
|
@ -374,6 +374,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
body = options.params or {}
|
||||
body |= options.json_data or {}
|
||||
|
||||
# Merge extra_json parameters (extra_body from SDK is converted to extra_json)
|
||||
if hasattr(options, "extra_json") and options.extra_json:
|
||||
body |= options.extra_json
|
||||
|
||||
matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
|
||||
body |= path_params
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from typing import Any
|
|||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.batches import Batches
|
||||
from llama_stack.apis.benchmarks import Benchmarks
|
||||
from llama_stack.apis.conversations import Conversations
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.datatypes import ExternalApiSpec
|
||||
|
|
@ -96,6 +97,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
|
|||
Api.tool_runtime: ToolRuntime,
|
||||
Api.files: Files,
|
||||
Api.prompts: Prompts,
|
||||
Api.conversations: Conversations,
|
||||
}
|
||||
|
||||
if external_apis:
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ from llama_stack.apis.inference import (
|
|||
CompletionMessage,
|
||||
Inference,
|
||||
ListOpenAIChatCompletionResponse,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
|
|
@ -34,12 +33,7 @@ from llama_stack.apis.inference import (
|
|||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
Order,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
|
|
@ -193,102 +187,6 @@ class InferenceRouter(Inference):
|
|||
raise ModelTypeError(model_id, model.model_type, expected_model_type)
|
||||
return model
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = None,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
logger.debug(
|
||||
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
|
||||
)
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self._get_model(model_id, ModelType.llm)
|
||||
if tool_config:
|
||||
if tool_choice and tool_choice != tool_config.tool_choice:
|
||||
raise ValueError("tool_choice and tool_config.tool_choice must match")
|
||||
if (
|
||||
tool_prompt_format
|
||||
and tool_prompt_format != tool_config.tool_prompt_format
|
||||
):
|
||||
raise ValueError(
|
||||
"tool_prompt_format and tool_config.tool_prompt_format must match"
|
||||
)
|
||||
else:
|
||||
params = {}
|
||||
if tool_choice:
|
||||
params["tool_choice"] = tool_choice
|
||||
if tool_prompt_format:
|
||||
params["tool_prompt_format"] = tool_prompt_format
|
||||
tool_config = ToolConfig(**params)
|
||||
|
||||
tools = tools or []
|
||||
if tool_config.tool_choice == ToolChoice.none:
|
||||
tools = []
|
||||
elif tool_config.tool_choice == ToolChoice.auto:
|
||||
pass
|
||||
elif tool_config.tool_choice == ToolChoice.required:
|
||||
pass
|
||||
else:
|
||||
# verify tool_choice is one of the tools
|
||||
tool_names = [
|
||||
t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value
|
||||
for t in tools
|
||||
]
|
||||
if tool_config.tool_choice not in tool_names:
|
||||
raise ValueError(
|
||||
f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}"
|
||||
)
|
||||
|
||||
params = dict(
|
||||
model_id=model_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
prompt_tokens = await self._count_tokens(
|
||||
messages, tool_config.tool_prompt_format
|
||||
)
|
||||
|
||||
if stream:
|
||||
response_stream = await provider.chat_completion(**params)
|
||||
return self.stream_tokens_and_compute_metrics(
|
||||
response=response_stream,
|
||||
prompt_tokens=prompt_tokens,
|
||||
model=model,
|
||||
tool_prompt_format=tool_config.tool_prompt_format,
|
||||
)
|
||||
|
||||
response = await provider.chat_completion(**params)
|
||||
metrics = await self.count_tokens_and_compute_metrics(
|
||||
response=response,
|
||||
prompt_tokens=prompt_tokens,
|
||||
model=model,
|
||||
tool_prompt_format=tool_config.tool_prompt_format,
|
||||
)
|
||||
# these metrics will show up in the client response.
|
||||
response.metrics = (
|
||||
metrics
|
||||
if not hasattr(response, "metrics") or response.metrics is None
|
||||
else response.metrics + metrics
|
||||
)
|
||||
return response
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import functools
|
||||
|
|
@ -12,7 +11,6 @@ import inspect
|
|||
import json
|
||||
import logging # allow-direct-logging
|
||||
import os
|
||||
import ssl
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
|
|
@ -35,7 +33,6 @@ from pydantic import BaseModel, ValidationError
|
|||
|
||||
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.cli.utils import add_config_distro_args, get_config_from_args
|
||||
from llama_stack.core.access_control.access_control import AccessDeniedError
|
||||
from llama_stack.core.datatypes import (
|
||||
AuthenticationRequiredError,
|
||||
|
|
@ -55,7 +52,6 @@ from llama_stack.core.stack import (
|
|||
Stack,
|
||||
cast_image_name_to_string,
|
||||
replace_env_vars,
|
||||
validate_env_pair,
|
||||
)
|
||||
from llama_stack.core.utils.config import redact_sensitive_fields
|
||||
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
|
||||
|
|
@ -333,23 +329,18 @@ class ClientVersionMiddleware:
|
|||
return await self.app(scope, receive, send)
|
||||
|
||||
|
||||
def create_app(
|
||||
config_file: str | None = None,
|
||||
env_vars: list[str] | None = None,
|
||||
) -> StackApp:
|
||||
def create_app() -> StackApp:
|
||||
"""Create and configure the FastAPI application.
|
||||
|
||||
Args:
|
||||
config_file: Path to config file. If None, uses LLAMA_STACK_CONFIG env var or default resolution.
|
||||
env_vars: List of environment variables in KEY=value format.
|
||||
disable_version_check: Whether to disable version checking. If None, uses LLAMA_STACK_DISABLE_VERSION_CHECK env var.
|
||||
This factory function reads configuration from environment variables:
|
||||
- LLAMA_STACK_CONFIG: Path to config file (required)
|
||||
|
||||
Returns:
|
||||
Configured StackApp instance.
|
||||
"""
|
||||
config_file = config_file or os.getenv("LLAMA_STACK_CONFIG")
|
||||
config_file = os.getenv("LLAMA_STACK_CONFIG")
|
||||
if config_file is None:
|
||||
raise ValueError("No config file provided and LLAMA_STACK_CONFIG env var is not set")
|
||||
raise ValueError("LLAMA_STACK_CONFIG environment variable is required")
|
||||
|
||||
config_file = resolve_config_or_distro(config_file, Mode.RUN)
|
||||
|
||||
|
|
@ -361,16 +352,6 @@ def create_app(
|
|||
logger_config = LoggingConfig(**cfg)
|
||||
logger = get_logger(name=__name__, category="core::server", config=logger_config)
|
||||
|
||||
if env_vars:
|
||||
for env_pair in env_vars:
|
||||
try:
|
||||
key, value = validate_env_pair(env_pair)
|
||||
logger.info(f"Setting environment variable {key} => {value}")
|
||||
os.environ[key] = value
|
||||
except ValueError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
raise ValueError(f"Invalid environment variable format: {env_pair}") from e
|
||||
|
||||
config = replace_env_vars(config_contents)
|
||||
config = StackRunConfig(**cast_image_name_to_string(config))
|
||||
|
||||
|
|
@ -451,6 +432,7 @@ def create_app(
|
|||
apis_to_serve.add("inspect")
|
||||
apis_to_serve.add("providers")
|
||||
apis_to_serve.add("prompts")
|
||||
apis_to_serve.add("conversations")
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
|
||||
|
|
@ -493,101 +475,6 @@ def create_app(
|
|||
return app
|
||||
|
||||
|
||||
def main(args: argparse.Namespace | None = None):
|
||||
"""Start the LlamaStack server."""
|
||||
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
|
||||
|
||||
add_config_distro_args(parser)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
|
||||
help="Port to listen on",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--env",
|
||||
action="append",
|
||||
help="Environment variables in KEY=value format. Can be specified multiple times.",
|
||||
)
|
||||
|
||||
# Determine whether the server args are being passed by the "run" command, if this is the case
|
||||
# the args will be passed as a Namespace object to the main function, otherwise they will be
|
||||
# parsed from the command line
|
||||
if args is None:
|
||||
args = parser.parse_args()
|
||||
|
||||
config_or_distro = get_config_from_args(args)
|
||||
|
||||
try:
|
||||
app = create_app(
|
||||
config_file=config_or_distro,
|
||||
env_vars=args.env,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating app: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
|
||||
with open(config_file) as fp:
|
||||
config_contents = yaml.safe_load(fp)
|
||||
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
|
||||
logger_config = LoggingConfig(**cfg)
|
||||
else:
|
||||
logger_config = None
|
||||
config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
|
||||
|
||||
import uvicorn
|
||||
|
||||
# Configure SSL if certificates are provided
|
||||
port = args.port or config.server.port
|
||||
|
||||
ssl_config = None
|
||||
keyfile = config.server.tls_keyfile
|
||||
certfile = config.server.tls_certfile
|
||||
|
||||
if keyfile and certfile:
|
||||
ssl_config = {
|
||||
"ssl_keyfile": keyfile,
|
||||
"ssl_certfile": certfile,
|
||||
}
|
||||
if config.server.tls_cafile:
|
||||
ssl_config["ssl_ca_certs"] = config.server.tls_cafile
|
||||
ssl_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED
|
||||
logger.info(
|
||||
f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
|
||||
|
||||
listen_host = config.server.host or ["::", "0.0.0.0"]
|
||||
logger.info(f"Listening on {listen_host}:{port}")
|
||||
|
||||
uvicorn_config = {
|
||||
"app": app,
|
||||
"host": listen_host,
|
||||
"port": port,
|
||||
"lifespan": "on",
|
||||
"log_level": logger.getEffectiveLevel(),
|
||||
"log_config": logger_config,
|
||||
}
|
||||
if ssl_config:
|
||||
uvicorn_config.update(ssl_config)
|
||||
|
||||
# We need to catch KeyboardInterrupt because uvicorn's signal handling
|
||||
# re-raises SIGINT signals using signal.raise_signal(), which Python
|
||||
# converts to KeyboardInterrupt. Without this catch, we'd get a confusing
|
||||
# stack trace when using Ctrl+C or kill -2 (SIGINT).
|
||||
# SIGTERM (kill -15) works fine without this because Python doesn't
|
||||
# have a default handler for it.
|
||||
#
|
||||
# Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
|
||||
# signal handling but this is quite intrusive and not worth the effort.
|
||||
try:
|
||||
asyncio.run(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
logger.info("Received interrupt signal, shutting down gracefully...")
|
||||
|
||||
|
||||
def _log_run_config(run_config: StackRunConfig):
|
||||
"""Logs the run config with redacted fields and disabled providers removed."""
|
||||
logger.info("Run configuration:")
|
||||
|
|
@ -614,7 +501,3 @@ def remove_disabled_providers(obj):
|
|||
return [item for item in (remove_disabled_providers(i) for i in obj) if item is not None]
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ import yaml
|
|||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.benchmarks import Benchmarks
|
||||
from llama_stack.apis.conversations import Conversations
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.eval import Eval
|
||||
|
|
@ -34,6 +35,7 @@ from llama_stack.apis.telemetry import Telemetry
|
|||
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDBs
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
|
||||
|
|
@ -73,6 +75,7 @@ class LlamaStack(
|
|||
RAGToolRuntime,
|
||||
Files,
|
||||
Prompts,
|
||||
Conversations,
|
||||
):
|
||||
pass
|
||||
|
||||
|
|
@ -312,6 +315,12 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
|
|||
)
|
||||
impls[Api.prompts] = prompts_impl
|
||||
|
||||
conversations_impl = ConversationServiceImpl(
|
||||
ConversationServiceConfig(run_config=run_config),
|
||||
deps=impls,
|
||||
)
|
||||
impls[Api.conversations] = conversations_impl
|
||||
|
||||
|
||||
class Stack:
|
||||
def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
|
||||
|
|
@ -342,6 +351,8 @@ class Stack:
|
|||
|
||||
if Api.prompts in impls:
|
||||
await impls[Api.prompts].initialize()
|
||||
if Api.conversations in impls:
|
||||
await impls[Api.conversations].initialize()
|
||||
|
||||
await register_resources(self.run_config, impls)
|
||||
|
||||
|
|
|
|||
|
|
@ -116,7 +116,7 @@ if [[ "$env_type" == "venv" ]]; then
|
|||
yaml_config_arg=""
|
||||
fi
|
||||
|
||||
$PYTHON_BINARY -m llama_stack.core.server.server \
|
||||
llama stack run \
|
||||
$yaml_config_arg \
|
||||
--port "$port" \
|
||||
$env_vars \
|
||||
|
|
|
|||
|
|
@ -128,7 +128,7 @@ def strip_rich_markup(text):
|
|||
|
||||
class CustomRichHandler(RichHandler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs["console"] = Console(width=150)
|
||||
kwargs["console"] = Console()
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def emit(self, record):
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from pathlib import Path
|
|||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(__name__, "tokenizer_utils")
|
||||
logger = get_logger(__name__, "models")
|
||||
|
||||
|
||||
def load_bpe_file(model_path: Path) -> dict[bytes, int]:
|
||||
|
|
|
|||
|
|
@ -329,6 +329,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
include: list[str] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
shields: list | None = None,
|
||||
) -> OpenAIResponseObject:
|
||||
return await self.openai_responses_impl.create_openai_response(
|
||||
input,
|
||||
|
|
@ -342,6 +343,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
tools,
|
||||
include,
|
||||
max_infer_iters,
|
||||
shields,
|
||||
)
|
||||
|
||||
async def list_openai_responses(
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import time
|
|||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from llama_stack.apis.agents import Order
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
|
|
@ -26,12 +26,16 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
)
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAIMessageParam,
|
||||
OpenAISystemMessageParam,
|
||||
)
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
from llama_stack.providers.utils.responses.responses_store import (
|
||||
ResponsesStore,
|
||||
_OpenAIResponseObjectWithInputAndMessages,
|
||||
)
|
||||
|
||||
from .streaming import StreamingResponseOrchestrator
|
||||
from .tool_executor import ToolExecutor
|
||||
|
|
@ -72,26 +76,48 @@ class OpenAIResponsesImpl:
|
|||
async def _prepend_previous_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
previous_response_id: str | None = None,
|
||||
previous_response: _OpenAIResponseObjectWithInputAndMessages,
|
||||
):
|
||||
new_input_items = previous_response.input.copy()
|
||||
new_input_items.extend(previous_response.output)
|
||||
|
||||
if isinstance(input, str):
|
||||
new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
|
||||
else:
|
||||
new_input_items.extend(input)
|
||||
|
||||
return new_input_items
|
||||
|
||||
async def _process_input_with_previous_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
previous_response_id: str | None,
|
||||
) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam]]:
|
||||
"""Process input with optional previous response context.
|
||||
|
||||
Returns:
|
||||
tuple: (all_input for storage, messages for chat completion)
|
||||
"""
|
||||
if previous_response_id:
|
||||
previous_response_with_input = await self.responses_store.get_response_object(previous_response_id)
|
||||
previous_response: _OpenAIResponseObjectWithInputAndMessages = (
|
||||
await self.responses_store.get_response_object(previous_response_id)
|
||||
)
|
||||
all_input = await self._prepend_previous_response(input, previous_response)
|
||||
|
||||
# previous response input items
|
||||
new_input_items = previous_response_with_input.input
|
||||
|
||||
# previous response output items
|
||||
new_input_items.extend(previous_response_with_input.output)
|
||||
|
||||
# new input items from the current request
|
||||
if isinstance(input, str):
|
||||
new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
|
||||
if previous_response.messages:
|
||||
# Use stored messages directly and convert only new input
|
||||
message_adapter = TypeAdapter(list[OpenAIMessageParam])
|
||||
messages = message_adapter.validate_python(previous_response.messages)
|
||||
new_messages = await convert_response_input_to_chat_messages(input)
|
||||
messages.extend(new_messages)
|
||||
else:
|
||||
new_input_items.extend(input)
|
||||
# Backward compatibility: reconstruct from inputs
|
||||
messages = await convert_response_input_to_chat_messages(all_input)
|
||||
else:
|
||||
all_input = input
|
||||
messages = await convert_response_input_to_chat_messages(input)
|
||||
|
||||
input = new_input_items
|
||||
|
||||
return input
|
||||
return all_input, messages
|
||||
|
||||
async def _prepend_instructions(self, messages, instructions):
|
||||
if instructions:
|
||||
|
|
@ -102,7 +128,7 @@ class OpenAIResponsesImpl:
|
|||
response_id: str,
|
||||
) -> OpenAIResponseObject:
|
||||
response_with_input = await self.responses_store.get_response_object(response_id)
|
||||
return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"})
|
||||
return response_with_input.to_response_object()
|
||||
|
||||
async def list_openai_responses(
|
||||
self,
|
||||
|
|
@ -138,6 +164,7 @@ class OpenAIResponsesImpl:
|
|||
self,
|
||||
response: OpenAIResponseObject,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
messages: list[OpenAIMessageParam],
|
||||
) -> None:
|
||||
new_input_id = f"msg_{uuid.uuid4()}"
|
||||
if isinstance(input, str):
|
||||
|
|
@ -165,6 +192,7 @@ class OpenAIResponsesImpl:
|
|||
await self.responses_store.store_response_object(
|
||||
response_object=response,
|
||||
input=input_items_data,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
async def create_openai_response(
|
||||
|
|
@ -180,10 +208,15 @@ class OpenAIResponsesImpl:
|
|||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
include: list[str] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
shields: list | None = None,
|
||||
):
|
||||
stream = bool(stream)
|
||||
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
|
||||
|
||||
# Shields parameter received via extra_body - not yet implemented
|
||||
if shields is not None:
|
||||
raise NotImplementedError("Shields parameter is not yet implemented in the meta-reference provider")
|
||||
|
||||
stream_gen = self._create_streaming_response(
|
||||
input=input,
|
||||
model=model,
|
||||
|
|
@ -224,8 +257,7 @@ class OpenAIResponsesImpl:
|
|||
max_infer_iters: int | None = 10,
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Input preprocessing
|
||||
input = await self._prepend_previous_response(input, previous_response_id)
|
||||
messages = await convert_response_input_to_chat_messages(input)
|
||||
all_input, messages = await self._process_input_with_previous_response(input, previous_response_id)
|
||||
await self._prepend_instructions(messages, instructions)
|
||||
|
||||
# Structured outputs
|
||||
|
|
@ -265,7 +297,8 @@ class OpenAIResponsesImpl:
|
|||
if store and final_response:
|
||||
await self._store_response(
|
||||
response=final_response,
|
||||
input=input,
|
||||
input=all_input,
|
||||
messages=orchestrator.final_messages,
|
||||
)
|
||||
|
||||
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ from llama_stack.apis.inference import (
|
|||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChoice,
|
||||
OpenAIMessageParam,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
|
|
@ -94,6 +95,8 @@ class StreamingResponseOrchestrator:
|
|||
self.sequence_number = 0
|
||||
# Store MCP tool mapping that gets built during tool processing
|
||||
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
|
||||
# Track final messages after all tool executions
|
||||
self.final_messages: list[OpenAIMessageParam] = []
|
||||
|
||||
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Initialize output messages
|
||||
|
|
@ -183,6 +186,8 @@ class StreamingResponseOrchestrator:
|
|||
|
||||
messages = next_turn_messages
|
||||
|
||||
self.final_messages = messages.copy() + [current_response.choices[0].message]
|
||||
|
||||
# Create final response
|
||||
final_response = OpenAIResponseObject(
|
||||
created_at=self.created_at,
|
||||
|
|
|
|||
|
|
@ -5,37 +5,17 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
TextDelta,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEvent,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
InferenceProvider,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
TokenLogProbs,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -53,13 +33,6 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
ModelRegistryHelper,
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_messages,
|
||||
convert_request_to_raw,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
from .generators import LlamaGenerator
|
||||
|
|
@ -76,7 +49,6 @@ def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_
|
|||
|
||||
|
||||
class MetaReferenceInferenceImpl(
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
InferenceProvider,
|
||||
ModelsProtocolPrivate,
|
||||
|
|
@ -161,10 +133,10 @@ class MetaReferenceInferenceImpl(
|
|||
self.llama_model = llama_model
|
||||
|
||||
log.info("Warming up...")
|
||||
await self.chat_completion(
|
||||
model_id=model_id,
|
||||
messages=[UserMessage(content="Hi how are you?")],
|
||||
sampling_params=SamplingParams(max_tokens=20),
|
||||
await self.openai_chat_completion(
|
||||
model=model_id,
|
||||
messages=[{"role": "user", "content": "Hi how are you?"}],
|
||||
max_tokens=20,
|
||||
)
|
||||
log.info("Warmed up!")
|
||||
|
||||
|
|
@ -176,242 +148,30 @@ class MetaReferenceInferenceImpl(
|
|||
elif request.model != self.model_id:
|
||||
raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
|
||||
|
||||
async def chat_completion(
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if logprobs:
|
||||
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
||||
|
||||
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
||||
request = ChatCompletionRequest(
|
||||
model=model_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config or ToolConfig(),
|
||||
)
|
||||
self.check_model(request)
|
||||
|
||||
# augment and rewrite messages depending on the model
|
||||
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
|
||||
# download media and convert to raw content so we can send it to the model
|
||||
request = await convert_request_to_raw(request)
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
if SEMAPHORE.locked():
|
||||
raise RuntimeError("Only one concurrent request is supported")
|
||||
|
||||
if request.stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
results = await self._nonstream_chat_completion([request])
|
||||
return results[0]
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request_batch: list[ChatCompletionRequest]
|
||||
) -> list[ChatCompletionResponse]:
|
||||
tokenizer = self.generator.formatter.tokenizer
|
||||
|
||||
first_request = request_batch[0]
|
||||
|
||||
class ItemState(BaseModel):
|
||||
tokens: list[int] = []
|
||||
logprobs: list[TokenLogProbs] = []
|
||||
stop_reason: StopReason | None = None
|
||||
finished: bool = False
|
||||
|
||||
def impl():
|
||||
states = [ItemState() for _ in request_batch]
|
||||
|
||||
for token_results in self.generator.chat_completion(request_batch):
|
||||
first = token_results[0]
|
||||
if not first.finished and not first.ignore_token:
|
||||
if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"):
|
||||
cprint(first.text, color="cyan", end="", file=sys.stderr)
|
||||
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
|
||||
cprint(f"<{first.token}>", color="magenta", end="", file=sys.stderr)
|
||||
|
||||
for result in token_results:
|
||||
idx = result.batch_idx
|
||||
state = states[idx]
|
||||
if state.finished or result.ignore_token:
|
||||
continue
|
||||
|
||||
state.finished = result.finished
|
||||
if first_request.logprobs:
|
||||
state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
|
||||
|
||||
state.tokens.append(result.token)
|
||||
if result.token == tokenizer.eot_id:
|
||||
state.stop_reason = StopReason.end_of_turn
|
||||
elif result.token == tokenizer.eom_id:
|
||||
state.stop_reason = StopReason.end_of_message
|
||||
|
||||
results = []
|
||||
for state in states:
|
||||
if state.stop_reason is None:
|
||||
state.stop_reason = StopReason.out_of_tokens
|
||||
|
||||
raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
|
||||
results.append(
|
||||
ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
content=raw_message.content,
|
||||
stop_reason=raw_message.stop_reason,
|
||||
tool_calls=raw_message.tool_calls,
|
||||
),
|
||||
logprobs=state.logprobs if first_request.logprobs else None,
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
async with SEMAPHORE:
|
||||
return impl()
|
||||
else:
|
||||
return impl()
|
||||
|
||||
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
||||
tokenizer = self.generator.formatter.tokenizer
|
||||
|
||||
def impl():
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta=TextDelta(text=""),
|
||||
)
|
||||
)
|
||||
|
||||
tokens = []
|
||||
logprobs = []
|
||||
stop_reason = None
|
||||
ipython = False
|
||||
|
||||
for token_results in self.generator.chat_completion([request]):
|
||||
token_result = token_results[0]
|
||||
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
|
||||
cprint(token_result.text, color="cyan", end="", file=sys.stderr)
|
||||
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
|
||||
cprint(f"<{token_result.token}>", color="magenta", end="", file=sys.stderr)
|
||||
|
||||
if token_result.token == tokenizer.eot_id:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
elif token_result.token == tokenizer.eom_id:
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
else:
|
||||
text = token_result.text
|
||||
|
||||
if request.logprobs:
|
||||
assert len(token_result.logprobs) == 1
|
||||
|
||||
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
|
||||
|
||||
tokens.append(token_result.token)
|
||||
|
||||
if not ipython and token_result.text.startswith("<|python_tag|>"):
|
||||
ipython = True
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
tool_call="",
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if token_result.token == tokenizer.eot_id:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
elif token_result.token == tokenizer.eom_id:
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
else:
|
||||
text = token_result.text
|
||||
|
||||
if ipython:
|
||||
delta = ToolCallDelta(
|
||||
tool_call=text,
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
)
|
||||
else:
|
||||
delta = TextDelta(text=text)
|
||||
|
||||
if stop_reason is None:
|
||||
if request.logprobs:
|
||||
assert len(token_result.logprobs) == 1
|
||||
|
||||
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
stop_reason=stop_reason,
|
||||
logprobs=logprobs if request.logprobs else None,
|
||||
)
|
||||
)
|
||||
|
||||
if stop_reason is None:
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
|
||||
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
tool_call="",
|
||||
parse_status=ToolCallParseStatus.failed,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
tool_call=tool_call,
|
||||
parse_status=ToolCallParseStatus.succeeded,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta=TextDelta(text=""),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
async with SEMAPHORE:
|
||||
for x in impl():
|
||||
yield x
|
||||
else:
|
||||
for x in impl():
|
||||
yield x
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider")
|
||||
|
|
|
|||
|
|
@ -4,21 +4,19 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
InferenceProvider,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAICompletion
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||
|
|
@ -69,21 +67,6 @@ class SentenceTransformersInferenceImpl(
|
|||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
raise ValueError("Sentence transformers don't support chat completion")
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
# Standard OpenAI completion parameters
|
||||
|
|
@ -110,6 +93,32 @@ class SentenceTransformersInferenceImpl(
|
|||
# for fill-in-the-middle type completion
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
raise NotImplementedError(
|
||||
"OpenAI completion not supported by sentence transformers provider"
|
||||
)
|
||||
raise NotImplementedError("OpenAI completion not supported by sentence transformers provider")
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
raise NotImplementedError("OpenAI chat completion not supported by sentence transformers provider")
|
||||
|
|
|
|||
|
|
@ -52,9 +52,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
api=Api.inference,
|
||||
adapter_type="cerebras",
|
||||
provider_type="remote::cerebras",
|
||||
pip_packages=[
|
||||
"cerebras_cloud_sdk",
|
||||
],
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.cerebras",
|
||||
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
|
||||
description="Cerebras inference provider for running models on Cerebras Cloud platform.",
|
||||
|
|
@ -169,7 +167,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
api=Api.inference,
|
||||
adapter_type="openai",
|
||||
provider_type="remote::openai",
|
||||
pip_packages=["litellm"],
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.openai",
|
||||
config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
|
||||
|
|
@ -179,7 +177,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
api=Api.inference,
|
||||
adapter_type="anthropic",
|
||||
provider_type="remote::anthropic",
|
||||
pip_packages=["litellm"],
|
||||
pip_packages=["anthropic"],
|
||||
module="llama_stack.providers.remote.inference.anthropic",
|
||||
config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
|
||||
|
|
@ -189,9 +187,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
api=Api.inference,
|
||||
adapter_type="gemini",
|
||||
provider_type="remote::gemini",
|
||||
pip_packages=[
|
||||
"litellm",
|
||||
],
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.gemini",
|
||||
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
|
||||
|
|
@ -202,7 +198,6 @@ def available_providers() -> list[ProviderSpec]:
|
|||
adapter_type="vertexai",
|
||||
provider_type="remote::vertexai",
|
||||
pip_packages=[
|
||||
"litellm",
|
||||
"google-cloud-aiplatform",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.vertexai",
|
||||
|
|
@ -233,9 +228,7 @@ Available Models:
|
|||
api=Api.inference,
|
||||
adapter_type="groq",
|
||||
provider_type="remote::groq",
|
||||
pip_packages=[
|
||||
"litellm",
|
||||
],
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.groq",
|
||||
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
|
||||
|
|
@ -245,7 +238,7 @@ Available Models:
|
|||
api=Api.inference,
|
||||
adapter_type="llama-openai-compat",
|
||||
provider_type="remote::llama-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.llama_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
|
||||
|
|
@ -255,9 +248,7 @@ Available Models:
|
|||
api=Api.inference,
|
||||
adapter_type="sambanova",
|
||||
provider_type="remote::sambanova",
|
||||
pip_packages=[
|
||||
"litellm",
|
||||
],
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.sambanova",
|
||||
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
|
||||
|
|
@ -287,7 +278,7 @@ Available Models:
|
|||
api=Api.inference,
|
||||
provider_type="remote::azure",
|
||||
adapter_type="azure",
|
||||
pip_packages=["litellm"],
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.azure",
|
||||
config_class="llama_stack.providers.remote.inference.azure.AzureConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.azure.config.AzureProviderDataValidator",
|
||||
|
|
|
|||
|
|
@ -500,7 +500,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
|
|||
api=Api.vector_io,
|
||||
adapter_type="weaviate",
|
||||
provider_type="remote::weaviate",
|
||||
pip_packages=["weaviate-client"],
|
||||
pip_packages=["weaviate-client>=4.16.5"],
|
||||
module="llama_stack.providers.remote.vector_io.weaviate",
|
||||
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
|
||||
|
|
|
|||
|
|
@ -10,6 +10,6 @@ from .config import AnthropicConfig
|
|||
async def get_adapter_impl(config: AnthropicConfig, _deps):
|
||||
from .anthropic import AnthropicInferenceAdapter
|
||||
|
||||
impl = AnthropicInferenceAdapter(config)
|
||||
impl = AnthropicInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -4,13 +4,19 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from collections.abc import Iterable
|
||||
|
||||
from anthropic import AsyncAnthropic
|
||||
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import AnthropicConfig
|
||||
|
||||
|
||||
class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
class AnthropicInferenceAdapter(OpenAIMixin):
|
||||
config: AnthropicConfig
|
||||
|
||||
provider_data_api_key_field: str = "anthropic_api_key"
|
||||
# source: https://docs.claude.com/en/docs/build-with-claude/embeddings
|
||||
# TODO: add support for voyageai, which is where these models are hosted
|
||||
# embedding_model_metadata = {
|
||||
|
|
@ -23,22 +29,11 @@ class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
|||
# "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000},
|
||||
# }
|
||||
|
||||
def __init__(self, config: AnthropicConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
litellm_provider_name="anthropic",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="anthropic_api_key",
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
await super().shutdown()
|
||||
|
||||
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key or ""
|
||||
|
||||
def get_base_url(self):
|
||||
return "https://api.anthropic.com/v1"
|
||||
|
||||
async def list_provider_model_ids(self) -> Iterable[str]:
|
||||
return [m.id async for m in AsyncAnthropic(api_key=self.get_api_key()).models.list()]
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
|
@ -19,7 +20,7 @@ class AnthropicProviderDataValidator(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class AnthropicConfig(BaseModel):
|
||||
class AnthropicConfig(RemoteInferenceProviderConfig):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for Anthropic models",
|
||||
|
|
|
|||
|
|
@ -10,6 +10,6 @@ from .config import AzureConfig
|
|||
async def get_adapter_impl(config: AzureConfig, _deps):
|
||||
from .azure import AzureInferenceAdapter
|
||||
|
||||
impl = AzureInferenceAdapter(config)
|
||||
impl = AzureInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -4,31 +4,20 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from llama_stack.apis.inference import ChatCompletionRequest
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
|
||||
LiteLLMOpenAIMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import AzureConfig
|
||||
|
||||
|
||||
class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
def __init__(self, config: AzureConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
litellm_provider_name="azure",
|
||||
api_key_from_config=config.api_key.get_secret_value(),
|
||||
provider_data_api_key_field="azure_api_key",
|
||||
openai_compat_api_base=str(config.api_base),
|
||||
)
|
||||
self.config = config
|
||||
class AzureInferenceAdapter(OpenAIMixin):
|
||||
config: AzureConfig
|
||||
|
||||
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
|
||||
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||
provider_data_api_key_field: str = "azure_api_key"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key.get_secret_value()
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
|
|
@ -37,26 +26,3 @@ class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
|||
Returns the Azure API base URL from the configuration.
|
||||
"""
|
||||
return urljoin(str(self.config.api_base), "/openai/v1")
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
|
||||
# Get base parameters from parent
|
||||
params = await super()._get_params(request)
|
||||
|
||||
# Add Azure specific parameters
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data:
|
||||
if getattr(provider_data, "azure_api_key", None):
|
||||
params["api_key"] = provider_data.azure_api_key
|
||||
if getattr(provider_data, "azure_api_base", None):
|
||||
params["api_base"] = provider_data.azure_api_base
|
||||
if getattr(provider_data, "azure_api_version", None):
|
||||
params["api_version"] = provider_data.azure_api_version
|
||||
if getattr(provider_data, "azure_api_type", None):
|
||||
params["api_type"] = provider_data.azure_api_type
|
||||
else:
|
||||
params["api_key"] = self.config.api_key.get_secret_value()
|
||||
params["api_base"] = str(self.config.api_base)
|
||||
params["api_version"] = self.config.api_version
|
||||
params["api_type"] = self.config.api_type
|
||||
|
||||
return params
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field, HttpUrl, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
|
@ -30,7 +31,7 @@ class AzureProviderDataValidator(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class AzureConfig(BaseModel):
|
||||
class AzureConfig(RemoteInferenceProviderConfig):
|
||||
api_key: SecretStr = Field(
|
||||
description="Azure API key for Azure",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,39 +5,30 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from botocore.client import BaseClient
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAICompletion
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
|
||||
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
get_sampling_strategy_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
|
|
@ -86,7 +77,6 @@ def _to_inference_profile_id(model_id: str, region: str = None) -> str:
|
|||
class BedrockInferenceAdapter(
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
):
|
||||
def __init__(self, config: BedrockConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||
|
|
@ -106,71 +96,6 @@ class BedrockInferenceAdapter(
|
|||
if self._client is not None:
|
||||
self._client.close()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request)
|
||||
|
||||
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params_for_chat_completion(request)
|
||||
res = self.client.invoke_model(**params)
|
||||
chunk = next(res["body"])
|
||||
result = json.loads(chunk.decode("utf-8"))
|
||||
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=result["stop_reason"],
|
||||
text=result["generation"],
|
||||
)
|
||||
|
||||
response = OpenAICompatCompletionResponse(choices=[choice])
|
||||
return process_chat_completion_response(response, request)
|
||||
|
||||
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params_for_chat_completion(request)
|
||||
res = self.client.invoke_model_with_response_stream(**params)
|
||||
event_stream = res["body"]
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
for chunk in event_stream:
|
||||
chunk = chunk["chunk"]["bytes"]
|
||||
result = json.loads(chunk.decode("utf-8"))
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=result["stop_reason"],
|
||||
text=result["generation"],
|
||||
)
|
||||
yield OpenAICompatCompletionResponse(choices=[choice])
|
||||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
||||
async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> dict:
|
||||
bedrock_model = request.model
|
||||
|
||||
|
|
@ -235,3 +160,31 @@ class BedrockInferenceAdapter(
|
|||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
raise NotImplementedError("OpenAI completion not supported by the Bedrock provider")
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ async def get_adapter_impl(config: CerebrasImplConfig, _deps):
|
|||
|
||||
assert isinstance(config, CerebrasImplConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = CerebrasInferenceAdapter(config)
|
||||
impl = CerebrasInferenceAdapter(config=config)
|
||||
|
||||
await impl.initialize()
|
||||
|
||||
|
|
|
|||
|
|
@ -4,53 +4,16 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from cerebras.cloud.sdk import AsyncCerebras
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
TopKSamplingStrategy,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.apis.inference import OpenAIEmbeddingsResponse
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
)
|
||||
|
||||
from .config import CerebrasImplConfig
|
||||
|
||||
|
||||
class CerebrasInferenceAdapter(
|
||||
OpenAIMixin,
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
):
|
||||
def __init__(self, config: CerebrasImplConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
# TODO: make this use provider data, etc. like other providers
|
||||
self._cerebras_client = AsyncCerebras(
|
||||
base_url=self.config.base_url,
|
||||
api_key=self.config.api_key.get_secret_value(),
|
||||
)
|
||||
class CerebrasInferenceAdapter(OpenAIMixin):
|
||||
config: CerebrasImplConfig
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key.get_secret_value()
|
||||
|
|
@ -58,86 +21,6 @@ class CerebrasInferenceAdapter(
|
|||
def get_base_url(self) -> str:
|
||||
return urljoin(self.config.base_url, "v1")
|
||||
|
||||
async def initialize(self) -> None:
|
||||
return
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
|
||||
r = await self._cerebras_client.completions.create(**params)
|
||||
|
||||
return process_chat_completion_response(r, request)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
stream = await self._cerebras_client.completions.create(**params)
|
||||
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
if request.sampling_params and isinstance(
|
||||
request.sampling_params.strategy, TopKSamplingStrategy
|
||||
):
|
||||
raise ValueError("`top_k` not supported by Cerebras")
|
||||
|
||||
prompt = ""
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
prompt = await chat_completion_request_to_prompt(
|
||||
request, self.get_llama_model(request.model)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown request type {type(request)}")
|
||||
|
||||
return {
|
||||
"model": request.model,
|
||||
"prompt": prompt,
|
||||
"stream": request.stream,
|
||||
**get_sampling_options(request.sampling_params),
|
||||
}
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
|||
|
|
@ -7,21 +7,22 @@
|
|||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
DEFAULT_BASE_URL = "https://api.cerebras.ai"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CerebrasImplConfig(BaseModel):
|
||||
class CerebrasImplConfig(RemoteInferenceProviderConfig):
|
||||
base_url: str = Field(
|
||||
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
|
||||
description="Base URL for the Cerebras API",
|
||||
)
|
||||
api_key: SecretStr = Field(
|
||||
default=SecretStr(os.environ.get("CEREBRAS_API_KEY")),
|
||||
default=SecretStr(os.environ.get("CEREBRAS_API_KEY")), # type: ignore[arg-type]
|
||||
description="Cerebras API Key",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,6 @@ async def get_adapter_impl(config: DatabricksImplConfig, _deps):
|
|||
from .databricks import DatabricksInferenceAdapter
|
||||
|
||||
assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = DatabricksInferenceAdapter(config)
|
||||
impl = DatabricksInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -6,19 +6,20 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DatabricksImplConfig(BaseModel):
|
||||
url: str = Field(
|
||||
class DatabricksImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str | None = Field(
|
||||
default=None,
|
||||
description="The URL for the Databricks model serving endpoint",
|
||||
)
|
||||
api_token: SecretStr = Field(
|
||||
default=SecretStr(None),
|
||||
default=SecretStr(None), # type: ignore[arg-type]
|
||||
description="The Databricks API token",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,27 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
from databricks.sdk import WorkspaceClient
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
Model,
|
||||
OpenAICompletion,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.apis.inference import OpenAICompletion
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
|
@ -33,30 +18,31 @@ from .config import DatabricksImplConfig
|
|||
logger = get_logger(name=__name__, category="inference::databricks")
|
||||
|
||||
|
||||
class DatabricksInferenceAdapter(
|
||||
OpenAIMixin,
|
||||
Inference,
|
||||
):
|
||||
class DatabricksInferenceAdapter(OpenAIMixin):
|
||||
config: DatabricksImplConfig
|
||||
|
||||
# source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
|
||||
embedding_model_metadata = {
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},
|
||||
"databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512},
|
||||
}
|
||||
|
||||
def __init__(self, config: DatabricksImplConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_token.get_secret_value()
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return f"{self.config.url}/serving-endpoints"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
return
|
||||
async def list_provider_model_ids(self) -> Iterable[str]:
|
||||
return [
|
||||
endpoint.name
|
||||
for endpoint in WorkspaceClient(
|
||||
host=self.config.url, token=self.get_api_key()
|
||||
).serving_endpoints.list() # TODO: this is not async
|
||||
]
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return False
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
|
|
@ -82,47 +68,3 @@ class DatabricksInferenceAdapter(
|
|||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
self._model_cache = {} # from OpenAIMixin
|
||||
ws_client = WorkspaceClient(host=self.config.url, token=self.get_api_key()) # TODO: this is not async
|
||||
endpoints = ws_client.serving_endpoints.list()
|
||||
for endpoint in endpoints:
|
||||
model = Model(
|
||||
provider_id=self.__provider_id__,
|
||||
provider_resource_id=endpoint.name,
|
||||
identifier=endpoint.name,
|
||||
)
|
||||
if endpoint.task == "llm/v1/chat":
|
||||
model.model_type = ModelType.llm # this is redundant, but informative
|
||||
elif endpoint.task == "llm/v1/embeddings":
|
||||
if endpoint.name not in self.embedding_model_metadata:
|
||||
logger.warning(f"No metadata information available for embedding model {endpoint.name}, skipping.")
|
||||
continue
|
||||
model.model_type = ModelType.embedding
|
||||
model.metadata = self.embedding_model_metadata[endpoint.name]
|
||||
else:
|
||||
logger.warning(f"Unknown model type, skipping: {endpoint}")
|
||||
continue
|
||||
|
||||
self._model_cache[endpoint.name] = model
|
||||
|
||||
return list(self._model_cache.values())
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -17,6 +17,6 @@ async def get_adapter_impl(config: FireworksImplConfig, _deps):
|
|||
from .fireworks import FireworksInferenceAdapter
|
||||
|
||||
assert isinstance(config, FireworksImplConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = FireworksInferenceAdapter(config)
|
||||
impl = FireworksInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -4,195 +4,27 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from fireworks.client import Fireworks
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_message_to_openai_dict,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
from .config import FireworksImplConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::fireworks")
|
||||
|
||||
|
||||
class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
embedding_model_metadata = {
|
||||
class FireworksInferenceAdapter(OpenAIMixin):
|
||||
config: FireworksImplConfig
|
||||
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"nomic-ai/nomic-embed-text-v1.5": {"embedding_dimension": 768, "context_length": 8192},
|
||||
"accounts/fireworks/models/qwen3-embedding-8b": {"embedding_dimension": 4096, "context_length": 40960},
|
||||
}
|
||||
|
||||
def __init__(self, config: FireworksImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self)
|
||||
self.config = config
|
||||
self.allowed_models = config.allowed_models
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
provider_data_api_key_field: str = "fireworks_api_key"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
|
||||
if config_api_key:
|
||||
return config_api_key
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.fireworks_api_key:
|
||||
raise ValueError(
|
||||
'Pass Fireworks API Key in the header X-LlamaStack-Provider-Data as { "fireworks_api_key": <your api key>}'
|
||||
)
|
||||
return provider_data.fireworks_api_key
|
||||
return self.config.api_key.get_secret_value() if self.config.api_key else None # type: ignore[return-value]
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return "https://api.fireworks.ai/inference/v1"
|
||||
|
||||
def _get_client(self) -> Fireworks:
|
||||
fireworks_api_key = self.get_api_key()
|
||||
return Fireworks(api_key=fireworks_api_key)
|
||||
|
||||
def _preprocess_prompt_for_fireworks(self, prompt: str) -> str:
|
||||
"""Remove BOS token as Fireworks automatically prepends it"""
|
||||
if prompt.startswith("<|begin_of_text|>"):
|
||||
return prompt[len("<|begin_of_text|>") :]
|
||||
return prompt
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request)
|
||||
|
||||
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
if "messages" in params:
|
||||
r = await self._get_client().chat.completions.acreate(**params)
|
||||
else:
|
||||
r = await self._get_client().completion.acreate(**params)
|
||||
return process_chat_completion_response(r, request)
|
||||
|
||||
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
async def _to_async_generator():
|
||||
if "messages" in params:
|
||||
stream = self._get_client().chat.completions.acreate(**params)
|
||||
else:
|
||||
stream = self._get_client().completion.acreate(**params)
|
||||
async for chunk in stream:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
||||
def _build_options(
|
||||
self,
|
||||
sampling_params: SamplingParams | None,
|
||||
fmt: ResponseFormat | None,
|
||||
logprobs: LogProbConfig | None,
|
||||
) -> dict:
|
||||
options = get_sampling_options(sampling_params)
|
||||
options.setdefault("max_tokens", 512)
|
||||
|
||||
if fmt:
|
||||
if fmt.type == ResponseFormatType.json_schema.value:
|
||||
options["response_format"] = {
|
||||
"type": "json_object",
|
||||
"schema": fmt.json_schema,
|
||||
}
|
||||
elif fmt.type == ResponseFormatType.grammar.value:
|
||||
options["response_format"] = {
|
||||
"type": "grammar",
|
||||
"grammar": fmt.bnf,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unknown response format {fmt.type}")
|
||||
|
||||
if logprobs and logprobs.top_k:
|
||||
options["logprobs"] = logprobs.top_k
|
||||
if options["logprobs"] <= 0 or options["logprobs"] >= 5:
|
||||
raise ValueError("Required range: 0 < top_k < 5")
|
||||
|
||||
return options
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
input_dict = {}
|
||||
media_present = request_has_media(request)
|
||||
|
||||
llama_model = self.get_llama_model(request.model)
|
||||
# TODO: tools are never added to the request, so we need to add them here
|
||||
if media_present or not llama_model:
|
||||
input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
|
||||
else:
|
||||
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
|
||||
|
||||
# Fireworks always prepends with BOS
|
||||
if "prompt" in input_dict:
|
||||
if input_dict["prompt"].startswith("<|begin_of_text|>"):
|
||||
input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
|
||||
|
||||
params = {
|
||||
"model": request.model,
|
||||
**input_dict,
|
||||
"stream": bool(request.stream),
|
||||
**self._build_options(request.sampling_params, request.response_format, request.logprobs),
|
||||
}
|
||||
logger.debug(f"params to fireworks: {params}")
|
||||
|
||||
return params
|
||||
|
|
|
|||
|
|
@ -10,6 +10,6 @@ from .config import GeminiConfig
|
|||
async def get_adapter_impl(config: GeminiConfig, _deps):
|
||||
from .gemini import GeminiInferenceAdapter
|
||||
|
||||
impl = GeminiInferenceAdapter(config)
|
||||
impl = GeminiInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
|
@ -19,7 +20,7 @@ class GeminiProviderDataValidator(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class GeminiConfig(BaseModel):
|
||||
class GeminiConfig(RemoteInferenceProviderConfig):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for Gemini models",
|
||||
|
|
|
|||
|
|
@ -4,33 +4,21 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import GeminiConfig
|
||||
|
||||
|
||||
class GeminiInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
embedding_model_metadata = {
|
||||
class GeminiInferenceAdapter(OpenAIMixin):
|
||||
config: GeminiConfig
|
||||
|
||||
provider_data_api_key_field: str = "gemini_api_key"
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
|
||||
}
|
||||
|
||||
def __init__(self, config: GeminiConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
litellm_provider_name="gemini",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="gemini_api_key",
|
||||
)
|
||||
self.config = config
|
||||
|
||||
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key or ""
|
||||
|
||||
def get_base_url(self):
|
||||
return "https://generativelanguage.googleapis.com/v1beta/openai/"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
await super().shutdown()
|
||||
|
|
|
|||
|
|
@ -11,5 +11,5 @@ async def get_adapter_impl(config: GroqConfig, _deps):
|
|||
# import dynamically so the import is used only when it is needed
|
||||
from .groq import GroqInferenceAdapter
|
||||
|
||||
adapter = GroqInferenceAdapter(config)
|
||||
adapter = GroqInferenceAdapter(config=config)
|
||||
return adapter
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
|
@ -19,7 +20,7 @@ class GroqProviderDataValidator(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class GroqConfig(BaseModel):
|
||||
class GroqConfig(RemoteInferenceProviderConfig):
|
||||
api_key: str | None = Field(
|
||||
# The Groq client library loads the GROQ_API_KEY environment variable by default
|
||||
default=None,
|
||||
|
|
|
|||
|
|
@ -6,30 +6,16 @@
|
|||
|
||||
|
||||
from llama_stack.providers.remote.inference.groq.config import GroqConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
||||
class GroqInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
_config: GroqConfig
|
||||
class GroqInferenceAdapter(OpenAIMixin):
|
||||
config: GroqConfig
|
||||
|
||||
def __init__(self, config: GroqConfig):
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
litellm_provider_name="groq",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="groq_api_key",
|
||||
)
|
||||
self.config = config
|
||||
provider_data_api_key_field: str = "groq_api_key"
|
||||
|
||||
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
|
||||
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key or ""
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return f"{self.config.url}/openai/v1"
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
||||
|
|
|
|||
|
|
@ -4,14 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import InferenceProvider
|
||||
|
||||
from .config import LlamaCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: LlamaCompatConfig, _deps) -> InferenceProvider:
|
||||
async def get_adapter_impl(config: LlamaCompatConfig, _deps):
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .llama import LlamaCompatInferenceAdapter
|
||||
|
||||
adapter = LlamaCompatInferenceAdapter(config)
|
||||
adapter = LlamaCompatInferenceAdapter(config=config)
|
||||
return adapter
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
|
@ -19,7 +20,7 @@ class LlamaProviderDataValidator(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class LlamaCompatConfig(BaseModel):
|
||||
class LlamaCompatConfig(RemoteInferenceProviderConfig):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="The Llama API key",
|
||||
|
|
|
|||
|
|
@ -3,40 +3,26 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference.inference import OpenAICompletion, OpenAIEmbeddingsResponse
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::llama_openai_compat")
|
||||
|
||||
|
||||
class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
class LlamaCompatInferenceAdapter(OpenAIMixin):
|
||||
config: LlamaCompatConfig
|
||||
|
||||
provider_data_api_key_field: str = "llama_api_key"
|
||||
"""
|
||||
Llama API Inference Adapter for Llama Stack.
|
||||
|
||||
Note: The inheritance order is important here. OpenAIMixin must come before
|
||||
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
|
||||
is used instead of ModelRegistryHelper.check_model_availability().
|
||||
|
||||
- OpenAIMixin.check_model_availability() queries the Llama API to check if a model exists
|
||||
- ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
|
||||
"""
|
||||
|
||||
_config: LlamaCompatConfig
|
||||
|
||||
def __init__(self, config: LlamaCompatConfig):
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
litellm_provider_name="meta_llama",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="llama_api_key",
|
||||
openai_compat_api_base=config.openai_compat_api_base,
|
||||
)
|
||||
self.config = config
|
||||
|
||||
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
|
||||
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key or ""
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
|
|
@ -46,8 +32,37 @@ class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
|||
"""
|
||||
return self.config.openai_compat_api_base
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
|||
|
|
@ -15,7 +15,8 @@ async def get_adapter_impl(config: NVIDIAConfig, _deps) -> Inference:
|
|||
|
||||
if not isinstance(config, NVIDIAConfig):
|
||||
raise RuntimeError(f"Unexpected config type: {type(config)}")
|
||||
adapter = NVIDIAInferenceAdapter(config)
|
||||
adapter = NVIDIAInferenceAdapter(config=config)
|
||||
await adapter.initialize()
|
||||
return adapter
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -7,13 +7,14 @@
|
|||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class NVIDIAConfig(BaseModel):
|
||||
class NVIDIAConfig(RemoteInferenceProviderConfig):
|
||||
"""
|
||||
Configuration for the NVIDIA NIM inference endpoint.
|
||||
|
||||
|
|
|
|||
|
|
@ -4,44 +4,26 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import warnings
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from openai import NOT_GIVEN, APIConnectionError
|
||||
from openai import NOT_GIVEN
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_openai_chat_completion_choice,
|
||||
convert_openai_chat_completion_stream,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from . import NVIDIAConfig
|
||||
from .openai_utils import (
|
||||
convert_chat_completion_request,
|
||||
)
|
||||
from .utils import _is_nvidia_hosted
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::nvidia")
|
||||
|
||||
|
||||
class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
|
||||
class NVIDIAInferenceAdapter(OpenAIMixin):
|
||||
config: NVIDIAConfig
|
||||
|
||||
"""
|
||||
NVIDIA Inference Adapter for Llama Stack.
|
||||
|
||||
|
|
@ -56,32 +38,21 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
|
|||
"""
|
||||
|
||||
# source: https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
|
||||
embedding_model_metadata = {
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"nvidia/llama-3.2-nv-embedqa-1b-v2": {"embedding_dimension": 2048, "context_length": 8192},
|
||||
"nvidia/nv-embedqa-e5-v5": {"embedding_dimension": 512, "context_length": 1024},
|
||||
"nvidia/nv-embedqa-mistral-7b-v2": {"embedding_dimension": 512, "context_length": 4096},
|
||||
"snowflake/arctic-embed-l": {"embedding_dimension": 512, "context_length": 1024},
|
||||
}
|
||||
|
||||
def __init__(self, config: NVIDIAConfig) -> None:
|
||||
logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")
|
||||
async def initialize(self) -> None:
|
||||
logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.url})...")
|
||||
|
||||
if _is_nvidia_hosted(config):
|
||||
if not config.api_key:
|
||||
if _is_nvidia_hosted(self.config):
|
||||
if not self.config.api_key:
|
||||
raise RuntimeError(
|
||||
"API key is required for hosted NVIDIA NIM. Either provide an API key or use a self-hosted NIM."
|
||||
)
|
||||
# elif self._config.api_key:
|
||||
#
|
||||
# we don't raise this warning because a user may have deployed their
|
||||
# self-hosted NIM with an API key requirement.
|
||||
#
|
||||
# warnings.warn(
|
||||
# "API key is not required for self-hosted NVIDIA NIM. "
|
||||
# "Consider removing the api_key from the configuration."
|
||||
# )
|
||||
|
||||
self._config = config
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
"""
|
||||
|
|
@ -89,7 +60,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
|
|||
|
||||
:return: The NVIDIA API key
|
||||
"""
|
||||
return self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"
|
||||
return self.config.api_key.get_secret_value() if self.config.api_key else "NO KEY"
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
|
|
@ -97,7 +68,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
|
|||
|
||||
:return: The NVIDIA API base URL
|
||||
"""
|
||||
return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
|
||||
return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
|
|
@ -149,49 +120,3 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
|
|||
model=response.model,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if tool_prompt_format:
|
||||
warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring", stacklevel=2)
|
||||
|
||||
# await check_health(self._config) # this raises errors
|
||||
|
||||
provider_model_id = await self._get_provider_model_id(model_id)
|
||||
request = await convert_chat_completion_request(
|
||||
request=ChatCompletionRequest(
|
||||
model=provider_model_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
tools=tools,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
),
|
||||
n=1,
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client.chat.completions.create(**request)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||
|
||||
if stream:
|
||||
return convert_openai_chat_completion_stream(response, enable_incremental_tool_calls=False)
|
||||
else:
|
||||
# we pass n=1 to get only one completion
|
||||
return convert_openai_chat_completion_choice(response.choices[0])
|
||||
|
|
|
|||
|
|
@ -10,6 +10,6 @@ from .config import OllamaImplConfig
|
|||
async def get_adapter_impl(config: OllamaImplConfig, _deps):
|
||||
from .ollama import OllamaInferenceAdapter
|
||||
|
||||
impl = OllamaInferenceAdapter(config)
|
||||
impl = OllamaInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -6,12 +6,14 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
|
||||
DEFAULT_OLLAMA_URL = "http://localhost:11434"
|
||||
|
||||
|
||||
class OllamaImplConfig(BaseModel):
|
||||
class OllamaImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = DEFAULT_OLLAMA_URL
|
||||
refresh_models: bool = Field(
|
||||
default=False,
|
||||
|
|
|
|||
|
|
@ -6,72 +6,29 @@
|
|||
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from ollama import AsyncClient as AsyncOllamaClient
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
ImageContentItem,
|
||||
TextContentItem,
|
||||
)
|
||||
from llama_stack.apis.common.errors import UnsupportedModelError
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
GrammarResponseFormat,
|
||||
InferenceProvider,
|
||||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.datatypes import (
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
ModelsProtocolPrivate,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
convert_image_content_to_url,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::ollama")
|
||||
|
||||
|
||||
class OllamaInferenceAdapter(
|
||||
OpenAIMixin,
|
||||
ModelRegistryHelper,
|
||||
InferenceProvider,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
class OllamaInferenceAdapter(OpenAIMixin):
|
||||
config: OllamaImplConfig
|
||||
|
||||
# automatically set by the resolver when instantiating the provider
|
||||
__provider_id__: str
|
||||
|
||||
embedding_model_metadata = {
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"all-minilm:l6-v2": {
|
||||
"embedding_dimension": 384,
|
||||
"context_length": 512,
|
||||
|
|
@ -90,29 +47,8 @@ class OllamaInferenceAdapter(
|
|||
},
|
||||
}
|
||||
|
||||
def __init__(self, config: OllamaImplConfig) -> None:
|
||||
# TODO: remove ModelRegistryHelper.__init__ when completion and
|
||||
# chat_completion are. this exists to satisfy the input /
|
||||
# output processing for llama models. specifically,
|
||||
# tool_calling is handled by raw template processing,
|
||||
# instead of using the /api/chat endpoint w/ tools=...
|
||||
ModelRegistryHelper.__init__(
|
||||
self,
|
||||
model_entries=[
|
||||
build_hf_repo_model_entry(
|
||||
"llama3.2:3b-instruct-fp16",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama-guard3:1b",
|
||||
CoreModelId.llama_guard_3_1b.value,
|
||||
),
|
||||
],
|
||||
)
|
||||
self.config = config
|
||||
# Ollama does not support image urls, so we need to download the image and convert it to base64
|
||||
self.download_images = True
|
||||
self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
|
||||
download_images: bool = True
|
||||
_clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
|
||||
|
||||
@property
|
||||
def ollama_client(self) -> AsyncOllamaClient:
|
||||
|
|
@ -156,134 +92,6 @@ class OllamaInferenceAdapter(
|
|||
async def shutdown(self) -> None:
|
||||
self._clients.clear()
|
||||
|
||||
async def _get_model(self, model_id: str) -> Model:
|
||||
if not self.model_store:
|
||||
raise ValueError("Model store not set")
|
||||
return await self.model_store.get_model(model_id)
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self._get_model(model_id)
|
||||
if model.provider_resource_id is None:
|
||||
raise ValueError(f"Model {model_id} has no provider_resource_id set")
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
response_format=response_format,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request)
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
sampling_options = get_sampling_options(request.sampling_params)
|
||||
# This is needed since the Ollama API expects num_predict to be set
|
||||
# for early truncation instead of max_tokens.
|
||||
if sampling_options.get("max_tokens") is not None:
|
||||
sampling_options["num_predict"] = sampling_options["max_tokens"]
|
||||
|
||||
input_dict: dict[str, Any] = {}
|
||||
media_present = request_has_media(request)
|
||||
llama_model = self.get_llama_model(request.model)
|
||||
if media_present or not llama_model:
|
||||
contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
|
||||
# flatten the list of lists
|
||||
input_dict["messages"] = [item for sublist in contents for item in sublist]
|
||||
else:
|
||||
input_dict["raw"] = True
|
||||
input_dict["prompt"] = await chat_completion_request_to_prompt(
|
||||
request,
|
||||
llama_model,
|
||||
)
|
||||
|
||||
if fmt := request.response_format:
|
||||
if isinstance(fmt, JsonSchemaResponseFormat):
|
||||
input_dict["format"] = fmt.json_schema
|
||||
elif isinstance(fmt, GrammarResponseFormat):
|
||||
raise NotImplementedError("Grammar response format is not supported")
|
||||
else:
|
||||
raise ValueError(f"Unknown response format type: {fmt.type}")
|
||||
|
||||
params = {
|
||||
"model": request.model,
|
||||
**input_dict,
|
||||
"options": sampling_options,
|
||||
"stream": request.stream,
|
||||
}
|
||||
logger.debug(f"params to ollama: {params}")
|
||||
|
||||
return params
|
||||
|
||||
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
if "messages" in params:
|
||||
r = await self.ollama_client.chat(**params)
|
||||
else:
|
||||
r = await self.ollama_client.generate(**params)
|
||||
|
||||
if "message" in r:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=r["done_reason"] if r["done"] else None,
|
||||
text=r["message"]["content"],
|
||||
)
|
||||
else:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=r["done_reason"] if r["done"] else None,
|
||||
text=r["response"],
|
||||
)
|
||||
response = OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
return process_chat_completion_response(response, request)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
|
||||
params = await self._get_params(request)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
if "messages" in params:
|
||||
s = await self.ollama_client.chat(**params)
|
||||
else:
|
||||
s = await self.ollama_client.generate(**params)
|
||||
async for chunk in s:
|
||||
if "message" in chunk:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=chunk["done_reason"] if chunk["done"] else None,
|
||||
text=chunk["message"]["content"],
|
||||
)
|
||||
else:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=chunk["done_reason"] if chunk["done"] else None,
|
||||
text=chunk["response"],
|
||||
)
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
if await self.check_model_availability(model.provider_model_id):
|
||||
return model
|
||||
|
|
@ -295,24 +103,3 @@ class OllamaInferenceAdapter(
|
|||
return model
|
||||
|
||||
raise UnsupportedModelError(model.provider_model_id, list(self._model_cache.keys()))
|
||||
|
||||
|
||||
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
|
||||
async def _convert_content(content) -> dict:
|
||||
if isinstance(content, ImageContentItem):
|
||||
return {
|
||||
"role": message.role,
|
||||
"images": [await convert_image_content_to_url(content, download=True, include_format=False)],
|
||||
}
|
||||
else:
|
||||
text = content.text if isinstance(content, TextContentItem) else content
|
||||
assert isinstance(text, str)
|
||||
return {
|
||||
"role": message.role,
|
||||
"content": text,
|
||||
}
|
||||
|
||||
if isinstance(message.content, list):
|
||||
return [await _convert_content(c) for c in message.content]
|
||||
else:
|
||||
return [await _convert_content(message.content)]
|
||||
|
|
|
|||
|
|
@ -10,6 +10,6 @@ from .config import OpenAIConfig
|
|||
async def get_adapter_impl(config: OpenAIConfig, _deps):
|
||||
from .openai import OpenAIInferenceAdapter
|
||||
|
||||
impl = OpenAIInferenceAdapter(config)
|
||||
impl = OpenAIInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
|
@ -19,7 +20,7 @@ class OpenAIProviderDataValidator(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIConfig(BaseModel):
|
||||
class OpenAIConfig(RemoteInferenceProviderConfig):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for OpenAI models",
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import OpenAIConfig
|
||||
|
|
@ -14,52 +13,24 @@ logger = get_logger(name=__name__, category="inference::openai")
|
|||
|
||||
|
||||
#
|
||||
# This OpenAI adapter implements Inference methods using two mixins -
|
||||
# This OpenAI adapter implements Inference methods using OpenAIMixin
|
||||
#
|
||||
# | Inference Method | Implementation Source |
|
||||
# |----------------------------|--------------------------|
|
||||
# | completion | LiteLLMOpenAIMixin |
|
||||
# | chat_completion | LiteLLMOpenAIMixin |
|
||||
# | embedding | LiteLLMOpenAIMixin |
|
||||
# | openai_completion | OpenAIMixin |
|
||||
# | openai_chat_completion | OpenAIMixin |
|
||||
# | openai_embeddings | OpenAIMixin |
|
||||
#
|
||||
class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
class OpenAIInferenceAdapter(OpenAIMixin):
|
||||
"""
|
||||
OpenAI Inference Adapter for Llama Stack.
|
||||
|
||||
Note: The inheritance order is important here. OpenAIMixin must come before
|
||||
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
|
||||
is used instead of ModelRegistryHelper.check_model_availability().
|
||||
|
||||
- OpenAIMixin.check_model_availability() queries the OpenAI API to check if a model exists
|
||||
- ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
|
||||
"""
|
||||
|
||||
embedding_model_metadata = {
|
||||
config: OpenAIConfig
|
||||
|
||||
provider_data_api_key_field: str = "openai_api_key"
|
||||
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
|
||||
"text-embedding-3-large": {"embedding_dimension": 3072, "context_length": 8192},
|
||||
}
|
||||
|
||||
def __init__(self, config: OpenAIConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
litellm_provider_name="openai",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="openai_api_key",
|
||||
)
|
||||
self.config = config
|
||||
# we set is_openai_compat so users can use the canonical
|
||||
# openai model names like "gpt-4" or "gpt-3.5-turbo"
|
||||
# and the model name will be translated to litellm's
|
||||
# "openai/gpt-4" or "openai/gpt-3.5-turbo" transparently.
|
||||
# if we do not set this, users will be exposed to the
|
||||
# litellm specific model names, an abstraction leak.
|
||||
self.is_openai_compat = True
|
||||
|
||||
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
|
||||
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key or ""
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
|
|
@ -68,9 +39,3 @@ class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
|||
Returns the OpenAI API base URL from the configuration.
|
||||
"""
|
||||
return self.config.base_url
|
||||
|
||||
async def initialize(self) -> None:
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
await super().shutdown()
|
||||
|
|
|
|||
|
|
@ -6,13 +6,14 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PassthroughImplConfig(BaseModel):
|
||||
class PassthroughImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default=None,
|
||||
description="The URL for the passthrough endpoint",
|
||||
|
|
|
|||
|
|
@ -4,33 +4,22 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from llama_stack_client import AsyncLlamaStackClient
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.core.library_client import convert_pydantic_to_json_value, convert_to_pydantic
|
||||
from llama_stack.core.library_client import convert_pydantic_to_json_value
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
||||
|
||||
|
|
@ -42,12 +31,6 @@ class PassthroughInferenceAdapter(Inference):
|
|||
ModelRegistryHelper.__init__(self)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
|
|
@ -85,76 +68,6 @@ class PassthroughInferenceAdapter(Inference):
|
|||
provider_data=provider_data,
|
||||
)
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
# TODO: revisit this remove tool_calls from messages logic
|
||||
for message in messages:
|
||||
if hasattr(message, "tool_calls"):
|
||||
message.tool_calls = None
|
||||
|
||||
request_params = {
|
||||
"model_id": model.provider_resource_id,
|
||||
"messages": messages,
|
||||
"sampling_params": sampling_params,
|
||||
"tools": tools,
|
||||
"tool_choice": tool_choice,
|
||||
"tool_prompt_format": tool_prompt_format,
|
||||
"response_format": response_format,
|
||||
"stream": stream,
|
||||
"logprobs": logprobs,
|
||||
}
|
||||
|
||||
# only pass through the not None params
|
||||
request_params = {key: value for key, value in request_params.items() if value is not None}
|
||||
|
||||
# cast everything to json dict
|
||||
json_params = self.cast_value_to_json_dict(request_params)
|
||||
|
||||
if stream:
|
||||
return self._stream_chat_completion(json_params)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(json_params)
|
||||
|
||||
async def _nonstream_chat_completion(self, json_params: dict[str, Any]) -> ChatCompletionResponse:
|
||||
client = self._get_client()
|
||||
response = await client.inference.chat_completion(**json_params)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
content=response.completion_message.content.text,
|
||||
stop_reason=response.completion_message.stop_reason,
|
||||
tool_calls=response.completion_message.tool_calls,
|
||||
),
|
||||
logprobs=response.logprobs,
|
||||
)
|
||||
|
||||
async def _stream_chat_completion(self, json_params: dict[str, Any]) -> AsyncGenerator:
|
||||
client = self._get_client()
|
||||
stream_response = await client.inference.chat_completion(**json_params)
|
||||
|
||||
async for chunk in stream_response:
|
||||
chunk = chunk.to_dict()
|
||||
|
||||
# temporary hack to remove the metrics from the response
|
||||
chunk["metrics"] = []
|
||||
chunk = convert_to_pydantic(ChatCompletionResponseStreamChunk, chunk)
|
||||
yield chunk
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
|||
|
|
@ -6,13 +6,14 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RunpodImplConfig(BaseModel):
|
||||
class RunpodImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str | None = Field(
|
||||
default=None,
|
||||
description="The URL for the Runpod model serving endpoint",
|
||||
|
|
|
|||
|
|
@ -3,9 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.inference import OpenAIEmbeddingsResponse
|
||||
|
|
@ -16,9 +14,6 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
|
|
@ -54,7 +49,6 @@ MODEL_ENTRIES = [
|
|||
class RunpodInferenceAdapter(
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
):
|
||||
def __init__(self, config: RunpodImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(
|
||||
|
|
@ -62,64 +56,6 @@ class RunpodInferenceAdapter(
|
|||
)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
return
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
|
||||
client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request, client)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request, client)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: OpenAI
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params(request)
|
||||
r = client.completions.create(**params)
|
||||
return process_chat_completion_response(r, request)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: OpenAI
|
||||
) -> AsyncGenerator:
|
||||
params = self._get_params(request)
|
||||
|
||||
async def _to_async_generator():
|
||||
s = client.completions.create(**params)
|
||||
for chunk in s:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
return {
|
||||
"model": self.map_to_provider_model(request.model),
|
||||
|
|
|
|||
|
|
@ -11,6 +11,6 @@ async def get_adapter_impl(config: SambaNovaImplConfig, _deps):
|
|||
from .sambanova import SambaNovaInferenceAdapter
|
||||
|
||||
assert isinstance(config, SambaNovaImplConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = SambaNovaInferenceAdapter(config)
|
||||
impl = SambaNovaInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
|
@ -19,7 +20,7 @@ class SambaNovaProviderDataValidator(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class SambaNovaImplConfig(BaseModel):
|
||||
class SambaNovaImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default="https://api.sambanova.ai/v1",
|
||||
description="The URL for the SambaNova AI server",
|
||||
|
|
|
|||
|
|
@ -5,39 +5,22 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import SambaNovaImplConfig
|
||||
|
||||
|
||||
class SambaNovaInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
class SambaNovaInferenceAdapter(OpenAIMixin):
|
||||
config: SambaNovaImplConfig
|
||||
|
||||
provider_data_api_key_field: str = "sambanova_api_key"
|
||||
download_images: bool = True # SambaNova does not support image downloads server-size, perform them on the client
|
||||
"""
|
||||
SambaNova Inference Adapter for Llama Stack.
|
||||
|
||||
Note: The inheritance order is important here. OpenAIMixin must come before
|
||||
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
|
||||
is used instead of LiteLLMOpenAIMixin.check_model_availability().
|
||||
|
||||
- OpenAIMixin.check_model_availability() queries the /v1/models to check if a model exists
|
||||
- LiteLLMOpenAIMixin.check_model_availability() checks the static registry within LiteLLM
|
||||
"""
|
||||
|
||||
def __init__(self, config: SambaNovaImplConfig):
|
||||
self.config = config
|
||||
self.environment_available_models: list[str] = []
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
litellm_provider_name="sambanova",
|
||||
api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
|
||||
provider_data_api_key_field="sambanova_api_key",
|
||||
openai_compat_api_base=self.config.url,
|
||||
download_images=True, # SambaNova requires base64 image encoding
|
||||
json_schema_strict=False, # SambaNova doesn't support strict=True yet
|
||||
)
|
||||
|
||||
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
|
||||
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key.get_secret_value() if self.config.api_key else ""
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -7,11 +7,12 @@
|
|||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TGIImplConfig(BaseModel):
|
||||
class TGIImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
description="The URL for the TGI serving endpoint",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,68 +5,21 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import Iterable
|
||||
|
||||
from huggingface_hub import AsyncInferenceClient, HfApi
|
||||
from pydantic import SecretStr
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.models.models import ModelType
|
||||
from llama_stack.apis.inference import OpenAIEmbeddingsResponse
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.sku_list import all_registered_models
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_model_input_info,
|
||||
)
|
||||
|
||||
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
||||
|
||||
log = get_logger(name=__name__, category="inference::tgi")
|
||||
|
||||
|
||||
def build_hf_repo_model_entries():
|
||||
return [
|
||||
build_hf_repo_model_entry(
|
||||
model.huggingface_repo,
|
||||
model.descriptor(),
|
||||
)
|
||||
for model in all_registered_models()
|
||||
if model.huggingface_repo
|
||||
]
|
||||
|
||||
|
||||
class _HfAdapter(
|
||||
OpenAIMixin,
|
||||
Inference,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
class _HfAdapter(OpenAIMixin):
|
||||
url: str
|
||||
api_key: SecretStr
|
||||
|
||||
|
|
@ -76,152 +29,14 @@ class _HfAdapter(
|
|||
|
||||
overwrite_completion_id = True # TGI always returns id=""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
|
||||
self.huggingface_repo_to_llama_model_id = {
|
||||
model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo
|
||||
}
|
||||
|
||||
def get_api_key(self):
|
||||
return self.api_key.get_secret_value()
|
||||
|
||||
def get_base_url(self):
|
||||
return self.url
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
models = []
|
||||
async for model in self.client.models.list():
|
||||
models.append(
|
||||
Model(
|
||||
identifier=model.id,
|
||||
provider_resource_id=model.id,
|
||||
provider_id=self.__provider_id__,
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
)
|
||||
return models
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
if model.provider_resource_id != self.model_id:
|
||||
raise ValueError(
|
||||
f"Model {model.provider_resource_id} does not match the model {self.model_id} served by TGI."
|
||||
)
|
||||
return model
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
def _get_max_new_tokens(self, sampling_params, input_tokens):
|
||||
return min(
|
||||
sampling_params.max_tokens or (self.max_tokens - input_tokens),
|
||||
self.max_tokens - input_tokens - 1,
|
||||
)
|
||||
|
||||
def _build_options(
|
||||
self,
|
||||
sampling_params: SamplingParams | None = None,
|
||||
fmt: ResponseFormat = None,
|
||||
):
|
||||
options = get_sampling_options(sampling_params)
|
||||
# TGI does not support temperature=0 when using greedy sampling
|
||||
# We set it to 1e-3 instead, anything lower outputs garbage from TGI
|
||||
# We can use top_p sampling strategy to specify lower temperature
|
||||
if abs(options["temperature"]) < 1e-10:
|
||||
options["temperature"] = 1e-3
|
||||
|
||||
# delete key "max_tokens" from options since its not supported by the API
|
||||
options.pop("max_tokens", None)
|
||||
if fmt:
|
||||
if fmt.type == ResponseFormatType.json_schema.value:
|
||||
options["grammar"] = {
|
||||
"type": "json",
|
||||
"value": fmt.json_schema,
|
||||
}
|
||||
elif fmt.type == ResponseFormatType.grammar.value:
|
||||
raise ValueError("Grammar response format not supported yet")
|
||||
else:
|
||||
raise ValueError(f"Unexpected response format: {fmt.type}")
|
||||
|
||||
return options
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request)
|
||||
|
||||
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = await self.hf_client.text_generation(**params)
|
||||
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=r.details.finish_reason,
|
||||
text="".join(t.text for t in r.details.tokens),
|
||||
)
|
||||
response = OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
return process_chat_completion_response(response, request)
|
||||
|
||||
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
s = await self.hf_client.text_generation(**params)
|
||||
async for chunk in s:
|
||||
token_result = chunk.token
|
||||
|
||||
choice = OpenAICompatCompletionChoice(text=token_result.text)
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
prompt, input_tokens = await chat_completion_request_to_model_input_info(
|
||||
request, self.register_helper.get_llama_model(request.model)
|
||||
)
|
||||
return dict(
|
||||
prompt=prompt,
|
||||
stream=request.stream,
|
||||
details=True,
|
||||
max_new_tokens=self._get_max_new_tokens(request.sampling_params, input_tokens),
|
||||
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
|
||||
**self._build_options(request.sampling_params, request.response_format),
|
||||
)
|
||||
async def list_provider_model_ids(self) -> Iterable[str]:
|
||||
return [self.model_id]
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -17,6 +17,6 @@ async def get_adapter_impl(config: TogetherImplConfig, _deps):
|
|||
from .together import TogetherInferenceAdapter
|
||||
|
||||
assert isinstance(config, TogetherImplConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = TogetherInferenceAdapter(config)
|
||||
impl = TogetherInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -4,51 +4,30 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from collections.abc import Iterable
|
||||
|
||||
from together import AsyncTogether
|
||||
from together.constants import BASE_URL
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIEmbeddingUsage
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_message_to_openai_dict,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
from .config import TogetherImplConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::together")
|
||||
|
||||
|
||||
class TogetherInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
embedding_model_metadata = {
|
||||
class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
|
||||
config: TogetherImplConfig
|
||||
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"togethercomputer/m2-bert-80M-32k-retrieval": {"embedding_dimension": 768, "context_length": 32768},
|
||||
"BAAI/bge-large-en-v1.5": {"embedding_dimension": 1024, "context_length": 512},
|
||||
"BAAI/bge-base-en-v1.5": {"embedding_dimension": 768, "context_length": 512},
|
||||
|
|
@ -56,24 +35,16 @@ class TogetherInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Need
|
|||
"intfloat/multilingual-e5-large-instruct": {"embedding_dimension": 1024, "context_length": 512},
|
||||
}
|
||||
|
||||
def __init__(self, config: TogetherImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self)
|
||||
self.config = config
|
||||
self.allowed_models = config.allowed_models
|
||||
self._model_cache: dict[str, Model] = {}
|
||||
_model_cache: dict[str, Model] = {}
|
||||
|
||||
provider_data_api_key_field: str = "together_api_key"
|
||||
|
||||
def get_api_key(self):
|
||||
return self.config.api_key.get_secret_value()
|
||||
return self.config.api_key.get_secret_value() if self.config.api_key else None
|
||||
|
||||
def get_base_url(self):
|
||||
return BASE_URL
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
def _get_client(self) -> AsyncTogether:
|
||||
together_api_key = None
|
||||
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
|
||||
|
|
@ -88,142 +59,13 @@ class TogetherInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Need
|
|||
together_api_key = provider_data.together_api_key
|
||||
return AsyncTogether(api_key=together_api_key)
|
||||
|
||||
def _get_openai_client(self) -> AsyncOpenAI:
|
||||
together_client = self._get_client().client
|
||||
return AsyncOpenAI(
|
||||
base_url=together_client.base_url,
|
||||
api_key=together_client.api_key,
|
||||
)
|
||||
|
||||
def _build_options(
|
||||
self,
|
||||
sampling_params: SamplingParams | None,
|
||||
logprobs: LogProbConfig | None,
|
||||
fmt: ResponseFormat,
|
||||
) -> dict:
|
||||
options = get_sampling_options(sampling_params)
|
||||
if fmt:
|
||||
if fmt.type == ResponseFormatType.json_schema.value:
|
||||
options["response_format"] = {
|
||||
"type": "json_object",
|
||||
"schema": fmt.json_schema,
|
||||
}
|
||||
elif fmt.type == ResponseFormatType.grammar.value:
|
||||
raise NotImplementedError("Grammar response format not supported yet")
|
||||
else:
|
||||
raise ValueError(f"Unknown response format {fmt.type}")
|
||||
|
||||
if logprobs and logprobs.top_k:
|
||||
if logprobs.top_k != 1:
|
||||
raise ValueError(
|
||||
f"Unsupported value: Together only supports logprobs top_k=1. {logprobs.top_k} was provided",
|
||||
)
|
||||
options["logprobs"] = 1
|
||||
|
||||
return options
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request)
|
||||
|
||||
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
client = self._get_client()
|
||||
if "messages" in params:
|
||||
r = await client.chat.completions.create(**params)
|
||||
else:
|
||||
r = await client.completions.create(**params)
|
||||
return process_chat_completion_response(r, request)
|
||||
|
||||
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
client = self._get_client()
|
||||
if "messages" in params:
|
||||
stream = await client.chat.completions.create(**params)
|
||||
else:
|
||||
stream = await client.completions.create(**params)
|
||||
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
input_dict = {}
|
||||
media_present = request_has_media(request)
|
||||
llama_model = self.get_llama_model(request.model)
|
||||
if media_present or not llama_model:
|
||||
input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
|
||||
else:
|
||||
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
|
||||
|
||||
params = {
|
||||
"model": request.model,
|
||||
**input_dict,
|
||||
"stream": request.stream,
|
||||
**self._build_options(request.sampling_params, request.logprobs, request.response_format),
|
||||
}
|
||||
logger.debug(f"params to together: {params}")
|
||||
return params
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
self._model_cache = {}
|
||||
async def list_provider_model_ids(self) -> Iterable[str]:
|
||||
# Together's /v1/models is not compatible with OpenAI's /v1/models. Together support ticket #13355 -> will not fix, use Together's own client
|
||||
for m in await self._get_client().models.list():
|
||||
if m.type == "embedding":
|
||||
if m.id not in self.embedding_model_metadata:
|
||||
logger.warning(f"Unknown embedding dimension for model {m.id}, skipping.")
|
||||
continue
|
||||
metadata = self.embedding_model_metadata[m.id]
|
||||
self._model_cache[m.id] = Model(
|
||||
provider_id=self.__provider_id__,
|
||||
provider_resource_id=m.id,
|
||||
identifier=m.id,
|
||||
model_type=ModelType.embedding,
|
||||
metadata=metadata,
|
||||
)
|
||||
else:
|
||||
self._model_cache[m.id] = Model(
|
||||
provider_id=self.__provider_id__,
|
||||
provider_resource_id=m.id,
|
||||
identifier=m.id,
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
|
||||
return self._model_cache.values()
|
||||
return [m.id for m in await self._get_client().models.list()]
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return True
|
||||
|
||||
async def check_model_availability(self, model):
|
||||
return model in self._model_cache
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
@ -264,4 +106,4 @@ class TogetherInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Need
|
|||
)
|
||||
response.usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
|
||||
|
||||
return response
|
||||
return response # type: ignore[no-any-return]
|
||||
|
|
|
|||
|
|
@ -10,6 +10,6 @@ from .config import VertexAIConfig
|
|||
async def get_adapter_impl(config: VertexAIConfig, _deps):
|
||||
from .vertexai import VertexAIInferenceAdapter
|
||||
|
||||
impl = VertexAIInferenceAdapter(config)
|
||||
impl = VertexAIInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
|
@ -23,7 +24,7 @@ class VertexAIProviderDataValidator(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class VertexAIConfig(BaseModel):
|
||||
class VertexAIConfig(RemoteInferenceProviderConfig):
|
||||
project: str = Field(
|
||||
description="Google Cloud project ID for Vertex AI",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,29 +4,19 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
import google.auth.transport.requests
|
||||
from google.auth import default
|
||||
|
||||
from llama_stack.apis.inference import ChatCompletionRequest
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
|
||||
LiteLLMOpenAIMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import VertexAIConfig
|
||||
|
||||
|
||||
class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
def __init__(self, config: VertexAIConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
litellm_provider_name="vertex_ai",
|
||||
api_key_from_config=None, # Vertex AI uses ADC, not API keys
|
||||
provider_data_api_key_field="vertex_project", # Use project for validation
|
||||
)
|
||||
self.config = config
|
||||
class VertexAIInferenceAdapter(OpenAIMixin):
|
||||
config: VertexAIConfig
|
||||
|
||||
provider_data_api_key_field: str = "vertex_project"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
"""
|
||||
|
|
@ -41,8 +31,7 @@ class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
|||
credentials.refresh(google.auth.transport.requests.Request())
|
||||
return str(credentials.token)
|
||||
except Exception:
|
||||
# If we can't get credentials, return empty string to let LiteLLM handle it
|
||||
# This allows the LiteLLM mixin to work with ADC directly
|
||||
# If we can't get credentials, return empty string to let the env work with ADC directly
|
||||
return ""
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
|
|
@ -53,23 +42,3 @@ class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
|||
Source: https://cloud.google.com/vertex-ai/generative-ai/docs/start/openai
|
||||
"""
|
||||
return f"https://{self.config.location}-aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/{self.config.location}/endpoints/openapi"
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
|
||||
# Get base parameters from parent
|
||||
params = await super()._get_params(request)
|
||||
|
||||
# Add Vertex AI specific parameters
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data:
|
||||
if getattr(provider_data, "vertex_project", None):
|
||||
params["vertex_project"] = provider_data.vertex_project
|
||||
if getattr(provider_data, "vertex_location", None):
|
||||
params["vertex_location"] = provider_data.vertex_location
|
||||
else:
|
||||
params["vertex_project"] = self.config.project
|
||||
params["vertex_location"] = self.config.location
|
||||
|
||||
# Remove api_key since Vertex AI uses ADC
|
||||
params.pop("api_key", None)
|
||||
|
||||
return params
|
||||
|
|
|
|||
|
|
@ -17,6 +17,6 @@ async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps):
|
|||
from .vllm import VLLMInferenceAdapter
|
||||
|
||||
assert isinstance(config, VLLMInferenceAdapterConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = VLLMInferenceAdapter(config)
|
||||
impl = VLLMInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -6,13 +6,14 @@
|
|||
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VLLMInferenceAdapterConfig(BaseModel):
|
||||
class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
|
||||
url: str | None = Field(
|
||||
default=None,
|
||||
description="The URL for the vLLM model serving endpoint",
|
||||
|
|
|
|||
|
|
@ -3,63 +3,26 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import json
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
from openai import APIConnectionError, AsyncOpenAI
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChatCompletionChunk as OpenAIChatCompletionChunk,
|
||||
)
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
TextDelta,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEvent,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
GrammarResponseFormat,
|
||||
Inference,
|
||||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ModelStore,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
|
||||
from llama_stack.models.llama.sku_list import all_registered_models
|
||||
from llama_stack.providers.datatypes import (
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
ModelsProtocolPrivate,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
UnparseableToolCall,
|
||||
convert_message_to_openai_dict,
|
||||
convert_openai_chat_completion_stream,
|
||||
convert_tool_call,
|
||||
get_sampling_options,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
|
@ -68,210 +31,15 @@ from .config import VLLMInferenceAdapterConfig
|
|||
log = get_logger(name=__name__, category="inference::vllm")
|
||||
|
||||
|
||||
def build_hf_repo_model_entries():
|
||||
return [
|
||||
build_hf_repo_model_entry(
|
||||
model.huggingface_repo,
|
||||
model.descriptor(),
|
||||
)
|
||||
for model in all_registered_models()
|
||||
if model.huggingface_repo
|
||||
]
|
||||
class VLLMInferenceAdapter(OpenAIMixin):
|
||||
config: VLLMInferenceAdapterConfig
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def _convert_to_vllm_tool_calls_in_response(
|
||||
tool_calls,
|
||||
) -> list[ToolCall]:
|
||||
if not tool_calls:
|
||||
return []
|
||||
provider_data_api_key_field: str = "vllm_api_token"
|
||||
|
||||
return [
|
||||
ToolCall(
|
||||
call_id=call.id,
|
||||
tool_name=call.function.name,
|
||||
arguments=call.function.arguments,
|
||||
)
|
||||
for call in tool_calls
|
||||
]
|
||||
|
||||
|
||||
def _convert_to_vllm_tools_in_request(tools: list[ToolDefinition]) -> list[dict]:
|
||||
compat_tools = []
|
||||
|
||||
for tool in tools:
|
||||
# The tool.tool_name can be a str or a BuiltinTool enum. If
|
||||
# it's the latter, convert to a string.
|
||||
tool_name = tool.tool_name
|
||||
if isinstance(tool_name, BuiltinTool):
|
||||
tool_name = tool_name.value
|
||||
|
||||
compat_tool = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.input_schema
|
||||
or {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
compat_tools.append(compat_tool)
|
||||
|
||||
return compat_tools
|
||||
|
||||
|
||||
def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
|
||||
return {
|
||||
"stop": StopReason.end_of_turn,
|
||||
"length": StopReason.out_of_tokens,
|
||||
"tool_calls": StopReason.end_of_message,
|
||||
}.get(finish_reason, StopReason.end_of_turn)
|
||||
|
||||
|
||||
def _process_vllm_chat_completion_end_of_stream(
|
||||
finish_reason: str | None,
|
||||
last_chunk_content: str | None,
|
||||
current_event_type: ChatCompletionResponseEventType,
|
||||
tool_call_bufs: dict[str, UnparseableToolCall] | None = None,
|
||||
) -> list[OpenAIChatCompletionChunk]:
|
||||
chunks = []
|
||||
|
||||
if finish_reason is not None:
|
||||
stop_reason = _convert_to_vllm_finish_reason(finish_reason)
|
||||
else:
|
||||
stop_reason = StopReason.end_of_message
|
||||
|
||||
tool_call_bufs = tool_call_bufs or {}
|
||||
for _index, tool_call_buf in sorted(tool_call_bufs.items()):
|
||||
args_str = tool_call_buf.arguments or "{}"
|
||||
try:
|
||||
chunks.append(
|
||||
ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=current_event_type,
|
||||
delta=ToolCallDelta(
|
||||
tool_call=ToolCall(
|
||||
call_id=tool_call_buf.call_id,
|
||||
tool_name=tool_call_buf.tool_name,
|
||||
arguments=args_str,
|
||||
),
|
||||
parse_status=ToolCallParseStatus.succeeded,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}")
|
||||
|
||||
chunks.append(
|
||||
ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
tool_call=str(tool_call_buf),
|
||||
parse_status=ToolCallParseStatus.failed,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
chunks.append(
|
||||
ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta=TextDelta(text=last_chunk_content or ""),
|
||||
logprobs=None,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
async def _process_vllm_chat_completion_stream_response(
|
||||
stream: AsyncGenerator[OpenAIChatCompletionChunk, None],
|
||||
) -> AsyncGenerator:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta=TextDelta(text=""),
|
||||
)
|
||||
)
|
||||
event_type = ChatCompletionResponseEventType.progress
|
||||
tool_call_bufs: dict[str, UnparseableToolCall] = {}
|
||||
end_of_stream_processed = False
|
||||
|
||||
async for chunk in stream:
|
||||
if not chunk.choices:
|
||||
log.warning("vLLM failed to generation any completions - check the vLLM server logs for an error.")
|
||||
return
|
||||
choice = chunk.choices[0]
|
||||
if choice.delta.tool_calls:
|
||||
for delta_tool_call in choice.delta.tool_calls:
|
||||
tool_call = convert_tool_call(delta_tool_call)
|
||||
if delta_tool_call.index not in tool_call_bufs:
|
||||
tool_call_bufs[delta_tool_call.index] = UnparseableToolCall()
|
||||
tool_call_buf = tool_call_bufs[delta_tool_call.index]
|
||||
tool_call_buf.tool_name += str(tool_call.tool_name)
|
||||
tool_call_buf.call_id += tool_call.call_id
|
||||
tool_call_buf.arguments += (
|
||||
tool_call.arguments if isinstance(tool_call.arguments, str) else json.dumps(tool_call.arguments)
|
||||
)
|
||||
if choice.finish_reason:
|
||||
chunks = _process_vllm_chat_completion_end_of_stream(
|
||||
finish_reason=choice.finish_reason,
|
||||
last_chunk_content=choice.delta.content,
|
||||
current_event_type=event_type,
|
||||
tool_call_bufs=tool_call_bufs,
|
||||
)
|
||||
for c in chunks:
|
||||
yield c
|
||||
end_of_stream_processed = True
|
||||
elif not choice.delta.tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=event_type,
|
||||
delta=TextDelta(text=choice.delta.content or ""),
|
||||
logprobs=None,
|
||||
)
|
||||
)
|
||||
event_type = ChatCompletionResponseEventType.progress
|
||||
|
||||
if end_of_stream_processed:
|
||||
return
|
||||
|
||||
# the stream ended without a chunk containing finish_reason - we have to generate the
|
||||
# respective completion chunks manually
|
||||
chunks = _process_vllm_chat_completion_end_of_stream(
|
||||
finish_reason=None, last_chunk_content=None, current_event_type=event_type, tool_call_bufs=tool_call_bufs
|
||||
)
|
||||
for c in chunks:
|
||||
yield c
|
||||
|
||||
|
||||
class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsProtocolPrivate):
|
||||
# automatically set by the resolver when instantiating the provider
|
||||
__provider_id__: str
|
||||
model_store: ModelStore | None = None
|
||||
|
||||
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=build_hf_repo_model_entries(),
|
||||
litellm_provider_name="vllm",
|
||||
api_key_from_config=config.api_token,
|
||||
provider_data_api_key_field="vllm_api_token",
|
||||
openai_compat_api_base=config.url,
|
||||
)
|
||||
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
|
||||
self.config = config
|
||||
|
||||
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_token or ""
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""Get the base URL from config."""
|
||||
|
|
@ -289,27 +57,6 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
|
|||
# Strictly respecting the refresh_models directive
|
||||
return self.config.refresh_models
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
models = []
|
||||
async for m in self.client.models.list():
|
||||
model_type = ModelType.llm # unclear how to determine embedding vs. llm models
|
||||
models.append(
|
||||
Model(
|
||||
identifier=m.id,
|
||||
provider_resource_id=m.id,
|
||||
provider_id=self.__provider_id__,
|
||||
metadata={},
|
||||
model_type=model_type,
|
||||
)
|
||||
)
|
||||
return models
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
async def health(self) -> HealthResponse:
|
||||
"""
|
||||
Performs a health check by verifying connectivity to the remote vLLM server.
|
||||
|
|
@ -331,143 +78,66 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
|
|||
except Exception as e:
|
||||
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
|
||||
|
||||
async def _get_model(self, model_id: str) -> Model:
|
||||
if not self.model_store:
|
||||
raise ValueError("Model store not set")
|
||||
return await self.model_store.get_model(model_id)
|
||||
|
||||
def get_extra_client_params(self):
|
||||
return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
|
||||
|
||||
async def chat_completion(
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self._get_model(model_id)
|
||||
if model.provider_resource_id is None:
|
||||
raise ValueError(f"Model {model_id} has no provider_resource_id set")
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
max_tokens = max_tokens or self.config.max_tokens
|
||||
|
||||
# This is to be consistent with OpenAI API and support vLLM <= v0.6.3
|
||||
# References:
|
||||
# * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
|
||||
# * https://github.com/vllm-project/vllm/pull/10000
|
||||
if not tools and tool_config is not None:
|
||||
tool_config.tool_choice = ToolChoice.none
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
if not tools and tool_choice is not None:
|
||||
tool_choice = ToolChoice.none.value
|
||||
|
||||
return await super().openai_chat_completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
stream=stream,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
tool_config=tool_config,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_chat_completion_with_client(request, self.client)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request, self.client)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: AsyncOpenAI
|
||||
) -> ChatCompletionResponse:
|
||||
assert self.client is not None
|
||||
params = await self._get_params(request)
|
||||
r = await client.chat.completions.create(**params)
|
||||
choice = r.choices[0]
|
||||
result = ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
content=choice.message.content or "",
|
||||
stop_reason=_convert_to_vllm_finish_reason(choice.finish_reason),
|
||||
tool_calls=_convert_to_vllm_tool_calls_in_response(choice.message.tool_calls),
|
||||
),
|
||||
logprobs=None,
|
||||
)
|
||||
return result
|
||||
|
||||
async def _stream_chat_completion(self, response: Any) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
# This method is called from LiteLLMOpenAIMixin.chat_completion
|
||||
# The response parameter contains the litellm response
|
||||
# We need to convert it to our format
|
||||
async def _stream_generator():
|
||||
async for chunk in response:
|
||||
yield chunk
|
||||
|
||||
async for chunk in convert_openai_chat_completion_stream(
|
||||
_stream_generator(), enable_incremental_tool_calls=True
|
||||
):
|
||||
yield chunk
|
||||
|
||||
async def _stream_chat_completion_with_client(
|
||||
self, request: ChatCompletionRequest, client: AsyncOpenAI
|
||||
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
|
||||
"""Helper method for streaming with explicit client parameter."""
|
||||
assert self.client is not None
|
||||
params = await self._get_params(request)
|
||||
|
||||
stream = await client.chat.completions.create(**params)
|
||||
if request.tools:
|
||||
res = _process_vllm_chat_completion_stream_response(stream)
|
||||
else:
|
||||
res = process_chat_completion_stream_response(stream, request)
|
||||
async for chunk in res:
|
||||
yield chunk
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
try:
|
||||
model = await self.register_helper.register_model(model)
|
||||
except ValueError:
|
||||
pass # Ignore statically unknown model, will check live listing
|
||||
try:
|
||||
res = self.client.models.list()
|
||||
except APIConnectionError as e:
|
||||
raise ValueError(
|
||||
f"Failed to connect to vLLM at {self.config.url}. Please check if vLLM is running and accessible at that URL."
|
||||
) from e
|
||||
available_models = [m.id async for m in res]
|
||||
if model.provider_resource_id not in available_models:
|
||||
raise ValueError(
|
||||
f"Model {model.provider_resource_id} is not being served by vLLM. "
|
||||
f"Available models: {', '.join(available_models)}"
|
||||
)
|
||||
return model
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
options = get_sampling_options(request.sampling_params)
|
||||
if "max_tokens" not in options:
|
||||
options["max_tokens"] = self.config.max_tokens
|
||||
|
||||
input_dict: dict[str, Any] = {}
|
||||
# Only include the 'tools' param if there is any. It can break things if an empty list is sent to the vLLM.
|
||||
if isinstance(request, ChatCompletionRequest) and request.tools:
|
||||
input_dict = {"tools": _convert_to_vllm_tools_in_request(request.tools)}
|
||||
|
||||
input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
|
||||
|
||||
if fmt := request.response_format:
|
||||
if isinstance(fmt, JsonSchemaResponseFormat):
|
||||
input_dict["extra_body"] = {"guided_json": fmt.json_schema}
|
||||
elif isinstance(fmt, GrammarResponseFormat):
|
||||
raise NotImplementedError("Grammar response format not supported yet")
|
||||
else:
|
||||
raise ValueError(f"Unknown response format {fmt.type}")
|
||||
|
||||
if request.logprobs and request.logprobs.top_k:
|
||||
input_dict["logprobs"] = request.logprobs.top_k
|
||||
|
||||
return {
|
||||
"model": request.model,
|
||||
**input_dict,
|
||||
"stream": request.stream,
|
||||
**options,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
|
@ -19,7 +20,7 @@ class WatsonXProviderDataValidator(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class WatsonXConfig(BaseModel):
|
||||
class WatsonXConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
|
||||
description="A base url for accessing the watsonx.ai",
|
||||
|
|
|
|||
|
|
@ -12,34 +12,21 @@ from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
|
|||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
GreedySamplingStrategy,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
TopKSamplingStrategy,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
prepare_openai_completion_params,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
|
|
@ -77,12 +64,6 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
|
||||
self._project_id = self._config.project_id
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
def _get_client(self, model_id) -> Model:
|
||||
config_api_key = (
|
||||
self._config.api_key.get_secret_value() if self._config.api_key else None
|
||||
|
|
@ -101,81 +82,9 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
)
|
||||
return self._openai_client
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = self._get_client(request.model).generate(**params)
|
||||
choices = []
|
||||
if "results" in r:
|
||||
for result in r["results"]:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=(
|
||||
result["stop_reason"] if result["stop_reason"] else None
|
||||
),
|
||||
text=result["generated_text"],
|
||||
)
|
||||
choices.append(choice)
|
||||
response = OpenAICompatCompletionResponse(
|
||||
choices=choices,
|
||||
)
|
||||
return process_chat_completion_response(response, request)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
model_id = request.model
|
||||
|
||||
# if we shift to TogetherAsyncClient, we won't need this wrapper
|
||||
async def _to_async_generator():
|
||||
s = self._get_client(model_id).generate_text_stream(**params)
|
||||
for chunk in s:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=None,
|
||||
text=chunk,
|
||||
)
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
|
||||
input_dict = {"params": {}}
|
||||
media_present = request_has_media(request)
|
||||
llama_model = self.get_llama_model(request.model)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import weaviate
|
|||
import weaviate.classes as wvc
|
||||
from numpy.typing import NDArray
|
||||
from weaviate.classes.init import Auth
|
||||
from weaviate.classes.query import Filter
|
||||
from weaviate.classes.query import Filter, HybridFusion
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
|
|
@ -26,6 +26,7 @@ from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
|
|||
OpenAIVectorStoreMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
RERANKER_TYPE_RRF,
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
|
|
@ -47,7 +48,7 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten
|
|||
class WeaviateIndex(EmbeddingIndex):
|
||||
def __init__(
|
||||
self,
|
||||
client: weaviate.Client,
|
||||
client: weaviate.WeaviateClient,
|
||||
collection_name: str,
|
||||
kvstore: KVStore | None = None,
|
||||
):
|
||||
|
|
@ -64,14 +65,14 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
)
|
||||
|
||||
data_objects = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
for chunk, embedding in zip(chunks, embeddings, strict=False):
|
||||
data_objects.append(
|
||||
wvc.data.DataObject(
|
||||
properties={
|
||||
"chunk_id": chunk.chunk_id,
|
||||
"chunk_content": chunk.model_dump_json(),
|
||||
},
|
||||
vector=embeddings[i].tolist(),
|
||||
vector=embedding.tolist(),
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -88,14 +89,30 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
collection.data.delete_many(where=Filter.by_property("chunk_id").contains_any(chunk_ids))
|
||||
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Performs vector search using Weaviate's built-in vector search.
|
||||
Args:
|
||||
embedding: The query embedding vector
|
||||
k: Limit of number of results to return
|
||||
score_threshold: Minimum similarity score threshold
|
||||
Returns:
|
||||
QueryChunksResponse with chunks and scores.
|
||||
"""
|
||||
log.debug(
|
||||
f"WEAVIATE VECTOR SEARCH CALLED: embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}"
|
||||
)
|
||||
sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
|
||||
collection = self.client.collections.get(sanitized_collection_name)
|
||||
|
||||
results = collection.query.near_vector(
|
||||
near_vector=embedding.tolist(),
|
||||
limit=k,
|
||||
return_metadata=wvc.query.MetadataQuery(distance=True),
|
||||
)
|
||||
try:
|
||||
results = collection.query.near_vector(
|
||||
near_vector=embedding.tolist(),
|
||||
limit=k,
|
||||
return_metadata=wvc.query.MetadataQuery(distance=True),
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Weaviate client vector search failed: {e}")
|
||||
raise
|
||||
|
||||
chunks = []
|
||||
scores = []
|
||||
|
|
@ -108,13 +125,17 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
log.exception(f"Failed to parse document: {chunk_json}")
|
||||
continue
|
||||
|
||||
score = 1.0 / doc.metadata.distance if doc.metadata.distance != 0 else float("inf")
|
||||
if doc.metadata.distance is None:
|
||||
continue
|
||||
# Convert cosine distance ∈ [0,2] -> normalized cosine similarity ∈ [0,1]
|
||||
score = 1.0 - (float(doc.metadata.distance) / 2.0)
|
||||
if score < score_threshold:
|
||||
continue
|
||||
|
||||
chunks.append(chunk)
|
||||
scores.append(score)
|
||||
|
||||
log.debug(f"WEAVIATE VECTOR SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def delete(self, chunk_ids: list[str] | None = None) -> None:
|
||||
|
|
@ -136,7 +157,50 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in Weaviate")
|
||||
"""
|
||||
Performs BM25-based keyword search using Weaviate's built-in full-text search.
|
||||
Args:
|
||||
query_string: The text query for keyword search
|
||||
k: Limit of number of results to return
|
||||
score_threshold: Minimum similarity score threshold
|
||||
Returns:
|
||||
QueryChunksResponse with chunks and scores
|
||||
"""
|
||||
log.debug(f"WEAVIATE KEYWORD SEARCH CALLED: query='{query_string}', k={k}, threshold={score_threshold}")
|
||||
sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
|
||||
collection = self.client.collections.get(sanitized_collection_name)
|
||||
|
||||
# Perform BM25 keyword search on chunk_content field
|
||||
try:
|
||||
results = collection.query.bm25(
|
||||
query=query_string,
|
||||
limit=k,
|
||||
return_metadata=wvc.query.MetadataQuery(score=True),
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Weaviate client keyword search failed: {e}")
|
||||
raise
|
||||
|
||||
chunks = []
|
||||
scores = []
|
||||
for doc in results.objects:
|
||||
chunk_json = doc.properties["chunk_content"]
|
||||
try:
|
||||
chunk_dict = json.loads(chunk_json)
|
||||
chunk = Chunk(**chunk_dict)
|
||||
except Exception:
|
||||
log.exception(f"Failed to parse document: {chunk_json}")
|
||||
continue
|
||||
|
||||
score = doc.metadata.score if doc.metadata.score is not None else 0.0
|
||||
if score < score_threshold:
|
||||
continue
|
||||
|
||||
chunks.append(chunk)
|
||||
scores.append(score)
|
||||
|
||||
log.debug(f"WEAVIATE KEYWORD SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}.")
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_hybrid(
|
||||
self,
|
||||
|
|
@ -147,7 +211,65 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
reranker_type: str,
|
||||
reranker_params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Hybrid search is not supported in Weaviate")
|
||||
"""
|
||||
Hybrid search combining vector similarity and keyword search using Weaviate's native hybrid search.
|
||||
Args:
|
||||
embedding: The query embedding vector
|
||||
query_string: The text query for keyword search
|
||||
k: Limit of number of results to return
|
||||
score_threshold: Minimum similarity score threshold
|
||||
reranker_type: Type of reranker to use ("rrf" or "normalized")
|
||||
reranker_params: Parameters for the reranker
|
||||
Returns:
|
||||
QueryChunksResponse with combined results
|
||||
"""
|
||||
log.debug(
|
||||
f"WEAVIATE HYBRID SEARCH CALLED: query='{query_string}', embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}, reranker={reranker_type}"
|
||||
)
|
||||
sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
|
||||
collection = self.client.collections.get(sanitized_collection_name)
|
||||
|
||||
# Ranked (RRF) reranker fusion type
|
||||
if reranker_type == RERANKER_TYPE_RRF:
|
||||
rerank = HybridFusion.RANKED
|
||||
# Relative score (Normalized) reranker fusion type
|
||||
else:
|
||||
rerank = HybridFusion.RELATIVE_SCORE
|
||||
|
||||
# Perform hybrid search using Weaviate's native hybrid search
|
||||
try:
|
||||
results = collection.query.hybrid(
|
||||
query=query_string,
|
||||
alpha=0.5, # Range <0, 1>, where 0.5 will equally favor vector and keyword search
|
||||
vector=embedding.tolist(),
|
||||
limit=k,
|
||||
fusion_type=rerank,
|
||||
return_metadata=wvc.query.MetadataQuery(score=True),
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Weaviate client hybrid search failed: {e}")
|
||||
raise
|
||||
|
||||
chunks = []
|
||||
scores = []
|
||||
for doc in results.objects:
|
||||
chunk_json = doc.properties["chunk_content"]
|
||||
try:
|
||||
chunk_dict = json.loads(chunk_json)
|
||||
chunk = Chunk(**chunk_dict)
|
||||
except Exception:
|
||||
log.exception(f"Failed to parse document: {chunk_json}")
|
||||
continue
|
||||
|
||||
score = doc.metadata.score if doc.metadata.score is not None else 0.0
|
||||
if score < score_threshold:
|
||||
continue
|
||||
|
||||
chunks.append(chunk)
|
||||
scores.append(score)
|
||||
|
||||
log.debug(f"WEAVIATE HYBRID SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class WeaviateVectorIOAdapter(
|
||||
|
|
@ -172,9 +294,9 @@ class WeaviateVectorIOAdapter(
|
|||
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
|
||||
self.metadata_collection_name = "openai_vector_stores_metadata"
|
||||
|
||||
def _get_client(self) -> weaviate.Client:
|
||||
def _get_client(self) -> weaviate.WeaviateClient:
|
||||
if "localhost" in self.config.weaviate_cluster_url:
|
||||
log.info("using Weaviate locally in container")
|
||||
log.info("Using Weaviate locally in container")
|
||||
host, port = self.config.weaviate_cluster_url.split(":")
|
||||
key = "local_test"
|
||||
client = weaviate.connect_to_local(
|
||||
|
|
@ -247,7 +369,7 @@ class WeaviateVectorIOAdapter(
|
|||
],
|
||||
)
|
||||
|
||||
self.cache[sanitized_collection_name] = VectorDBWithIndex(
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(
|
||||
vector_db,
|
||||
WeaviateIndex(client=client, collection_name=sanitized_collection_name),
|
||||
self.inference_api,
|
||||
|
|
@ -256,32 +378,34 @@ class WeaviateVectorIOAdapter(
|
|||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
client = self._get_client()
|
||||
sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
|
||||
if sanitized_collection_name not in self.cache or client.collections.exists(sanitized_collection_name) is False:
|
||||
log.warning(f"Vector DB {sanitized_collection_name} not found")
|
||||
if vector_db_id not in self.cache or client.collections.exists(sanitized_collection_name) is False:
|
||||
return
|
||||
client.collections.delete(sanitized_collection_name)
|
||||
await self.cache[sanitized_collection_name].index.delete()
|
||||
del self.cache[sanitized_collection_name]
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
|
||||
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
|
||||
sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
|
||||
if sanitized_collection_name in self.cache:
|
||||
return self.cache[sanitized_collection_name]
|
||||
if vector_db_id in self.cache:
|
||||
return self.cache[vector_db_id]
|
||||
|
||||
vector_db = await self.vector_db_store.get_vector_db(sanitized_collection_name)
|
||||
if self.vector_db_store is None:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
||||
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
|
||||
if not vector_db:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
||||
client = self._get_client()
|
||||
if not client.collections.exists(vector_db.identifier):
|
||||
sanitized_collection_name = sanitize_collection_name(vector_db.identifier, weaviate_format=True)
|
||||
if not client.collections.exists(sanitized_collection_name):
|
||||
raise ValueError(f"Collection with name `{sanitized_collection_name}` not found")
|
||||
|
||||
index = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=WeaviateIndex(client=client, collection_name=sanitized_collection_name),
|
||||
index=WeaviateIndex(client=client, collection_name=vector_db.identifier),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[sanitized_collection_name] = index
|
||||
self.cache[vector_db_id] = index
|
||||
return index
|
||||
|
||||
async def insert_chunks(
|
||||
|
|
@ -290,8 +414,7 @@ class WeaviateVectorIOAdapter(
|
|||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
|
||||
index = await self._get_and_cache_vector_db_index(sanitized_collection_name)
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
||||
|
|
@ -303,17 +426,15 @@ class WeaviateVectorIOAdapter(
|
|||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
|
||||
index = await self._get_and_cache_vector_db_index(sanitized_collection_name)
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
sanitized_collection_name = sanitize_collection_name(store_id, weaviate_format=True)
|
||||
index = await self._get_and_cache_vector_db_index(sanitized_collection_name)
|
||||
index = await self._get_and_cache_vector_db_index(store_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {sanitized_collection_name} not found")
|
||||
raise ValueError(f"Vector DB {store_id} not found")
|
||||
|
||||
await index.index.delete_chunks(chunks_for_deletion)
|
||||
|
|
|
|||
|
|
@ -6,10 +6,12 @@
|
|||
|
||||
import os
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
|
||||
|
||||
class BedrockBaseConfig(BaseModel):
|
||||
class BedrockBaseConfig(RemoteInferenceProviderConfig):
|
||||
aws_access_key_id: str | None = Field(
|
||||
default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"),
|
||||
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
|
||||
|
|
|
|||
|
|
@ -11,12 +11,8 @@ import litellm
|
|||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
InferenceProvider,
|
||||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
|
|
@ -24,12 +20,7 @@ from llama_stack.apis.inference import (
|
|||
OpenAIEmbeddingUsage,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -37,8 +28,6 @@ from llama_stack.providers.utils.inference.model_registry import ModelRegistryHe
|
|||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
b64_encode_openai_embeddings_response,
|
||||
convert_message_to_openai_dict_new,
|
||||
convert_openai_chat_completion_choice,
|
||||
convert_openai_chat_completion_stream,
|
||||
convert_tooldef_to_openai_tool,
|
||||
get_sampling_options,
|
||||
prepare_openai_completion_params,
|
||||
|
|
@ -105,57 +94,6 @@ class LiteLLMOpenAIMixin(
|
|||
else model_id
|
||||
)
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
|
||||
params = await self._get_params(request)
|
||||
params["model"] = self.get_litellm_model_name(params["model"])
|
||||
|
||||
logger.debug(f"params to litellm (openai compat): {params}")
|
||||
# see https://docs.litellm.ai/docs/completion/stream#async-completion
|
||||
response = await litellm.acompletion(**params)
|
||||
if stream:
|
||||
return self._stream_chat_completion(response)
|
||||
else:
|
||||
return convert_openai_chat_completion_choice(response.choices[0])
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, response: litellm.ModelResponse
|
||||
) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
async def _stream_generator():
|
||||
async for chunk in response:
|
||||
yield chunk
|
||||
|
||||
async for chunk in convert_openai_chat_completion_stream(
|
||||
_stream_generator(), enable_incremental_tool_calls=True
|
||||
):
|
||||
yield chunk
|
||||
|
||||
def _add_additional_properties_recursive(self, schema):
|
||||
"""
|
||||
Recursively add additionalProperties: False to all object schemas
|
||||
|
|
|
|||
|
|
@ -7,10 +7,11 @@
|
|||
import base64
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterator
|
||||
from collections.abc import AsyncIterator, Iterable
|
||||
from typing import Any
|
||||
|
||||
from openai import NOT_GIVEN, AsyncOpenAI
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
Model,
|
||||
|
|
@ -26,14 +27,14 @@ from llama_stack.apis.inference import (
|
|||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import localize_image_content
|
||||
|
||||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
||||
class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
|
||||
class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
||||
"""
|
||||
Mixin class that provides OpenAI-specific functionality for inference providers.
|
||||
This class handles direct OpenAI API calls using the AsyncOpenAI client.
|
||||
|
|
@ -42,12 +43,25 @@ class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
|
|||
- get_api_key(): Method to retrieve the API key
|
||||
- get_base_url(): Method to retrieve the OpenAI-compatible API base URL
|
||||
|
||||
The behavior of this class can be customized by child classes in the following ways:
|
||||
- overwrite_completion_id: If True, overwrites the 'id' field in OpenAI responses
|
||||
- download_images: If True, downloads images and converts to base64 for providers that require it
|
||||
- embedding_model_metadata: A dictionary mapping model IDs to their embedding metadata
|
||||
- provider_data_api_key_field: Optional field name in provider data to look for API key
|
||||
- list_provider_model_ids: Method to list available models from the provider
|
||||
- get_extra_client_params: Method to provide extra parameters to the AsyncOpenAI client
|
||||
|
||||
Expected Dependencies:
|
||||
- self.model_store: Injected by the Llama Stack distribution system at runtime.
|
||||
This provides model registry functionality for looking up registered models.
|
||||
The model_store is set in routing_tables/common.py during provider initialization.
|
||||
"""
|
||||
|
||||
# Allow extra fields so the routing infra can inject model_store, __provider_id__, etc.
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
config: RemoteInferenceProviderConfig
|
||||
|
||||
# Allow subclasses to control whether to overwrite the 'id' field in OpenAI responses
|
||||
# is overwritten with a client-side generated id.
|
||||
#
|
||||
|
|
@ -108,6 +122,38 @@ class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
|
|||
"""
|
||||
return {}
|
||||
|
||||
async def list_provider_model_ids(self) -> Iterable[str]:
|
||||
"""
|
||||
List available models from the provider.
|
||||
|
||||
Child classes can override this method to provide a custom implementation
|
||||
for listing models. The default implementation uses the AsyncOpenAI client
|
||||
to list models from the OpenAI-compatible endpoint.
|
||||
|
||||
:return: An iterable of model IDs or None if not implemented
|
||||
"""
|
||||
return [m.id async for m in self.client.models.list()]
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""
|
||||
Initialize the OpenAI mixin.
|
||||
|
||||
This method provides a default implementation that does nothing.
|
||||
Subclasses can override this method to perform initialization tasks
|
||||
such as setting up clients, validating configurations, etc.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""
|
||||
Shutdown the OpenAI mixin.
|
||||
|
||||
This method provides a default implementation that does nothing.
|
||||
Subclasses can override this method to perform cleanup tasks
|
||||
such as closing connections, releasing resources, etc.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def client(self) -> AsyncOpenAI:
|
||||
"""
|
||||
|
|
@ -356,6 +402,24 @@ class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
|
|||
usage=usage,
|
||||
)
|
||||
|
||||
###
|
||||
# ModelsProtocolPrivate implementation - provide model management functionality
|
||||
#
|
||||
# async def register_model(self, model: Model) -> Model: ...
|
||||
# async def unregister_model(self, model_id: str) -> None: ...
|
||||
#
|
||||
# async def list_models(self) -> list[Model] | None: ...
|
||||
# async def should_refresh_models(self) -> bool: ...
|
||||
##
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
if not await self.check_model_availability(model.provider_model_id):
|
||||
raise ValueError(f"Model {model.provider_model_id} is not available from provider {self.__provider_id__}") # type: ignore[attr-defined]
|
||||
return model
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
return None
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
"""
|
||||
List available models from the provider's /v1/models endpoint augmented with static embedding model metadata.
|
||||
|
|
@ -366,28 +430,42 @@ class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
|
|||
"""
|
||||
self._model_cache = {}
|
||||
|
||||
async for m in self.client.models.list():
|
||||
if self.allowed_models and m.id not in self.allowed_models:
|
||||
logger.info(f"Skipping model {m.id} as it is not in the allowed models list")
|
||||
try:
|
||||
iterable = await self.list_provider_model_ids()
|
||||
except Exception as e:
|
||||
logger.error(f"{self.__class__.__name__}.list_provider_model_ids() failed with: {e}")
|
||||
raise
|
||||
if not hasattr(iterable, "__iter__"):
|
||||
raise TypeError(
|
||||
f"Failed to list models: {self.__class__.__name__}.list_provider_model_ids() must return an iterable of "
|
||||
f"strings, but returned {type(iterable).__name__}"
|
||||
)
|
||||
|
||||
provider_models_ids = list(iterable)
|
||||
logger.info(f"{self.__class__.__name__}.list_provider_model_ids() returned {len(provider_models_ids)} models")
|
||||
|
||||
for provider_model_id in provider_models_ids:
|
||||
if not isinstance(provider_model_id, str):
|
||||
raise ValueError(f"Model ID {provider_model_id} from list_provider_model_ids() is not a string")
|
||||
if self.allowed_models and provider_model_id not in self.allowed_models:
|
||||
logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list")
|
||||
continue
|
||||
if metadata := self.embedding_model_metadata.get(m.id):
|
||||
# This is an embedding model - augment with metadata
|
||||
if metadata := self.embedding_model_metadata.get(provider_model_id):
|
||||
model = Model(
|
||||
provider_id=self.__provider_id__, # type: ignore[attr-defined]
|
||||
provider_resource_id=m.id,
|
||||
identifier=m.id,
|
||||
provider_resource_id=provider_model_id,
|
||||
identifier=provider_model_id,
|
||||
model_type=ModelType.embedding,
|
||||
metadata=metadata,
|
||||
)
|
||||
else:
|
||||
# This is an LLM
|
||||
model = Model(
|
||||
provider_id=self.__provider_id__, # type: ignore[attr-defined]
|
||||
provider_resource_id=m.id,
|
||||
identifier=m.id,
|
||||
provider_resource_id=provider_model_id,
|
||||
identifier=provider_model_id,
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
self._model_cache[m.id] = model
|
||||
self._model_cache[provider_model_id] = model
|
||||
|
||||
return list(self._model_cache.values())
|
||||
|
||||
|
|
@ -400,5 +478,33 @@ class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
|
|||
"""
|
||||
if not self._model_cache:
|
||||
await self.list_models()
|
||||
|
||||
return model in self._model_cache
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return False
|
||||
|
||||
#
|
||||
# The model_dump implementations are to avoid serializing the extra fields,
|
||||
# e.g. model_store, which are not pydantic.
|
||||
#
|
||||
|
||||
def _filter_fields(self, **kwargs):
|
||||
"""Helper to exclude extra fields from serialization."""
|
||||
# Exclude any extra fields stored in __pydantic_extra__
|
||||
if hasattr(self, "__pydantic_extra__") and self.__pydantic_extra__:
|
||||
exclude = kwargs.get("exclude", set())
|
||||
if not isinstance(exclude, set):
|
||||
exclude = set(exclude) if exclude else set()
|
||||
exclude.update(self.__pydantic_extra__.keys())
|
||||
kwargs["exclude"] = exclude
|
||||
return kwargs
|
||||
|
||||
def model_dump(self, **kwargs):
|
||||
"""Override to exclude extra fields from serialization."""
|
||||
kwargs = self._filter_fields(**kwargs)
|
||||
return super().model_dump(**kwargs)
|
||||
|
||||
def model_dump_json(self, **kwargs):
|
||||
"""Override to exclude extra fields from JSON serialization."""
|
||||
kwargs = self._filter_fields(**kwargs)
|
||||
return super().model_dump_json(**kwargs)
|
||||
|
|
|
|||
|
|
@ -50,6 +50,7 @@ class ChunkForDeletion(BaseModel):
|
|||
# Constants for reranker types
|
||||
RERANKER_TYPE_RRF = "rrf"
|
||||
RERANKER_TYPE_WEIGHTED = "weighted"
|
||||
RERANKER_TYPE_NORMALIZED = "normalized"
|
||||
|
||||
|
||||
def parse_pdf(data: bytes) -> str:
|
||||
|
|
@ -325,6 +326,8 @@ class VectorDBWithIndex:
|
|||
weights = ranker.get("params", {}).get("weights", [0.5, 0.5])
|
||||
reranker_type = RERANKER_TYPE_WEIGHTED
|
||||
reranker_params = {"alpha": weights[0] if len(weights) > 0 else 0.5}
|
||||
elif strategy == "normalized":
|
||||
reranker_type = RERANKER_TYPE_NORMALIZED
|
||||
else:
|
||||
reranker_type = RERANKER_TYPE_RRF
|
||||
k_value = ranker.get("params", {}).get("k", 60.0)
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectWithInput,
|
||||
)
|
||||
from llama_stack.apis.inference import OpenAIMessageParam
|
||||
from llama_stack.core.datatypes import AccessRule, ResponsesStoreConfig
|
||||
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -28,6 +29,19 @@ from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, SqlStoreTy
|
|||
logger = get_logger(name=__name__, category="openai_responses")
|
||||
|
||||
|
||||
class _OpenAIResponseObjectWithInputAndMessages(OpenAIResponseObjectWithInput):
|
||||
"""Internal class for storing responses with chat completion messages.
|
||||
|
||||
This extends the public OpenAIResponseObjectWithInput with messages field
|
||||
for internal storage. The messages field is not exposed in the public API.
|
||||
|
||||
The messages field is optional for backward compatibility with responses
|
||||
stored before this feature was added.
|
||||
"""
|
||||
|
||||
messages: list[OpenAIMessageParam] | None = None
|
||||
|
||||
|
||||
class ResponsesStore:
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -54,7 +68,9 @@ class ResponsesStore:
|
|||
self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite
|
||||
|
||||
# Async write queue and worker control
|
||||
self._queue: asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput]]] | None = None
|
||||
self._queue: (
|
||||
asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput], list[OpenAIMessageParam]]] | None
|
||||
) = None
|
||||
self._worker_tasks: list[asyncio.Task[Any]] = []
|
||||
self._max_write_queue_size: int = config.max_write_queue_size
|
||||
self._num_writers: int = max(1, config.num_writers)
|
||||
|
|
@ -100,18 +116,21 @@ class ResponsesStore:
|
|||
await self._queue.join()
|
||||
|
||||
async def store_response_object(
|
||||
self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput]
|
||||
self,
|
||||
response_object: OpenAIResponseObject,
|
||||
input: list[OpenAIResponseInput],
|
||||
messages: list[OpenAIMessageParam],
|
||||
) -> None:
|
||||
if self.enable_write_queue:
|
||||
if self._queue is None:
|
||||
raise ValueError("Responses store is not initialized")
|
||||
try:
|
||||
self._queue.put_nowait((response_object, input))
|
||||
self._queue.put_nowait((response_object, input, messages))
|
||||
except asyncio.QueueFull:
|
||||
logger.warning(f"Write queue full; adding response id={getattr(response_object, 'id', '<unknown>')}")
|
||||
await self._queue.put((response_object, input))
|
||||
await self._queue.put((response_object, input, messages))
|
||||
else:
|
||||
await self._write_response_object(response_object, input)
|
||||
await self._write_response_object(response_object, input, messages)
|
||||
|
||||
async def _worker_loop(self) -> None:
|
||||
assert self._queue is not None
|
||||
|
|
@ -120,22 +139,26 @@ class ResponsesStore:
|
|||
item = await self._queue.get()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
response_object, input = item
|
||||
response_object, input, messages = item
|
||||
try:
|
||||
await self._write_response_object(response_object, input)
|
||||
await self._write_response_object(response_object, input, messages)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(f"Error writing response object: {e}")
|
||||
finally:
|
||||
self._queue.task_done()
|
||||
|
||||
async def _write_response_object(
|
||||
self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput]
|
||||
self,
|
||||
response_object: OpenAIResponseObject,
|
||||
input: list[OpenAIResponseInput],
|
||||
messages: list[OpenAIMessageParam],
|
||||
) -> None:
|
||||
if self.sql_store is None:
|
||||
raise ValueError("Responses store is not initialized")
|
||||
|
||||
data = response_object.model_dump()
|
||||
data["input"] = [input_item.model_dump() for input_item in input]
|
||||
data["messages"] = [msg.model_dump() for msg in messages]
|
||||
|
||||
await self.sql_store.insert(
|
||||
"openai_responses",
|
||||
|
|
@ -188,7 +211,7 @@ class ResponsesStore:
|
|||
last_id=data[-1].id if data else "",
|
||||
)
|
||||
|
||||
async def get_response_object(self, response_id: str) -> OpenAIResponseObjectWithInput:
|
||||
async def get_response_object(self, response_id: str) -> _OpenAIResponseObjectWithInputAndMessages:
|
||||
"""
|
||||
Get a response object with automatic access control checking.
|
||||
"""
|
||||
|
|
@ -205,7 +228,7 @@ class ResponsesStore:
|
|||
# This provides security by not revealing whether the record exists
|
||||
raise ValueError(f"Response with id {response_id} not found") from None
|
||||
|
||||
return OpenAIResponseObjectWithInput(**row["response_object"])
|
||||
return _OpenAIResponseObjectWithInputAndMessages(**row["response_object"])
|
||||
|
||||
async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||
if not self.sql_store:
|
||||
|
|
@ -241,8 +264,8 @@ class ResponsesStore:
|
|||
if before and after:
|
||||
raise ValueError("Cannot specify both 'before' and 'after' parameters")
|
||||
|
||||
response_with_input = await self.get_response_object(response_id)
|
||||
items = response_with_input.input
|
||||
response_with_input_and_messages = await self.get_response_object(response_id)
|
||||
items = response_with_input_and_messages.input
|
||||
|
||||
if order == Order.desc:
|
||||
items = list(reversed(items))
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import Enum
|
||||
from typing import Any, Literal, Protocol
|
||||
|
||||
|
|
@ -41,9 +41,9 @@ class SqlStore(Protocol):
|
|||
"""
|
||||
pass
|
||||
|
||||
async def insert(self, table: str, data: Mapping[str, Any]) -> None:
|
||||
async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
|
||||
"""
|
||||
Insert a row into a table.
|
||||
Insert a row or batch of rows into a table.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Literal
|
||||
|
||||
from llama_stack.core.access_control.access_control import default_policy, is_action_allowed
|
||||
|
|
@ -38,6 +38,18 @@ SQL_OPTIMIZED_POLICY = [
|
|||
]
|
||||
|
||||
|
||||
def _enhance_item_with_access_control(item: Mapping[str, Any], current_user: User | None) -> Mapping[str, Any]:
|
||||
"""Add access control attributes to a data item."""
|
||||
enhanced = dict(item)
|
||||
if current_user:
|
||||
enhanced["owner_principal"] = current_user.principal
|
||||
enhanced["access_attributes"] = current_user.attributes
|
||||
else:
|
||||
enhanced["owner_principal"] = None
|
||||
enhanced["access_attributes"] = None
|
||||
return enhanced
|
||||
|
||||
|
||||
class SqlRecord(ProtectedResource):
|
||||
def __init__(self, record_id: str, table_name: str, owner: User):
|
||||
self.type = f"sql_record::{table_name}"
|
||||
|
|
@ -102,18 +114,14 @@ class AuthorizedSqlStore:
|
|||
await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
|
||||
await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING)
|
||||
|
||||
async def insert(self, table: str, data: Mapping[str, Any]) -> None:
|
||||
"""Insert a row with automatic access control attribute capture."""
|
||||
enhanced_data = dict(data)
|
||||
|
||||
async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
|
||||
"""Insert a row or batch of rows with automatic access control attribute capture."""
|
||||
current_user = get_authenticated_user()
|
||||
if current_user:
|
||||
enhanced_data["owner_principal"] = current_user.principal
|
||||
enhanced_data["access_attributes"] = current_user.attributes
|
||||
enhanced_data: Mapping[str, Any] | Sequence[Mapping[str, Any]]
|
||||
if isinstance(data, Mapping):
|
||||
enhanced_data = _enhance_item_with_access_control(data, current_user)
|
||||
else:
|
||||
enhanced_data["owner_principal"] = None
|
||||
enhanced_data["access_attributes"] = None
|
||||
|
||||
enhanced_data = [_enhance_item_with_access_control(item, current_user) for item in data]
|
||||
await self.sql_store.insert(table, enhanced_data)
|
||||
|
||||
async def fetch_all(
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Literal
|
||||
|
||||
from sqlalchemy import (
|
||||
|
|
@ -116,7 +116,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
|||
async with engine.begin() as conn:
|
||||
await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)
|
||||
|
||||
async def insert(self, table: str, data: Mapping[str, Any]) -> None:
|
||||
async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
|
||||
async with self.async_session() as session:
|
||||
await session.execute(self.metadata.tables[table].insert(), data)
|
||||
await session.commit()
|
||||
|
|
|
|||
|
|
@ -11,6 +11,43 @@ from typing import Any, TypeVar
|
|||
from .strong_typing.schema import json_schema_type, register_schema # noqa: F401
|
||||
|
||||
|
||||
class ExtraBodyField[T]:
|
||||
"""
|
||||
Marker annotation for parameters that arrive via extra_body in the client SDK.
|
||||
|
||||
These parameters:
|
||||
- Will NOT appear in the generated client SDK method signature
|
||||
- WILL be documented in OpenAPI spec under x-llama-stack-extra-body-params
|
||||
- MUST be passed via the extra_body parameter in client SDK calls
|
||||
- WILL be available in server-side method signature with proper typing
|
||||
|
||||
Example:
|
||||
```python
|
||||
async def create_openai_response(
|
||||
self,
|
||||
input: str,
|
||||
model: str,
|
||||
shields: Annotated[
|
||||
list[str] | None, ExtraBodyField("List of shields to apply")
|
||||
] = None,
|
||||
) -> ResponseObject:
|
||||
# shields is available here with proper typing
|
||||
if shields:
|
||||
print(f"Using shields: {shields}")
|
||||
```
|
||||
|
||||
Client usage:
|
||||
```python
|
||||
client.responses.create(
|
||||
input="hello", model="llama-3", extra_body={"shields": ["shield-1"]}
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, description: str | None = None):
|
||||
self.description = description
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebMethod:
|
||||
level: str | None = None
|
||||
|
|
@ -26,7 +63,7 @@ class WebMethod:
|
|||
deprecated: bool | None = False
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Callable[..., Any])
|
||||
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def webmethod(
|
||||
|
|
@ -40,7 +77,7 @@ def webmethod(
|
|||
descriptive_name: str | None = None,
|
||||
required_scope: str | None = None,
|
||||
deprecated: bool | None = False,
|
||||
) -> Callable[[T], T]:
|
||||
) -> Callable[[CallableT], CallableT]:
|
||||
"""
|
||||
Decorator that supplies additional metadata to an endpoint operation function.
|
||||
|
||||
|
|
@ -51,7 +88,7 @@ def webmethod(
|
|||
:param required_scope: Required scope for this endpoint (e.g., 'monitoring.viewer').
|
||||
"""
|
||||
|
||||
def wrap(func: T) -> T:
|
||||
def wrap(func: CallableT) -> CallableT:
|
||||
webmethod_obj = WebMethod(
|
||||
route=route,
|
||||
method=method,
|
||||
|
|
|
|||
|
|
@ -484,12 +484,19 @@ class JsonSchemaGenerator:
|
|||
}
|
||||
return ret
|
||||
elif origin_type is Literal:
|
||||
if len(typing.get_args(typ)) != 1:
|
||||
raise ValueError(f"Literal type {typ} has {len(typing.get_args(typ))} arguments")
|
||||
(literal_value,) = typing.get_args(typ) # unpack value of literal type
|
||||
schema = self.type_to_schema(type(literal_value))
|
||||
schema["const"] = literal_value
|
||||
return schema
|
||||
literal_args = typing.get_args(typ)
|
||||
if len(literal_args) == 1:
|
||||
(literal_value,) = literal_args
|
||||
schema = self.type_to_schema(type(literal_value))
|
||||
schema["const"] = literal_value
|
||||
return schema
|
||||
elif len(literal_args) > 1:
|
||||
first_value = literal_args[0]
|
||||
schema = self.type_to_schema(type(first_value))
|
||||
schema["enum"] = list(literal_args)
|
||||
return schema
|
||||
else:
|
||||
return {"enum": []}
|
||||
elif origin_type is type:
|
||||
(concrete_type,) = typing.get_args(typ) # unpack single tuple element
|
||||
return {"const": self.type_to_schema(concrete_type, force_expand=True)}
|
||||
|
|
|
|||
|
|
@ -22,10 +22,18 @@ from llama_stack.log import get_logger
|
|||
logger = get_logger(__name__, category="testing")
|
||||
|
||||
# Global state for the recording system
|
||||
# Note: Using module globals instead of ContextVars because the session-scoped
|
||||
# client initialization happens in one async context, but tests run in different
|
||||
# contexts, and we need the mode/storage to persist across all contexts.
|
||||
_current_mode: str | None = None
|
||||
_current_storage: ResponseStorage | None = None
|
||||
_original_methods: dict[str, Any] = {}
|
||||
|
||||
# Test context uses ContextVar since it changes per-test and needs async isolation
|
||||
from contextvars import ContextVar
|
||||
|
||||
_test_context: ContextVar[str | None] = ContextVar("_test_context", default=None)
|
||||
|
||||
from openai.types.completion_choice import CompletionChoice
|
||||
|
||||
# update the "finish_reason" field, since its type definition is wrong (no None is accepted)
|
||||
|
|
@ -33,22 +41,38 @@ CompletionChoice.model_fields["finish_reason"].annotation = Literal["stop", "len
|
|||
CompletionChoice.model_rebuild()
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent
|
||||
DEFAULT_STORAGE_DIR = REPO_ROOT / "tests/integration/recordings"
|
||||
DEFAULT_STORAGE_DIR = REPO_ROOT / "tests/integration/common"
|
||||
|
||||
|
||||
class InferenceMode(StrEnum):
|
||||
LIVE = "live"
|
||||
RECORD = "record"
|
||||
REPLAY = "replay"
|
||||
RECORD_IF_MISSING = "record-if-missing"
|
||||
|
||||
|
||||
def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
|
||||
"""Create a normalized hash of the request for consistent matching."""
|
||||
"""Create a normalized hash of the request for consistent matching.
|
||||
|
||||
Includes test_id from context to ensure test isolation - identical requests
|
||||
from different tests will have different hashes.
|
||||
|
||||
Exception: Model list endpoints (/v1/models, /api/tags) exclude test_id since
|
||||
they are infrastructure/shared and need to work across session setup and tests.
|
||||
"""
|
||||
# Extract just the endpoint path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(url)
|
||||
normalized = {"method": method.upper(), "endpoint": parsed.path, "body": body}
|
||||
normalized: dict[str, Any] = {
|
||||
"method": method.upper(),
|
||||
"endpoint": parsed.path,
|
||||
"body": body,
|
||||
}
|
||||
|
||||
# Include test_id for isolation, except for shared infrastructure endpoints
|
||||
if parsed.path not in ("/api/tags", "/v1/models"):
|
||||
normalized["test_id"] = _test_context.get()
|
||||
|
||||
# Create hash - sort_keys=True ensures deterministic ordering
|
||||
normalized_json = json.dumps(normalized, sort_keys=True)
|
||||
|
|
@ -67,7 +91,11 @@ def setup_inference_recording():
|
|||
Currently, this is only supported for OpenAI and Ollama clients. These should cover the vast majority of use cases.
|
||||
|
||||
Two environment variables are supported:
|
||||
- LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', or 'replay'. Default is 'replay'.
|
||||
- LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', 'replay', or 'record-if-missing'. Default is 'replay'.
|
||||
- 'live': Make all requests live without recording
|
||||
- 'record': Record all requests (overwrites existing recordings)
|
||||
- 'replay': Use only recorded responses (fails if recording not found)
|
||||
- 'record-if-missing': Use recorded responses when available, record new ones when not found
|
||||
- LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in. Default is 'tests/integration/recordings'.
|
||||
|
||||
The recordings are stored as JSON files.
|
||||
|
|
@ -80,9 +108,43 @@ def setup_inference_recording():
|
|||
return inference_recording(mode=mode, storage_dir=storage_dir)
|
||||
|
||||
|
||||
def _serialize_response(response: Any) -> Any:
|
||||
def _normalize_response_data(data: dict[str, Any], request_hash: str) -> dict[str, Any]:
|
||||
"""Normalize fields that change between recordings but don't affect functionality.
|
||||
|
||||
This reduces noise in git diffs by making IDs deterministic and timestamps constant.
|
||||
"""
|
||||
# Only normalize ID for completion/chat responses, not for model objects
|
||||
# Model objects have "object": "model" and the ID is the actual model identifier
|
||||
if "id" in data and data.get("object") != "model":
|
||||
data["id"] = f"rec-{request_hash[:12]}"
|
||||
|
||||
# Normalize timestamp to epoch (0) (for OpenAI-style responses)
|
||||
# But not for model objects where created timestamp might be meaningful
|
||||
if "created" in data and data.get("object") != "model":
|
||||
data["created"] = 0
|
||||
|
||||
# Normalize Ollama-specific timestamp fields
|
||||
if "created_at" in data:
|
||||
data["created_at"] = "1970-01-01T00:00:00.000000Z"
|
||||
|
||||
# Normalize Ollama-specific duration fields (these vary based on system load)
|
||||
if "total_duration" in data and data["total_duration"] is not None:
|
||||
data["total_duration"] = 0
|
||||
if "load_duration" in data and data["load_duration"] is not None:
|
||||
data["load_duration"] = 0
|
||||
if "prompt_eval_duration" in data and data["prompt_eval_duration"] is not None:
|
||||
data["prompt_eval_duration"] = 0
|
||||
if "eval_duration" in data and data["eval_duration"] is not None:
|
||||
data["eval_duration"] = 0
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def _serialize_response(response: Any, request_hash: str = "") -> Any:
|
||||
if hasattr(response, "model_dump"):
|
||||
data = response.model_dump(mode="json")
|
||||
# Normalize fields to reduce noise
|
||||
data = _normalize_response_data(data, request_hash)
|
||||
return {
|
||||
"__type__": f"{response.__class__.__module__}.{response.__class__.__qualname__}",
|
||||
"__data__": data,
|
||||
|
|
@ -120,61 +182,121 @@ def _deserialize_response(data: dict[str, Any]) -> Any:
|
|||
class ResponseStorage:
|
||||
"""Handles SQLite index + JSON file storage/retrieval for inference recordings."""
|
||||
|
||||
def __init__(self, test_dir: Path):
|
||||
self.test_dir = test_dir
|
||||
self.responses_dir = self.test_dir / "responses"
|
||||
def __init__(self, base_dir: Path):
|
||||
self.base_dir = base_dir
|
||||
# Don't create responses_dir here - determine it per-test at runtime
|
||||
|
||||
self._ensure_directories()
|
||||
def _get_test_dir(self) -> Path:
|
||||
"""Get the recordings directory in the test file's parent directory.
|
||||
|
||||
For test at "tests/integration/inference/test_foo.py::test_bar",
|
||||
returns "tests/integration/inference/recordings/".
|
||||
"""
|
||||
test_id = _test_context.get()
|
||||
if test_id:
|
||||
# Extract the directory path from the test nodeid
|
||||
# e.g., "tests/integration/inference/test_basic.py::test_foo[params]"
|
||||
# -> get "tests/integration/inference"
|
||||
test_file = test_id.split("::")[0] # Remove test function part
|
||||
test_dir = Path(test_file).parent # Get parent directory
|
||||
|
||||
# Put recordings in a "recordings" subdirectory of the test's parent dir
|
||||
# e.g., "tests/integration/inference" -> "tests/integration/inference/recordings"
|
||||
return test_dir / "recordings"
|
||||
else:
|
||||
# Fallback for non-test contexts
|
||||
return self.base_dir / "recordings"
|
||||
|
||||
def _ensure_directories(self):
|
||||
self.test_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.responses_dir.mkdir(exist_ok=True)
|
||||
"""Ensure test-specific directories exist."""
|
||||
test_dir = self._get_test_dir()
|
||||
test_dir.mkdir(parents=True, exist_ok=True)
|
||||
return test_dir
|
||||
|
||||
def store_recording(self, request_hash: str, request: dict[str, Any], response: dict[str, Any]):
|
||||
"""Store a request/response pair."""
|
||||
# Generate unique response filename
|
||||
short_hash = request_hash[:12]
|
||||
response_file = f"{short_hash}.json"
|
||||
responses_dir = self._ensure_directories()
|
||||
|
||||
# Use FULL hash (not truncated)
|
||||
response_file = f"{request_hash}.json"
|
||||
|
||||
# Serialize response body if needed
|
||||
serialized_response = dict(response)
|
||||
if "body" in serialized_response:
|
||||
if isinstance(serialized_response["body"], list):
|
||||
# Handle streaming responses (list of chunks)
|
||||
serialized_response["body"] = [_serialize_response(chunk) for chunk in serialized_response["body"]]
|
||||
serialized_response["body"] = [
|
||||
_serialize_response(chunk, request_hash) for chunk in serialized_response["body"]
|
||||
]
|
||||
else:
|
||||
# Handle single response
|
||||
serialized_response["body"] = _serialize_response(serialized_response["body"])
|
||||
serialized_response["body"] = _serialize_response(serialized_response["body"], request_hash)
|
||||
|
||||
# If this is an Ollama /api/tags recording, include models digest in filename to distinguish variants
|
||||
# For model-list endpoints, include digest in filename to distinguish different model sets
|
||||
endpoint = request.get("endpoint")
|
||||
if endpoint in ("/api/tags", "/v1/models"):
|
||||
digest = _model_identifiers_digest(endpoint, response)
|
||||
response_file = f"models-{short_hash}-{digest}.json"
|
||||
response_file = f"models-{request_hash}-{digest}.json"
|
||||
|
||||
response_path = self.responses_dir / response_file
|
||||
response_path = responses_dir / response_file
|
||||
|
||||
# Save response to JSON file
|
||||
# Save response to JSON file with metadata
|
||||
with open(response_path, "w") as f:
|
||||
json.dump({"request": request, "response": serialized_response}, f, indent=2)
|
||||
json.dump(
|
||||
{
|
||||
"test_id": _test_context.get(), # Include for debugging
|
||||
"request": request,
|
||||
"response": serialized_response,
|
||||
},
|
||||
f,
|
||||
indent=2,
|
||||
)
|
||||
f.write("\n")
|
||||
f.flush()
|
||||
|
||||
def find_recording(self, request_hash: str) -> dict[str, Any] | None:
|
||||
"""Find a recorded response by request hash."""
|
||||
response_file = f"{request_hash[:12]}.json"
|
||||
response_path = self.responses_dir / response_file
|
||||
"""Find a recorded response by request hash.
|
||||
|
||||
if not response_path.exists():
|
||||
return None
|
||||
Uses fallback: first checks test-specific dir, then falls back to base recordings dir.
|
||||
This handles cases where recordings happen during session setup (no test context) but
|
||||
are requested during tests (with test context).
|
||||
"""
|
||||
response_file = f"{request_hash}.json"
|
||||
|
||||
return _recording_from_file(response_path)
|
||||
# Try test-specific directory first
|
||||
test_dir = self._get_test_dir()
|
||||
response_path = test_dir / response_file
|
||||
|
||||
def _model_list_responses(self, short_hash: str) -> list[dict[str, Any]]:
|
||||
if response_path.exists():
|
||||
return _recording_from_file(response_path)
|
||||
|
||||
# Fallback to base recordings directory (for session-level recordings)
|
||||
fallback_dir = self.base_dir / "recordings"
|
||||
fallback_path = fallback_dir / response_file
|
||||
|
||||
if fallback_path.exists():
|
||||
return _recording_from_file(fallback_path)
|
||||
|
||||
return None
|
||||
|
||||
def _model_list_responses(self, request_hash: str) -> list[dict[str, Any]]:
|
||||
"""Find all model-list recordings with the given hash (different digests)."""
|
||||
results: list[dict[str, Any]] = []
|
||||
for path in self.responses_dir.glob(f"models-{short_hash}-*.json"):
|
||||
data = _recording_from_file(path)
|
||||
results.append(data)
|
||||
|
||||
# Check test-specific directory first
|
||||
test_dir = self._get_test_dir()
|
||||
if test_dir.exists():
|
||||
for path in test_dir.glob(f"models-{request_hash}-*.json"):
|
||||
data = _recording_from_file(path)
|
||||
results.append(data)
|
||||
|
||||
# Also check fallback directory
|
||||
fallback_dir = self.base_dir / "recordings"
|
||||
if fallback_dir.exists():
|
||||
for path in fallback_dir.glob(f"models-{request_hash}-*.json"):
|
||||
data = _recording_from_file(path)
|
||||
results.append(data)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
|
|
@ -195,6 +317,8 @@ def _recording_from_file(response_path) -> dict[str, Any]:
|
|||
|
||||
|
||||
def _model_identifiers_digest(endpoint: str, response: dict[str, Any]) -> str:
|
||||
"""Generate a digest from model identifiers for distinguishing different model sets."""
|
||||
|
||||
def _extract_model_identifiers():
|
||||
"""Extract a stable set of identifiers for model-list endpoints.
|
||||
|
||||
|
|
@ -217,7 +341,14 @@ def _model_identifiers_digest(endpoint: str, response: dict[str, Any]) -> str:
|
|||
|
||||
|
||||
def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]]) -> dict[str, Any] | None:
|
||||
"""Return a single, unioned recording for supported model-list endpoints."""
|
||||
"""Return a single, unioned recording for supported model-list endpoints.
|
||||
|
||||
Merges multiple recordings with different model sets (from different servers) into
|
||||
a single response containing all models.
|
||||
"""
|
||||
if not records:
|
||||
return None
|
||||
|
||||
seen: dict[str, dict[str, Any]] = {}
|
||||
for rec in records:
|
||||
body = rec["response"]["body"]
|
||||
|
|
@ -246,7 +377,10 @@ def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]])
|
|||
async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
|
||||
global _current_mode, _current_storage
|
||||
|
||||
if _current_mode == InferenceMode.LIVE or _current_storage is None:
|
||||
mode = _current_mode
|
||||
storage = _current_storage
|
||||
|
||||
if mode == InferenceMode.LIVE or storage is None:
|
||||
if endpoint == "/v1/models":
|
||||
return original_method(self, *args, **kwargs)
|
||||
else:
|
||||
|
|
@ -277,13 +411,16 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
|
|||
|
||||
request_hash = normalize_request(method, url, headers, body)
|
||||
|
||||
if _current_mode == InferenceMode.REPLAY:
|
||||
# Special handling for model-list endpoints: return union of all responses
|
||||
# Try to find existing recording for REPLAY or RECORD_IF_MISSING modes
|
||||
recording = None
|
||||
if mode == InferenceMode.REPLAY or mode == InferenceMode.RECORD_IF_MISSING:
|
||||
# Special handling for model-list endpoints: merge all recordings with this hash
|
||||
if endpoint in ("/api/tags", "/v1/models"):
|
||||
records = _current_storage._model_list_responses(request_hash[:12])
|
||||
records = storage._model_list_responses(request_hash)
|
||||
recording = _combine_model_list_responses(endpoint, records)
|
||||
else:
|
||||
recording = _current_storage.find_recording(request_hash)
|
||||
recording = storage.find_recording(request_hash)
|
||||
|
||||
if recording:
|
||||
response_body = recording["response"]["body"]
|
||||
|
||||
|
|
@ -296,7 +433,8 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
|
|||
return replay_stream()
|
||||
else:
|
||||
return response_body
|
||||
else:
|
||||
elif mode == InferenceMode.REPLAY:
|
||||
# REPLAY mode requires recording to exist
|
||||
raise RuntimeError(
|
||||
f"No recorded response found for request hash: {request_hash}\n"
|
||||
f"Request: {method} {url} {body}\n"
|
||||
|
|
@ -304,7 +442,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
|
|||
f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record"
|
||||
)
|
||||
|
||||
elif _current_mode == InferenceMode.RECORD:
|
||||
if mode == InferenceMode.RECORD or (mode == InferenceMode.RECORD_IF_MISSING and not recording):
|
||||
if endpoint == "/v1/models":
|
||||
response = original_method(self, *args, **kwargs)
|
||||
else:
|
||||
|
|
@ -335,7 +473,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
|
|||
|
||||
# Store the recording immediately
|
||||
response_data = {"body": chunks, "is_streaming": True}
|
||||
_current_storage.store_recording(request_hash, request_data, response_data)
|
||||
storage.store_recording(request_hash, request_data, response_data)
|
||||
|
||||
# Return a generator that replays the stored chunks
|
||||
async def replay_recorded_stream():
|
||||
|
|
@ -345,11 +483,11 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
|
|||
return replay_recorded_stream()
|
||||
else:
|
||||
response_data = {"body": response, "is_streaming": False}
|
||||
_current_storage.store_recording(request_hash, request_data, response_data)
|
||||
storage.store_recording(request_hash, request_data, response_data)
|
||||
return response
|
||||
|
||||
else:
|
||||
raise AssertionError(f"Invalid mode: {_current_mode}")
|
||||
raise AssertionError(f"Invalid mode: {mode}")
|
||||
|
||||
|
||||
def patch_inference_clients():
|
||||
|
|
@ -490,9 +628,9 @@ def inference_recording(mode: str, storage_dir: str | Path | None = None) -> Gen
|
|||
try:
|
||||
_current_mode = mode
|
||||
|
||||
if mode in ["record", "replay"]:
|
||||
if mode in ["record", "replay", "record-if-missing"]:
|
||||
if storage_dir is None:
|
||||
raise ValueError("storage_dir is required for record and replay modes")
|
||||
raise ValueError("storage_dir is required for record, replay, and record-if-missing modes")
|
||||
_current_storage = ResponseStorage(Path(storage_dir))
|
||||
patch_inference_clients()
|
||||
|
||||
|
|
@ -500,7 +638,7 @@ def inference_recording(mode: str, storage_dir: str | Path | None = None) -> Gen
|
|||
|
||||
finally:
|
||||
# Restore previous state
|
||||
if mode in ["record", "replay"]:
|
||||
if mode in ["record", "replay", "record-if-missing"]:
|
||||
unpatch_inference_clients()
|
||||
|
||||
_current_mode = prev_mode
|
||||
|
|
|
|||
118
llama_stack/ui/package-lock.json
generated
118
llama_stack/ui/package-lock.json
generated
|
|
@ -20,11 +20,11 @@
|
|||
"framer-motion": "^12.23.12",
|
||||
"llama-stack-client": "^0.2.23",
|
||||
"lucide-react": "^0.542.0",
|
||||
"next": "15.5.3",
|
||||
"next": "15.5.4",
|
||||
"next-auth": "^4.24.11",
|
||||
"next-themes": "^0.4.6",
|
||||
"react": "^19.0.0",
|
||||
"react-dom": "^19.1.1",
|
||||
"react-dom": "^19.2.0",
|
||||
"react-markdown": "^10.1.0",
|
||||
"remark-gfm": "^4.0.1",
|
||||
"remeda": "^2.32.0",
|
||||
|
|
@ -2279,9 +2279,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@next/env": {
|
||||
"version": "15.5.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/env/-/env-15.5.3.tgz",
|
||||
"integrity": "sha512-RSEDTRqyihYXygx/OJXwvVupfr9m04+0vH8vyy0HfZ7keRto6VX9BbEk0J2PUk0VGy6YhklJUSrgForov5F9pw==",
|
||||
"version": "15.5.4",
|
||||
"resolved": "https://registry.npmjs.org/@next/env/-/env-15.5.4.tgz",
|
||||
"integrity": "sha512-27SQhYp5QryzIT5uO8hq99C69eLQ7qkzkDPsk3N+GuS2XgOgoYEeOav7Pf8Tn4drECOVDsDg8oj+/DVy8qQL2A==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@next/eslint-plugin-next": {
|
||||
|
|
@ -2295,9 +2295,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@next/swc-darwin-arm64": {
|
||||
"version": "15.5.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-15.5.3.tgz",
|
||||
"integrity": "sha512-nzbHQo69+au9wJkGKTU9lP7PXv0d1J5ljFpvb+LnEomLtSbJkbZyEs6sbF3plQmiOB2l9OBtN2tNSvCH1nQ9Jg==",
|
||||
"version": "15.5.4",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-15.5.4.tgz",
|
||||
"integrity": "sha512-nopqz+Ov6uvorej8ndRX6HlxCYWCO3AHLfKK2TYvxoSB2scETOcfm/HSS3piPqc3A+MUgyHoqE6je4wnkjfrOA==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
|
|
@ -2311,9 +2311,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@next/swc-darwin-x64": {
|
||||
"version": "15.5.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-15.5.3.tgz",
|
||||
"integrity": "sha512-w83w4SkOOhekJOcA5HBvHyGzgV1W/XvOfpkrxIse4uPWhYTTRwtGEM4v/jiXwNSJvfRvah0H8/uTLBKRXlef8g==",
|
||||
"version": "15.5.4",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-15.5.4.tgz",
|
||||
"integrity": "sha512-QOTCFq8b09ghfjRJKfb68kU9k2K+2wsC4A67psOiMn849K9ZXgCSRQr0oVHfmKnoqCbEmQWG1f2h1T2vtJJ9mA==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
|
|
@ -2327,9 +2327,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-arm64-gnu": {
|
||||
"version": "15.5.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-15.5.3.tgz",
|
||||
"integrity": "sha512-+m7pfIs0/yvgVu26ieaKrifV8C8yiLe7jVp9SpcIzg7XmyyNE7toC1fy5IOQozmr6kWl/JONC51osih2RyoXRw==",
|
||||
"version": "15.5.4",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-15.5.4.tgz",
|
||||
"integrity": "sha512-eRD5zkts6jS3VfE/J0Kt1VxdFqTnMc3QgO5lFE5GKN3KDI/uUpSyK3CjQHmfEkYR4wCOl0R0XrsjpxfWEA++XA==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
|
|
@ -2343,9 +2343,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-arm64-musl": {
|
||||
"version": "15.5.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-15.5.3.tgz",
|
||||
"integrity": "sha512-u3PEIzuguSenoZviZJahNLgCexGFhso5mxWCrrIMdvpZn6lkME5vc/ADZG8UUk5K1uWRy4hqSFECrON6UKQBbQ==",
|
||||
"version": "15.5.4",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-15.5.4.tgz",
|
||||
"integrity": "sha512-TOK7iTxmXFc45UrtKqWdZ1shfxuL4tnVAOuuJK4S88rX3oyVV4ZkLjtMT85wQkfBrOOvU55aLty+MV8xmcJR8A==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
|
|
@ -2359,9 +2359,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-x64-gnu": {
|
||||
"version": "15.5.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-15.5.3.tgz",
|
||||
"integrity": "sha512-lDtOOScYDZxI2BENN9m0pfVPJDSuUkAD1YXSvlJF0DKwZt0WlA7T7o3wrcEr4Q+iHYGzEaVuZcsIbCps4K27sA==",
|
||||
"version": "15.5.4",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-15.5.4.tgz",
|
||||
"integrity": "sha512-7HKolaj+481FSW/5lL0BcTkA4Ueam9SPYWyN/ib/WGAFZf0DGAN8frNpNZYFHtM4ZstrHZS3LY3vrwlIQfsiMA==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
|
|
@ -2375,9 +2375,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-x64-musl": {
|
||||
"version": "15.5.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-15.5.3.tgz",
|
||||
"integrity": "sha512-9vWVUnsx9PrY2NwdVRJ4dUURAQ8Su0sLRPqcCCxtX5zIQUBES12eRVHq6b70bbfaVaxIDGJN2afHui0eDm+cLg==",
|
||||
"version": "15.5.4",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-15.5.4.tgz",
|
||||
"integrity": "sha512-nlQQ6nfgN0nCO/KuyEUwwOdwQIGjOs4WNMjEUtpIQJPR2NUfmGpW2wkJln1d4nJ7oUzd1g4GivH5GoEPBgfsdw==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
|
|
@ -2391,9 +2391,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@next/swc-win32-arm64-msvc": {
|
||||
"version": "15.5.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-15.5.3.tgz",
|
||||
"integrity": "sha512-1CU20FZzY9LFQigRi6jM45oJMU3KziA5/sSG+dXeVaTm661snQP6xu3ykGxxwU5sLG3sh14teO/IOEPVsQMRfA==",
|
||||
"version": "15.5.4",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-15.5.4.tgz",
|
||||
"integrity": "sha512-PcR2bN7FlM32XM6eumklmyWLLbu2vs+D7nJX8OAIoWy69Kef8mfiN4e8TUv2KohprwifdpFKPzIP1njuCjD0YA==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
|
|
@ -2407,9 +2407,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@next/swc-win32-x64-msvc": {
|
||||
"version": "15.5.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-15.5.3.tgz",
|
||||
"integrity": "sha512-JMoLAq3n3y5tKXPQwCK5c+6tmwkuFDa2XAxz8Wm4+IVthdBZdZGh+lmiLUHg9f9IDwIQpUjp+ysd6OkYTyZRZw==",
|
||||
"version": "15.5.4",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-15.5.4.tgz",
|
||||
"integrity": "sha512-1ur2tSHZj8Px/KMAthmuI9FMp/YFusMMGoRNJaRZMOlSkgvLjzosSdQI0cJAKogdHl3qXUQKL9MGaYvKwA7DXg==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
|
|
@ -3995,22 +3995,22 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@types/react": {
|
||||
"version": "19.1.4",
|
||||
"resolved": "https://registry.npmjs.org/@types/react/-/react-19.1.4.tgz",
|
||||
"integrity": "sha512-EB1yiiYdvySuIITtD5lhW4yPyJ31RkJkkDw794LaQYrxCSaQV/47y5o1FMC4zF9ZyjUjzJMZwbovEnT5yHTW6g==",
|
||||
"version": "19.2.0",
|
||||
"resolved": "https://registry.npmjs.org/@types/react/-/react-19.2.0.tgz",
|
||||
"integrity": "sha512-1LOH8xovvsKsCBq1wnT4ntDUdCJKmnEakhsuoUSy6ExlHCkGP2hqnatagYTgFk6oeL0VU31u7SNjunPN+GchtA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"csstype": "^3.0.2"
|
||||
}
|
||||
},
|
||||
"node_modules/@types/react-dom": {
|
||||
"version": "19.1.9",
|
||||
"resolved": "https://registry.npmjs.org/@types/react-dom/-/react-dom-19.1.9.tgz",
|
||||
"integrity": "sha512-qXRuZaOsAdXKFyOhRBg6Lqqc0yay13vN7KrIg4L7N4aaHN68ma9OK3NE1BoDFgFOTfM7zg+3/8+2n8rLUH3OKQ==",
|
||||
"version": "19.2.0",
|
||||
"resolved": "https://registry.npmjs.org/@types/react-dom/-/react-dom-19.2.0.tgz",
|
||||
"integrity": "sha512-brtBs0MnE9SMx7px208g39lRmC5uHZs96caOJfTjFcYSLHNamvaSMfJNagChVNkup2SdtOxKX1FDBkRSJe1ZAg==",
|
||||
"devOptional": true,
|
||||
"license": "MIT",
|
||||
"peerDependencies": {
|
||||
"@types/react": "^19.0.0"
|
||||
"@types/react": "^19.2.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@types/stack-utils": {
|
||||
|
|
@ -11414,12 +11414,12 @@
|
|||
}
|
||||
},
|
||||
"node_modules/next": {
|
||||
"version": "15.5.3",
|
||||
"resolved": "https://registry.npmjs.org/next/-/next-15.5.3.tgz",
|
||||
"integrity": "sha512-r/liNAx16SQj4D+XH/oI1dlpv9tdKJ6cONYPwwcCC46f2NjpaRWY+EKCzULfgQYV6YKXjHBchff2IZBSlZmJNw==",
|
||||
"version": "15.5.4",
|
||||
"resolved": "https://registry.npmjs.org/next/-/next-15.5.4.tgz",
|
||||
"integrity": "sha512-xH4Yjhb82sFYQfY3vbkJfgSDgXvBB6a8xPs9i35k6oZJRoQRihZH+4s9Yo2qsWpzBmZ3lPXaJ2KPXLfkvW4LnA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@next/env": "15.5.3",
|
||||
"@next/env": "15.5.4",
|
||||
"@swc/helpers": "0.5.15",
|
||||
"caniuse-lite": "^1.0.30001579",
|
||||
"postcss": "8.4.31",
|
||||
|
|
@ -11432,14 +11432,14 @@
|
|||
"node": "^18.18.0 || ^19.8.0 || >= 20.0.0"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@next/swc-darwin-arm64": "15.5.3",
|
||||
"@next/swc-darwin-x64": "15.5.3",
|
||||
"@next/swc-linux-arm64-gnu": "15.5.3",
|
||||
"@next/swc-linux-arm64-musl": "15.5.3",
|
||||
"@next/swc-linux-x64-gnu": "15.5.3",
|
||||
"@next/swc-linux-x64-musl": "15.5.3",
|
||||
"@next/swc-win32-arm64-msvc": "15.5.3",
|
||||
"@next/swc-win32-x64-msvc": "15.5.3",
|
||||
"@next/swc-darwin-arm64": "15.5.4",
|
||||
"@next/swc-darwin-x64": "15.5.4",
|
||||
"@next/swc-linux-arm64-gnu": "15.5.4",
|
||||
"@next/swc-linux-arm64-musl": "15.5.4",
|
||||
"@next/swc-linux-x64-gnu": "15.5.4",
|
||||
"@next/swc-linux-x64-musl": "15.5.4",
|
||||
"@next/swc-win32-arm64-msvc": "15.5.4",
|
||||
"@next/swc-win32-x64-msvc": "15.5.4",
|
||||
"sharp": "^0.34.3"
|
||||
},
|
||||
"peerDependencies": {
|
||||
|
|
@ -12450,24 +12450,24 @@
|
|||
}
|
||||
},
|
||||
"node_modules/react": {
|
||||
"version": "19.1.1",
|
||||
"resolved": "https://registry.npmjs.org/react/-/react-19.1.1.tgz",
|
||||
"integrity": "sha512-w8nqGImo45dmMIfljjMwOGtbmC/mk4CMYhWIicdSflH91J9TyCyczcPFXJzrZ/ZXcgGRFeP6BU0BEJTw6tZdfQ==",
|
||||
"version": "19.2.0",
|
||||
"resolved": "https://registry.npmjs.org/react/-/react-19.2.0.tgz",
|
||||
"integrity": "sha512-tmbWg6W31tQLeB5cdIBOicJDJRR2KzXsV7uSK9iNfLWQ5bIZfxuPEHp7M8wiHyHnn0DD1i7w3Zmin0FtkrwoCQ==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=0.10.0"
|
||||
}
|
||||
},
|
||||
"node_modules/react-dom": {
|
||||
"version": "19.1.1",
|
||||
"resolved": "https://registry.npmjs.org/react-dom/-/react-dom-19.1.1.tgz",
|
||||
"integrity": "sha512-Dlq/5LAZgF0Gaz6yiqZCf6VCcZs1ghAJyrsu84Q/GT0gV+mCxbfmKNoGRKBYMJ8IEdGPqu49YWXD02GCknEDkw==",
|
||||
"version": "19.2.0",
|
||||
"resolved": "https://registry.npmjs.org/react-dom/-/react-dom-19.2.0.tgz",
|
||||
"integrity": "sha512-UlbRu4cAiGaIewkPyiRGJk0imDN2T3JjieT6spoL2UeSf5od4n5LB/mQ4ejmxhCFT1tYe8IvaFulzynWovsEFQ==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"scheduler": "^0.26.0"
|
||||
"scheduler": "^0.27.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"react": "^19.1.1"
|
||||
"react": "^19.2.0"
|
||||
}
|
||||
},
|
||||
"node_modules/react-is": {
|
||||
|
|
@ -12982,9 +12982,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/scheduler": {
|
||||
"version": "0.26.0",
|
||||
"resolved": "https://registry.npmjs.org/scheduler/-/scheduler-0.26.0.tgz",
|
||||
"integrity": "sha512-NlHwttCI/l5gCPR3D1nNXtWABUmBwvZpEQiD4IXSbIDq8BzLIK/7Ir5gTFSGZDUu37K5cMNp0hFtzO38sC7gWA==",
|
||||
"version": "0.27.0",
|
||||
"resolved": "https://registry.npmjs.org/scheduler/-/scheduler-0.27.0.tgz",
|
||||
"integrity": "sha512-eNv+WrVbKu1f3vbYJT/xtiF5syA5HPIMtf9IgY/nKg0sWqzAUEvqY/xm7OcZc/qafLx/iO9FgOmeSAp4v5ti/Q==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/semver": {
|
||||
|
|
|
|||
|
|
@ -25,11 +25,11 @@
|
|||
"framer-motion": "^12.23.12",
|
||||
"llama-stack-client": "^0.2.23",
|
||||
"lucide-react": "^0.542.0",
|
||||
"next": "15.5.3",
|
||||
"next": "15.5.4",
|
||||
"next-auth": "^4.24.11",
|
||||
"next-themes": "^0.4.6",
|
||||
"react": "^19.0.0",
|
||||
"react-dom": "^19.1.1",
|
||||
"react-dom": "^19.2.0",
|
||||
"react-markdown": "^10.1.0",
|
||||
"remark-gfm": "^4.0.1",
|
||||
"remeda": "^2.32.0",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue