mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-14 16:52:37 +00:00
Merge branch 'llamastack:main' into model_unregisteration_error_message
This commit is contained in:
commit
aa09a44c94
1036 changed files with 314835 additions and 114394 deletions
|
|
@ -28,7 +28,7 @@ from llama_stack.apis.inference import (
|
|||
from llama_stack.apis.safety import SafetyViolation
|
||||
from llama_stack.apis.tools import ToolDef
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
from llama_stack.schema_utils import ExtraBodyField, json_schema_type, register_schema, webmethod
|
||||
|
||||
from .openai_responses import (
|
||||
ListOpenAIResponseInputItem,
|
||||
|
|
@ -42,6 +42,20 @@ from .openai_responses import (
|
|||
)
|
||||
|
||||
|
||||
@json_schema_type
class ResponseShieldSpec(BaseModel):
    """Describes a single shield to run while a response is being generated.

    :param type: The type/identifier of the shield.
    """

    # Currently the only field on the spec.
    type: str
    # TODO: more fields to be added for shield configuration


# A shield is given either as a plain string or as a full ResponseShieldSpec.
ResponseShield = str | ResponseShieldSpec
|
||||
|
||||
|
||||
class Attachment(BaseModel):
|
||||
"""An attachment to an agent turn.
|
||||
|
||||
|
|
@ -472,20 +486,23 @@ class AgentStepResponse(BaseModel):
|
|||
|
||||
@runtime_checkable
|
||||
class Agents(Protocol):
|
||||
"""Agents API for creating and interacting with agentic systems.
|
||||
"""Agents
|
||||
|
||||
Main functionalities provided by this API:
|
||||
- Create agents with specific instructions and ability to use tools.
|
||||
- Interactions with agents are grouped into sessions ("threads"), and each interaction is called a "turn".
|
||||
- Agents can be provided with various tools (see the ToolGroups and ToolRuntime APIs for more details).
|
||||
- Agents can be provided with various shields (see the Safety API for more details).
|
||||
- Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details.
|
||||
"""
|
||||
APIs for creating and interacting with agentic systems."""
|
||||
|
||||
@webmethod(
|
||||
route="/agents", method="POST", descriptive_name="create_agent", deprecated=True, level=LLAMA_STACK_API_V1
|
||||
route="/agents",
|
||||
method="POST",
|
||||
descriptive_name="create_agent",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/agents",
|
||||
method="POST",
|
||||
descriptive_name="create_agent",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
@webmethod(route="/agents", method="POST", descriptive_name="create_agent", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def create_agent(
|
||||
self,
|
||||
agent_config: AgentConfig,
|
||||
|
|
@ -648,8 +665,17 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}",
|
||||
method="GET",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def get_agents_session(
|
||||
self,
|
||||
session_id: str,
|
||||
|
|
@ -666,9 +692,16 @@ class Agents(Protocol):
|
|||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}", method="DELETE", deprecated=True, level=LLAMA_STACK_API_V1
|
||||
route="/agents/{agent_id}/session/{session_id}",
|
||||
method="DELETE",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}",
|
||||
method="DELETE",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def delete_agents_session(
|
||||
self,
|
||||
session_id: str,
|
||||
|
|
@ -681,7 +714,12 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}", method="DELETE", deprecated=True, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}",
|
||||
method="DELETE",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(route="/agents/{agent_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def delete_agent(
|
||||
self,
|
||||
|
|
@ -704,7 +742,12 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}",
|
||||
method="GET",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(route="/agents/{agent_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def get_agent(self, agent_id: str) -> Agent:
|
||||
"""Describe an agent by its ID.
|
||||
|
|
@ -714,7 +757,12 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/sessions", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/sessions",
|
||||
method="GET",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(route="/agents/{agent_id}/sessions", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def list_agent_sessions(
|
||||
self,
|
||||
|
|
@ -738,6 +786,12 @@ class Agents(Protocol):
|
|||
#
|
||||
# Both of these APIs are inherently stateful.
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/responses/{response_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(route="/responses/{response_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_openai_response(
|
||||
self,
|
||||
|
|
@ -750,6 +804,7 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/responses", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/responses", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def create_openai_response(
|
||||
self,
|
||||
|
|
@ -764,6 +819,12 @@ class Agents(Protocol):
|
|||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
include: list[str] | None = None,
|
||||
max_infer_iters: int | None = 10, # this is an extension to the OpenAI API
|
||||
shields: Annotated[
|
||||
list[ResponseShield] | None,
|
||||
ExtraBodyField(
|
||||
"List of shields to apply during response generation. Shields provide safety and content moderation."
|
||||
),
|
||||
] = None,
|
||||
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
|
||||
"""Create a new OpenAI response.
|
||||
|
||||
|
|
@ -771,10 +832,12 @@ class Agents(Protocol):
|
|||
:param model: The underlying LLM used for completions.
|
||||
:param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
|
||||
:param include: (Optional) Additional fields to include in the response.
|
||||
:param shields: (Optional) List of shields to apply during response generation. Can be shield IDs (strings) or shield specifications.
|
||||
:returns: An OpenAIResponseObject.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/responses", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/responses", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_openai_responses(
|
||||
self,
|
||||
|
|
@ -793,6 +856,9 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/responses/{response_id}/input_items", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
|
||||
)
|
||||
@webmethod(route="/responses/{response_id}/input_items", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_openai_response_input_items(
|
||||
self,
|
||||
|
|
@ -815,6 +881,7 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||
"""Delete an OpenAI response by its ID.
|
||||
|
|
|
|||
|
|
@ -888,6 +888,10 @@ class OpenAIResponseObjectWithInput(OpenAIResponseObject):
|
|||
|
||||
input: list[OpenAIResponseInput]
|
||||
|
||||
def to_response_object(self) -> OpenAIResponseObject:
    """Return an OpenAIResponseObject built from this instance without its ``input`` field."""
    # Dump every field, then drop the one key the target object is built without.
    payload = self.model_dump()
    payload.pop("input", None)
    return OpenAIResponseObject(**payload)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListOpenAIResponseObject(BaseModel):
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ class Batches(Protocol):
|
|||
Note: This API is currently under active development and may undergo changes.
|
||||
"""
|
||||
|
||||
@webmethod(route="/openai/v1/batches", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/batches", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def create_batch(
|
||||
self,
|
||||
|
|
@ -63,6 +64,7 @@ class Batches(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/batches/{batch_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/batches/{batch_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def retrieve_batch(self, batch_id: str) -> BatchObject:
|
||||
"""Retrieve information about a specific batch.
|
||||
|
|
@ -72,6 +74,7 @@ class Batches(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/batches/{batch_id}/cancel", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/batches/{batch_id}/cancel", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def cancel_batch(self, batch_id: str) -> BatchObject:
|
||||
"""Cancel a batch that is in progress.
|
||||
|
|
@ -81,6 +84,7 @@ class Batches(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/batches", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/batches", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_batches(
|
||||
self,
|
||||
|
|
|
|||
31
llama_stack/apis/conversations/__init__.py
Normal file
31
llama_stack/apis/conversations/__init__.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .conversations import (
|
||||
Conversation,
|
||||
ConversationCreateRequest,
|
||||
ConversationDeletedResource,
|
||||
ConversationItem,
|
||||
ConversationItemCreateRequest,
|
||||
ConversationItemDeletedResource,
|
||||
ConversationItemList,
|
||||
Conversations,
|
||||
ConversationUpdateRequest,
|
||||
Metadata,
|
||||
)
|
||||
|
||||
# Public names of the conversations package; each entry matches a name
# imported from `.conversations` in this module.
__all__ = [
    "Conversation",
    "ConversationCreateRequest",
    "ConversationDeletedResource",
    "ConversationItem",
    "ConversationItemCreateRequest",
    "ConversationItemDeletedResource",
    "ConversationItemList",
    "Conversations",
    "ConversationUpdateRequest",
    "Metadata",
]
|
||||
260
llama_stack/apis/conversations/conversations.py
Normal file
260
llama_stack/apis/conversations/conversations.py
Normal file
|
|
@ -0,0 +1,260 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Annotated, Literal, Protocol, runtime_checkable
|
||||
|
||||
from openai import NOT_GIVEN
|
||||
from openai._types import NotGiven
|
||||
from openai.types.responses.response_includable import ResponseIncludable
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseOutputMessageFileSearchToolCall,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseOutputMessageMCPCall,
|
||||
OpenAIResponseOutputMessageMCPListTools,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
)
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
# String-to-string mapping used as the type of the `metadata` fields and parameters below.
Metadata = dict[str, str]
|
||||
|
||||
|
||||
@json_schema_type
class Conversation(BaseModel):
    """Conversation object compatible with the OpenAI API."""

    # Required: no default is supplied.
    id: str = Field(
        ...,
        description="The unique ID of the conversation.",
    )
    # Constant tag; the only permitted value is "conversation".
    object: Literal["conversation"] = Field(
        default="conversation",
        description="The object type, which is always conversation.",
    )
    # Required: no default is supplied.
    created_at: int = Field(
        ...,
        description="The time at which the conversation was created, measured in seconds since the Unix epoch.",
    )
    # Optional; defaults to None.
    metadata: Metadata | None = Field(
        default=None,
        description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard.",
    )
    # Optional; items are kept as plain dicts on this model.
    items: list[dict] | None = Field(
        default=None,
        description="Initial items to include in the conversation context. You may add up to 20 items at a time.",
    )
|
||||
|
||||
|
||||
@json_schema_type
class ConversationMessage(BaseModel):
    """Message item for conversations, compatible with the OpenAI API."""

    # The four fields below are required (no defaults).
    id: str = Field(
        ...,
        description="unique identifier for this message",
    )
    content: list[dict] = Field(
        ...,
        description="message content",
    )
    role: str = Field(
        ...,
        description="message role",
    )
    status: str = Field(
        ...,
        description="message status",
    )
    # Constant tags; both are pinned to "message".
    type: Literal["message"] = "message"
    object: Literal["message"] = "message"
|
||||
|
||||
|
||||
# Every concrete item type a conversation may hold.
_ConversationItemVariants = (
    OpenAIResponseMessage
    | OpenAIResponseOutputMessageFunctionToolCall
    | OpenAIResponseOutputMessageFileSearchToolCall
    | OpenAIResponseOutputMessageWebSearchToolCall
    | OpenAIResponseOutputMessageMCPCall
    | OpenAIResponseOutputMessageMCPListTools
)

# Tagged union: the `type` field of each variant is the discriminator.
ConversationItem = Annotated[_ConversationItemVariants, Field(discriminator="type")]
register_schema(ConversationItem, name="ConversationItem")
|
||||
|
||||
# Using OpenAI types directly caused issues but some notes for reference:
|
||||
# Note that ConversationItem is an Annotated Union of the types below:
|
||||
# from openai.types.responses import *
|
||||
# from openai.types.responses.response_item import *
|
||||
# from openai.types.conversations import ConversationItem
|
||||
# f = [
|
||||
# ResponseFunctionToolCallItem,
|
||||
# ResponseFunctionToolCallOutputItem,
|
||||
# ResponseFileSearchToolCall,
|
||||
# ResponseFunctionWebSearch,
|
||||
# ImageGenerationCall,
|
||||
# ResponseComputerToolCall,
|
||||
# ResponseComputerToolCallOutputItem,
|
||||
# ResponseReasoningItem,
|
||||
# ResponseCodeInterpreterToolCall,
|
||||
# LocalShellCall,
|
||||
# LocalShellCallOutput,
|
||||
# McpListTools,
|
||||
# McpApprovalRequest,
|
||||
# McpApprovalResponse,
|
||||
# McpCall,
|
||||
# ResponseCustomToolCall,
|
||||
# ResponseCustomToolCallOutput
|
||||
# ]
|
||||
|
||||
|
||||
@json_schema_type
class ConversationCreateRequest(BaseModel):
    """Body of a request that creates a conversation."""

    # Optional; max_length=20 matches the 20-item limit stated in the description.
    items: list[ConversationItem] | None = Field(
        default=[],
        max_length=20,
        description="Initial items to include in the conversation context. You may add up to 20 items at a time.",
    )
    # Optional; max_length=16 matches the 16-pair limit stated in the description.
    metadata: Metadata | None = Field(
        default={},
        max_length=16,
        description="Set of 16 key-value pairs that can be attached to an object. Useful for storing additional information",
    )
|
||||
|
||||
|
||||
@json_schema_type
class ConversationUpdateRequest(BaseModel):
    """Body of a request that updates a conversation."""

    # Required: unlike the create request, no default is supplied here.
    metadata: Metadata = Field(
        ...,
        description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard. Keys are strings with a maximum length of 64 characters. Values are strings with a maximum length of 512 characters.",
    )
|
||||
|
||||
|
||||
@json_schema_type
class ConversationDeletedResource(BaseModel):
    """Payload returned after a conversation is deleted."""

    # Required: no default is supplied.
    id: str = Field(
        ...,
        description="The deleted conversation identifier",
    )
    # Defaults to the "conversation.deleted" tag.
    object: str = Field(
        default="conversation.deleted",
        description="Object type",
    )
    # Defaults to True.
    deleted: bool = Field(
        default=True,
        description="Whether the object was deleted",
    )
|
||||
|
||||
|
||||
@json_schema_type
class ConversationItemCreateRequest(BaseModel):
    """Body of a request that adds items to a conversation."""

    # Required; max_length=20 matches the 20-item limit stated in the description.
    items: list[ConversationItem] = Field(
        ...,
        max_length=20,
        description="Items to include in the conversation context. You may add up to 20 items at a time.",
    )
|
||||
|
||||
|
||||
@json_schema_type
class ConversationItemList(BaseModel):
    """One page of conversation items together with its pagination markers."""

    # Defaults to the "list" tag.
    object: str = Field(
        default="list",
        description="Object type",
    )
    # Required: the items themselves.
    data: list[ConversationItem] = Field(
        ...,
        description="List of conversation items",
    )
    # Pagination markers; the two IDs default to None and has_more to False.
    first_id: str | None = Field(
        default=None,
        description="The ID of the first item in the list",
    )
    last_id: str | None = Field(
        default=None,
        description="The ID of the last item in the list",
    )
    has_more: bool = Field(
        default=False,
        description="Whether there are more items available",
    )
|
||||
|
||||
|
||||
@json_schema_type
class ConversationItemDeletedResource(BaseModel):
    """Payload returned after a conversation item is deleted."""

    # Required: no default is supplied.
    id: str = Field(
        ...,
        description="The deleted item identifier",
    )
    # Defaults to the "conversation.item.deleted" tag.
    object: str = Field(
        default="conversation.item.deleted",
        description="Object type",
    )
    # Defaults to True.
    deleted: bool = Field(
        default=True,
        description="Whether the object was deleted",
    )
|
||||
|
||||
|
||||
@runtime_checkable
@trace_protocol
class Conversations(Protocol):
    """Interface for managing conversations and the items stored in them."""

    @webmethod(route="/conversations", method="POST", level=LLAMA_STACK_API_V1)
    async def create_conversation(
        self,
        items: list[ConversationItem] | None = None,
        metadata: Metadata | None = None,
    ) -> Conversation:
        """Create a new conversation.

        :param items: Initial items to include in the conversation context.
        :param metadata: Set of key-value pairs that can be attached to an object.
        :returns: The created conversation object.
        """
        ...

    @webmethod(route="/conversations/{conversation_id}", method="GET", level=LLAMA_STACK_API_V1)
    async def get_conversation(
        self,
        conversation_id: str,
    ) -> Conversation:
        """Fetch the conversation that has the given ID.

        :param conversation_id: The conversation identifier.
        :returns: The conversation object.
        """
        ...

    @webmethod(route="/conversations/{conversation_id}", method="POST", level=LLAMA_STACK_API_V1)
    async def update_conversation(
        self,
        conversation_id: str,
        metadata: Metadata,
    ) -> Conversation:
        """Update the metadata of the conversation that has the given ID.

        :param conversation_id: The conversation identifier.
        :param metadata: Set of key-value pairs that can be attached to an object.
        :returns: The updated conversation object.
        """
        ...

    @webmethod(route="/conversations/{conversation_id}", method="DELETE", level=LLAMA_STACK_API_V1)
    async def openai_delete_conversation(
        self,
        conversation_id: str,
    ) -> ConversationDeletedResource:
        """Delete the conversation that has the given ID.

        :param conversation_id: The conversation identifier.
        :returns: The deleted conversation resource.
        """
        ...

    @webmethod(route="/conversations/{conversation_id}/items", method="POST", level=LLAMA_STACK_API_V1)
    async def add_items(
        self,
        conversation_id: str,
        items: list[ConversationItem],
    ) -> ConversationItemList:
        """Create items in the given conversation.

        :param conversation_id: The conversation identifier.
        :param items: Items to include in the conversation context.
        :returns: List of created items.
        """
        ...

    @webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="GET", level=LLAMA_STACK_API_V1)
    async def retrieve(
        self,
        conversation_id: str,
        item_id: str,
    ) -> ConversationItem:
        """Retrieve a single item of a conversation.

        :param conversation_id: The conversation identifier.
        :param item_id: The item identifier.
        :returns: The conversation item.
        """
        ...

    # NOTE: once this method is defined, the name `list` inside the class body
    # refers to it rather than to the builtin. Any member declared after it must
    # not use `list[...]` in an annotation that is evaluated at definition time.
    @webmethod(route="/conversations/{conversation_id}/items", method="GET", level=LLAMA_STACK_API_V1)
    async def list(
        self,
        conversation_id: str,
        after: str | NotGiven = NOT_GIVEN,
        include: list[ResponseIncludable] | NotGiven = NOT_GIVEN,
        limit: int | NotGiven = NOT_GIVEN,
        order: Literal["asc", "desc"] | NotGiven = NOT_GIVEN,
    ) -> ConversationItemList:
        """List the items of a conversation.

        :param conversation_id: The conversation identifier.
        :param after: An item ID to list items after, used in pagination.
        :param include: Specify additional output data to include in the response.
        :param limit: A limit on the number of objects to be returned (1-100, default 20).
        :param order: The order to return items in (asc or desc, default desc).
        :returns: List of conversation items.
        """
        ...

    @webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="DELETE", level=LLAMA_STACK_API_V1)
    async def openai_delete_conversation_item(
        self,
        conversation_id: str,
        item_id: str,
    ) -> ConversationItemDeletedResource:
        """Delete a single item of a conversation.

        :param conversation_id: The conversation identifier.
        :param item_id: The item identifier.
        :returns: The deleted item resource.
        """
        ...
|
||||
|
|
@ -8,7 +8,7 @@ from typing import Any, Protocol, runtime_checkable
|
|||
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.apis.datasets import Dataset
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1BETA
|
||||
from llama_stack.schema_utils import webmethod
|
||||
|
||||
|
||||
|
|
@ -21,7 +21,8 @@ class DatasetIO(Protocol):
|
|||
# keeping for aligning with inference/safety, but this is not used
|
||||
dataset_store: DatasetStore
|
||||
|
||||
@webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET", level=LLAMA_STACK_API_V1BETA)
|
||||
async def iterrows(
|
||||
self,
|
||||
dataset_id: str,
|
||||
|
|
@ -45,7 +46,10 @@ class DatasetIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST", level=LLAMA_STACK_API_V1)
|
||||
@webmethod(
|
||||
route="/datasetio/append-rows/{dataset_id:path}", method="POST", deprecated=True, level=LLAMA_STACK_API_V1
|
||||
)
|
||||
@webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST", level=LLAMA_STACK_API_V1BETA)
|
||||
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
|
||||
"""Append rows to a dataset.
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from typing import Annotated, Any, Literal, Protocol
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1BETA
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
|
|
@ -146,7 +146,8 @@ class ListDatasetsResponse(BaseModel):
|
|||
|
||||
|
||||
class Datasets(Protocol):
|
||||
@webmethod(route="/datasets", method="POST", level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/datasets", method="POST", deprecated=True, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/datasets", method="POST", level=LLAMA_STACK_API_V1BETA)
|
||||
async def register_dataset(
|
||||
self,
|
||||
purpose: DatasetPurpose,
|
||||
|
|
@ -215,7 +216,8 @@ class Datasets(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="GET", level=LLAMA_STACK_API_V1BETA)
|
||||
async def get_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
|
|
@ -227,7 +229,8 @@ class Datasets(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasets", method="GET", level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/datasets", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/datasets", method="GET", level=LLAMA_STACK_API_V1BETA)
|
||||
async def list_datasets(self) -> ListDatasetsResponse:
|
||||
"""List all datasets.
|
||||
|
||||
|
|
@ -235,7 +238,8 @@ class Datasets(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE", deprecated=True, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE", level=LLAMA_STACK_API_V1BETA)
|
||||
async def unregister_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
|
|
|
|||
|
|
@ -129,6 +129,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
|
|||
tool_groups = "tool_groups"
|
||||
files = "files"
|
||||
prompts = "prompts"
|
||||
conversations = "conversations"
|
||||
|
||||
# built-in API
|
||||
inspect = "inspect"
|
||||
|
|
|
|||
|
|
@ -105,6 +105,7 @@ class OpenAIFileDeleteResponse(BaseModel):
|
|||
@trace_protocol
|
||||
class Files(Protocol):
|
||||
# OpenAI Files API Endpoints
|
||||
@webmethod(route="/openai/v1/files", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/files", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def openai_upload_file(
|
||||
self,
|
||||
|
|
@ -127,6 +128,7 @@ class Files(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/files", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/files", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def openai_list_files(
|
||||
self,
|
||||
|
|
@ -146,6 +148,7 @@ class Files(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/files/{file_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/files/{file_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def openai_retrieve_file(
|
||||
self,
|
||||
|
|
@ -159,6 +162,7 @@ class Files(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/files/{file_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/files/{file_id}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def openai_delete_file(
|
||||
self,
|
||||
|
|
@ -172,6 +176,7 @@ class Files(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/files/{file_id}/content", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/files/{file_id}/content", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def openai_retrieve_file_content(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -27,14 +27,12 @@ from llama_stack.models.llama.datatypes import (
|
|||
StopReason,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolParamDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
register_schema(ToolCall)
|
||||
register_schema(ToolParamDefinition)
|
||||
register_schema(ToolDefinition)
|
||||
|
||||
from enum import StrEnum
|
||||
|
|
@ -1008,67 +1006,6 @@ class InferenceProvider(Protocol):
|
|||
|
||||
model_store: ModelStore | None = None
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
|
||||
"""Generate a completion for the given content using the specified model.
|
||||
|
||||
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param content: The content to generate a completion for.
|
||||
:param sampling_params: (Optional) Parameters to control the sampling strategy.
|
||||
:param response_format: (Optional) Grammar specification for guided (structured) decoding.
|
||||
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
|
||||
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
|
||||
:returns: If stream=False, returns a CompletionResponse with the full completion.
|
||||
If stream=True, returns an SSE event stream of CompletionResponseStreamChunk.
|
||||
"""
|
||||
...
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
"""Generate a chat completion for the given messages using the specified model.
|
||||
|
||||
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param messages: List of messages in the conversation.
|
||||
:param sampling_params: Parameters to control the sampling strategy.
|
||||
:param tools: (Optional) List of tool definitions available to the model.
|
||||
:param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
|
||||
.. deprecated::
|
||||
Use tool_config instead.
|
||||
:param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
|
||||
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
|
||||
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
|
||||
- `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
|
||||
.. deprecated::
|
||||
Use tool_config instead.
|
||||
:param response_format: (Optional) Grammar specification for guided (structured) decoding. There are two options:
|
||||
- `ResponseFormat.json_schema`: The grammar is a JSON schema. Most providers support this format.
|
||||
- `ResponseFormat.grammar`: The grammar is a BNF grammar. This format is more flexible, but not all providers support it.
|
||||
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
|
||||
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
|
||||
:param tool_config: (Optional) Configuration for tool use.
|
||||
:returns: If stream=False, returns a ChatCompletionResponse with the full completion.
|
||||
If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/inference/rerank", method="POST", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def rerank(
|
||||
self,
|
||||
|
|
@ -1088,6 +1025,7 @@ class InferenceProvider(Protocol):
|
|||
raise NotImplementedError("Reranking is not implemented")
|
||||
return # this is so mypy's safe-super rule will consider the method concrete
|
||||
|
||||
@webmethod(route="/openai/v1/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/completions", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def openai_completion(
|
||||
self,
|
||||
|
|
@ -1139,6 +1077,7 @@ class InferenceProvider(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/chat/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/chat/completions", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
|
|
@ -1195,6 +1134,7 @@ class InferenceProvider(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/embeddings", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/embeddings", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
|
|
@ -1224,6 +1164,7 @@ class Inference(InferenceProvider):
|
|||
- Embedding models: these models generate embeddings to be used for semantic search.
|
||||
"""
|
||||
|
||||
@webmethod(route="/openai/v1/chat/completions", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/chat/completions", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_chat_completions(
|
||||
self,
|
||||
|
|
@ -1242,6 +1183,9 @@ class Inference(InferenceProvider):
|
|||
"""
|
||||
raise NotImplementedError("List chat completions is not implemented")
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
|
||||
)
|
||||
@webmethod(route="/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
|
||||
"""Describe a chat completion by its ID.
|
||||
|
|
|
|||
|
|
@ -111,6 +111,14 @@ class Models(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/models", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
async def openai_list_models(self) -> OpenAIListModelsResponse:
|
||||
"""List models using the OpenAI API.
|
||||
|
||||
:returns: A OpenAIListModelsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/models/{model_id:path}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_model(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -114,6 +114,7 @@ class Safety(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/moderations", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/moderations", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
||||
"""Classifies if text and/or image inputs are potentially harmful.
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from typing import (
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.models.llama.datatypes import Primitive
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
|
@ -426,7 +426,14 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/traces", method="POST", required_scope=REQUIRED_SCOPE, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(
|
||||
route="/telemetry/traces",
|
||||
method="POST",
|
||||
required_scope=REQUIRED_SCOPE,
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(route="/telemetry/traces", method="POST", required_scope=REQUIRED_SCOPE, level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_filters: list[QueryCondition] | None = None,
|
||||
|
|
@ -445,7 +452,17 @@ class Telemetry(Protocol):
|
|||
...
|
||||
|
||||
@webmethod(
|
||||
route="/telemetry/traces/{trace_id:path}", method="GET", required_scope=REQUIRED_SCOPE, level=LLAMA_STACK_API_V1
|
||||
route="/telemetry/traces/{trace_id:path}",
|
||||
method="GET",
|
||||
required_scope=REQUIRED_SCOPE,
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/telemetry/traces/{trace_id:path}",
|
||||
method="GET",
|
||||
required_scope=REQUIRED_SCOPE,
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def get_trace(self, trace_id: str) -> Trace:
|
||||
"""Get a trace by its ID.
|
||||
|
|
@ -459,8 +476,15 @@ class Telemetry(Protocol):
|
|||
route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}",
|
||||
method="GET",
|
||||
required_scope=REQUIRED_SCOPE,
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}",
|
||||
method="GET",
|
||||
required_scope=REQUIRED_SCOPE,
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def get_span(self, trace_id: str, span_id: str) -> Span:
|
||||
"""Get a span by its ID.
|
||||
|
||||
|
|
@ -473,9 +497,16 @@ class Telemetry(Protocol):
|
|||
@webmethod(
|
||||
route="/telemetry/spans/{span_id:path}/tree",
|
||||
method="POST",
|
||||
deprecated=True,
|
||||
required_scope=REQUIRED_SCOPE,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/telemetry/spans/{span_id:path}/tree",
|
||||
method="POST",
|
||||
required_scope=REQUIRED_SCOPE,
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def get_span_tree(
|
||||
self,
|
||||
span_id: str,
|
||||
|
|
@ -491,7 +522,14 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/spans", method="POST", required_scope=REQUIRED_SCOPE, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(
|
||||
route="/telemetry/spans",
|
||||
method="POST",
|
||||
required_scope=REQUIRED_SCOPE,
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(route="/telemetry/spans", method="POST", required_scope=REQUIRED_SCOPE, level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def query_spans(
|
||||
self,
|
||||
attribute_filters: list[QueryCondition],
|
||||
|
|
@ -507,7 +545,8 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/spans/export", method="POST", level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/telemetry/spans/export", method="POST", deprecated=True, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/telemetry/spans/export", method="POST", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def save_spans_to_dataset(
|
||||
self,
|
||||
attribute_filters: list[QueryCondition],
|
||||
|
|
@ -525,7 +564,17 @@ class Telemetry(Protocol):
|
|||
...
|
||||
|
||||
@webmethod(
|
||||
route="/telemetry/metrics/{metric_name}", method="POST", required_scope=REQUIRED_SCOPE, level=LLAMA_STACK_API_V1
|
||||
route="/telemetry/metrics/{metric_name}",
|
||||
method="POST",
|
||||
required_scope=REQUIRED_SCOPE,
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/telemetry/metrics/{metric_name}",
|
||||
method="POST",
|
||||
required_scope=REQUIRED_SCOPE,
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def query_metrics(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
from enum import Enum
|
||||
from typing import Any, Literal, Protocol
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import runtime_checkable
|
||||
|
||||
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
||||
|
|
@ -19,59 +19,23 @@ from llama_stack.schema_utils import json_schema_type, webmethod
|
|||
from .rag_tool import RAGToolRuntime
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolParameter(BaseModel):
|
||||
"""Parameter definition for a tool.
|
||||
|
||||
:param name: Name of the parameter
|
||||
:param parameter_type: Type of the parameter (e.g., string, integer)
|
||||
:param description: Human-readable description of what the parameter does
|
||||
:param required: Whether this parameter is required for tool invocation
|
||||
:param items: Type of the elements when parameter_type is array
|
||||
:param title: (Optional) Title of the parameter
|
||||
:param default: (Optional) Default value for the parameter if not provided
|
||||
"""
|
||||
|
||||
name: str
|
||||
parameter_type: str
|
||||
description: str
|
||||
required: bool = Field(default=True)
|
||||
items: dict | None = None
|
||||
title: str | None = None
|
||||
default: Any | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Tool(Resource):
|
||||
"""A tool that can be invoked by agents.
|
||||
|
||||
:param type: Type of resource, always 'tool'
|
||||
:param toolgroup_id: ID of the tool group this tool belongs to
|
||||
:param description: Human-readable description of what the tool does
|
||||
:param parameters: List of parameters this tool accepts
|
||||
:param metadata: (Optional) Additional metadata about the tool
|
||||
"""
|
||||
|
||||
type: Literal[ResourceType.tool] = ResourceType.tool
|
||||
toolgroup_id: str
|
||||
description: str
|
||||
parameters: list[ToolParameter]
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolDef(BaseModel):
|
||||
"""Tool definition used in runtime contexts.
|
||||
|
||||
:param name: Name of the tool
|
||||
:param description: (Optional) Human-readable description of what the tool does
|
||||
:param parameters: (Optional) List of parameters this tool accepts
|
||||
:param input_schema: (Optional) JSON Schema for tool inputs (MCP inputSchema)
|
||||
:param output_schema: (Optional) JSON Schema for tool outputs (MCP outputSchema)
|
||||
:param metadata: (Optional) Additional metadata about the tool
|
||||
:param toolgroup_id: (Optional) ID of the tool group this tool belongs to
|
||||
"""
|
||||
|
||||
toolgroup_id: str | None = None
|
||||
name: str
|
||||
description: str | None = None
|
||||
parameters: list[ToolParameter] | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
output_schema: dict[str, Any] | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
|
|
@ -122,7 +86,7 @@ class ToolInvocationResult(BaseModel):
|
|||
|
||||
|
||||
class ToolStore(Protocol):
|
||||
async def get_tool(self, tool_name: str) -> Tool: ...
|
||||
async def get_tool(self, tool_name: str) -> ToolDef: ...
|
||||
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
|
||||
|
||||
|
||||
|
|
@ -135,15 +99,6 @@ class ListToolGroupsResponse(BaseModel):
|
|||
data: list[ToolGroup]
|
||||
|
||||
|
||||
class ListToolsResponse(BaseModel):
|
||||
"""Response containing a list of tools.
|
||||
|
||||
:param data: List of tools
|
||||
"""
|
||||
|
||||
data: list[Tool]
|
||||
|
||||
|
||||
class ListToolDefsResponse(BaseModel):
|
||||
"""Response containing a list of tool definitions.
|
||||
|
||||
|
|
@ -194,11 +149,11 @@ class ToolGroups(Protocol):
|
|||
...
|
||||
|
||||
@webmethod(route="/tools", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolDefsResponse:
|
||||
"""List tools with optional tool group.
|
||||
|
||||
:param toolgroup_id: The ID of the tool group to list tools for.
|
||||
:returns: A ListToolsResponse.
|
||||
:returns: A ListToolDefsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
|
|
@ -206,11 +161,11 @@ class ToolGroups(Protocol):
|
|||
async def get_tool(
|
||||
self,
|
||||
tool_name: str,
|
||||
) -> Tool:
|
||||
) -> ToolDef:
|
||||
"""Get a tool by its name.
|
||||
|
||||
:param tool_name: The name of the tool to get.
|
||||
:returns: A Tool.
|
||||
:returns: A ToolDef.
|
||||
"""
|
||||
...
|
||||
|
||||
|
|
|
|||
|
|
@ -512,6 +512,7 @@ class VectorIO(Protocol):
|
|||
...
|
||||
|
||||
# OpenAI Vector Stores API endpoints
|
||||
@webmethod(route="/openai/v1/vector_stores", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/vector_stores", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def openai_create_vector_store(
|
||||
self,
|
||||
|
|
@ -538,6 +539,7 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/vector_stores", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def openai_list_vector_stores(
|
||||
self,
|
||||
|
|
@ -556,6 +558,9 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
|
||||
)
|
||||
@webmethod(route="/vector_stores/{vector_store_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def openai_retrieve_vector_store(
|
||||
self,
|
||||
|
|
@ -568,6 +573,9 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}", method="POST", level=LLAMA_STACK_API_V1, deprecated=True
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}",
|
||||
method="POST",
|
||||
|
|
@ -590,6 +598,9 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}",
|
||||
method="DELETE",
|
||||
|
|
@ -606,6 +617,12 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/search",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/search",
|
||||
method="POST",
|
||||
|
|
@ -638,6 +655,12 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/files",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files",
|
||||
method="POST",
|
||||
|
|
@ -660,6 +683,12 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/files",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files",
|
||||
method="GET",
|
||||
|
|
@ -686,6 +715,12 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files/{file_id}",
|
||||
method="GET",
|
||||
|
|
@ -704,6 +739,12 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}/content",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files/{file_id}/content",
|
||||
method="GET",
|
||||
|
|
@ -722,6 +763,12 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files/{file_id}",
|
||||
method="POST",
|
||||
|
|
@ -742,6 +789,12 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}",
|
||||
method="DELETE",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files/{file_id}",
|
||||
method="DELETE",
|
||||
|
|
@ -765,6 +818,12 @@ class VectorIO(Protocol):
|
|||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/file_batches",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
async def openai_create_vector_store_file_batch(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
|
|
@ -787,6 +846,12 @@ class VectorIO(Protocol):
|
|||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
async def openai_retrieve_vector_store_file_batch(
|
||||
self,
|
||||
batch_id: str,
|
||||
|
|
@ -800,6 +865,12 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/files",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/file_batches/{batch_id}/files",
|
||||
method="GET",
|
||||
|
|
@ -828,6 +899,12 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel",
|
||||
method="POST",
|
||||
|
|
|
|||
|
|
@ -6,11 +6,18 @@
|
|||
|
||||
import argparse
|
||||
import os
|
||||
import ssl
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import uvicorn
|
||||
import yaml
|
||||
|
||||
from llama_stack.cli.stack.utils import ImageType
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.core.datatypes import LoggingConfig, StackRunConfig
|
||||
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars, validate_env_pair
|
||||
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
|
@ -146,23 +153,7 @@ class StackRun(Subcommand):
|
|||
# using the current environment packages.
|
||||
if not image_type and not image_name:
|
||||
logger.info("No image type or image name provided. Assuming environment packages.")
|
||||
from llama_stack.core.server.server import main as server_main
|
||||
|
||||
# Build the server args from the current args passed to the CLI
|
||||
server_args = argparse.Namespace()
|
||||
for arg in vars(args):
|
||||
# If this is a function, avoid passing it
|
||||
# "args" contains:
|
||||
# func=<bound method StackRun._run_stack_run_cmd of <llama_stack.cli.stack.run.StackRun object at 0x10484b010>>
|
||||
if callable(getattr(args, arg)):
|
||||
continue
|
||||
if arg == "config":
|
||||
server_args.config = str(config_file)
|
||||
else:
|
||||
setattr(server_args, arg, getattr(args, arg))
|
||||
|
||||
# Run the server
|
||||
server_main(server_args)
|
||||
self._uvicorn_run(config_file, args)
|
||||
else:
|
||||
run_args = formulate_run_args(image_type, image_name)
|
||||
|
||||
|
|
@ -184,6 +175,76 @@ class StackRun(Subcommand):
|
|||
|
||||
run_command(run_args)
|
||||
|
||||
def _uvicorn_run(self, config_file: Path | None, args: argparse.Namespace) -> None:
    """Load the run config and start the stack server through uvicorn.

    Exports any pairs from ``args.env`` into ``os.environ``, resolves and
    parses the run config, derives host, port and optional TLS settings from
    it, and then calls ``uvicorn.run`` on the ``create_app`` factory.

    :param config_file: Config path (or distro name) to resolve; a missing value is reported via ``self.parser.error``.
    :param args: Parsed CLI arguments; ``env`` and ``port`` are read here.
    """
    if not config_file:
        self.parser.error("Config file is required")

    # Set environment variables if provided. This runs before the config is
    # loaded and passed through replace_env_vars below.
    for env_pair in args.env or []:
        try:
            key, value = validate_env_pair(env_pair)
            logger.info(f"Setting environment variable {key} => {value}")
            os.environ[key] = value
        except ValueError as err:
            logger.error(f"Error: {str(err)}")
            self.parser.error(f"Invalid environment variable format: {env_pair}")

    resolved_path = resolve_config_or_distro(str(config_file), Mode.RUN)
    with open(resolved_path) as fp:
        config_contents = yaml.safe_load(fp)

    # The optional logging section is only looked up when the YAML root is a mapping.
    logging_section = config_contents.get("logging_config") if isinstance(config_contents, dict) else None
    logger_config = LoggingConfig(**logging_section) if logging_section else None
    config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))

    server_cfg = config.server
    port = args.port or server_cfg.port
    host = server_cfg.host or ["::", "0.0.0.0"]

    # Set the config file in environment so create_app can find it
    os.environ["LLAMA_STACK_CONFIG"] = str(resolved_path)

    uvicorn_config = {
        "factory": True,
        "host": host,
        "port": port,
        "lifespan": "on",
        "log_level": logger.getEffectiveLevel(),
        "log_config": logger_config,
    }

    # TLS is only configured when both the key file and the cert file are set.
    keyfile = server_cfg.tls_keyfile
    certfile = server_cfg.tls_certfile
    if keyfile and certfile:
        uvicorn_config.update(ssl_keyfile=keyfile, ssl_certfile=certfile)
        if server_cfg.tls_cafile:
            # A CA file additionally switches on required client certificates.
            uvicorn_config.update(ssl_ca_certs=server_cfg.tls_cafile, ssl_cert_reqs=ssl.CERT_REQUIRED)
            logger.info(
                f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
            )
        else:
            logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")

    logger.info(f"Listening on {host}:{port}")

    # We need to catch KeyboardInterrupt because uvicorn's signal handling
    # re-raises SIGINT signals using signal.raise_signal(), which Python
    # converts to KeyboardInterrupt. Without this catch, we'd get a confusing
    # stack trace when using Ctrl+C or kill -2 (SIGINT).
    # SIGTERM (kill -15) works fine without this because Python doesn't
    # have a default handler for it.
    #
    # Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
    # signal handling but this is quite intrusive and not worth the effort.
    try:
        uvicorn.run("llama_stack.core.server.server:create_app", **uvicorn_config)
    except (KeyboardInterrupt, SystemExit):
        logger.info("Received interrupt signal, shutting down gracefully...")
|
||||
|
||||
def _start_ui_development_server(self, stack_server_port: int):
|
||||
logger.info("Attempting to start UI development server...")
|
||||
# Check if npm is available
|
||||
|
|
|
|||
|
|
@ -324,14 +324,14 @@ fi
|
|||
RUN pip uninstall -y uv
|
||||
EOF
|
||||
|
||||
# If a run config is provided, we use the --config flag
|
||||
# If a run config is provided, we use the llama stack CLI
|
||||
if [[ -n "$run_config" ]]; then
|
||||
add_to_container << EOF
|
||||
ENTRYPOINT ["python", "-m", "llama_stack.core.server.server", "$RUN_CONFIG_PATH"]
|
||||
ENTRYPOINT ["llama", "stack", "run", "$RUN_CONFIG_PATH"]
|
||||
EOF
|
||||
elif [[ "$distro_or_config" != *.yaml ]]; then
|
||||
add_to_container << EOF
|
||||
ENTRYPOINT ["python", "-m", "llama_stack.core.server.server", "$distro_or_config"]
|
||||
ENTRYPOINT ["llama", "stack", "run", "$distro_or_config"]
|
||||
EOF
|
||||
fi
|
||||
|
||||
|
|
|
|||
5
llama_stack/core/conversations/__init__.py
Normal file
5
llama_stack/core/conversations/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
306
llama_stack/core/conversations/conversations.py
Normal file
306
llama_stack/core/conversations/conversations.py
Normal file
|
|
@ -0,0 +1,306 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import secrets
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from openai import NOT_GIVEN
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from llama_stack.apis.conversations.conversations import (
|
||||
Conversation,
|
||||
ConversationDeletedResource,
|
||||
ConversationItem,
|
||||
ConversationItemDeletedResource,
|
||||
ConversationItemList,
|
||||
Conversations,
|
||||
Metadata,
|
||||
)
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import (
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreConfig,
|
||||
sqlstore_impl,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="openai::conversations")
|
||||
|
||||
|
||||
class ConversationServiceConfig(BaseModel):
    """Settings for the built-in conversation service.

    :param conversations_store: SQL store configuration for conversations; defaults to a SQLite database file named conversations.db under DISTRIBS_BASE_DIR
    :param policy: Access control rules; defaults to an empty list
    """

    conversations_store: SqlStoreConfig = SqliteSqlStoreConfig(
        db_path=(DISTRIBS_BASE_DIR / "conversations.db").as_posix(),
    )
    policy: list[AccessRule] = []
|
||||
|
||||
|
||||
async def get_provider_impl(config: ConversationServiceConfig, deps: dict[Any, Any]):
    """Construct the conversation service, initialize it, and return it.

    :param config: Configuration for the conversation service.
    :param deps: Dependencies handed to the implementation unchanged.
    :returns: The ConversationServiceImpl instance after ``initialize()`` has been awaited.
    """
    service = ConversationServiceImpl(config, deps)
    await service.initialize()
    return service
|
||||
|
||||
|
||||
class ConversationServiceImpl(Conversations):
|
||||
"""Built-in conversation service implementation using AuthorizedSqlStore."""
|
||||
|
||||
def __init__(self, config: ConversationServiceConfig, deps: dict[Any, Any]):
    """Keep the config and dependencies and build the SQL store used by the service.

    :param config: Service configuration; supplies the store config and the access policy.
    :param deps: Dependencies, stored as-is on the instance.
    """
    self.config = config
    self.deps = deps
    self.policy = config.policy

    # The store built from the config is wrapped in an AuthorizedSqlStore together with the policy.
    self.sql_store = AuthorizedSqlStore(sqlstore_impl(config.conversations_store), self.policy)
|
||||
|
||||
async def initialize(self) -> None:
    """Initialize the store and create tables.

    For a SQLite-backed store, the parent directory of the database file is
    created first (when the path has a directory component). The
    ``openai_conversations`` and ``conversation_items`` tables are then
    created through the SQL store.
    """
    store_config = self.config.conversations_store
    if isinstance(store_config, SqliteSqlStoreConfig):
        # A bare filename such as "conversations.db" has an empty dirname, and
        # os.makedirs("") raises FileNotFoundError; only create a directory
        # when the path actually names one.
        db_dir = os.path.dirname(store_config.db_path)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)

    await self.sql_store.create_table(
        "openai_conversations",
        {
            "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
            "created_at": ColumnType.INTEGER,
            "items": ColumnType.JSON,
            "metadata": ColumnType.JSON,
        },
    )

    await self.sql_store.create_table(
        "conversation_items",
        {
            "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
            "conversation_id": ColumnType.STRING,
            "created_at": ColumnType.INTEGER,
            "item_data": ColumnType.JSON,
        },
    )
|
||||
|
||||
async def create_conversation(
    self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
) -> Conversation:
    """Create a conversation, optionally seeded with initial items.

    :param items: Items to store with the new conversation; skipped when empty or None.
    :param metadata: Metadata stored on the conversation row and echoed in the result.
    :returns: The newly created Conversation.
    """
    conversation_id = f"conv_{secrets.token_bytes(24).hex()}"
    created_at = int(time.time())

    # The conversation row itself always starts with an empty "items" column;
    # the actual items go into the conversation_items table below.
    await self.sql_store.insert(
        table="openai_conversations",
        data={
            "id": conversation_id,
            "created_at": created_at,
            "items": [],
            "metadata": metadata,
        },
    )

    if items:
        item_records = []
        for entry in items:
            entry_dict = entry.model_dump()
            # The helper receives entry_dict so it can record a generated ID in it.
            entry_id = self._get_or_generate_item_id(entry, entry_dict)
            item_records.append(
                {
                    "id": entry_id,
                    "conversation_id": conversation_id,
                    "created_at": created_at,
                    "item_data": entry_dict,
                }
            )

        await self.sql_store.insert(table="conversation_items", data=item_records)

    conversation = Conversation(
        id=conversation_id,
        created_at=created_at,
        metadata=metadata,
        object="conversation",
    )
    logger.info(f"Created conversation {conversation_id}")
    return conversation
|
||||
|
||||
async def get_conversation(self, conversation_id: str) -> Conversation:
    """Look up a conversation by ID.

    :param conversation_id: ID of the conversation to fetch.
    :returns: The stored conversation.
    :raises ValueError: If no row with that ID exists in ``openai_conversations``.
    """
    row = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})

    if row is not None:
        return Conversation(
            id=row["id"],
            created_at=row["created_at"],
            metadata=row.get("metadata"),
            object="conversation",
        )

    raise ValueError(f"Conversation {conversation_id} not found")
|
||||
|
||||
async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
    """Replace the metadata of a conversation and return the stored result.

    :param conversation_id: ID of the conversation to update.
    :param metadata: New metadata value written to the row.
    :returns: The conversation as re-read from the store after the update.
    :raises ValueError: If the conversation cannot be found when re-reading it.
    """
    changes = {"metadata": metadata}
    target = {"id": conversation_id}
    await self.sql_store.update(table="openai_conversations", data=changes, where=target)

    # Re-read so the caller sees what is actually persisted.
    refreshed = await self.get_conversation(conversation_id)
    return refreshed
|
||||
|
||||
async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
    """Delete the conversation row with the given ID.

    Only the ``openai_conversations`` table is touched here; rows in
    ``conversation_items`` are not removed by this method.

    :param conversation_id: ID of the conversation to delete.
    :returns: A deletion resource carrying the conversation ID.
    """
    target = {"id": conversation_id}
    await self.sql_store.delete(table="openai_conversations", where=target)

    logger.info(f"Deleted conversation {conversation_id}")
    deleted = ConversationDeletedResource(id=conversation_id)
    return deleted
|
||||
|
||||
def _validate_conversation_id(self, conversation_id: str) -> None:
|
||||
"""Validate conversation ID format."""
|
||||
if not conversation_id.startswith("conv_"):
|
||||
raise ValueError(
|
||||
f"Invalid 'conversation_id': '{conversation_id}'. Expected an ID that begins with 'conv_'."
|
||||
)
|
||||
|
||||
def _get_or_generate_item_id(self, item: ConversationItem, item_dict: dict) -> str:
|
||||
"""Get existing item ID or generate one if missing."""
|
||||
if item.id is None:
|
||||
random_bytes = secrets.token_bytes(24)
|
||||
if item.type == "message":
|
||||
item_id = f"msg_{random_bytes.hex()}"
|
||||
else:
|
||||
item_id = f"item_{random_bytes.hex()}"
|
||||
item_dict["id"] = item_id
|
||||
return item_id
|
||||
return item.id
|
||||
|
||||
async def _get_validated_conversation(self, conversation_id: str) -> Conversation:
    """Check the ID format, then fetch the conversation.

    :param conversation_id: ID of the conversation to validate and load.
    :returns: The stored conversation.
    :raises ValueError: If the ID is malformed or the conversation does not exist.
    """
    # Format check first, so a malformed ID never reaches the store.
    self._validate_conversation_id(conversation_id)
    conversation = await self.get_conversation(conversation_id)
    return conversation
|
||||
|
||||
async def add_items(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
    """Create (add) items to a conversation.

    Each item is written to the ``conversation_items`` table. An item whose ID
    already exists in this conversation is overwritten (its ``created_at`` and
    ``item_data`` are refreshed). An item whose ID already exists in a
    *different* conversation is rejected instead of silently overwriting it.

    :param conversation_id: ID of the conversation receiving the items.
    :param items: Items to add; items without an ID get a generated one.
    :returns: A list object holding the stored items, in input order.
    :raises ValueError: If the conversation ID is malformed, the conversation
        does not exist, or an item ID belongs to another conversation.
    """
    await self._get_validated_conversation(conversation_id)

    created_items = []
    created_at = int(time.time())

    for item in items:
        item_dict = item.model_dump()
        item_id = self._get_or_generate_item_id(item, item_dict)

        # TODO: Add support for upsert in sql_store. Until then, look the ID up
        # explicitly rather than catching every insert failure: a blanket
        # `except Exception` would also swallow unrelated store errors.
        existing = await self.sql_store.fetch_one(table="conversation_items", where={"id": item_id})

        if existing is None:
            await self.sql_store.insert(
                table="conversation_items",
                data={
                    "id": item_id,
                    "conversation_id": conversation_id,
                    "created_at": created_at,
                    "item_data": item_dict,
                },
            )
        elif existing["conversation_id"] != conversation_id:
            # Item IDs may be caller-supplied; never let one conversation
            # overwrite an item owned by another.
            raise ValueError(f"Item {item_id} already exists in a different conversation")
        else:
            await self.sql_store.update(
                table="conversation_items",
                data={"created_at": created_at, "item_data": item_dict},
                where={"id": item_id, "conversation_id": conversation_id},
            )

        created_items.append(item_dict)

    logger.info(f"Created {len(created_items)} items in conversation {conversation_id}")

    # Convert created items (dicts) to proper ConversationItem types
    adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
    response_items = [adapter.validate_python(item_dict) for item_dict in created_items]

    return ConversationItemList(
        data=response_items,
        first_id=created_items[0]["id"] if created_items else None,
        last_id=created_items[-1]["id"] if created_items else None,
        has_more=False,
    )
|
||||
|
||||
async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
    """Fetch a single item from a conversation.

    :param conversation_id: ID of the conversation that owns the item.
    :param item_id: ID of the item to fetch.
    :returns: The stored item, validated into its ``ConversationItem`` type.
    :raises ValueError: If either ID is empty, or no matching item exists.
    """
    if not conversation_id:
        raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
    if not item_id:
        raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")

    # Match on both keys so an item is only visible through its own conversation.
    lookup = {"id": item_id, "conversation_id": conversation_id}
    row = await self.sql_store.fetch_one(table="conversation_items", where=lookup)

    if row is not None:
        return TypeAdapter(ConversationItem).validate_python(row["item_data"])

    raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
|
||||
|
||||
async def list(self, conversation_id: str, after=NOT_GIVEN, include=NOT_GIVEN, limit=NOT_GIVEN, order=NOT_GIVEN):
    """List items in the conversation, one page at a time.

    Items are ordered by ``created_at``: ascending when ``order == "asc"``,
    descending otherwise. Ties keep the order returned by the store.

    :param conversation_id: ID of the conversation whose items are listed.
    :param after: Optional item ID used as a pagination cursor; the page
        starts with the item immediately following it in the chosen order.
    :param include: Accepted for interface compatibility; not used here.
    :param limit: Maximum number of items to return. Defaults to 20 when not
        given or not an ``int``.
    :param order: ``"asc"`` for oldest first; any other value means newest first.
    :returns: A ``ConversationItemList`` whose ``has_more`` is true when items
        remain beyond the returned page.
    :raises ValueError: If ``after`` is given but matches no item in the conversation.
    """
    result = await self.sql_store.fetch_all(table="conversation_items", where={"conversation_id": conversation_id})

    ascending = order != NOT_GIVEN and order == "asc"
    # sorted() leaves result.data untouched and is stable, like list.sort().
    records = sorted(result.data, key=lambda row: row["created_at"], reverse=not ascending)

    if after != NOT_GIVEN and after is not None:
        # Drop everything up to and including the cursor item.
        cursor_index = next((idx for idx, row in enumerate(records) if row["id"] == after), None)
        if cursor_index is None:
            raise ValueError(f"Item {after} not found in conversation {conversation_id}")
        records = records[cursor_index + 1 :]

    actual_limit = 20
    if limit != NOT_GIVEN and isinstance(limit, int):
        actual_limit = limit

    page = records[:actual_limit]
    # Report truthfully whether anything was cut off by the limit.
    has_more = len(page) < len(records)

    adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
    response_items = [adapter.validate_python(record["item_data"]) for record in page]

    first_id = response_items[0].id if response_items else None
    last_id = response_items[-1].id if response_items else None

    return ConversationItemList(
        data=response_items,
        first_id=first_id,
        last_id=last_id,
        has_more=has_more,
    )
|
||||
|
||||
async def openai_delete_conversation_item(
    self, conversation_id: str, item_id: str
) -> ConversationItemDeletedResource:
    """Delete one item from a conversation.

    :param conversation_id: ID of the conversation that owns the item.
    :param item_id: ID of the item to delete.
    :returns: A deletion resource carrying the item ID.
    :raises ValueError: If either ID is empty, the conversation ID is malformed,
        the conversation does not exist, or the item is not in the conversation.
    """
    if not conversation_id:
        raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
    if not item_id:
        raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")

    # Raises if the conversation ID is malformed or unknown; the result is unused.
    await self._get_validated_conversation(conversation_id)

    selector = {"id": item_id, "conversation_id": conversation_id}
    existing = await self.sql_store.fetch_one(table="conversation_items", where=selector)
    if existing is None:
        raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")

    # Pass a fresh dict, as the original did for each store call.
    await self.sql_store.delete(table="conversation_items", where=dict(selector))

    logger.info(f"Deleted item {item_id} from conversation {conversation_id}")
    return ConversationItemDeletedResource(id=item_id)
|
||||
|
|
@ -22,7 +22,7 @@ from llama_stack.apis.safety import Safety
|
|||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
|
||||
from llama_stack.apis.shields import Shield, ShieldInput
|
||||
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
|
||||
from llama_stack.apis.tools import ToolGroup, ToolGroupInput, ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.core.access_control.datatypes import AccessRule
|
||||
|
|
@ -84,15 +84,11 @@ class BenchmarkWithOwner(Benchmark, ResourceWithOwner):
|
|||
pass
|
||||
|
||||
|
||||
class ToolWithOwner(Tool, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class ToolGroupWithOwner(ToolGroup, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | Tool | ToolGroup
|
||||
RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | ToolGroup
|
||||
|
||||
RoutableObjectWithProvider = Annotated[
|
||||
ModelWithOwner
|
||||
|
|
@ -101,7 +97,6 @@ RoutableObjectWithProvider = Annotated[
|
|||
| DatasetWithOwner
|
||||
| ScoringFnWithOwner
|
||||
| BenchmarkWithOwner
|
||||
| ToolWithOwner
|
||||
| ToolGroupWithOwner,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
|
@ -480,6 +475,13 @@ InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (depreca
|
|||
If not specified, a default SQLite store will be used.""",
|
||||
)
|
||||
|
||||
conversations_store: SqlStoreConfig | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Configuration for the persistence store used by the conversations API.
|
||||
If not specified, a default SQLite store will be used.""",
|
||||
)
|
||||
|
||||
# registry of "resources" in the distribution
|
||||
models: list[ModelInput] = Field(default_factory=list)
|
||||
shields: list[ShieldInput] = Field(default_factory=list)
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ from llama_stack.providers.datatypes import (
|
|||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts}
|
||||
INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts, Api.conversations}
|
||||
|
||||
|
||||
def stack_apis() -> list[Api]:
|
||||
|
|
@ -243,6 +243,7 @@ def get_external_providers_from_module(
|
|||
spec = module.get_provider_spec()
|
||||
else:
|
||||
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run
|
||||
# in the case we are building we CANNOT import this module of course because it has not been installed.
|
||||
spec = ProviderSpec(
|
||||
api=Api(provider_api),
|
||||
provider_type=provider.provider_type,
|
||||
|
|
@ -251,9 +252,20 @@ def get_external_providers_from_module(
|
|||
config_class="",
|
||||
)
|
||||
provider_type = provider.provider_type
|
||||
# in the case we are building we CANNOT import this module of course because it has not been installed.
|
||||
# return a partially filled out spec that the build script will populate.
|
||||
registry[Api(provider_api)][provider_type] = spec
|
||||
if isinstance(spec, list):
|
||||
# optionally allow people to pass inline and remote provider specs as a returned list.
|
||||
# with the old method, users could pass in directories of specs using overlapping code
|
||||
# we want to ensure we preserve that flexibility in this method.
|
||||
logger.info(
|
||||
f"Detected a list of external provider specs from {provider.module} adding all to the registry"
|
||||
)
|
||||
for provider_spec in spec:
|
||||
if provider_spec.provider_type != provider.provider_type:
|
||||
continue
|
||||
logger.info(f"Adding {provider.provider_type} to registry")
|
||||
registry[Api(provider_api)][provider.provider_type] = provider_spec
|
||||
else:
|
||||
registry[Api(provider_api)][provider_type] = spec
|
||||
except ModuleNotFoundError as exc:
|
||||
raise ValueError(
|
||||
"get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available"
|
||||
|
|
|
|||
|
|
@ -374,6 +374,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
body = options.params or {}
|
||||
body |= options.json_data or {}
|
||||
|
||||
# Merge extra_json parameters (extra_body from SDK is converted to extra_json)
|
||||
if hasattr(options, "extra_json") and options.extra_json:
|
||||
body |= options.extra_json
|
||||
|
||||
matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
|
||||
body |= path_params
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from typing import Any
|
|||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.batches import Batches
|
||||
from llama_stack.apis.benchmarks import Benchmarks
|
||||
from llama_stack.apis.conversations import Conversations
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.datatypes import ExternalApiSpec
|
||||
|
|
@ -96,6 +97,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
|
|||
Api.tool_runtime: ToolRuntime,
|
||||
Api.files: Files,
|
||||
Api.prompts: Prompts,
|
||||
Api.conversations: Conversations,
|
||||
}
|
||||
|
||||
if external_apis:
|
||||
|
|
|
|||
|
|
@ -27,7 +27,6 @@ from llama_stack.apis.inference import (
|
|||
CompletionResponseStreamChunk,
|
||||
Inference,
|
||||
ListOpenAIChatCompletionResponse,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
|
|
@ -42,12 +41,7 @@ from llama_stack.apis.inference import (
|
|||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
Order,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
|
|
@ -185,129 +179,6 @@ class InferenceRouter(Inference):
|
|||
raise ModelTypeError(model_id, model.model_type, expected_model_type)
|
||||
return model
|
||||
|
||||
async def chat_completion(
    self,
    model_id: str,
    messages: list[Message],
    sampling_params: SamplingParams | None = None,
    response_format: ResponseFormat | None = None,
    tools: list[ToolDefinition] | None = None,
    tool_choice: ToolChoice | None = None,
    tool_prompt_format: ToolPromptFormat | None = None,
    stream: bool | None = False,
    logprobs: LogProbConfig | None = None,
    tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
    """Forward a chat completion request to the provider serving ``model_id``.

    Reconciles the legacy ``tool_choice`` / ``tool_prompt_format`` arguments
    with ``tool_config``, checks the tool choice against the supplied tools,
    calls the provider, and attaches token metrics to the result.

    :returns: For ``stream`` requests, the result of
        ``stream_tokens_and_compute_metrics`` wrapping the provider stream;
        otherwise the provider response with ``metrics`` set or extended.
    :raises ValueError: If the legacy arguments disagree with ``tool_config``,
        or a named tool choice is not among ``tools``.
    """
    logger.debug(
        f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
    )
    if sampling_params is None:
        sampling_params = SamplingParams()
    model = await self._get_model(model_id, ModelType.llm)

    if not tool_config:
        # No explicit config: build one from whichever legacy arguments were given.
        overrides = {}
        if tool_choice:
            overrides["tool_choice"] = tool_choice
        if tool_prompt_format:
            overrides["tool_prompt_format"] = tool_prompt_format
        tool_config = ToolConfig(**overrides)
    else:
        if tool_choice and tool_choice != tool_config.tool_choice:
            raise ValueError("tool_choice and tool_config.tool_choice must match")
        if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
            raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")

    tools = tools or []
    choice = tool_config.tool_choice
    if choice == ToolChoice.none:
        tools = []
    elif not (choice == ToolChoice.auto or choice == ToolChoice.required):
        # verify tool_choice is one of the tools
        tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
        if choice not in tool_names:
            raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")

    request_kwargs = {
        "model_id": model_id,
        "messages": messages,
        "sampling_params": sampling_params,
        "tools": tools,
        "tool_choice": tool_choice,
        "tool_prompt_format": tool_prompt_format,
        "response_format": response_format,
        "stream": stream,
        "logprobs": logprobs,
        "tool_config": tool_config,
    }
    provider = await self.routing_table.get_provider_impl(model_id)
    prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)

    # Both the streaming and non-streaming paths start with the same provider call.
    response = await provider.chat_completion(**request_kwargs)

    if stream:
        return self.stream_tokens_and_compute_metrics(
            response=response,
            prompt_tokens=prompt_tokens,
            model=model,
            tool_prompt_format=tool_config.tool_prompt_format,
        )

    metrics = await self.count_tokens_and_compute_metrics(
        response=response,
        prompt_tokens=prompt_tokens,
        model=model,
        tool_prompt_format=tool_config.tool_prompt_format,
    )
    # these metrics will show up in the client response.
    existing_metrics = getattr(response, "metrics", None)
    if existing_metrics is None:
        response.metrics = metrics
    else:
        response.metrics = existing_metrics + metrics
    return response
|
||||
|
||||
async def completion(
    self,
    model_id: str,
    content: InterleavedContent,
    sampling_params: SamplingParams | None = None,
    response_format: ResponseFormat | None = None,
    stream: bool | None = False,
    logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
    """Forward a completion request to the provider serving ``model_id``.

    :returns: For ``stream`` requests, the result of
        ``stream_tokens_and_compute_metrics`` wrapping the provider response;
        otherwise the provider response with ``metrics`` set or extended.
    """
    if sampling_params is None:
        sampling_params = SamplingParams()
    logger.debug(
        f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
    )
    model = await self._get_model(model_id, ModelType.llm)
    provider = await self.routing_table.get_provider_impl(model_id)

    # Prompt tokens are counted before the provider is invoked.
    prompt_tokens = await self._count_tokens(content)
    response = await provider.completion(
        model_id=model_id,
        content=content,
        sampling_params=sampling_params,
        response_format=response_format,
        stream=stream,
        logprobs=logprobs,
    )

    if stream:
        return self.stream_tokens_and_compute_metrics(
            response=response,
            prompt_tokens=prompt_tokens,
            model=model,
        )

    metrics = await self.count_tokens_and_compute_metrics(
        response=response, prompt_tokens=prompt_tokens, model=model
    )
    if response.metrics is None:
        response.metrics = metrics
    else:
        response.metrics = response.metrics + metrics

    return response
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from llama_stack.apis.common.content_types import (
|
|||
InterleavedContent,
|
||||
)
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolsResponse,
|
||||
ListToolDefsResponse,
|
||||
RAGDocument,
|
||||
RAGQueryConfig,
|
||||
RAGQueryResult,
|
||||
|
|
@ -86,6 +86,6 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
) -> ListToolsResponse:
|
||||
) -> ListToolDefsResponse:
|
||||
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
||||
return await self.routing_table.list_tools(tool_group_id)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from typing import Any
|
|||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.common.errors import ToolGroupNotFoundError
|
||||
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
|
||||
from llama_stack.apis.tools import ListToolDefsResponse, ListToolGroupsResponse, ToolDef, ToolGroup, ToolGroups
|
||||
from llama_stack.core.datatypes import AuthenticationRequiredError, ToolGroupWithOwner
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
|
|
@ -27,7 +27,7 @@ def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name
|
|||
|
||||
|
||||
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||
toolgroups_to_tools: dict[str, list[Tool]] = {}
|
||||
toolgroups_to_tools: dict[str, list[ToolDef]] = {}
|
||||
tool_to_toolgroup: dict[str, str] = {}
|
||||
|
||||
# overridden
|
||||
|
|
@ -43,7 +43,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
routing_key = self.tool_to_toolgroup[routing_key]
|
||||
return await super().get_provider_impl(routing_key, provider_id)
|
||||
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolDefsResponse:
|
||||
if toolgroup_id:
|
||||
if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
|
||||
toolgroup_id = group_id
|
||||
|
|
@ -68,30 +68,19 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
continue
|
||||
all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier])
|
||||
|
||||
return ListToolsResponse(data=all_tools)
|
||||
return ListToolDefsResponse(data=all_tools)
|
||||
|
||||
async def _index_tools(self, toolgroup: ToolGroup):
|
||||
provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
|
||||
tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)
|
||||
|
||||
# TODO: kill this Tool vs ToolDef distinction
|
||||
tooldefs = tooldefs_response.data
|
||||
tools = []
|
||||
for t in tooldefs:
|
||||
tools.append(
|
||||
Tool(
|
||||
identifier=t.name,
|
||||
toolgroup_id=toolgroup.identifier,
|
||||
description=t.description or "",
|
||||
parameters=t.parameters or [],
|
||||
metadata=t.metadata,
|
||||
provider_id=toolgroup.provider_id,
|
||||
)
|
||||
)
|
||||
t.toolgroup_id = toolgroup.identifier
|
||||
|
||||
self.toolgroups_to_tools[toolgroup.identifier] = tools
|
||||
for tool in tools:
|
||||
self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier
|
||||
self.toolgroups_to_tools[toolgroup.identifier] = tooldefs
|
||||
for tool in tooldefs:
|
||||
self.tool_to_toolgroup[tool.name] = toolgroup.identifier
|
||||
|
||||
async def list_tool_groups(self) -> ListToolGroupsResponse:
|
||||
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
|
||||
|
|
@ -102,12 +91,12 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
raise ToolGroupNotFoundError(toolgroup_id)
|
||||
return tool_group
|
||||
|
||||
async def get_tool(self, tool_name: str) -> Tool:
|
||||
async def get_tool(self, tool_name: str) -> ToolDef:
|
||||
if tool_name in self.tool_to_toolgroup:
|
||||
toolgroup_id = self.tool_to_toolgroup[tool_name]
|
||||
tools = self.toolgroups_to_tools[toolgroup_id]
|
||||
for tool in tools:
|
||||
if tool.identifier == tool_name:
|
||||
if tool.name == tool_name:
|
||||
return tool
|
||||
raise ValueError(f"Tool '{tool_name}' not found")
|
||||
|
||||
|
|
@ -132,7 +121,6 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
# baked in some of the code and tests right now.
|
||||
if not toolgroup.mcp_endpoint:
|
||||
await self._index_tools(toolgroup)
|
||||
return toolgroup
|
||||
|
||||
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||
await self.unregister_object(await self.get_tool_group(toolgroup_id))
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import functools
|
||||
|
|
@ -12,7 +11,6 @@ import inspect
|
|||
import json
|
||||
import logging # allow-direct-logging
|
||||
import os
|
||||
import ssl
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
|
|
@ -35,7 +33,6 @@ from pydantic import BaseModel, ValidationError
|
|||
|
||||
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.cli.utils import add_config_distro_args, get_config_from_args
|
||||
from llama_stack.core.access_control.access_control import AccessDeniedError
|
||||
from llama_stack.core.datatypes import (
|
||||
AuthenticationRequiredError,
|
||||
|
|
@ -55,7 +52,6 @@ from llama_stack.core.stack import (
|
|||
Stack,
|
||||
cast_image_name_to_string,
|
||||
replace_env_vars,
|
||||
validate_env_pair,
|
||||
)
|
||||
from llama_stack.core.utils.config import redact_sensitive_fields
|
||||
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
|
||||
|
|
@ -257,7 +253,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
|||
|
||||
return result
|
||||
except Exception as e:
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
if logger.isEnabledFor(logging.INFO):
|
||||
logger.exception(f"Error executing endpoint {route=} {method=}")
|
||||
else:
|
||||
logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
|
||||
|
|
@ -333,23 +329,18 @@ class ClientVersionMiddleware:
|
|||
return await self.app(scope, receive, send)
|
||||
|
||||
|
||||
def create_app(
|
||||
config_file: str | None = None,
|
||||
env_vars: list[str] | None = None,
|
||||
) -> StackApp:
|
||||
def create_app() -> StackApp:
|
||||
"""Create and configure the FastAPI application.
|
||||
|
||||
Args:
|
||||
config_file: Path to config file. If None, uses LLAMA_STACK_CONFIG env var or default resolution.
|
||||
env_vars: List of environment variables in KEY=value format.
|
||||
disable_version_check: Whether to disable version checking. If None, uses LLAMA_STACK_DISABLE_VERSION_CHECK env var.
|
||||
This factory function reads configuration from environment variables:
|
||||
- LLAMA_STACK_CONFIG: Path to config file (required)
|
||||
|
||||
Returns:
|
||||
Configured StackApp instance.
|
||||
"""
|
||||
config_file = config_file or os.getenv("LLAMA_STACK_CONFIG")
|
||||
config_file = os.getenv("LLAMA_STACK_CONFIG")
|
||||
if config_file is None:
|
||||
raise ValueError("No config file provided and LLAMA_STACK_CONFIG env var is not set")
|
||||
raise ValueError("LLAMA_STACK_CONFIG environment variable is required")
|
||||
|
||||
config_file = resolve_config_or_distro(config_file, Mode.RUN)
|
||||
|
||||
|
|
@ -361,16 +352,6 @@ def create_app(
|
|||
logger_config = LoggingConfig(**cfg)
|
||||
logger = get_logger(name=__name__, category="core::server", config=logger_config)
|
||||
|
||||
if env_vars:
|
||||
for env_pair in env_vars:
|
||||
try:
|
||||
key, value = validate_env_pair(env_pair)
|
||||
logger.info(f"Setting environment variable {key} => {value}")
|
||||
os.environ[key] = value
|
||||
except ValueError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
raise ValueError(f"Invalid environment variable format: {env_pair}") from e
|
||||
|
||||
config = replace_env_vars(config_contents)
|
||||
config = StackRunConfig(**cast_image_name_to_string(config))
|
||||
|
||||
|
|
@ -451,6 +432,7 @@ def create_app(
|
|||
apis_to_serve.add("inspect")
|
||||
apis_to_serve.add("providers")
|
||||
apis_to_serve.add("prompts")
|
||||
apis_to_serve.add("conversations")
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
|
||||
|
|
@ -493,101 +475,6 @@ def create_app(
|
|||
return app
|
||||
|
||||
|
||||
def main(args: argparse.Namespace | None = None):
    """Start the LlamaStack server.

    Builds the application via ``create_app``, re-reads the resolved config
    for the server and logging settings, and serves the app with uvicorn.

    :param args: Pre-parsed arguments (as passed by the "run" command). When
        ``None``, arguments are parsed from the command line.
    """
    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
    add_config_distro_args(parser)
    parser.add_argument(
        "--port",
        type=int,
        default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
        help="Port to listen on",
    )
    parser.add_argument(
        "--env",
        action="append",
        help="Environment variables in KEY=value format. Can be specified multiple times.",
    )

    # The "run" command hands over an already-populated Namespace; only parse
    # the command line when nothing was supplied.
    if args is None:
        args = parser.parse_args()

    config_or_distro = get_config_from_args(args)

    try:
        app = create_app(
            config_file=config_or_distro,
            env_vars=args.env,
        )
    except Exception as e:
        logger.error(f"Error creating app: {str(e)}")
        sys.exit(1)

    config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
    with open(config_file) as fp:
        config_contents = yaml.safe_load(fp)

    cfg = config_contents.get("logging_config") if isinstance(config_contents, dict) else None
    logger_config = LoggingConfig(**cfg) if cfg else None
    config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))

    import uvicorn

    port = args.port or config.server.port

    # Configure SSL if certificates are provided
    tls_options = {}
    keyfile = config.server.tls_keyfile
    certfile = config.server.tls_certfile
    if keyfile and certfile:
        tls_options["ssl_keyfile"] = keyfile
        tls_options["ssl_certfile"] = certfile
        if config.server.tls_cafile:
            tls_options["ssl_ca_certs"] = config.server.tls_cafile
            tls_options["ssl_cert_reqs"] = ssl.CERT_REQUIRED
            logger.info(
                f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
            )
        else:
            logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")

    listen_host = config.server.host or ["::", "0.0.0.0"]
    logger.info(f"Listening on {listen_host}:{port}")

    uvicorn_config = {
        "app": app,
        "host": listen_host,
        "port": port,
        "lifespan": "on",
        "log_level": logger.getEffectiveLevel(),
        "log_config": logger_config,
        **tls_options,
    }

    # uvicorn re-raises SIGINT via signal.raise_signal(), which Python turns
    # into KeyboardInterrupt. Catching it here avoids a confusing stack trace
    # on Ctrl+C or kill -2. SIGTERM (kill -15) needs no such handling because
    # Python has no default handler for it. Ignoring SIGINT entirely and
    # leaving it to uvicorn would also work, but is more intrusive than it is
    # worth.
    try:
        asyncio.run(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
    except (KeyboardInterrupt, SystemExit):
        logger.info("Received interrupt signal, shutting down gracefully...")
|
||||
|
||||
|
||||
def _log_run_config(run_config: StackRunConfig):
|
||||
"""Logs the run config with redacted fields and disabled providers removed."""
|
||||
logger.info("Run configuration:")
|
||||
|
|
@ -614,7 +501,3 @@ def remove_disabled_providers(obj):
|
|||
return [item for item in (remove_disabled_providers(i) for i in obj) if item is not None]
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ import yaml
|
|||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.benchmarks import Benchmarks
|
||||
from llama_stack.apis.conversations import Conversations
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.eval import Eval
|
||||
|
|
@ -34,6 +35,7 @@ from llama_stack.apis.telemetry import Telemetry
|
|||
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDBs
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
|
||||
|
|
@ -73,6 +75,7 @@ class LlamaStack(
|
|||
RAGToolRuntime,
|
||||
Files,
|
||||
Prompts,
|
||||
Conversations,
|
||||
):
|
||||
pass
|
||||
|
||||
|
|
@ -312,6 +315,12 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
|
|||
)
|
||||
impls[Api.prompts] = prompts_impl
|
||||
|
||||
conversations_impl = ConversationServiceImpl(
|
||||
ConversationServiceConfig(run_config=run_config),
|
||||
deps=impls,
|
||||
)
|
||||
impls[Api.conversations] = conversations_impl
|
||||
|
||||
|
||||
class Stack:
|
||||
def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
|
||||
|
|
@ -342,6 +351,8 @@ class Stack:
|
|||
|
||||
if Api.prompts in impls:
|
||||
await impls[Api.prompts].initialize()
|
||||
if Api.conversations in impls:
|
||||
await impls[Api.conversations].initialize()
|
||||
|
||||
await register_resources(self.run_config, impls)
|
||||
|
||||
|
|
|
|||
|
|
@ -116,7 +116,7 @@ if [[ "$env_type" == "venv" ]]; then
|
|||
yaml_config_arg=""
|
||||
fi
|
||||
|
||||
$PYTHON_BINARY -m llama_stack.core.server.server \
|
||||
llama stack run \
|
||||
$yaml_config_arg \
|
||||
--port "$port" \
|
||||
$env_vars \
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ class DistributionRegistry(Protocol):
|
|||
|
||||
|
||||
REGISTER_PREFIX = "distributions:registry"
|
||||
KEY_VERSION = "v9"
|
||||
KEY_VERSION = "v10"
|
||||
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ def tool_chat_page():
|
|||
|
||||
for toolgroup_id in toolgroup_selection:
|
||||
tools = client.tools.list(toolgroup_id=toolgroup_id)
|
||||
grouped_tools[toolgroup_id] = [tool.identifier for tool in tools]
|
||||
grouped_tools[toolgroup_id] = [tool.name for tool in tools]
|
||||
total_tools += len(tools)
|
||||
|
||||
st.markdown(f"Active Tools: 🛠 {total_tools}")
|
||||
|
|
|
|||
|
|
@ -159,7 +159,7 @@ providers:
|
|||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sinks: ${env.TELEMETRY_SINKS:=sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
post_training:
|
||||
|
|
|
|||
|
|
@ -115,7 +115,7 @@ docker run -it \
|
|||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v $HOME/.llama:/root/.llama \
|
||||
# NOTE: mount the llama-stack directory if testing local changes else not needed
|
||||
-v /home/hjshah/git/llama-stack:/app/llama-stack-source \
|
||||
-v $HOME/git/llama-stack:/app/llama-stack-source \
|
||||
# localhost/distribution-dell:dev if building / testing locally
|
||||
llamastack/distribution-{{ name }}\
|
||||
--port $LLAMA_STACK_PORT \
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ providers:
|
|||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sinks: ${env.TELEMETRY_SINKS:=sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
eval:
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ providers:
|
|||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sinks: ${env.TELEMETRY_SINKS:=sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
eval:
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ providers:
|
|||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sinks: ${env.TELEMETRY_SINKS:=sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
eval:
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ providers:
|
|||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sinks: ${env.TELEMETRY_SINKS:=sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
eval:
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ providers:
|
|||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sinks: ${env.TELEMETRY_SINKS:=sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
eval:
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ providers:
|
|||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sinks: ${env.TELEMETRY_SINKS:=sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
eval:
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ providers:
|
|||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sinks: ${env.TELEMETRY_SINKS:=sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
eval:
|
||||
|
|
|
|||
|
|
@ -159,7 +159,7 @@ providers:
|
|||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sinks: ${env.TELEMETRY_SINKS:=sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
post_training:
|
||||
|
|
|
|||
|
|
@ -159,7 +159,7 @@ providers:
|
|||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sinks: ${env.TELEMETRY_SINKS:=sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
post_training:
|
||||
|
|
|
|||
|
|
@ -31,7 +31,14 @@ CATEGORIES = [
|
|||
"client",
|
||||
"telemetry",
|
||||
"openai_responses",
|
||||
"testing",
|
||||
"providers",
|
||||
"models",
|
||||
"files",
|
||||
"vector_io",
|
||||
"tool_runtime",
|
||||
]
|
||||
UNCATEGORIZED = "uncategorized"
|
||||
|
||||
# Initialize category levels with default level
|
||||
_category_levels: dict[str, int] = dict.fromkeys(CATEGORIES, DEFAULT_LOG_LEVEL)
|
||||
|
|
@ -121,7 +128,7 @@ def strip_rich_markup(text):
|
|||
|
||||
class CustomRichHandler(RichHandler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs["console"] = Console(width=150)
|
||||
kwargs["console"] = Console()
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def emit(self, record):
|
||||
|
|
@ -165,7 +172,7 @@ def setup_logging(category_levels: dict[str, int], log_file: str | None) -> None
|
|||
|
||||
def filter(self, record):
|
||||
if not hasattr(record, "category"):
|
||||
record.category = "uncategorized" # Default to 'uncategorized' if no category found
|
||||
record.category = UNCATEGORIZED # Default to 'uncategorized' if no category found
|
||||
return True
|
||||
|
||||
# Determine the root logger's level (default to WARNING if not specified)
|
||||
|
|
@ -247,7 +254,19 @@ def get_logger(
|
|||
_category_levels.update(parse_yaml_config(config))
|
||||
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(_category_levels.get(category, DEFAULT_LOG_LEVEL))
|
||||
if category in _category_levels:
|
||||
log_level = _category_levels[category]
|
||||
else:
|
||||
root_category = category.split("::")[0]
|
||||
if root_category in _category_levels:
|
||||
log_level = _category_levels[root_category]
|
||||
else:
|
||||
log_level = _category_levels.get("root", DEFAULT_LOG_LEVEL)
|
||||
if category != UNCATEGORIZED:
|
||||
logging.warning(
|
||||
f"Unknown logging category: {category}. Falling back to default 'root' level: {log_level}"
|
||||
)
|
||||
logger.setLevel(log_level)
|
||||
return logging.LoggerAdapter(logger, {"category": category})
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -37,14 +37,7 @@ RecursiveType = Primitive | list[Primitive] | dict[str, Primitive]
|
|||
class ToolCall(BaseModel):
|
||||
call_id: str
|
||||
tool_name: BuiltinTool | str
|
||||
# Plan is to deprecate the Dict in favor of a JSON string
|
||||
# that is parsed on the client side instead of trying to manage
|
||||
# the recursive type here.
|
||||
# Making this a union so that client side can start prepping for this change.
|
||||
# Eventually, we will remove both the Dict and arguments_json field,
|
||||
# and arguments will just be a str
|
||||
arguments: str | dict[str, RecursiveType]
|
||||
arguments_json: str | None = None
|
||||
arguments: str
|
||||
|
||||
@field_validator("tool_name", mode="before")
|
||||
@classmethod
|
||||
|
|
@ -88,19 +81,11 @@ class StopReason(Enum):
|
|||
out_of_tokens = "out_of_tokens"
|
||||
|
||||
|
||||
class ToolParamDefinition(BaseModel):
|
||||
param_type: str
|
||||
description: str | None = None
|
||||
required: bool | None = True
|
||||
items: Any | None = None
|
||||
title: str | None = None
|
||||
default: Any | None = None
|
||||
|
||||
|
||||
class ToolDefinition(BaseModel):
|
||||
tool_name: BuiltinTool | str
|
||||
description: str | None = None
|
||||
parameters: dict[str, ToolParamDefinition] | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
output_schema: dict[str, Any] | None = None
|
||||
|
||||
@field_validator("tool_name", mode="before")
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -232,8 +232,7 @@ class ChatFormat:
|
|||
ToolCall(
|
||||
call_id=call_id,
|
||||
tool_name=tool_name,
|
||||
arguments=tool_arguments,
|
||||
arguments_json=json.dumps(tool_arguments),
|
||||
arguments=json.dumps(tool_arguments),
|
||||
)
|
||||
)
|
||||
content = ""
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ from typing import Any
|
|||
from llama_stack.apis.inference import (
|
||||
BuiltinTool,
|
||||
ToolDefinition,
|
||||
ToolParamDefinition,
|
||||
)
|
||||
|
||||
from .base import PromptTemplate, PromptTemplateGeneratorBase
|
||||
|
|
@ -101,11 +100,8 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
|
|||
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
||||
{%- set tname = t.tool_name -%}
|
||||
{%- set tdesc = t.description -%}
|
||||
{%- set tparams = t.parameters -%}
|
||||
{%- set required_params = [] -%}
|
||||
{%- for name, param in tparams.items() if param.required == true -%}
|
||||
{%- set _ = required_params.append(name) -%}
|
||||
{%- endfor -%}
|
||||
{%- set tprops = t.input_schema.get('properties', {}) -%}
|
||||
{%- set required_params = t.input_schema.get('required', []) -%}
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
|
|
@ -114,11 +110,11 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
|
|||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": [
|
||||
{%- for name, param in tparams.items() %}
|
||||
{%- for name, param in tprops.items() %}
|
||||
{
|
||||
"{{name}}": {
|
||||
"type": "object",
|
||||
"description": "{{param.description}}"
|
||||
"description": "{{param.get('description', '')}}"
|
||||
}
|
||||
}{% if not loop.last %},{% endif %}
|
||||
{%- endfor %}
|
||||
|
|
@ -143,17 +139,19 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
|
|||
ToolDefinition(
|
||||
tool_name="trending_songs",
|
||||
description="Returns the trending songs on a Music site",
|
||||
parameters={
|
||||
"n": ToolParamDefinition(
|
||||
param_type="int",
|
||||
description="The number of songs to return",
|
||||
required=True,
|
||||
),
|
||||
"genre": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="The genre of the songs to return",
|
||||
required=False,
|
||||
),
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"n": {
|
||||
"type": "int",
|
||||
"description": "The number of songs to return",
|
||||
},
|
||||
"genre": {
|
||||
"type": "str",
|
||||
"description": "The genre of the songs to return",
|
||||
},
|
||||
},
|
||||
"required": ["n"],
|
||||
},
|
||||
),
|
||||
]
|
||||
|
|
@ -170,11 +168,14 @@ class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
|
|||
{#- manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
||||
{%- set tname = t.tool_name -%}
|
||||
{%- set tdesc = t.description -%}
|
||||
{%- set modified_params = t.parameters.copy() -%}
|
||||
{%- for key, value in modified_params.items() -%}
|
||||
{%- if 'default' in value -%}
|
||||
{%- set _ = value.pop('default', None) -%}
|
||||
{%- set tprops = t.input_schema.get('properties', {}) -%}
|
||||
{%- set modified_params = {} -%}
|
||||
{%- for key, value in tprops.items() -%}
|
||||
{%- set param_copy = value.copy() -%}
|
||||
{%- if 'default' in param_copy -%}
|
||||
{%- set _ = param_copy.pop('default', None) -%}
|
||||
{%- endif -%}
|
||||
{%- set _ = modified_params.update({key: param_copy}) -%}
|
||||
{%- endfor -%}
|
||||
{%- set tparams = modified_params | tojson -%}
|
||||
Use the function '{{ tname }}' to '{{ tdesc }}':
|
||||
|
|
@ -205,17 +206,19 @@ class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
|
|||
ToolDefinition(
|
||||
tool_name="trending_songs",
|
||||
description="Returns the trending songs on a Music site",
|
||||
parameters={
|
||||
"n": ToolParamDefinition(
|
||||
param_type="int",
|
||||
description="The number of songs to return",
|
||||
required=True,
|
||||
),
|
||||
"genre": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="The genre of the songs to return",
|
||||
required=False,
|
||||
),
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"n": {
|
||||
"type": "int",
|
||||
"description": "The number of songs to return",
|
||||
},
|
||||
"genre": {
|
||||
"type": "str",
|
||||
"description": "The genre of the songs to return",
|
||||
},
|
||||
},
|
||||
"required": ["n"],
|
||||
},
|
||||
),
|
||||
]
|
||||
|
|
@ -255,11 +258,8 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
||||
{%- set tname = t.tool_name -%}
|
||||
{%- set tdesc = t.description -%}
|
||||
{%- set tparams = t.parameters -%}
|
||||
{%- set required_params = [] -%}
|
||||
{%- for name, param in tparams.items() if param.required == true -%}
|
||||
{%- set _ = required_params.append(name) -%}
|
||||
{%- endfor -%}
|
||||
{%- set tprops = (t.input_schema or {}).get('properties', {}) -%}
|
||||
{%- set required_params = (t.input_schema or {}).get('required', []) -%}
|
||||
{
|
||||
"name": "{{tname}}",
|
||||
"description": "{{tdesc}}",
|
||||
|
|
@ -267,11 +267,11 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
"type": "dict",
|
||||
"required": {{ required_params | tojson }},
|
||||
"properties": {
|
||||
{%- for name, param in tparams.items() %}
|
||||
{%- for name, param in tprops.items() %}
|
||||
"{{name}}": {
|
||||
"type": "{{param.param_type}}",
|
||||
"description": "{{param.description}}"{% if param.default %},
|
||||
"default": "{{param.default}}"{% endif %}
|
||||
"type": "{{param.get('type', 'string')}}",
|
||||
"description": "{{param.get('description', '')}}"{% if param.get('default') %},
|
||||
"default": "{{param.get('default')}}"{% endif %}
|
||||
}{% if not loop.last %},{% endif %}
|
||||
{%- endfor %}
|
||||
}
|
||||
|
|
@ -299,18 +299,20 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
ToolDefinition(
|
||||
tool_name="get_weather",
|
||||
description="Get weather info for places",
|
||||
parameters={
|
||||
"city": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="The name of the city to get the weather for",
|
||||
required=True,
|
||||
),
|
||||
"metric": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="The metric for weather. Options are: celsius, fahrenheit",
|
||||
required=False,
|
||||
default="celsius",
|
||||
),
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The name of the city to get the weather for",
|
||||
},
|
||||
"metric": {
|
||||
"type": "string",
|
||||
"description": "The metric for weather. Options are: celsius, fahrenheit",
|
||||
"default": "celsius",
|
||||
},
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -220,17 +220,18 @@ class ToolUtils:
|
|||
|
||||
@staticmethod
|
||||
def encode_tool_call(t: ToolCall, tool_prompt_format: ToolPromptFormat) -> str:
|
||||
args = json.loads(t.arguments)
|
||||
if t.tool_name == BuiltinTool.brave_search:
|
||||
q = t.arguments["query"]
|
||||
q = args["query"]
|
||||
return f'brave_search.call(query="{q}")'
|
||||
elif t.tool_name == BuiltinTool.wolfram_alpha:
|
||||
q = t.arguments["query"]
|
||||
q = args["query"]
|
||||
return f'wolfram_alpha.call(query="{q}")'
|
||||
elif t.tool_name == BuiltinTool.photogen:
|
||||
q = t.arguments["query"]
|
||||
q = args["query"]
|
||||
return f'photogen.call(query="{q}")'
|
||||
elif t.tool_name == BuiltinTool.code_interpreter:
|
||||
return t.arguments["code"]
|
||||
return args["code"]
|
||||
else:
|
||||
fname = t.tool_name
|
||||
|
||||
|
|
@ -239,12 +240,11 @@ class ToolUtils:
|
|||
{
|
||||
"type": "function",
|
||||
"name": fname,
|
||||
"parameters": t.arguments,
|
||||
"parameters": args,
|
||||
}
|
||||
)
|
||||
elif tool_prompt_format == ToolPromptFormat.function_tag:
|
||||
args = json.dumps(t.arguments)
|
||||
return f"<function={fname}>{args}</function>"
|
||||
return f"<function={fname}>{t.arguments}</function>"
|
||||
|
||||
elif tool_prompt_format == ToolPromptFormat.python_list:
|
||||
|
||||
|
|
@ -260,7 +260,7 @@ class ToolUtils:
|
|||
else:
|
||||
raise ValueError(f"Unsupported type: {type(value)}")
|
||||
|
||||
args_str = ", ".join(f"{k}={format_value(v)}" for k, v in t.arguments.items())
|
||||
args_str = ", ".join(f"{k}={format_value(v)}" for k, v in args.items())
|
||||
return f"[{fname}({args_str})]"
|
||||
else:
|
||||
raise ValueError(f"Unsupported tool prompt format: {tool_prompt_format}")
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@
|
|||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import json
|
||||
import textwrap
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
|
|
@ -184,7 +185,7 @@ def usecases() -> list[UseCase | str]:
|
|||
ToolCall(
|
||||
call_id="tool_call_id",
|
||||
tool_name=BuiltinTool.wolfram_alpha,
|
||||
arguments={"query": "100th decimal of pi"},
|
||||
arguments=json.dumps({"query": "100th decimal of pi"}),
|
||||
)
|
||||
],
|
||||
),
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@
|
|||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import json
|
||||
import textwrap
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
|
|
@ -185,7 +186,7 @@ def usecases() -> list[UseCase | str]:
|
|||
ToolCall(
|
||||
call_id="tool_call_id",
|
||||
tool_name=BuiltinTool.wolfram_alpha,
|
||||
arguments={"query": "100th decimal of pi"},
|
||||
arguments=json.dumps({"query": "100th decimal of pi"}),
|
||||
)
|
||||
],
|
||||
),
|
||||
|
|
|
|||
|
|
@ -298,8 +298,7 @@ class ChatFormat:
|
|||
ToolCall(
|
||||
call_id=call_id,
|
||||
tool_name=tool_name,
|
||||
arguments=tool_arguments,
|
||||
arguments_json=json.dumps(tool_arguments),
|
||||
arguments=json.dumps(tool_arguments),
|
||||
)
|
||||
)
|
||||
content = ""
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@
|
|||
|
||||
import textwrap
|
||||
|
||||
from llama_stack.apis.inference import ToolDefinition, ToolParamDefinition
|
||||
from llama_stack.apis.inference import ToolDefinition
|
||||
from llama_stack.models.llama.llama3.prompt_templates.base import (
|
||||
PromptTemplate,
|
||||
PromptTemplateGeneratorBase,
|
||||
|
|
@ -81,11 +81,8 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
||||
{%- set tname = t.tool_name -%}
|
||||
{%- set tdesc = t.description -%}
|
||||
{%- set tparams = t.parameters -%}
|
||||
{%- set required_params = [] -%}
|
||||
{%- for name, param in tparams.items() if param.required == true -%}
|
||||
{%- set _ = required_params.append(name) -%}
|
||||
{%- endfor -%}
|
||||
{%- set tprops = t.input_schema.get('properties', {}) -%}
|
||||
{%- set required_params = t.input_schema.get('required', []) -%}
|
||||
{
|
||||
"name": "{{tname}}",
|
||||
"description": "{{tdesc}}",
|
||||
|
|
@ -93,11 +90,11 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
"type": "dict",
|
||||
"required": {{ required_params | tojson }},
|
||||
"properties": {
|
||||
{%- for name, param in tparams.items() %}
|
||||
{%- for name, param in tprops.items() %}
|
||||
"{{name}}": {
|
||||
"type": "{{param.param_type}}",
|
||||
"description": "{{param.description}}"{% if param.default %},
|
||||
"default": "{{param.default}}"{% endif %}
|
||||
"type": "{{param.get('type', 'string')}}",
|
||||
"description": "{{param.get('description', '')}}"{% if param.get('default') %},
|
||||
"default": "{{param.get('default')}}"{% endif %}
|
||||
}{% if not loop.last %},{% endif %}
|
||||
{%- endfor %}
|
||||
}
|
||||
|
|
@ -119,18 +116,20 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
ToolDefinition(
|
||||
tool_name="get_weather",
|
||||
description="Get weather info for places",
|
||||
parameters={
|
||||
"city": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="The name of the city to get the weather for",
|
||||
required=True,
|
||||
),
|
||||
"metric": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="The metric for weather. Options are: celsius, fahrenheit",
|
||||
required=False,
|
||||
default="celsius",
|
||||
),
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The name of the city to get the weather for",
|
||||
},
|
||||
"metric": {
|
||||
"type": "string",
|
||||
"description": "The metric for weather. Options are: celsius, fahrenheit",
|
||||
"default": "celsius",
|
||||
},
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -50,11 +50,16 @@ from llama_stack.apis.inference import (
|
|||
CompletionMessage,
|
||||
Inference,
|
||||
Message,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIMessageParam,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
SystemMessage,
|
||||
ToolDefinition,
|
||||
ToolParamDefinition,
|
||||
ToolResponse,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
|
|
@ -68,6 +73,11 @@ from llama_stack.models.llama.datatypes import (
|
|||
BuiltinTool,
|
||||
ToolCall,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_message_to_openai_dict_new,
|
||||
convert_openai_chat_completion_stream,
|
||||
convert_tooldef_to_openai_tool,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
||||
|
|
@ -177,12 +187,12 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
return messages
|
||||
|
||||
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
||||
turn_id = str(uuid.uuid4())
|
||||
span = tracing.get_current_span()
|
||||
if span:
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
turn_id = str(uuid.uuid4())
|
||||
span.set_attribute("turn_id", turn_id)
|
||||
if self.agent_config.name:
|
||||
span.set_attribute("agent_name", self.agent_config.name)
|
||||
|
|
@ -505,26 +515,93 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
tool_calls = []
|
||||
content = ""
|
||||
stop_reason = None
|
||||
stop_reason: StopReason | None = None
|
||||
|
||||
async with tracing.span("inference") as span:
|
||||
if self.agent_config.name:
|
||||
span.set_attribute("agent_name", self.agent_config.name)
|
||||
async for chunk in await self.inference_api.chat_completion(
|
||||
self.agent_config.model,
|
||||
input_messages,
|
||||
tools=self.tool_defs,
|
||||
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
|
||||
|
||||
def _serialize_nested(value):
|
||||
"""Recursively serialize nested Pydantic models to dicts."""
|
||||
from pydantic import BaseModel
|
||||
|
||||
if isinstance(value, BaseModel):
|
||||
return value.model_dump(mode="json")
|
||||
elif isinstance(value, dict):
|
||||
return {k: _serialize_nested(v) for k, v in value.items()}
|
||||
elif isinstance(value, list):
|
||||
return [_serialize_nested(item) for item in value]
|
||||
else:
|
||||
return value
|
||||
|
||||
def _add_type(openai_msg: dict) -> OpenAIMessageParam:
|
||||
# Serialize any nested Pydantic models to plain dicts
|
||||
openai_msg = _serialize_nested(openai_msg)
|
||||
|
||||
role = openai_msg.get("role")
|
||||
if role == "user":
|
||||
return OpenAIUserMessageParam(**openai_msg)
|
||||
elif role == "system":
|
||||
return OpenAISystemMessageParam(**openai_msg)
|
||||
elif role == "assistant":
|
||||
return OpenAIAssistantMessageParam(**openai_msg)
|
||||
elif role == "tool":
|
||||
return OpenAIToolMessageParam(**openai_msg)
|
||||
elif role == "developer":
|
||||
return OpenAIDeveloperMessageParam(**openai_msg)
|
||||
else:
|
||||
raise ValueError(f"Unknown message role: {role}")
|
||||
|
||||
# Convert messages to OpenAI format
|
||||
openai_messages: list[OpenAIMessageParam] = [
|
||||
_add_type(await convert_message_to_openai_dict_new(message)) for message in input_messages
|
||||
]
|
||||
|
||||
# Convert tool definitions to OpenAI format
|
||||
openai_tools = [convert_tooldef_to_openai_tool(x) for x in (self.tool_defs or [])]
|
||||
|
||||
# Extract tool_choice from tool_config for OpenAI compatibility
|
||||
# Note: tool_choice can only be provided when tools are also provided
|
||||
tool_choice = None
|
||||
if openai_tools and self.agent_config.tool_config and self.agent_config.tool_config.tool_choice:
|
||||
tc = self.agent_config.tool_config.tool_choice
|
||||
tool_choice_str = tc.value if hasattr(tc, "value") else str(tc)
|
||||
# Convert tool_choice to OpenAI format
|
||||
if tool_choice_str in ("auto", "none", "required"):
|
||||
tool_choice = tool_choice_str
|
||||
else:
|
||||
# It's a specific tool name, wrap it in the proper format
|
||||
tool_choice = {"type": "function", "function": {"name": tool_choice_str}}
|
||||
|
||||
# Convert sampling params to OpenAI format (temperature, top_p, max_tokens)
|
||||
temperature = getattr(getattr(sampling_params, "strategy", None), "temperature", None)
|
||||
top_p = getattr(getattr(sampling_params, "strategy", None), "top_p", None)
|
||||
max_tokens = getattr(sampling_params, "max_tokens", None)
|
||||
|
||||
# Use OpenAI chat completion
|
||||
openai_stream = await self.inference_api.openai_chat_completion(
|
||||
model=self.agent_config.model,
|
||||
messages=openai_messages,
|
||||
tools=openai_tools if openai_tools else None,
|
||||
tool_choice=tool_choice,
|
||||
response_format=self.agent_config.response_format,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_tokens=max_tokens,
|
||||
stream=True,
|
||||
sampling_params=sampling_params,
|
||||
tool_config=self.agent_config.tool_config,
|
||||
):
|
||||
)
|
||||
|
||||
# Convert OpenAI stream back to Llama Stack format
|
||||
response_stream = convert_openai_chat_completion_stream(
|
||||
openai_stream, enable_incremental_tool_calls=True
|
||||
)
|
||||
|
||||
async for chunk in response_stream:
|
||||
event = chunk.event
|
||||
if event.event_type == ChatCompletionResponseEventType.start:
|
||||
continue
|
||||
elif event.event_type == ChatCompletionResponseEventType.complete:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
stop_reason = event.stop_reason or StopReason.end_of_turn
|
||||
continue
|
||||
|
||||
delta = event.delta
|
||||
|
|
@ -533,7 +610,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tool_calls.append(delta.tool_call)
|
||||
elif delta.parse_status == ToolCallParseStatus.failed:
|
||||
# If we cannot parse the tools, set the content to the unparsed raw text
|
||||
content = delta.tool_call
|
||||
content = str(delta.tool_call)
|
||||
if stream:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
|
|
@ -560,9 +637,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
else:
|
||||
raise ValueError(f"Unexpected delta type {type(delta)}")
|
||||
|
||||
if event.stop_reason is not None:
|
||||
stop_reason = event.stop_reason
|
||||
span.set_attribute("stop_reason", stop_reason)
|
||||
span.set_attribute("stop_reason", stop_reason or StopReason.end_of_turn)
|
||||
span.set_attribute(
|
||||
"input",
|
||||
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
|
||||
|
|
@ -790,20 +865,12 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
for tool_def in self.agent_config.client_tools:
|
||||
if tool_name_to_def.get(tool_def.name, None):
|
||||
raise ValueError(f"Tool {tool_def.name} already exists")
|
||||
|
||||
# Use input_schema from ToolDef directly
|
||||
tool_name_to_def[tool_def.name] = ToolDefinition(
|
||||
tool_name=tool_def.name,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
items=param.items,
|
||||
title=param.title,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
},
|
||||
input_schema=tool_def.input_schema,
|
||||
)
|
||||
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
|
||||
toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
|
||||
|
|
@ -813,44 +880,34 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
[t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data]
|
||||
)
|
||||
raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}")
|
||||
if input_tool_name is not None and not any(tool.identifier == input_tool_name for tool in tools.data):
|
||||
if input_tool_name is not None and not any(tool.name == input_tool_name for tool in tools.data):
|
||||
raise ValueError(
|
||||
f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
|
||||
f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.name for tool in tools.data])}"
|
||||
)
|
||||
|
||||
for tool_def in tools.data:
|
||||
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
|
||||
identifier: str | BuiltinTool | None = tool_def.identifier
|
||||
identifier: str | BuiltinTool | None = tool_def.name
|
||||
if identifier == "web_search":
|
||||
identifier = BuiltinTool.brave_search
|
||||
else:
|
||||
identifier = BuiltinTool(identifier)
|
||||
else:
|
||||
# add if tool_name is unspecified or the tool_def identifier is the same as the tool_name
|
||||
if input_tool_name in (None, tool_def.identifier):
|
||||
identifier = tool_def.identifier
|
||||
if input_tool_name in (None, tool_def.name):
|
||||
identifier = tool_def.name
|
||||
else:
|
||||
identifier = None
|
||||
|
||||
if tool_name_to_def.get(identifier, None):
|
||||
raise ValueError(f"Tool {identifier} already exists")
|
||||
if identifier:
|
||||
tool_name_to_def[tool_def.identifier] = ToolDefinition(
|
||||
tool_name_to_def[identifier] = ToolDefinition(
|
||||
tool_name=identifier,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
items=param.items,
|
||||
title=param.title,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
},
|
||||
input_schema=tool_def.input_schema,
|
||||
)
|
||||
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
|
||||
tool_name_to_args[identifier] = toolgroup_to_args.get(toolgroup_name, {})
|
||||
|
||||
self.tool_defs, self.tool_name_to_args = (
|
||||
list(tool_name_to_def.values()),
|
||||
|
|
@ -894,12 +951,18 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tool_name_str = tool_name
|
||||
|
||||
logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
|
||||
|
||||
try:
|
||||
args = json.loads(tool_call.arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Failed to parse arguments for tool call: {tool_call.arguments}") from e
|
||||
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=tool_name_str,
|
||||
kwargs={
|
||||
"session_id": session_id,
|
||||
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
|
||||
**tool_call.arguments,
|
||||
**args,
|
||||
**self.tool_name_to_args.get(tool_name_str, {}),
|
||||
},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -329,6 +329,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
include: list[str] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
shields: list | None = None,
|
||||
) -> OpenAIResponseObject:
|
||||
return await self.openai_responses_impl.create_openai_response(
|
||||
input,
|
||||
|
|
@ -342,6 +343,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
tools,
|
||||
include,
|
||||
max_infer_iters,
|
||||
shields,
|
||||
)
|
||||
|
||||
async def list_openai_responses(
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import time
|
|||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from llama_stack.apis.agents import Order
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
|
|
@ -26,12 +26,16 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
)
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAIMessageParam,
|
||||
OpenAISystemMessageParam,
|
||||
)
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
from llama_stack.providers.utils.responses.responses_store import (
|
||||
ResponsesStore,
|
||||
_OpenAIResponseObjectWithInputAndMessages,
|
||||
)
|
||||
|
||||
from .streaming import StreamingResponseOrchestrator
|
||||
from .tool_executor import ToolExecutor
|
||||
|
|
@ -41,7 +45,7 @@ from .utils import (
|
|||
convert_response_text_to_chat_response_format,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="openai::responses")
|
||||
logger = get_logger(name=__name__, category="openai_responses")
|
||||
|
||||
|
||||
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
|
||||
|
|
@ -72,26 +76,48 @@ class OpenAIResponsesImpl:
|
|||
async def _prepend_previous_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
previous_response_id: str | None = None,
|
||||
previous_response: _OpenAIResponseObjectWithInputAndMessages,
|
||||
):
|
||||
new_input_items = previous_response.input.copy()
|
||||
new_input_items.extend(previous_response.output)
|
||||
|
||||
if isinstance(input, str):
|
||||
new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
|
||||
else:
|
||||
new_input_items.extend(input)
|
||||
|
||||
return new_input_items
|
||||
|
||||
async def _process_input_with_previous_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
previous_response_id: str | None,
|
||||
) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam]]:
|
||||
"""Process input with optional previous response context.
|
||||
|
||||
Returns:
|
||||
tuple: (all_input for storage, messages for chat completion)
|
||||
"""
|
||||
if previous_response_id:
|
||||
previous_response_with_input = await self.responses_store.get_response_object(previous_response_id)
|
||||
previous_response: _OpenAIResponseObjectWithInputAndMessages = (
|
||||
await self.responses_store.get_response_object(previous_response_id)
|
||||
)
|
||||
all_input = await self._prepend_previous_response(input, previous_response)
|
||||
|
||||
# previous response input items
|
||||
new_input_items = previous_response_with_input.input
|
||||
|
||||
# previous response output items
|
||||
new_input_items.extend(previous_response_with_input.output)
|
||||
|
||||
# new input items from the current request
|
||||
if isinstance(input, str):
|
||||
new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
|
||||
if previous_response.messages:
|
||||
# Use stored messages directly and convert only new input
|
||||
message_adapter = TypeAdapter(list[OpenAIMessageParam])
|
||||
messages = message_adapter.validate_python(previous_response.messages)
|
||||
new_messages = await convert_response_input_to_chat_messages(input)
|
||||
messages.extend(new_messages)
|
||||
else:
|
||||
new_input_items.extend(input)
|
||||
# Backward compatibility: reconstruct from inputs
|
||||
messages = await convert_response_input_to_chat_messages(all_input)
|
||||
else:
|
||||
all_input = input
|
||||
messages = await convert_response_input_to_chat_messages(input)
|
||||
|
||||
input = new_input_items
|
||||
|
||||
return input
|
||||
return all_input, messages
|
||||
|
||||
async def _prepend_instructions(self, messages, instructions):
|
||||
if instructions:
|
||||
|
|
@ -102,7 +128,7 @@ class OpenAIResponsesImpl:
|
|||
response_id: str,
|
||||
) -> OpenAIResponseObject:
|
||||
response_with_input = await self.responses_store.get_response_object(response_id)
|
||||
return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"})
|
||||
return response_with_input.to_response_object()
|
||||
|
||||
async def list_openai_responses(
|
||||
self,
|
||||
|
|
@ -138,6 +164,7 @@ class OpenAIResponsesImpl:
|
|||
self,
|
||||
response: OpenAIResponseObject,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
messages: list[OpenAIMessageParam],
|
||||
) -> None:
|
||||
new_input_id = f"msg_{uuid.uuid4()}"
|
||||
if isinstance(input, str):
|
||||
|
|
@ -165,6 +192,7 @@ class OpenAIResponsesImpl:
|
|||
await self.responses_store.store_response_object(
|
||||
response_object=response,
|
||||
input=input_items_data,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
async def create_openai_response(
|
||||
|
|
@ -180,10 +208,15 @@ class OpenAIResponsesImpl:
|
|||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
include: list[str] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
shields: list | None = None,
|
||||
):
|
||||
stream = bool(stream)
|
||||
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
|
||||
|
||||
# Shields parameter received via extra_body - not yet implemented
|
||||
if shields is not None:
|
||||
raise NotImplementedError("Shields parameter is not yet implemented in the meta-reference provider")
|
||||
|
||||
stream_gen = self._create_streaming_response(
|
||||
input=input,
|
||||
model=model,
|
||||
|
|
@ -224,8 +257,7 @@ class OpenAIResponsesImpl:
|
|||
max_infer_iters: int | None = 10,
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Input preprocessing
|
||||
input = await self._prepend_previous_response(input, previous_response_id)
|
||||
messages = await convert_response_input_to_chat_messages(input)
|
||||
all_input, messages = await self._process_input_with_previous_response(input, previous_response_id)
|
||||
await self._prepend_instructions(messages, instructions)
|
||||
|
||||
# Structured outputs
|
||||
|
|
@ -265,7 +297,8 @@ class OpenAIResponsesImpl:
|
|||
if store and final_response:
|
||||
await self._store_response(
|
||||
response=final_response,
|
||||
input=input,
|
||||
input=all_input,
|
||||
messages=orchestrator.final_messages,
|
||||
)
|
||||
|
||||
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ from llama_stack.apis.inference import (
|
|||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChoice,
|
||||
OpenAIMessageParam,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
|
|
@ -62,22 +63,13 @@ def convert_tooldef_to_chat_tool(tool_def):
|
|||
ChatCompletionToolParam suitable for OpenAI chat completion
|
||||
"""
|
||||
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
|
||||
|
||||
internal_tool_def = ToolDefinition(
|
||||
tool_name=tool_def.name,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
items=param.items,
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
},
|
||||
input_schema=tool_def.input_schema,
|
||||
)
|
||||
return convert_tooldef_to_openai_tool(internal_tool_def)
|
||||
|
||||
|
|
@ -103,6 +95,8 @@ class StreamingResponseOrchestrator:
|
|||
self.sequence_number = 0
|
||||
# Store MCP tool mapping that gets built during tool processing
|
||||
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
|
||||
# Track final messages after all tool executions
|
||||
self.final_messages: list[OpenAIMessageParam] = []
|
||||
|
||||
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Initialize output messages
|
||||
|
|
@ -129,13 +123,16 @@ class StreamingResponseOrchestrator:
|
|||
messages = self.ctx.messages.copy()
|
||||
|
||||
while True:
|
||||
# Text is the default response format for chat completion so don't need to pass it
|
||||
# (some providers don't support non-empty response_format when tools are present)
|
||||
response_format = None if self.ctx.response_format.type == "text" else self.ctx.response_format
|
||||
completion_result = await self.inference_api.openai_chat_completion(
|
||||
model=self.ctx.model,
|
||||
messages=messages,
|
||||
tools=self.ctx.chat_tools,
|
||||
stream=True,
|
||||
temperature=self.ctx.temperature,
|
||||
response_format=self.ctx.response_format,
|
||||
response_format=response_format,
|
||||
)
|
||||
|
||||
# Process streaming chunks and build complete response
|
||||
|
|
@ -189,6 +186,8 @@ class StreamingResponseOrchestrator:
|
|||
|
||||
messages = next_turn_messages
|
||||
|
||||
self.final_messages = messages.copy() + [current_response.choices[0].message]
|
||||
|
||||
# Create final response
|
||||
final_response = OpenAIResponseObject(
|
||||
created_at=self.created_at,
|
||||
|
|
@ -352,8 +351,11 @@ class StreamingResponseOrchestrator:
|
|||
|
||||
# Emit arguments.done events for completed tool calls (differentiate between MCP and function calls)
|
||||
for tool_call_index in sorted(chat_response_tool_calls.keys()):
|
||||
tool_call = chat_response_tool_calls[tool_call_index]
|
||||
# Ensure that arguments, if sent back to the inference provider, are not None
|
||||
tool_call.function.arguments = tool_call.function.arguments or "{}"
|
||||
tool_call_item_id = tool_call_item_ids[tool_call_index]
|
||||
final_arguments = chat_response_tool_calls[tool_call_index].function.arguments or ""
|
||||
final_arguments = tool_call.function.arguments
|
||||
tool_call_name = chat_response_tool_calls[tool_call_index].function.name
|
||||
|
||||
# Check if this is an MCP tool call
|
||||
|
|
@ -522,23 +524,15 @@ class StreamingResponseOrchestrator:
|
|||
"""Process all tools and emit appropriate streaming events."""
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from llama_stack.apis.tools import Tool
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||
from llama_stack.apis.tools import ToolDef
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
|
||||
|
||||
def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam:
|
||||
def make_openai_tool(tool_name: str, tool: ToolDef) -> ChatCompletionToolParam:
|
||||
tool_def = ToolDefinition(
|
||||
tool_name=tool_name,
|
||||
description=tool.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool.parameters
|
||||
},
|
||||
input_schema=tool.input_schema,
|
||||
)
|
||||
return convert_tooldef_to_openai_tool(tool_def)
|
||||
|
||||
|
|
@ -625,16 +619,11 @@ class StreamingResponseOrchestrator:
|
|||
MCPListToolsTool(
|
||||
name=t.name,
|
||||
description=t.description,
|
||||
input_schema={
|
||||
input_schema=t.input_schema
|
||||
or {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
p.name: {
|
||||
"type": p.parameter_type,
|
||||
"description": p.description,
|
||||
}
|
||||
for p in t.parameters
|
||||
},
|
||||
"required": [p.name for p in t.parameters if p.required],
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,41 +5,17 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
TextDelta,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEvent,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
InferenceProvider,
|
||||
InterleavedContent,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
TokenLogProbs,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -57,15 +33,6 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
ModelRegistryHelper,
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
augment_content_with_response_format_prompt,
|
||||
chat_completion_request_to_messages,
|
||||
convert_request_to_raw,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
from .generators import LlamaGenerator
|
||||
|
|
@ -82,8 +49,6 @@ def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_
|
|||
|
||||
|
||||
class MetaReferenceInferenceImpl(
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
InferenceProvider,
|
||||
ModelsProtocolPrivate,
|
||||
|
|
@ -100,6 +65,9 @@ class MetaReferenceInferenceImpl(
|
|||
if self.config.create_distributed_process_group:
|
||||
self.generator.stop()
|
||||
|
||||
async def openai_completion(self, *args, **kwargs):
|
||||
raise NotImplementedError("OpenAI completion not supported by meta reference provider")
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return False
|
||||
|
||||
|
|
@ -165,15 +133,10 @@ class MetaReferenceInferenceImpl(
|
|||
self.llama_model = llama_model
|
||||
|
||||
log.info("Warming up...")
|
||||
await self.completion(
|
||||
model_id=model_id,
|
||||
content="Hello, world!",
|
||||
sampling_params=SamplingParams(max_tokens=10),
|
||||
)
|
||||
await self.chat_completion(
|
||||
model_id=model_id,
|
||||
messages=[UserMessage(content="Hi how are you?")],
|
||||
sampling_params=SamplingParams(max_tokens=20),
|
||||
await self.openai_chat_completion(
|
||||
model=model_id,
|
||||
messages=[{"role": "user", "content": "Hi how are you?"}],
|
||||
max_tokens=20,
|
||||
)
|
||||
log.info("Warmed up!")
|
||||
|
||||
|
|
@ -185,373 +148,30 @@ class MetaReferenceInferenceImpl(
|
|||
elif request.model != self.model_id:
|
||||
raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
|
||||
|
||||
async def completion(
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> CompletionResponse | CompletionResponseStreamChunk:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if logprobs:
|
||||
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
||||
|
||||
content = augment_content_with_response_format_prompt(response_format, content)
|
||||
request = CompletionRequest(
|
||||
model=model_id,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
self.check_model(request)
|
||||
request = await convert_request_to_raw(request)
|
||||
|
||||
if request.stream:
|
||||
return self._stream_completion(request)
|
||||
else:
|
||||
results = await self._nonstream_completion([request])
|
||||
return results[0]
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
tokenizer = self.generator.formatter.tokenizer
|
||||
|
||||
def impl():
|
||||
stop_reason = None
|
||||
|
||||
for token_results in self.generator.completion([request]):
|
||||
token_result = token_results[0]
|
||||
if token_result.token == tokenizer.eot_id:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
elif token_result.token == tokenizer.eom_id:
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
else:
|
||||
text = token_result.text
|
||||
|
||||
logprobs = None
|
||||
if stop_reason is None:
|
||||
if request.logprobs:
|
||||
assert len(token_result.logprobs) == 1
|
||||
|
||||
logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})]
|
||||
|
||||
yield CompletionResponseStreamChunk(
|
||||
delta=text,
|
||||
stop_reason=stop_reason,
|
||||
logprobs=logprobs if request.logprobs else None,
|
||||
)
|
||||
|
||||
if stop_reason is None:
|
||||
yield CompletionResponseStreamChunk(
|
||||
delta="",
|
||||
stop_reason=StopReason.out_of_tokens,
|
||||
)
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
async with SEMAPHORE:
|
||||
for x in impl():
|
||||
yield x
|
||||
else:
|
||||
for x in impl():
|
||||
yield x
|
||||
|
||||
async def _nonstream_completion(self, request_batch: list[CompletionRequest]) -> list[CompletionResponse]:
|
||||
tokenizer = self.generator.formatter.tokenizer
|
||||
|
||||
first_request = request_batch[0]
|
||||
|
||||
class ItemState(BaseModel):
|
||||
tokens: list[int] = []
|
||||
logprobs: list[TokenLogProbs] = []
|
||||
stop_reason: StopReason | None = None
|
||||
finished: bool = False
|
||||
|
||||
def impl():
|
||||
states = [ItemState() for _ in request_batch]
|
||||
|
||||
results = []
|
||||
for token_results in self.generator.completion(request_batch):
|
||||
for result in token_results:
|
||||
idx = result.batch_idx
|
||||
state = states[idx]
|
||||
if state.finished or result.ignore_token:
|
||||
continue
|
||||
|
||||
state.finished = result.finished
|
||||
if first_request.logprobs:
|
||||
state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
|
||||
|
||||
state.tokens.append(result.token)
|
||||
if result.token == tokenizer.eot_id:
|
||||
state.stop_reason = StopReason.end_of_turn
|
||||
elif result.token == tokenizer.eom_id:
|
||||
state.stop_reason = StopReason.end_of_message
|
||||
|
||||
for state in states:
|
||||
if state.stop_reason is None:
|
||||
state.stop_reason = StopReason.out_of_tokens
|
||||
|
||||
if state.tokens[-1] in self.generator.formatter.tokenizer.stop_tokens:
|
||||
state.tokens = state.tokens[:-1]
|
||||
content = self.generator.formatter.tokenizer.decode(state.tokens)
|
||||
results.append(
|
||||
CompletionResponse(
|
||||
content=content,
|
||||
stop_reason=state.stop_reason,
|
||||
logprobs=state.logprobs if first_request.logprobs else None,
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
async with SEMAPHORE:
|
||||
return impl()
|
||||
else:
|
||||
return impl()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if logprobs:
|
||||
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
||||
|
||||
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
||||
request = ChatCompletionRequest(
|
||||
model=model_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config or ToolConfig(),
|
||||
)
|
||||
self.check_model(request)
|
||||
|
||||
# augment and rewrite messages depending on the model
|
||||
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
|
||||
# download media and convert to raw content so we can send it to the model
|
||||
request = await convert_request_to_raw(request)
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
if SEMAPHORE.locked():
|
||||
raise RuntimeError("Only one concurrent request is supported")
|
||||
|
||||
if request.stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
results = await self._nonstream_chat_completion([request])
|
||||
return results[0]
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request_batch: list[ChatCompletionRequest]
|
||||
) -> list[ChatCompletionResponse]:
|
||||
tokenizer = self.generator.formatter.tokenizer
|
||||
|
||||
first_request = request_batch[0]
|
||||
|
||||
class ItemState(BaseModel):
|
||||
tokens: list[int] = []
|
||||
logprobs: list[TokenLogProbs] = []
|
||||
stop_reason: StopReason | None = None
|
||||
finished: bool = False
|
||||
|
||||
def impl():
|
||||
states = [ItemState() for _ in request_batch]
|
||||
|
||||
for token_results in self.generator.chat_completion(request_batch):
|
||||
first = token_results[0]
|
||||
if not first.finished and not first.ignore_token:
|
||||
if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"):
|
||||
cprint(first.text, color="cyan", end="", file=sys.stderr)
|
||||
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
|
||||
cprint(f"<{first.token}>", color="magenta", end="", file=sys.stderr)
|
||||
|
||||
for result in token_results:
|
||||
idx = result.batch_idx
|
||||
state = states[idx]
|
||||
if state.finished or result.ignore_token:
|
||||
continue
|
||||
|
||||
state.finished = result.finished
|
||||
if first_request.logprobs:
|
||||
state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
|
||||
|
||||
state.tokens.append(result.token)
|
||||
if result.token == tokenizer.eot_id:
|
||||
state.stop_reason = StopReason.end_of_turn
|
||||
elif result.token == tokenizer.eom_id:
|
||||
state.stop_reason = StopReason.end_of_message
|
||||
|
||||
results = []
|
||||
for state in states:
|
||||
if state.stop_reason is None:
|
||||
state.stop_reason = StopReason.out_of_tokens
|
||||
|
||||
raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
|
||||
results.append(
|
||||
ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
content=raw_message.content,
|
||||
stop_reason=raw_message.stop_reason,
|
||||
tool_calls=raw_message.tool_calls,
|
||||
),
|
||||
logprobs=state.logprobs if first_request.logprobs else None,
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
async with SEMAPHORE:
|
||||
return impl()
|
||||
else:
|
||||
return impl()
|
||||
|
||||
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
||||
tokenizer = self.generator.formatter.tokenizer
|
||||
|
||||
def impl():
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta=TextDelta(text=""),
|
||||
)
|
||||
)
|
||||
|
||||
tokens = []
|
||||
logprobs = []
|
||||
stop_reason = None
|
||||
ipython = False
|
||||
|
||||
for token_results in self.generator.chat_completion([request]):
|
||||
token_result = token_results[0]
|
||||
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
|
||||
cprint(token_result.text, color="cyan", end="", file=sys.stderr)
|
||||
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
|
||||
cprint(f"<{token_result.token}>", color="magenta", end="", file=sys.stderr)
|
||||
|
||||
if token_result.token == tokenizer.eot_id:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
elif token_result.token == tokenizer.eom_id:
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
else:
|
||||
text = token_result.text
|
||||
|
||||
if request.logprobs:
|
||||
assert len(token_result.logprobs) == 1
|
||||
|
||||
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
|
||||
|
||||
tokens.append(token_result.token)
|
||||
|
||||
if not ipython and token_result.text.startswith("<|python_tag|>"):
|
||||
ipython = True
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
tool_call="",
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if token_result.token == tokenizer.eot_id:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
elif token_result.token == tokenizer.eom_id:
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
else:
|
||||
text = token_result.text
|
||||
|
||||
if ipython:
|
||||
delta = ToolCallDelta(
|
||||
tool_call=text,
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
)
|
||||
else:
|
||||
delta = TextDelta(text=text)
|
||||
|
||||
if stop_reason is None:
|
||||
if request.logprobs:
|
||||
assert len(token_result.logprobs) == 1
|
||||
|
||||
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
stop_reason=stop_reason,
|
||||
logprobs=logprobs if request.logprobs else None,
|
||||
)
|
||||
)
|
||||
|
||||
if stop_reason is None:
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
|
||||
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
tool_call="",
|
||||
parse_status=ToolCallParseStatus.failed,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
tool_call=tool_call,
|
||||
parse_status=ToolCallParseStatus.succeeded,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta=TextDelta(text=""),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
async with SEMAPHORE:
|
||||
for x in impl():
|
||||
yield x
|
||||
else:
|
||||
for x in impl():
|
||||
yield x
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider")
|
||||
|
|
|
|||
|
|
@ -27,8 +27,6 @@ class ModelRunner:
|
|||
def __call__(self, task: Any):
|
||||
if task[0] == "chat_completion":
|
||||
return self.llama.chat_completion(task[1])
|
||||
elif task[0] == "completion":
|
||||
return self.llama.completion(task[1])
|
||||
else:
|
||||
raise ValueError(f"Unexpected task type {task[0]}")
|
||||
|
||||
|
|
|
|||
|
|
@ -4,19 +4,18 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionResponse,
|
||||
InferenceProvider,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -26,7 +25,6 @@ from llama_stack.providers.utils.inference.embedding_mixin import (
|
|||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
)
|
||||
|
||||
from .config import SentenceTransformersInferenceConfig
|
||||
|
|
@ -36,7 +34,6 @@ log = get_logger(name=__name__, category="inference")
|
|||
|
||||
class SentenceTransformersInferenceImpl(
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
InferenceProvider,
|
||||
ModelsProtocolPrivate,
|
||||
|
|
@ -74,28 +71,58 @@ class SentenceTransformersInferenceImpl(
|
|||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
async def completion(
|
||||
async def openai_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: str,
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> CompletionResponse | AsyncGenerator:
|
||||
raise ValueError("Sentence transformers don't support completion")
|
||||
# Standard OpenAI completion parameters
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
# vLLM-specific parameters
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
# for fill-in-the-middle type completion
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
raise NotImplementedError("OpenAI completion not supported by sentence transformers provider")
|
||||
|
||||
async def chat_completion(
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
raise ValueError("Sentence transformers don't support chat completion")
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
raise NotImplementedError("OpenAI chat completion not supported by sentence transformers provider")
|
||||
|
|
|
|||
|
|
@ -68,9 +68,7 @@ public class FunctionTagCustomToolGenerator {
|
|||
{
|
||||
"name": "{{t.tool_name}}",
|
||||
"description": "{{t.description}}",
|
||||
"parameters": {
|
||||
"type": "dict",
|
||||
"properties": { {{t.parameters}} }
|
||||
"input_schema": { {{t.input_schema}} }
|
||||
}
|
||||
|
||||
{{/let}}
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
import re
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference import Inference, UserMessage
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
|
@ -55,15 +55,16 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
|
|||
generated_answer=generated_answer,
|
||||
)
|
||||
|
||||
judge_response = await self.inference_api.chat_completion(
|
||||
model_id=fn_def.params.judge_model,
|
||||
judge_response = await self.inference_api.openai_chat_completion(
|
||||
model=fn_def.params.judge_model,
|
||||
messages=[
|
||||
UserMessage(
|
||||
content=judge_input_msg,
|
||||
),
|
||||
{
|
||||
"role": "user",
|
||||
"content": judge_input_msg,
|
||||
}
|
||||
],
|
||||
)
|
||||
content = judge_response.completion_message.content
|
||||
content = judge_response.choices[0].message.content
|
||||
rating_regexes = fn_def.params.judge_score_regexes
|
||||
|
||||
judge_rating = None
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ class TelemetryConfig(BaseModel):
|
|||
description="The service name to use for telemetry",
|
||||
)
|
||||
sinks: list[TelemetrySink] = Field(
|
||||
default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE],
|
||||
default=[TelemetrySink.SQLITE],
|
||||
description="List of telemetry sinks to enable (possible values: otel_trace, otel_metric, sqlite, console)",
|
||||
)
|
||||
sqlite_db_path: str = Field(
|
||||
|
|
@ -49,7 +49,7 @@ class TelemetryConfig(BaseModel):
|
|||
def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> dict[str, Any]:
|
||||
return {
|
||||
"service_name": "${env.OTEL_SERVICE_NAME:=\u200b}",
|
||||
"sinks": "${env.TELEMETRY_SINKS:=console,sqlite}",
|
||||
"sinks": "${env.TELEMETRY_SINKS:=sqlite}",
|
||||
"sqlite_db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
|
||||
"otel_exporter_otlp_endpoint": "${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -130,11 +130,9 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
trace.get_tracer_provider().force_flush()
|
||||
|
||||
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
|
||||
logger.debug(f"DEBUG: log_event called with event type: {type(event).__name__}")
|
||||
if isinstance(event, UnstructuredLogEvent):
|
||||
self._log_unstructured(event, ttl_seconds)
|
||||
elif isinstance(event, MetricEvent):
|
||||
logger.debug("DEBUG: Routing MetricEvent to _log_metric")
|
||||
self._log_metric(event)
|
||||
elif isinstance(event, StructuredLogEvent):
|
||||
self._log_structured(event, ttl_seconds)
|
||||
|
|
|
|||
|
|
@ -33,7 +33,6 @@ from llama_stack.apis.tools import (
|
|||
ToolDef,
|
||||
ToolGroup,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.apis.vector_io import (
|
||||
|
|
@ -301,13 +300,16 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
|||
ToolDef(
|
||||
name="knowledge_search",
|
||||
description="Search for information in a database.",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="query",
|
||||
description="The query to search for. Can be a natural language sentence or keywords.",
|
||||
parameter_type="string",
|
||||
),
|
||||
],
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The query to search for. Can be a natural language sentence or keywords.",
|
||||
}
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
),
|
||||
]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -52,9 +52,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
api=Api.inference,
|
||||
adapter_type="cerebras",
|
||||
provider_type="remote::cerebras",
|
||||
pip_packages=[
|
||||
"cerebras_cloud_sdk",
|
||||
],
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.cerebras",
|
||||
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
|
||||
description="Cerebras inference provider for running models on Cerebras Cloud platform.",
|
||||
|
|
@ -169,7 +167,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
api=Api.inference,
|
||||
adapter_type="openai",
|
||||
provider_type="remote::openai",
|
||||
pip_packages=["litellm"],
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.openai",
|
||||
config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
|
||||
|
|
@ -179,7 +177,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
api=Api.inference,
|
||||
adapter_type="anthropic",
|
||||
provider_type="remote::anthropic",
|
||||
pip_packages=["litellm"],
|
||||
pip_packages=["anthropic"],
|
||||
module="llama_stack.providers.remote.inference.anthropic",
|
||||
config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
|
||||
|
|
@ -189,9 +187,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
api=Api.inference,
|
||||
adapter_type="gemini",
|
||||
provider_type="remote::gemini",
|
||||
pip_packages=[
|
||||
"litellm",
|
||||
],
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.gemini",
|
||||
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
|
||||
|
|
@ -202,7 +198,6 @@ def available_providers() -> list[ProviderSpec]:
|
|||
adapter_type="vertexai",
|
||||
provider_type="remote::vertexai",
|
||||
pip_packages=[
|
||||
"litellm",
|
||||
"google-cloud-aiplatform",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.vertexai",
|
||||
|
|
@ -233,9 +228,7 @@ Available Models:
|
|||
api=Api.inference,
|
||||
adapter_type="groq",
|
||||
provider_type="remote::groq",
|
||||
pip_packages=[
|
||||
"litellm",
|
||||
],
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.groq",
|
||||
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
|
||||
|
|
@ -245,7 +238,7 @@ Available Models:
|
|||
api=Api.inference,
|
||||
adapter_type="llama-openai-compat",
|
||||
provider_type="remote::llama-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.llama_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
|
||||
|
|
@ -255,9 +248,7 @@ Available Models:
|
|||
api=Api.inference,
|
||||
adapter_type="sambanova",
|
||||
provider_type="remote::sambanova",
|
||||
pip_packages=[
|
||||
"litellm",
|
||||
],
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.sambanova",
|
||||
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
|
||||
|
|
@ -287,7 +278,7 @@ Available Models:
|
|||
api=Api.inference,
|
||||
provider_type="remote::azure",
|
||||
adapter_type="azure",
|
||||
pip_packages=["litellm"],
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.azure",
|
||||
config_class="llama_stack.providers.remote.inference.azure.AzureConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.azure.config.AzureProviderDataValidator",
|
||||
|
|
|
|||
|
|
@ -500,7 +500,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
|
|||
api=Api.vector_io,
|
||||
adapter_type="weaviate",
|
||||
provider_type="remote::weaviate",
|
||||
pip_packages=["weaviate-client"],
|
||||
pip_packages=["weaviate-client>=4.16.5"],
|
||||
module="llama_stack.providers.remote.vector_io.weaviate",
|
||||
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
|
||||
|
|
|
|||
|
|
@ -10,6 +10,6 @@ from .config import AnthropicConfig
|
|||
async def get_adapter_impl(config: AnthropicConfig, _deps):
|
||||
from .anthropic import AnthropicInferenceAdapter
|
||||
|
||||
impl = AnthropicInferenceAdapter(config)
|
||||
impl = AnthropicInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -4,13 +4,19 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from collections.abc import Iterable
|
||||
|
||||
from anthropic import AsyncAnthropic
|
||||
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import AnthropicConfig
|
||||
|
||||
|
||||
class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
class AnthropicInferenceAdapter(OpenAIMixin):
|
||||
config: AnthropicConfig
|
||||
|
||||
provider_data_api_key_field: str = "anthropic_api_key"
|
||||
# source: https://docs.claude.com/en/docs/build-with-claude/embeddings
|
||||
# TODO: add support for voyageai, which is where these models are hosted
|
||||
# embedding_model_metadata = {
|
||||
|
|
@ -23,22 +29,11 @@ class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
|||
# "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000},
|
||||
# }
|
||||
|
||||
def __init__(self, config: AnthropicConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
litellm_provider_name="anthropic",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="anthropic_api_key",
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
await super().shutdown()
|
||||
|
||||
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key or ""
|
||||
|
||||
def get_base_url(self):
|
||||
return "https://api.anthropic.com/v1"
|
||||
|
||||
async def list_provider_model_ids(self) -> Iterable[str]:
|
||||
return [m.id async for m in AsyncAnthropic(api_key=self.get_api_key()).models.list()]
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
|
@ -19,7 +20,7 @@ class AnthropicProviderDataValidator(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class AnthropicConfig(BaseModel):
|
||||
class AnthropicConfig(RemoteInferenceProviderConfig):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for Anthropic models",
|
||||
|
|
|
|||
|
|
@ -10,6 +10,6 @@ from .config import AzureConfig
|
|||
async def get_adapter_impl(config: AzureConfig, _deps):
|
||||
from .azure import AzureInferenceAdapter
|
||||
|
||||
impl = AzureInferenceAdapter(config)
|
||||
impl = AzureInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -4,31 +4,20 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from llama_stack.apis.inference import ChatCompletionRequest
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
|
||||
LiteLLMOpenAIMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import AzureConfig
|
||||
|
||||
|
||||
class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
def __init__(self, config: AzureConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
litellm_provider_name="azure",
|
||||
api_key_from_config=config.api_key.get_secret_value(),
|
||||
provider_data_api_key_field="azure_api_key",
|
||||
openai_compat_api_base=str(config.api_base),
|
||||
)
|
||||
self.config = config
|
||||
class AzureInferenceAdapter(OpenAIMixin):
|
||||
config: AzureConfig
|
||||
|
||||
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
|
||||
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||
provider_data_api_key_field: str = "azure_api_key"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key.get_secret_value()
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
|
|
@ -37,26 +26,3 @@ class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
|||
Returns the Azure API base URL from the configuration.
|
||||
"""
|
||||
return urljoin(str(self.config.api_base), "/openai/v1")
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
|
||||
# Get base parameters from parent
|
||||
params = await super()._get_params(request)
|
||||
|
||||
# Add Azure specific parameters
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data:
|
||||
if getattr(provider_data, "azure_api_key", None):
|
||||
params["api_key"] = provider_data.azure_api_key
|
||||
if getattr(provider_data, "azure_api_base", None):
|
||||
params["api_base"] = provider_data.azure_api_base
|
||||
if getattr(provider_data, "azure_api_version", None):
|
||||
params["api_version"] = provider_data.azure_api_version
|
||||
if getattr(provider_data, "azure_api_type", None):
|
||||
params["api_type"] = provider_data.azure_api_type
|
||||
else:
|
||||
params["api_key"] = self.config.api_key.get_secret_value()
|
||||
params["api_base"] = str(self.config.api_base)
|
||||
params["api_version"] = self.config.api_version
|
||||
params["api_type"] = self.config.api_type
|
||||
|
||||
return params
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field, HttpUrl, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
|
@ -30,7 +31,7 @@ class AzureProviderDataValidator(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class AzureConfig(BaseModel):
|
||||
class AzureConfig(RemoteInferenceProviderConfig):
|
||||
api_key: SecretStr = Field(
|
||||
description="Azure API key for Azure",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,27 +5,22 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from botocore.client import BaseClient
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
|
||||
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
|
||||
|
|
@ -33,13 +28,7 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_sampling_strategy_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
|
|
@ -88,8 +77,6 @@ def _to_inference_profile_id(model_id: str, region: str = None) -> str:
|
|||
class BedrockInferenceAdapter(
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
):
|
||||
def __init__(self, config: BedrockConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||
|
|
@ -109,82 +96,6 @@ class BedrockInferenceAdapter(
|
|||
if self._client is not None:
|
||||
self._client.close()
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request)
|
||||
|
||||
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params_for_chat_completion(request)
|
||||
res = self.client.invoke_model(**params)
|
||||
chunk = next(res["body"])
|
||||
result = json.loads(chunk.decode("utf-8"))
|
||||
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=result["stop_reason"],
|
||||
text=result["generation"],
|
||||
)
|
||||
|
||||
response = OpenAICompatCompletionResponse(choices=[choice])
|
||||
return process_chat_completion_response(response, request)
|
||||
|
||||
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params_for_chat_completion(request)
|
||||
res = self.client.invoke_model_with_response_stream(**params)
|
||||
event_stream = res["body"]
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
for chunk in event_stream:
|
||||
chunk = chunk["chunk"]["bytes"]
|
||||
result = json.loads(chunk.decode("utf-8"))
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=result["stop_reason"],
|
||||
text=result["generation"],
|
||||
)
|
||||
yield OpenAICompatCompletionResponse(choices=[choice])
|
||||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
||||
async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> dict:
|
||||
bedrock_model = request.model
|
||||
|
||||
|
|
@ -221,3 +132,59 @@ class BedrockInferenceAdapter(
|
|||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
# Standard OpenAI completion parameters
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
# vLLM-specific parameters
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
# for fill-in-the-middle type completion
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
raise NotImplementedError("OpenAI completion not supported by the Bedrock provider")
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ async def get_adapter_impl(config: CerebrasImplConfig, _deps):
|
|||
|
||||
assert isinstance(config, CerebrasImplConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = CerebrasInferenceAdapter(config)
|
||||
impl = CerebrasInferenceAdapter(config=config)
|
||||
|
||||
await impl.initialize()
|
||||
|
||||
|
|
|
|||
|
|
@ -4,62 +4,16 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from cerebras.cloud.sdk import AsyncCerebras
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
TopKSamplingStrategy,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
process_completion_stream_response,
|
||||
)
|
||||
from llama_stack.apis.inference import OpenAIEmbeddingsResponse
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
)
|
||||
|
||||
from .config import CerebrasImplConfig
|
||||
|
||||
|
||||
class CerebrasInferenceAdapter(
|
||||
OpenAIMixin,
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
):
|
||||
def __init__(self, config: CerebrasImplConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
# TODO: make this use provider data, etc. like other providers
|
||||
self._cerebras_client = AsyncCerebras(
|
||||
base_url=self.config.base_url,
|
||||
api_key=self.config.api_key.get_secret_value(),
|
||||
)
|
||||
class CerebrasInferenceAdapter(OpenAIMixin):
|
||||
config: CerebrasImplConfig
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key.get_secret_value()
|
||||
|
|
@ -67,122 +21,6 @@ class CerebrasInferenceAdapter(
|
|||
def get_base_url(self) -> str:
|
||||
return urljoin(self.config.base_url, "v1")
|
||||
|
||||
async def initialize(self) -> None:
|
||||
return
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = CompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_completion(
|
||||
request,
|
||||
)
|
||||
else:
|
||||
return await self._nonstream_completion(request)
|
||||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
|
||||
r = await self._cerebras_client.completions.create(**params)
|
||||
|
||||
return process_completion_response(r)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
stream = await self._cerebras_client.completions.create(**params)
|
||||
|
||||
async for chunk in process_completion_stream_response(stream):
|
||||
yield chunk
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request)
|
||||
|
||||
async def _nonstream_chat_completion(self, request: CompletionRequest) -> CompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
|
||||
r = await self._cerebras_client.completions.create(**params)
|
||||
|
||||
return process_chat_completion_response(r, request)
|
||||
|
||||
async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
stream = await self._cerebras_client.completions.create(**params)
|
||||
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
|
||||
if request.sampling_params and isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
|
||||
raise ValueError("`top_k` not supported by Cerebras")
|
||||
|
||||
prompt = ""
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
|
||||
elif isinstance(request, CompletionRequest):
|
||||
prompt = await completion_request_to_prompt(request)
|
||||
else:
|
||||
raise ValueError(f"Unknown request type {type(request)}")
|
||||
|
||||
return {
|
||||
"model": request.model,
|
||||
"prompt": prompt,
|
||||
"stream": request.stream,
|
||||
**get_sampling_options(request.sampling_params),
|
||||
}
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
|||
|
|
@ -7,21 +7,22 @@
|
|||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
DEFAULT_BASE_URL = "https://api.cerebras.ai"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CerebrasImplConfig(BaseModel):
|
||||
class CerebrasImplConfig(RemoteInferenceProviderConfig):
|
||||
base_url: str = Field(
|
||||
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
|
||||
description="Base URL for the Cerebras API",
|
||||
)
|
||||
api_key: SecretStr = Field(
|
||||
default=SecretStr(os.environ.get("CEREBRAS_API_KEY")),
|
||||
default=SecretStr(os.environ.get("CEREBRAS_API_KEY")), # type: ignore[arg-type]
|
||||
description="Cerebras API Key",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,6 @@ async def get_adapter_impl(config: DatabricksImplConfig, _deps):
|
|||
from .databricks import DatabricksInferenceAdapter
|
||||
|
||||
assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = DatabricksInferenceAdapter(config)
|
||||
impl = DatabricksInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -6,19 +6,20 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DatabricksImplConfig(BaseModel):
|
||||
url: str = Field(
|
||||
class DatabricksImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str | None = Field(
|
||||
default=None,
|
||||
description="The URL for the Databricks model serving endpoint",
|
||||
)
|
||||
api_token: SecretStr = Field(
|
||||
default=SecretStr(None),
|
||||
default=SecretStr(None), # type: ignore[arg-type]
|
||||
description="The Databricks API token",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,32 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
from databricks.sdk import WorkspaceClient
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
Model,
|
||||
OpenAICompletion,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.apis.inference import OpenAICompletion
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
|
@ -38,41 +18,31 @@ from .config import DatabricksImplConfig
|
|||
logger = get_logger(name=__name__, category="inference::databricks")
|
||||
|
||||
|
||||
class DatabricksInferenceAdapter(
|
||||
OpenAIMixin,
|
||||
Inference,
|
||||
):
|
||||
class DatabricksInferenceAdapter(OpenAIMixin):
|
||||
config: DatabricksImplConfig
|
||||
|
||||
# source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
|
||||
embedding_model_metadata = {
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},
|
||||
"databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512},
|
||||
}
|
||||
|
||||
def __init__(self, config: DatabricksImplConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_token.get_secret_value()
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return f"{self.config.url}/serving-endpoints"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
return
|
||||
async def list_provider_model_ids(self) -> Iterable[str]:
|
||||
return [
|
||||
endpoint.name
|
||||
for endpoint in WorkspaceClient(
|
||||
host=self.config.url, token=self.get_api_key()
|
||||
).serving_endpoints.list() # TODO: this is not async
|
||||
]
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
|
||||
raise NotImplementedError()
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return False
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
|
|
@ -98,47 +68,3 @@ class DatabricksInferenceAdapter(
|
|||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
self._model_cache = {} # from OpenAIMixin
|
||||
ws_client = WorkspaceClient(host=self.config.url, token=self.get_api_key()) # TODO: this is not async
|
||||
endpoints = ws_client.serving_endpoints.list()
|
||||
for endpoint in endpoints:
|
||||
model = Model(
|
||||
provider_id=self.__provider_id__,
|
||||
provider_resource_id=endpoint.name,
|
||||
identifier=endpoint.name,
|
||||
)
|
||||
if endpoint.task == "llm/v1/chat":
|
||||
model.model_type = ModelType.llm # this is redundant, but informative
|
||||
elif endpoint.task == "llm/v1/embeddings":
|
||||
if endpoint.name not in self.embedding_model_metadata:
|
||||
logger.warning(f"No metadata information available for embedding model {endpoint.name}, skipping.")
|
||||
continue
|
||||
model.model_type = ModelType.embedding
|
||||
model.metadata = self.embedding_model_metadata[endpoint.name]
|
||||
else:
|
||||
logger.warning(f"Unknown model type, skipping: {endpoint}")
|
||||
continue
|
||||
|
||||
self._model_cache[endpoint.name] = model
|
||||
|
||||
return list(self._model_cache.values())
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -17,6 +17,6 @@ async def get_adapter_impl(config: FireworksImplConfig, _deps):
|
|||
from .fireworks import FireworksInferenceAdapter
|
||||
|
||||
assert isinstance(config, FireworksImplConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = FireworksInferenceAdapter(config)
|
||||
impl = FireworksInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -4,252 +4,27 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from fireworks.client import Fireworks
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_message_to_openai_dict,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
process_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
from .config import FireworksImplConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::fireworks")
|
||||
|
||||
|
||||
class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
embedding_model_metadata = {
|
||||
class FireworksInferenceAdapter(OpenAIMixin):
|
||||
config: FireworksImplConfig
|
||||
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"nomic-ai/nomic-embed-text-v1.5": {"embedding_dimension": 768, "context_length": 8192},
|
||||
"accounts/fireworks/models/qwen3-embedding-8b": {"embedding_dimension": 4096, "context_length": 40960},
|
||||
}
|
||||
|
||||
def __init__(self, config: FireworksImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self)
|
||||
self.config = config
|
||||
self.allowed_models = config.allowed_models
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
provider_data_api_key_field: str = "fireworks_api_key"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
|
||||
if config_api_key:
|
||||
return config_api_key
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.fireworks_api_key:
|
||||
raise ValueError(
|
||||
'Pass Fireworks API Key in the header X-LlamaStack-Provider-Data as { "fireworks_api_key": <your api key>}'
|
||||
)
|
||||
return provider_data.fireworks_api_key
|
||||
return self.config.api_key.get_secret_value() if self.config.api_key else None # type: ignore[return-value]
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return "https://api.fireworks.ai/inference/v1"
|
||||
|
||||
def _get_client(self) -> Fireworks:
|
||||
fireworks_api_key = self.get_api_key()
|
||||
return Fireworks(api_key=fireworks_api_key)
|
||||
|
||||
def _preprocess_prompt_for_fireworks(self, prompt: str) -> str:
|
||||
"""Remove BOS token as Fireworks automatically prepends it"""
|
||||
if prompt.startswith("<|begin_of_text|>"):
|
||||
return prompt[len("<|begin_of_text|>") :]
|
||||
return prompt
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = CompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_completion(request)
|
||||
else:
|
||||
return await self._nonstream_completion(request)
|
||||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = await self._get_client().completion.acreate(**params)
|
||||
return process_completion_response(r)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
# Wrapper for async generator similar
|
||||
async def _to_async_generator():
|
||||
stream = self._get_client().completion.create(**params)
|
||||
for chunk in stream:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_completion_stream_response(stream):
|
||||
yield chunk
|
||||
|
||||
def _build_options(
|
||||
self,
|
||||
sampling_params: SamplingParams | None,
|
||||
fmt: ResponseFormat,
|
||||
logprobs: LogProbConfig | None,
|
||||
) -> dict:
|
||||
options = get_sampling_options(sampling_params)
|
||||
options.setdefault("max_tokens", 512)
|
||||
|
||||
if fmt:
|
||||
if fmt.type == ResponseFormatType.json_schema.value:
|
||||
options["response_format"] = {
|
||||
"type": "json_object",
|
||||
"schema": fmt.json_schema,
|
||||
}
|
||||
elif fmt.type == ResponseFormatType.grammar.value:
|
||||
options["response_format"] = {
|
||||
"type": "grammar",
|
||||
"grammar": fmt.bnf,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unknown response format {fmt.type}")
|
||||
|
||||
if logprobs and logprobs.top_k:
|
||||
options["logprobs"] = logprobs.top_k
|
||||
if options["logprobs"] <= 0 or options["logprobs"] >= 5:
|
||||
raise ValueError("Required range: 0 < top_k < 5")
|
||||
|
||||
return options
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request)
|
||||
|
||||
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
if "messages" in params:
|
||||
r = await self._get_client().chat.completions.acreate(**params)
|
||||
else:
|
||||
r = await self._get_client().completion.acreate(**params)
|
||||
return process_chat_completion_response(r, request)
|
||||
|
||||
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
async def _to_async_generator():
|
||||
if "messages" in params:
|
||||
stream = self._get_client().chat.completions.acreate(**params)
|
||||
else:
|
||||
stream = self._get_client().completion.acreate(**params)
|
||||
async for chunk in stream:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
|
||||
input_dict = {}
|
||||
media_present = request_has_media(request)
|
||||
|
||||
llama_model = self.get_llama_model(request.model)
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
# TODO: tools are never added to the request, so we need to add them here
|
||||
if media_present or not llama_model:
|
||||
input_dict["messages"] = [
|
||||
await convert_message_to_openai_dict(m, download=True) for m in request.messages
|
||||
]
|
||||
else:
|
||||
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
|
||||
else:
|
||||
assert not media_present, "Fireworks does not support media for Completion requests"
|
||||
input_dict["prompt"] = await completion_request_to_prompt(request)
|
||||
|
||||
# Fireworks always prepends with BOS
|
||||
if "prompt" in input_dict:
|
||||
if input_dict["prompt"].startswith("<|begin_of_text|>"):
|
||||
input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
|
||||
|
||||
params = {
|
||||
"model": request.model,
|
||||
**input_dict,
|
||||
"stream": bool(request.stream),
|
||||
**self._build_options(request.sampling_params, request.response_format, request.logprobs),
|
||||
}
|
||||
logger.debug(f"params to fireworks: {params}")
|
||||
|
||||
return params
|
||||
|
|
|
|||
|
|
@ -10,6 +10,6 @@ from .config import GeminiConfig
|
|||
async def get_adapter_impl(config: GeminiConfig, _deps):
|
||||
from .gemini import GeminiInferenceAdapter
|
||||
|
||||
impl = GeminiInferenceAdapter(config)
|
||||
impl = GeminiInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
|
@ -19,7 +20,7 @@ class GeminiProviderDataValidator(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class GeminiConfig(BaseModel):
|
||||
class GeminiConfig(RemoteInferenceProviderConfig):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for Gemini models",
|
||||
|
|
|
|||
|
|
@ -4,33 +4,21 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import GeminiConfig
|
||||
|
||||
|
||||
class GeminiInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
embedding_model_metadata = {
|
||||
class GeminiInferenceAdapter(OpenAIMixin):
|
||||
config: GeminiConfig
|
||||
|
||||
provider_data_api_key_field: str = "gemini_api_key"
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
|
||||
}
|
||||
|
||||
def __init__(self, config: GeminiConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
litellm_provider_name="gemini",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="gemini_api_key",
|
||||
)
|
||||
self.config = config
|
||||
|
||||
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key or ""
|
||||
|
||||
def get_base_url(self):
|
||||
return "https://generativelanguage.googleapis.com/v1beta/openai/"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
await super().shutdown()
|
||||
|
|
|
|||
|
|
@ -11,5 +11,5 @@ async def get_adapter_impl(config: GroqConfig, _deps):
|
|||
# import dynamically so the import is used only when it is needed
|
||||
from .groq import GroqInferenceAdapter
|
||||
|
||||
adapter = GroqInferenceAdapter(config)
|
||||
adapter = GroqInferenceAdapter(config=config)
|
||||
return adapter
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
|
@ -19,7 +20,7 @@ class GroqProviderDataValidator(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class GroqConfig(BaseModel):
|
||||
class GroqConfig(RemoteInferenceProviderConfig):
|
||||
api_key: str | None = Field(
|
||||
# The Groq client library loads the GROQ_API_KEY environment variable by default
|
||||
default=None,
|
||||
|
|
|
|||
|
|
@ -6,30 +6,16 @@
|
|||
|
||||
|
||||
from llama_stack.providers.remote.inference.groq.config import GroqConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
||||
class GroqInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
_config: GroqConfig
|
||||
class GroqInferenceAdapter(OpenAIMixin):
|
||||
config: GroqConfig
|
||||
|
||||
def __init__(self, config: GroqConfig):
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
litellm_provider_name="groq",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="groq_api_key",
|
||||
)
|
||||
self.config = config
|
||||
provider_data_api_key_field: str = "groq_api_key"
|
||||
|
||||
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
|
||||
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key or ""
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return f"{self.config.url}/openai/v1"
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
||||
|
|
|
|||
|
|
@ -4,14 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import InferenceProvider
|
||||
|
||||
from .config import LlamaCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: LlamaCompatConfig, _deps) -> InferenceProvider:
|
||||
async def get_adapter_impl(config: LlamaCompatConfig, _deps):
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .llama import LlamaCompatInferenceAdapter
|
||||
|
||||
adapter = LlamaCompatInferenceAdapter(config)
|
||||
adapter = LlamaCompatInferenceAdapter(config=config)
|
||||
return adapter
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
|
@ -19,7 +20,7 @@ class LlamaProviderDataValidator(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class LlamaCompatConfig(BaseModel):
|
||||
class LlamaCompatConfig(RemoteInferenceProviderConfig):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="The Llama API key",
|
||||
|
|
|
|||
|
|
@ -3,40 +3,26 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference.inference import OpenAICompletion
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::llama_openai_compat")
|
||||
|
||||
|
||||
class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
class LlamaCompatInferenceAdapter(OpenAIMixin):
|
||||
config: LlamaCompatConfig
|
||||
|
||||
provider_data_api_key_field: str = "llama_api_key"
|
||||
"""
|
||||
Llama API Inference Adapter for Llama Stack.
|
||||
|
||||
Note: The inheritance order is important here. OpenAIMixin must come before
|
||||
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
|
||||
is used instead of ModelRegistryHelper.check_model_availability().
|
||||
|
||||
- OpenAIMixin.check_model_availability() queries the Llama API to check if a model exists
|
||||
- ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
|
||||
"""
|
||||
|
||||
_config: LlamaCompatConfig
|
||||
|
||||
def __init__(self, config: LlamaCompatConfig):
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
litellm_provider_name="meta_llama",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="llama_api_key",
|
||||
openai_compat_api_base=config.openai_compat_api_base,
|
||||
)
|
||||
self.config = config
|
||||
|
||||
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
|
||||
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key or ""
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
|
|
@ -46,8 +32,27 @@ class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
|||
"""
|
||||
return self.config.openai_compat_api_base
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
raise NotImplementedError()
|
||||
|
|
|
|||
|
|
@ -15,7 +15,8 @@ async def get_adapter_impl(config: NVIDIAConfig, _deps) -> Inference:
|
|||
|
||||
if not isinstance(config, NVIDIAConfig):
|
||||
raise RuntimeError(f"Unexpected config type: {type(config)}")
|
||||
adapter = NVIDIAInferenceAdapter(config)
|
||||
adapter = NVIDIAInferenceAdapter(config=config)
|
||||
await adapter.initialize()
|
||||
return adapter
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -7,13 +7,14 @@
|
|||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class NVIDIAConfig(BaseModel):
|
||||
class NVIDIAConfig(RemoteInferenceProviderConfig):
|
||||
"""
|
||||
Configuration for the NVIDIA NIM inference endpoint.
|
||||
|
||||
|
|
|
|||
|
|
@ -4,54 +4,26 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import warnings
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from openai import NOT_GIVEN, APIConnectionError
|
||||
from openai import NOT_GIVEN
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_openai_chat_completion_choice,
|
||||
convert_openai_chat_completion_stream,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
|
||||
|
||||
from . import NVIDIAConfig
|
||||
from .openai_utils import (
|
||||
convert_chat_completion_request,
|
||||
convert_completion_request,
|
||||
convert_openai_completion_choice,
|
||||
convert_openai_completion_stream,
|
||||
)
|
||||
from .utils import _is_nvidia_hosted
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::nvidia")
|
||||
|
||||
|
||||
class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
|
||||
class NVIDIAInferenceAdapter(OpenAIMixin):
|
||||
config: NVIDIAConfig
|
||||
|
||||
"""
|
||||
NVIDIA Inference Adapter for Llama Stack.
|
||||
|
||||
|
|
@ -66,32 +38,21 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
|
|||
"""
|
||||
|
||||
# source: https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
|
||||
embedding_model_metadata = {
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"nvidia/llama-3.2-nv-embedqa-1b-v2": {"embedding_dimension": 2048, "context_length": 8192},
|
||||
"nvidia/nv-embedqa-e5-v5": {"embedding_dimension": 512, "context_length": 1024},
|
||||
"nvidia/nv-embedqa-mistral-7b-v2": {"embedding_dimension": 512, "context_length": 4096},
|
||||
"snowflake/arctic-embed-l": {"embedding_dimension": 512, "context_length": 1024},
|
||||
}
|
||||
|
||||
def __init__(self, config: NVIDIAConfig) -> None:
|
||||
logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")
|
||||
async def initialize(self) -> None:
|
||||
logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.url})...")
|
||||
|
||||
if _is_nvidia_hosted(config):
|
||||
if not config.api_key:
|
||||
if _is_nvidia_hosted(self.config):
|
||||
if not self.config.api_key:
|
||||
raise RuntimeError(
|
||||
"API key is required for hosted NVIDIA NIM. Either provide an API key or use a self-hosted NIM."
|
||||
)
|
||||
# elif self._config.api_key:
|
||||
#
|
||||
# we don't raise this warning because a user may have deployed their
|
||||
# self-hosted NIM with an API key requirement.
|
||||
#
|
||||
# warnings.warn(
|
||||
# "API key is not required for self-hosted NVIDIA NIM. "
|
||||
# "Consider removing the api_key from the configuration."
|
||||
# )
|
||||
|
||||
self._config = config
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
"""
|
||||
|
|
@ -99,7 +60,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
|
|||
|
||||
:return: The NVIDIA API key
|
||||
"""
|
||||
return self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"
|
||||
return self.config.api_key.get_secret_value() if self.config.api_key else "NO KEY"
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
|
|
@ -107,49 +68,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
|
|||
|
||||
:return: The NVIDIA API base URL
|
||||
"""
|
||||
return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if content_has_media(content):
|
||||
raise NotImplementedError("Media is not supported")
|
||||
|
||||
# ToDo: check health of NeMo endpoints and enable this
|
||||
# removing this health check as NeMo customizer endpoint health check is returning 404
|
||||
# await check_health(self._config) # this raises errors
|
||||
|
||||
provider_model_id = await self._get_provider_model_id(model_id)
|
||||
request = convert_completion_request(
|
||||
request=CompletionRequest(
|
||||
model=provider_model_id,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
),
|
||||
n=1,
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client.completions.create(**request)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||
|
||||
if stream:
|
||||
return convert_openai_completion_stream(response)
|
||||
else:
|
||||
# we pass n=1 to get only one completion
|
||||
return convert_openai_completion_choice(response.choices[0])
|
||||
return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
|
|
@ -201,49 +120,3 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
|
|||
model=response.model,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if tool_prompt_format:
|
||||
warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring", stacklevel=2)
|
||||
|
||||
# await check_health(self._config) # this raises errors
|
||||
|
||||
provider_model_id = await self._get_provider_model_id(model_id)
|
||||
request = await convert_chat_completion_request(
|
||||
request=ChatCompletionRequest(
|
||||
model=provider_model_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
tools=tools,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
),
|
||||
n=1,
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client.chat.completions.create(**request)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||
|
||||
if stream:
|
||||
return convert_openai_chat_completion_stream(response, enable_incremental_tool_calls=False)
|
||||
else:
|
||||
# we pass n=1 to get only one completion
|
||||
return convert_openai_chat_completion_choice(response.choices[0])
|
||||
|
|
|
|||
|
|
@ -10,6 +10,6 @@ from .config import OllamaImplConfig
|
|||
async def get_adapter_impl(config: OllamaImplConfig, _deps):
|
||||
from .ollama import OllamaInferenceAdapter
|
||||
|
||||
impl = OllamaInferenceAdapter(config)
|
||||
impl = OllamaInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -6,12 +6,14 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
|
||||
DEFAULT_OLLAMA_URL = "http://localhost:11434"
|
||||
|
||||
|
||||
class OllamaImplConfig(BaseModel):
|
||||
class OllamaImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = DEFAULT_OLLAMA_URL
|
||||
refresh_models: bool = Field(
|
||||
default=False,
|
||||
|
|
|
|||
|
|
@ -6,79 +6,29 @@
|
|||
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from ollama import AsyncClient as AsyncOllamaClient
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
ImageContentItem,
|
||||
InterleavedContent,
|
||||
TextContentItem,
|
||||
)
|
||||
from llama_stack.apis.common.errors import UnsupportedModelError
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
GrammarResponseFormat,
|
||||
InferenceProvider,
|
||||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.datatypes import (
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
ModelsProtocolPrivate,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
process_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
convert_image_content_to_url,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::ollama")
|
||||
|
||||
|
||||
class OllamaInferenceAdapter(
|
||||
OpenAIMixin,
|
||||
ModelRegistryHelper,
|
||||
InferenceProvider,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
class OllamaInferenceAdapter(OpenAIMixin):
|
||||
config: OllamaImplConfig
|
||||
|
||||
# automatically set by the resolver when instantiating the provider
|
||||
__provider_id__: str
|
||||
|
||||
embedding_model_metadata = {
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"all-minilm:l6-v2": {
|
||||
"embedding_dimension": 384,
|
||||
"context_length": 512,
|
||||
|
|
@ -97,29 +47,8 @@ class OllamaInferenceAdapter(
|
|||
},
|
||||
}
|
||||
|
||||
def __init__(self, config: OllamaImplConfig) -> None:
|
||||
# TODO: remove ModelRegistryHelper.__init__ when completion and
|
||||
# chat_completion are. this exists to satisfy the input /
|
||||
# output processing for llama models. specifically,
|
||||
# tool_calling is handled by raw template processing,
|
||||
# instead of using the /api/chat endpoint w/ tools=...
|
||||
ModelRegistryHelper.__init__(
|
||||
self,
|
||||
model_entries=[
|
||||
build_hf_repo_model_entry(
|
||||
"llama3.2:3b-instruct-fp16",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama-guard3:1b",
|
||||
CoreModelId.llama_guard_3_1b.value,
|
||||
),
|
||||
],
|
||||
)
|
||||
self.config = config
|
||||
# Ollama does not support image urls, so we need to download the image and convert it to base64
|
||||
self.download_images = True
|
||||
self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
|
||||
download_images: bool = True
|
||||
_clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
|
||||
|
||||
@property
|
||||
def ollama_client(self) -> AsyncOllamaClient:
|
||||
|
|
@ -163,200 +92,6 @@ class OllamaInferenceAdapter(
|
|||
async def shutdown(self) -> None:
|
||||
self._clients.clear()
|
||||
|
||||
async def _get_model(self, model_id: str) -> Model:
|
||||
if not self.model_store:
|
||||
raise ValueError("Model store not set")
|
||||
return await self.model_store.get_model(model_id)
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self._get_model(model_id)
|
||||
if model.provider_resource_id is None:
|
||||
raise ValueError(f"Model {model_id} has no provider_resource_id set")
|
||||
request = CompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_completion(request)
|
||||
else:
|
||||
return await self._nonstream_completion(request)
|
||||
|
||||
async def _stream_completion(
|
||||
self, request: CompletionRequest
|
||||
) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
|
||||
params = await self._get_params(request)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
s = await self.ollama_client.generate(**params)
|
||||
async for chunk in s:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=chunk["done_reason"] if chunk["done"] else None,
|
||||
text=chunk["response"],
|
||||
)
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_completion_stream_response(stream):
|
||||
yield chunk
|
||||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = await self.ollama_client.generate(**params)
|
||||
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=r["done_reason"] if r["done"] else None,
|
||||
text=r["response"],
|
||||
)
|
||||
response = OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
return process_completion_response(response)
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self._get_model(model_id)
|
||||
if model.provider_resource_id is None:
|
||||
raise ValueError(f"Model {model_id} has no provider_resource_id set")
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
response_format=response_format,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request)
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
|
||||
sampling_options = get_sampling_options(request.sampling_params)
|
||||
# This is needed since the Ollama API expects num_predict to be set
|
||||
# for early truncation instead of max_tokens.
|
||||
if sampling_options.get("max_tokens") is not None:
|
||||
sampling_options["num_predict"] = sampling_options["max_tokens"]
|
||||
|
||||
input_dict: dict[str, Any] = {}
|
||||
media_present = request_has_media(request)
|
||||
llama_model = self.get_llama_model(request.model)
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
if media_present or not llama_model:
|
||||
contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
|
||||
# flatten the list of lists
|
||||
input_dict["messages"] = [item for sublist in contents for item in sublist]
|
||||
else:
|
||||
input_dict["raw"] = True
|
||||
input_dict["prompt"] = await chat_completion_request_to_prompt(
|
||||
request,
|
||||
llama_model,
|
||||
)
|
||||
else:
|
||||
assert not media_present, "Ollama does not support media for Completion requests"
|
||||
input_dict["prompt"] = await completion_request_to_prompt(request)
|
||||
input_dict["raw"] = True
|
||||
|
||||
if fmt := request.response_format:
|
||||
if isinstance(fmt, JsonSchemaResponseFormat):
|
||||
input_dict["format"] = fmt.json_schema
|
||||
elif isinstance(fmt, GrammarResponseFormat):
|
||||
raise NotImplementedError("Grammar response format is not supported")
|
||||
else:
|
||||
raise ValueError(f"Unknown response format type: {fmt.type}")
|
||||
|
||||
params = {
|
||||
"model": request.model,
|
||||
**input_dict,
|
||||
"options": sampling_options,
|
||||
"stream": request.stream,
|
||||
}
|
||||
logger.debug(f"params to ollama: {params}")
|
||||
|
||||
return params
|
||||
|
||||
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
if "messages" in params:
|
||||
r = await self.ollama_client.chat(**params)
|
||||
else:
|
||||
r = await self.ollama_client.generate(**params)
|
||||
|
||||
if "message" in r:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=r["done_reason"] if r["done"] else None,
|
||||
text=r["message"]["content"],
|
||||
)
|
||||
else:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=r["done_reason"] if r["done"] else None,
|
||||
text=r["response"],
|
||||
)
|
||||
response = OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
return process_chat_completion_response(response, request)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
|
||||
params = await self._get_params(request)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
if "messages" in params:
|
||||
s = await self.ollama_client.chat(**params)
|
||||
else:
|
||||
s = await self.ollama_client.generate(**params)
|
||||
async for chunk in s:
|
||||
if "message" in chunk:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=chunk["done_reason"] if chunk["done"] else None,
|
||||
text=chunk["message"]["content"],
|
||||
)
|
||||
else:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=chunk["done_reason"] if chunk["done"] else None,
|
||||
text=chunk["response"],
|
||||
)
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
if await self.check_model_availability(model.provider_model_id):
|
||||
return model
|
||||
|
|
@ -368,24 +103,3 @@ class OllamaInferenceAdapter(
|
|||
return model
|
||||
|
||||
raise UnsupportedModelError(model.provider_model_id, list(self._model_cache.keys()))
|
||||
|
||||
|
||||
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
|
||||
async def _convert_content(content) -> dict:
|
||||
if isinstance(content, ImageContentItem):
|
||||
return {
|
||||
"role": message.role,
|
||||
"images": [await convert_image_content_to_url(content, download=True, include_format=False)],
|
||||
}
|
||||
else:
|
||||
text = content.text if isinstance(content, TextContentItem) else content
|
||||
assert isinstance(text, str)
|
||||
return {
|
||||
"role": message.role,
|
||||
"content": text,
|
||||
}
|
||||
|
||||
if isinstance(message.content, list):
|
||||
return [await _convert_content(c) for c in message.content]
|
||||
else:
|
||||
return [await _convert_content(message.content)]
|
||||
|
|
|
|||
|
|
@ -10,6 +10,6 @@ from .config import OpenAIConfig
|
|||
async def get_adapter_impl(config: OpenAIConfig, _deps):
|
||||
from .openai import OpenAIInferenceAdapter
|
||||
|
||||
impl = OpenAIInferenceAdapter(config)
|
||||
impl = OpenAIInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
|
@ -19,7 +20,7 @@ class OpenAIProviderDataValidator(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIConfig(BaseModel):
|
||||
class OpenAIConfig(RemoteInferenceProviderConfig):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for OpenAI models",
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import OpenAIConfig
|
||||
|
|
@ -14,52 +13,24 @@ logger = get_logger(name=__name__, category="inference::openai")
|
|||
|
||||
|
||||
#
|
||||
# This OpenAI adapter implements Inference methods using two mixins -
|
||||
# This OpenAI adapter implements Inference methods using OpenAIMixin
|
||||
#
|
||||
# | Inference Method | Implementation Source |
|
||||
# |----------------------------|--------------------------|
|
||||
# | completion | LiteLLMOpenAIMixin |
|
||||
# | chat_completion | LiteLLMOpenAIMixin |
|
||||
# | embedding | LiteLLMOpenAIMixin |
|
||||
# | openai_completion | OpenAIMixin |
|
||||
# | openai_chat_completion | OpenAIMixin |
|
||||
# | openai_embeddings | OpenAIMixin |
|
||||
#
|
||||
class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
class OpenAIInferenceAdapter(OpenAIMixin):
|
||||
"""
|
||||
OpenAI Inference Adapter for Llama Stack.
|
||||
|
||||
Note: The inheritance order is important here. OpenAIMixin must come before
|
||||
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
|
||||
is used instead of ModelRegistryHelper.check_model_availability().
|
||||
|
||||
- OpenAIMixin.check_model_availability() queries the OpenAI API to check if a model exists
|
||||
- ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
|
||||
"""
|
||||
|
||||
embedding_model_metadata = {
|
||||
config: OpenAIConfig
|
||||
|
||||
provider_data_api_key_field: str = "openai_api_key"
|
||||
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
|
||||
"text-embedding-3-large": {"embedding_dimension": 3072, "context_length": 8192},
|
||||
}
|
||||
|
||||
def __init__(self, config: OpenAIConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
litellm_provider_name="openai",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="openai_api_key",
|
||||
)
|
||||
self.config = config
|
||||
# we set is_openai_compat so users can use the canonical
|
||||
# openai model names like "gpt-4" or "gpt-3.5-turbo"
|
||||
# and the model name will be translated to litellm's
|
||||
# "openai/gpt-4" or "openai/gpt-3.5-turbo" transparently.
|
||||
# if we do not set this, users will be exposed to the
|
||||
# litellm specific model names, an abstraction leak.
|
||||
self.is_openai_compat = True
|
||||
|
||||
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
|
||||
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key or ""
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
|
|
@ -68,9 +39,3 @@ class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
|||
Returns the OpenAI API base URL from the configuration.
|
||||
"""
|
||||
return self.config.base_url
|
||||
|
||||
async def initialize(self) -> None:
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
await super().shutdown()
|
||||
|
|
|
|||
|
|
@ -6,13 +6,14 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PassthroughImplConfig(BaseModel):
|
||||
class PassthroughImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default=None,
|
||||
description="The URL for the passthrough endpoint",
|
||||
|
|
|
|||
|
|
@ -4,34 +4,22 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from llama_stack_client import AsyncLlamaStackClient
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.core.library_client import convert_pydantic_to_json_value, convert_to_pydantic
|
||||
from llama_stack.core.library_client import convert_pydantic_to_json_value
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
||||
|
||||
|
|
@ -43,12 +31,6 @@ class PassthroughInferenceAdapter(Inference):
|
|||
ModelRegistryHelper.__init__(self)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
|
|
@ -86,107 +68,6 @@ class PassthroughInferenceAdapter(Inference):
|
|||
provider_data=provider_data,
|
||||
)
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
client = self._get_client()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
request_params = {
|
||||
"model_id": model.provider_resource_id,
|
||||
"content": content,
|
||||
"sampling_params": sampling_params,
|
||||
"response_format": response_format,
|
||||
"stream": stream,
|
||||
"logprobs": logprobs,
|
||||
}
|
||||
|
||||
request_params = {key: value for key, value in request_params.items() if value is not None}
|
||||
|
||||
# cast everything to json dict
|
||||
json_params = self.cast_value_to_json_dict(request_params)
|
||||
|
||||
# only pass through the not None params
|
||||
return await client.inference.completion(**json_params)
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
# TODO: revisit this remove tool_calls from messages logic
|
||||
for message in messages:
|
||||
if hasattr(message, "tool_calls"):
|
||||
message.tool_calls = None
|
||||
|
||||
request_params = {
|
||||
"model_id": model.provider_resource_id,
|
||||
"messages": messages,
|
||||
"sampling_params": sampling_params,
|
||||
"tools": tools,
|
||||
"tool_choice": tool_choice,
|
||||
"tool_prompt_format": tool_prompt_format,
|
||||
"response_format": response_format,
|
||||
"stream": stream,
|
||||
"logprobs": logprobs,
|
||||
}
|
||||
|
||||
# only pass through the not None params
|
||||
request_params = {key: value for key, value in request_params.items() if value is not None}
|
||||
|
||||
# cast everything to json dict
|
||||
json_params = self.cast_value_to_json_dict(request_params)
|
||||
|
||||
if stream:
|
||||
return self._stream_chat_completion(json_params)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(json_params)
|
||||
|
||||
async def _nonstream_chat_completion(self, json_params: dict[str, Any]) -> ChatCompletionResponse:
|
||||
client = self._get_client()
|
||||
response = await client.inference.chat_completion(**json_params)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
content=response.completion_message.content.text,
|
||||
stop_reason=response.completion_message.stop_reason,
|
||||
tool_calls=response.completion_message.tool_calls,
|
||||
),
|
||||
logprobs=response.logprobs,
|
||||
)
|
||||
|
||||
async def _stream_chat_completion(self, json_params: dict[str, Any]) -> AsyncGenerator:
|
||||
client = self._get_client()
|
||||
stream_response = await client.inference.chat_completion(**json_params)
|
||||
|
||||
async for chunk in stream_response:
|
||||
chunk = chunk.to_dict()
|
||||
|
||||
# temporary hack to remove the metrics from the response
|
||||
chunk["metrics"] = []
|
||||
chunk = convert_to_pydantic(ChatCompletionResponseStreamChunk, chunk)
|
||||
yield chunk
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue