mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-23 00:27:26 +00:00
Merge remote-tracking branch 'origin/main' into stores
This commit is contained in:
commit
b72154ce5e
1161 changed files with 609896 additions and 42960 deletions
|
@ -797,7 +797,7 @@ class Agents(Protocol):
|
|||
self,
|
||||
response_id: str,
|
||||
) -> OpenAIResponseObject:
|
||||
"""Retrieve an OpenAI response by its ID.
|
||||
"""Get a model response.
|
||||
|
||||
:param response_id: The ID of the OpenAI response to retrieve.
|
||||
:returns: An OpenAIResponseObject.
|
||||
|
@ -812,6 +812,7 @@ class Agents(Protocol):
|
|||
model: str,
|
||||
instructions: str | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
conversation: str | None = None,
|
||||
store: bool | None = True,
|
||||
stream: bool | None = False,
|
||||
temperature: float | None = None,
|
||||
|
@ -826,11 +827,12 @@ class Agents(Protocol):
|
|||
),
|
||||
] = None,
|
||||
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
|
||||
"""Create a new OpenAI response.
|
||||
"""Create a model response.
|
||||
|
||||
:param input: Input message(s) to create the response.
|
||||
:param model: The underlying LLM used for completions.
|
||||
:param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
|
||||
:param conversation: (Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation.
|
||||
:param include: (Optional) Additional fields to include in the response.
|
||||
:param shields: (Optional) List of shields to apply during response generation. Can be shield IDs (strings) or shield specifications.
|
||||
:returns: An OpenAIResponseObject.
|
||||
|
@ -846,7 +848,7 @@ class Agents(Protocol):
|
|||
model: str | None = None,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseObject:
|
||||
"""List all OpenAI responses.
|
||||
"""List all responses.
|
||||
|
||||
:param after: The ID of the last response to return.
|
||||
:param limit: The number of responses to return.
|
||||
|
@ -869,7 +871,7 @@ class Agents(Protocol):
|
|||
limit: int | None = 20,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseInputItem:
|
||||
"""List input items for a given OpenAI response.
|
||||
"""List input items.
|
||||
|
||||
:param response_id: The ID of the response to retrieve input items for.
|
||||
:param after: An item ID to list items after, used for pagination.
|
||||
|
@ -884,7 +886,7 @@ class Agents(Protocol):
|
|||
@webmethod(route="/openai/v1/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||
"""Delete an OpenAI response by its ID.
|
||||
"""Delete a response.
|
||||
|
||||
:param response_id: The ID of the OpenAI response to delete.
|
||||
:returns: An OpenAIDeleteResponseObject
|
||||
|
|
|
@ -346,6 +346,174 @@ class OpenAIResponseText(BaseModel):
|
|||
format: OpenAIResponseTextFormat | None = None
|
||||
|
||||
|
||||
# Must match type Literals of OpenAIResponseInputToolWebSearch below
|
||||
WebSearchToolTypes = ["web_search", "web_search_preview", "web_search_preview_2025_03_11"]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputToolWebSearch(BaseModel):
|
||||
"""Web search tool configuration for OpenAI response inputs.
|
||||
|
||||
:param type: Web search tool type variant to use
|
||||
:param search_context_size: (Optional) Size of search context, must be "low", "medium", or "high"
|
||||
"""
|
||||
|
||||
# Must match values of WebSearchToolTypes above
|
||||
type: Literal["web_search"] | Literal["web_search_preview"] | Literal["web_search_preview_2025_03_11"] = (
|
||||
"web_search"
|
||||
)
|
||||
# TODO: actually use search_context_size somewhere...
|
||||
search_context_size: str | None = Field(default="medium", pattern="^low|medium|high$")
|
||||
# TODO: add user_location
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputToolFunction(BaseModel):
|
||||
"""Function tool configuration for OpenAI response inputs.
|
||||
|
||||
:param type: Tool type identifier, always "function"
|
||||
:param name: Name of the function that can be called
|
||||
:param description: (Optional) Description of what the function does
|
||||
:param parameters: (Optional) JSON schema defining the function's parameters
|
||||
:param strict: (Optional) Whether to enforce strict parameter validation
|
||||
"""
|
||||
|
||||
type: Literal["function"] = "function"
|
||||
name: str
|
||||
description: str | None = None
|
||||
parameters: dict[str, Any] | None
|
||||
strict: bool | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputToolFileSearch(BaseModel):
|
||||
"""File search tool configuration for OpenAI response inputs.
|
||||
|
||||
:param type: Tool type identifier, always "file_search"
|
||||
:param vector_store_ids: List of vector store identifiers to search within
|
||||
:param filters: (Optional) Additional filters to apply to the search
|
||||
:param max_num_results: (Optional) Maximum number of search results to return (1-50)
|
||||
:param ranking_options: (Optional) Options for ranking and scoring search results
|
||||
"""
|
||||
|
||||
type: Literal["file_search"] = "file_search"
|
||||
vector_store_ids: list[str]
|
||||
filters: dict[str, Any] | None = None
|
||||
max_num_results: int | None = Field(default=10, ge=1, le=50)
|
||||
ranking_options: FileSearchRankingOptions | None = None
|
||||
|
||||
|
||||
class ApprovalFilter(BaseModel):
|
||||
"""Filter configuration for MCP tool approval requirements.
|
||||
|
||||
:param always: (Optional) List of tool names that always require approval
|
||||
:param never: (Optional) List of tool names that never require approval
|
||||
"""
|
||||
|
||||
always: list[str] | None = None
|
||||
never: list[str] | None = None
|
||||
|
||||
|
||||
class AllowedToolsFilter(BaseModel):
|
||||
"""Filter configuration for restricting which MCP tools can be used.
|
||||
|
||||
:param tool_names: (Optional) List of specific tool names that are allowed
|
||||
"""
|
||||
|
||||
tool_names: list[str] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputToolMCP(BaseModel):
|
||||
"""Model Context Protocol (MCP) tool configuration for OpenAI response inputs.
|
||||
|
||||
:param type: Tool type identifier, always "mcp"
|
||||
:param server_label: Label to identify this MCP server
|
||||
:param server_url: URL endpoint of the MCP server
|
||||
:param headers: (Optional) HTTP headers to include when connecting to the server
|
||||
:param require_approval: Approval requirement for tool calls ("always", "never", or filter)
|
||||
:param allowed_tools: (Optional) Restriction on which tools can be used from this server
|
||||
"""
|
||||
|
||||
type: Literal["mcp"] = "mcp"
|
||||
server_label: str
|
||||
server_url: str
|
||||
headers: dict[str, Any] | None = None
|
||||
|
||||
require_approval: Literal["always"] | Literal["never"] | ApprovalFilter = "never"
|
||||
allowed_tools: list[str] | AllowedToolsFilter | None = None
|
||||
|
||||
|
||||
OpenAIResponseInputTool = Annotated[
|
||||
OpenAIResponseInputToolWebSearch
|
||||
| OpenAIResponseInputToolFileSearch
|
||||
| OpenAIResponseInputToolFunction
|
||||
| OpenAIResponseInputToolMCP,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseToolMCP(BaseModel):
|
||||
"""Model Context Protocol (MCP) tool configuration for OpenAI response object.
|
||||
|
||||
:param type: Tool type identifier, always "mcp"
|
||||
:param server_label: Label to identify this MCP server
|
||||
:param allowed_tools: (Optional) Restriction on which tools can be used from this server
|
||||
"""
|
||||
|
||||
type: Literal["mcp"] = "mcp"
|
||||
server_label: str
|
||||
allowed_tools: list[str] | AllowedToolsFilter | None = None
|
||||
|
||||
|
||||
OpenAIResponseTool = Annotated[
|
||||
OpenAIResponseInputToolWebSearch
|
||||
| OpenAIResponseInputToolFileSearch
|
||||
| OpenAIResponseInputToolFunction
|
||||
| OpenAIResponseToolMCP, # The only type that differes from that in the inputs is the MCP tool
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseTool, name="OpenAIResponseTool")
|
||||
|
||||
|
||||
class OpenAIResponseUsageOutputTokensDetails(BaseModel):
|
||||
"""Token details for output tokens in OpenAI response usage.
|
||||
|
||||
:param reasoning_tokens: Number of tokens used for reasoning (o1/o3 models)
|
||||
"""
|
||||
|
||||
reasoning_tokens: int | None = None
|
||||
|
||||
|
||||
class OpenAIResponseUsageInputTokensDetails(BaseModel):
|
||||
"""Token details for input tokens in OpenAI response usage.
|
||||
|
||||
:param cached_tokens: Number of tokens retrieved from cache
|
||||
"""
|
||||
|
||||
cached_tokens: int | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseUsage(BaseModel):
|
||||
"""Usage information for OpenAI response.
|
||||
|
||||
:param input_tokens: Number of tokens in the input
|
||||
:param output_tokens: Number of tokens in the output
|
||||
:param total_tokens: Total tokens used (input + output)
|
||||
:param input_tokens_details: Detailed breakdown of input token usage
|
||||
:param output_tokens_details: Detailed breakdown of output token usage
|
||||
"""
|
||||
|
||||
input_tokens: int
|
||||
output_tokens: int
|
||||
total_tokens: int
|
||||
input_tokens_details: OpenAIResponseUsageInputTokensDetails | None = None
|
||||
output_tokens_details: OpenAIResponseUsageOutputTokensDetails | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObject(BaseModel):
|
||||
"""Complete OpenAI response object containing generation results and metadata.
|
||||
|
@ -362,7 +530,9 @@ class OpenAIResponseObject(BaseModel):
|
|||
:param temperature: (Optional) Sampling temperature used for generation
|
||||
:param text: Text formatting configuration for the response
|
||||
:param top_p: (Optional) Nucleus sampling parameter used for generation
|
||||
:param tools: (Optional) An array of tools the model may call while generating a response.
|
||||
:param truncation: (Optional) Truncation strategy applied to the response
|
||||
:param usage: (Optional) Token usage information for the response
|
||||
"""
|
||||
|
||||
created_at: int
|
||||
|
@ -379,7 +549,9 @@ class OpenAIResponseObject(BaseModel):
|
|||
# before the field was added. New responses will have this set always.
|
||||
text: OpenAIResponseText = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
|
||||
top_p: float | None = None
|
||||
tools: list[OpenAIResponseTool] | None = None
|
||||
truncation: str | None = None
|
||||
usage: OpenAIResponseUsage | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -400,7 +572,7 @@ class OpenAIDeleteResponseObject(BaseModel):
|
|||
class OpenAIResponseObjectStreamResponseCreated(BaseModel):
|
||||
"""Streaming event indicating a new response has been created.
|
||||
|
||||
:param response: The newly created response object
|
||||
:param response: The response object that was created
|
||||
:param type: Event type identifier, always "response.created"
|
||||
"""
|
||||
|
||||
|
@ -408,11 +580,25 @@ class OpenAIResponseObjectStreamResponseCreated(BaseModel):
|
|||
type: Literal["response.created"] = "response.created"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseInProgress(BaseModel):
|
||||
"""Streaming event indicating the response remains in progress.
|
||||
|
||||
:param response: Current response state while in progress
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.in_progress"
|
||||
"""
|
||||
|
||||
response: OpenAIResponseObject
|
||||
sequence_number: int
|
||||
type: Literal["response.in_progress"] = "response.in_progress"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
|
||||
"""Streaming event indicating a response has been completed.
|
||||
|
||||
:param response: The completed response object
|
||||
:param response: Completed response object
|
||||
:param type: Event type identifier, always "response.completed"
|
||||
"""
|
||||
|
||||
|
@ -420,6 +606,34 @@ class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
|
|||
type: Literal["response.completed"] = "response.completed"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseIncomplete(BaseModel):
|
||||
"""Streaming event emitted when a response ends in an incomplete state.
|
||||
|
||||
:param response: Response object describing the incomplete state
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.incomplete"
|
||||
"""
|
||||
|
||||
response: OpenAIResponseObject
|
||||
sequence_number: int
|
||||
type: Literal["response.incomplete"] = "response.incomplete"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseFailed(BaseModel):
|
||||
"""Streaming event emitted when a response fails.
|
||||
|
||||
:param response: Response object describing the failure
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.failed"
|
||||
"""
|
||||
|
||||
response: OpenAIResponseObject
|
||||
sequence_number: int
|
||||
type: Literal["response.failed"] = "response.failed"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseOutputItemAdded(BaseModel):
|
||||
"""Streaming event for when a new output item is added to the response.
|
||||
|
@ -650,19 +864,46 @@ class OpenAIResponseObjectStreamResponseMcpCallCompleted(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIResponseContentPartOutputText(BaseModel):
|
||||
"""Text content within a streamed response part.
|
||||
|
||||
:param type: Content part type identifier, always "output_text"
|
||||
:param text: Text emitted for this content part
|
||||
:param annotations: Structured annotations associated with the text
|
||||
:param logprobs: (Optional) Token log probability details
|
||||
"""
|
||||
|
||||
type: Literal["output_text"] = "output_text"
|
||||
text: str
|
||||
# TODO: add annotations, logprobs, etc.
|
||||
annotations: list[OpenAIResponseAnnotations] = Field(default_factory=list)
|
||||
logprobs: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseContentPartRefusal(BaseModel):
|
||||
"""Refusal content within a streamed response part.
|
||||
|
||||
:param type: Content part type identifier, always "refusal"
|
||||
:param refusal: Refusal text supplied by the model
|
||||
"""
|
||||
|
||||
type: Literal["refusal"] = "refusal"
|
||||
refusal: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseContentPartReasoningText(BaseModel):
|
||||
"""Reasoning text emitted as part of a streamed response.
|
||||
|
||||
:param type: Content part type identifier, always "reasoning_text"
|
||||
:param text: Reasoning text supplied by the model
|
||||
"""
|
||||
|
||||
type: Literal["reasoning_text"] = "reasoning_text"
|
||||
text: str
|
||||
|
||||
|
||||
OpenAIResponseContentPart = Annotated[
|
||||
OpenAIResponseContentPartOutputText | OpenAIResponseContentPartRefusal,
|
||||
OpenAIResponseContentPartOutputText | OpenAIResponseContentPartRefusal | OpenAIResponseContentPartReasoningText,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseContentPart, name="OpenAIResponseContentPart")
|
||||
|
@ -672,15 +913,19 @@ register_schema(OpenAIResponseContentPart, name="OpenAIResponseContentPart")
|
|||
class OpenAIResponseObjectStreamResponseContentPartAdded(BaseModel):
|
||||
"""Streaming event for when a new content part is added to a response item.
|
||||
|
||||
:param content_index: Index position of the part within the content array
|
||||
:param response_id: Unique identifier of the response containing this content
|
||||
:param item_id: Unique identifier of the output item containing this content part
|
||||
:param output_index: Index position of the output item in the response
|
||||
:param part: The content part that was added
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.content_part.added"
|
||||
"""
|
||||
|
||||
content_index: int
|
||||
response_id: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
part: OpenAIResponseContentPart
|
||||
sequence_number: int
|
||||
type: Literal["response.content_part.added"] = "response.content_part.added"
|
||||
|
@ -690,22 +935,269 @@ class OpenAIResponseObjectStreamResponseContentPartAdded(BaseModel):
|
|||
class OpenAIResponseObjectStreamResponseContentPartDone(BaseModel):
|
||||
"""Streaming event for when a content part is completed.
|
||||
|
||||
:param content_index: Index position of the part within the content array
|
||||
:param response_id: Unique identifier of the response containing this content
|
||||
:param item_id: Unique identifier of the output item containing this content part
|
||||
:param output_index: Index position of the output item in the response
|
||||
:param part: The completed content part
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.content_part.done"
|
||||
"""
|
||||
|
||||
content_index: int
|
||||
response_id: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
part: OpenAIResponseContentPart
|
||||
sequence_number: int
|
||||
type: Literal["response.content_part.done"] = "response.content_part.done"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseReasoningTextDelta(BaseModel):
|
||||
"""Streaming event for incremental reasoning text updates.
|
||||
|
||||
:param content_index: Index position of the reasoning content part
|
||||
:param delta: Incremental reasoning text being added
|
||||
:param item_id: Unique identifier of the output item being updated
|
||||
:param output_index: Index position of the item in the output list
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.reasoning_text.delta"
|
||||
"""
|
||||
|
||||
content_index: int
|
||||
delta: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.reasoning_text.delta"] = "response.reasoning_text.delta"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseReasoningTextDone(BaseModel):
|
||||
"""Streaming event for when reasoning text is completed.
|
||||
|
||||
:param content_index: Index position of the reasoning content part
|
||||
:param text: Final complete reasoning text
|
||||
:param item_id: Unique identifier of the completed output item
|
||||
:param output_index: Index position of the item in the output list
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.reasoning_text.done"
|
||||
"""
|
||||
|
||||
content_index: int
|
||||
text: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.reasoning_text.done"] = "response.reasoning_text.done"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseContentPartReasoningSummary(BaseModel):
|
||||
"""Reasoning summary part in a streamed response.
|
||||
|
||||
:param type: Content part type identifier, always "summary_text"
|
||||
:param text: Summary text
|
||||
"""
|
||||
|
||||
type: Literal["summary_text"] = "summary_text"
|
||||
text: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseReasoningSummaryPartAdded(BaseModel):
|
||||
"""Streaming event for when a new reasoning summary part is added.
|
||||
|
||||
:param item_id: Unique identifier of the output item
|
||||
:param output_index: Index position of the output item
|
||||
:param part: The summary part that was added
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param summary_index: Index of the summary part within the reasoning summary
|
||||
:param type: Event type identifier, always "response.reasoning_summary_part.added"
|
||||
"""
|
||||
|
||||
item_id: str
|
||||
output_index: int
|
||||
part: OpenAIResponseContentPartReasoningSummary
|
||||
sequence_number: int
|
||||
summary_index: int
|
||||
type: Literal["response.reasoning_summary_part.added"] = "response.reasoning_summary_part.added"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseReasoningSummaryPartDone(BaseModel):
|
||||
"""Streaming event for when a reasoning summary part is completed.
|
||||
|
||||
:param item_id: Unique identifier of the output item
|
||||
:param output_index: Index position of the output item
|
||||
:param part: The completed summary part
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param summary_index: Index of the summary part within the reasoning summary
|
||||
:param type: Event type identifier, always "response.reasoning_summary_part.done"
|
||||
"""
|
||||
|
||||
item_id: str
|
||||
output_index: int
|
||||
part: OpenAIResponseContentPartReasoningSummary
|
||||
sequence_number: int
|
||||
summary_index: int
|
||||
type: Literal["response.reasoning_summary_part.done"] = "response.reasoning_summary_part.done"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseReasoningSummaryTextDelta(BaseModel):
|
||||
"""Streaming event for incremental reasoning summary text updates.
|
||||
|
||||
:param delta: Incremental summary text being added
|
||||
:param item_id: Unique identifier of the output item
|
||||
:param output_index: Index position of the output item
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param summary_index: Index of the summary part within the reasoning summary
|
||||
:param type: Event type identifier, always "response.reasoning_summary_text.delta"
|
||||
"""
|
||||
|
||||
delta: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
summary_index: int
|
||||
type: Literal["response.reasoning_summary_text.delta"] = "response.reasoning_summary_text.delta"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseReasoningSummaryTextDone(BaseModel):
|
||||
"""Streaming event for when reasoning summary text is completed.
|
||||
|
||||
:param text: Final complete summary text
|
||||
:param item_id: Unique identifier of the output item
|
||||
:param output_index: Index position of the output item
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param summary_index: Index of the summary part within the reasoning summary
|
||||
:param type: Event type identifier, always "response.reasoning_summary_text.done"
|
||||
"""
|
||||
|
||||
text: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
summary_index: int
|
||||
type: Literal["response.reasoning_summary_text.done"] = "response.reasoning_summary_text.done"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseRefusalDelta(BaseModel):
|
||||
"""Streaming event for incremental refusal text updates.
|
||||
|
||||
:param content_index: Index position of the content part
|
||||
:param delta: Incremental refusal text being added
|
||||
:param item_id: Unique identifier of the output item
|
||||
:param output_index: Index position of the item in the output list
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.refusal.delta"
|
||||
"""
|
||||
|
||||
content_index: int
|
||||
delta: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.refusal.delta"] = "response.refusal.delta"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseRefusalDone(BaseModel):
|
||||
"""Streaming event for when refusal text is completed.
|
||||
|
||||
:param content_index: Index position of the content part
|
||||
:param refusal: Final complete refusal text
|
||||
:param item_id: Unique identifier of the output item
|
||||
:param output_index: Index position of the item in the output list
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.refusal.done"
|
||||
"""
|
||||
|
||||
content_index: int
|
||||
refusal: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.refusal.done"] = "response.refusal.done"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseOutputTextAnnotationAdded(BaseModel):
|
||||
"""Streaming event for when an annotation is added to output text.
|
||||
|
||||
:param item_id: Unique identifier of the item to which the annotation is being added
|
||||
:param output_index: Index position of the output item in the response's output array
|
||||
:param content_index: Index position of the content part within the output item
|
||||
:param annotation_index: Index of the annotation within the content part
|
||||
:param annotation: The annotation object being added
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.output_text.annotation.added"
|
||||
"""
|
||||
|
||||
item_id: str
|
||||
output_index: int
|
||||
content_index: int
|
||||
annotation_index: int
|
||||
annotation: OpenAIResponseAnnotations
|
||||
sequence_number: int
|
||||
type: Literal["response.output_text.annotation.added"] = "response.output_text.annotation.added"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseFileSearchCallInProgress(BaseModel):
|
||||
"""Streaming event for file search calls in progress.
|
||||
|
||||
:param item_id: Unique identifier of the file search call
|
||||
:param output_index: Index position of the item in the output list
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.file_search_call.in_progress"
|
||||
"""
|
||||
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.file_search_call.in_progress"] = "response.file_search_call.in_progress"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseFileSearchCallSearching(BaseModel):
|
||||
"""Streaming event for file search currently searching.
|
||||
|
||||
:param item_id: Unique identifier of the file search call
|
||||
:param output_index: Index position of the item in the output list
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.file_search_call.searching"
|
||||
"""
|
||||
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.file_search_call.searching"] = "response.file_search_call.searching"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseFileSearchCallCompleted(BaseModel):
|
||||
"""Streaming event for completed file search calls.
|
||||
|
||||
:param item_id: Unique identifier of the completed file search call
|
||||
:param output_index: Index position of the item in the output list
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.file_search_call.completed"
|
||||
"""
|
||||
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.file_search_call.completed"] = "response.file_search_call.completed"
|
||||
|
||||
|
||||
OpenAIResponseObjectStream = Annotated[
|
||||
OpenAIResponseObjectStreamResponseCreated
|
||||
| OpenAIResponseObjectStreamResponseInProgress
|
||||
| OpenAIResponseObjectStreamResponseOutputItemAdded
|
||||
| OpenAIResponseObjectStreamResponseOutputItemDone
|
||||
| OpenAIResponseObjectStreamResponseOutputTextDelta
|
||||
|
@ -725,6 +1217,20 @@ OpenAIResponseObjectStream = Annotated[
|
|||
| OpenAIResponseObjectStreamResponseMcpCallCompleted
|
||||
| OpenAIResponseObjectStreamResponseContentPartAdded
|
||||
| OpenAIResponseObjectStreamResponseContentPartDone
|
||||
| OpenAIResponseObjectStreamResponseReasoningTextDelta
|
||||
| OpenAIResponseObjectStreamResponseReasoningTextDone
|
||||
| OpenAIResponseObjectStreamResponseReasoningSummaryPartAdded
|
||||
| OpenAIResponseObjectStreamResponseReasoningSummaryPartDone
|
||||
| OpenAIResponseObjectStreamResponseReasoningSummaryTextDelta
|
||||
| OpenAIResponseObjectStreamResponseReasoningSummaryTextDone
|
||||
| OpenAIResponseObjectStreamResponseRefusalDelta
|
||||
| OpenAIResponseObjectStreamResponseRefusalDone
|
||||
| OpenAIResponseObjectStreamResponseOutputTextAnnotationAdded
|
||||
| OpenAIResponseObjectStreamResponseFileSearchCallInProgress
|
||||
| OpenAIResponseObjectStreamResponseFileSearchCallSearching
|
||||
| OpenAIResponseObjectStreamResponseFileSearchCallCompleted
|
||||
| OpenAIResponseObjectStreamResponseIncomplete
|
||||
| OpenAIResponseObjectStreamResponseFailed
|
||||
| OpenAIResponseObjectStreamResponseCompleted,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
@ -760,114 +1266,6 @@ OpenAIResponseInput = Annotated[
|
|||
register_schema(OpenAIResponseInput, name="OpenAIResponseInput")
|
||||
|
||||
|
||||
# Must match type Literals of OpenAIResponseInputToolWebSearch below
|
||||
WebSearchToolTypes = ["web_search", "web_search_preview", "web_search_preview_2025_03_11"]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputToolWebSearch(BaseModel):
|
||||
"""Web search tool configuration for OpenAI response inputs.
|
||||
|
||||
:param type: Web search tool type variant to use
|
||||
:param search_context_size: (Optional) Size of search context, must be "low", "medium", or "high"
|
||||
"""
|
||||
|
||||
# Must match values of WebSearchToolTypes above
|
||||
type: Literal["web_search"] | Literal["web_search_preview"] | Literal["web_search_preview_2025_03_11"] = (
|
||||
"web_search"
|
||||
)
|
||||
# TODO: actually use search_context_size somewhere...
|
||||
search_context_size: str | None = Field(default="medium", pattern="^low|medium|high$")
|
||||
# TODO: add user_location
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputToolFunction(BaseModel):
|
||||
"""Function tool configuration for OpenAI response inputs.
|
||||
|
||||
:param type: Tool type identifier, always "function"
|
||||
:param name: Name of the function that can be called
|
||||
:param description: (Optional) Description of what the function does
|
||||
:param parameters: (Optional) JSON schema defining the function's parameters
|
||||
:param strict: (Optional) Whether to enforce strict parameter validation
|
||||
"""
|
||||
|
||||
type: Literal["function"] = "function"
|
||||
name: str
|
||||
description: str | None = None
|
||||
parameters: dict[str, Any] | None
|
||||
strict: bool | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputToolFileSearch(BaseModel):
|
||||
"""File search tool configuration for OpenAI response inputs.
|
||||
|
||||
:param type: Tool type identifier, always "file_search"
|
||||
:param vector_store_ids: List of vector store identifiers to search within
|
||||
:param filters: (Optional) Additional filters to apply to the search
|
||||
:param max_num_results: (Optional) Maximum number of search results to return (1-50)
|
||||
:param ranking_options: (Optional) Options for ranking and scoring search results
|
||||
"""
|
||||
|
||||
type: Literal["file_search"] = "file_search"
|
||||
vector_store_ids: list[str]
|
||||
filters: dict[str, Any] | None = None
|
||||
max_num_results: int | None = Field(default=10, ge=1, le=50)
|
||||
ranking_options: FileSearchRankingOptions | None = None
|
||||
|
||||
|
||||
class ApprovalFilter(BaseModel):
|
||||
"""Filter configuration for MCP tool approval requirements.
|
||||
|
||||
:param always: (Optional) List of tool names that always require approval
|
||||
:param never: (Optional) List of tool names that never require approval
|
||||
"""
|
||||
|
||||
always: list[str] | None = None
|
||||
never: list[str] | None = None
|
||||
|
||||
|
||||
class AllowedToolsFilter(BaseModel):
|
||||
"""Filter configuration for restricting which MCP tools can be used.
|
||||
|
||||
:param tool_names: (Optional) List of specific tool names that are allowed
|
||||
"""
|
||||
|
||||
tool_names: list[str] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputToolMCP(BaseModel):
|
||||
"""Model Context Protocol (MCP) tool configuration for OpenAI response inputs.
|
||||
|
||||
:param type: Tool type identifier, always "mcp"
|
||||
:param server_label: Label to identify this MCP server
|
||||
:param server_url: URL endpoint of the MCP server
|
||||
:param headers: (Optional) HTTP headers to include when connecting to the server
|
||||
:param require_approval: Approval requirement for tool calls ("always", "never", or filter)
|
||||
:param allowed_tools: (Optional) Restriction on which tools can be used from this server
|
||||
"""
|
||||
|
||||
type: Literal["mcp"] = "mcp"
|
||||
server_label: str
|
||||
server_url: str
|
||||
headers: dict[str, Any] | None = None
|
||||
|
||||
require_approval: Literal["always"] | Literal["never"] | ApprovalFilter = "never"
|
||||
allowed_tools: list[str] | AllowedToolsFilter | None = None
|
||||
|
||||
|
||||
OpenAIResponseInputTool = Annotated[
|
||||
OpenAIResponseInputToolWebSearch
|
||||
| OpenAIResponseInputToolFileSearch
|
||||
| OpenAIResponseInputToolFunction
|
||||
| OpenAIResponseInputToolMCP,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
|
||||
|
||||
|
||||
class ListOpenAIResponseInputItem(BaseModel):
|
||||
"""List container for OpenAI response input items.
|
||||
|
||||
|
|
|
@ -86,3 +86,18 @@ class TokenValidationError(ValueError):
|
|||
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ConversationNotFoundError(ResourceNotFoundError):
|
||||
"""raised when Llama Stack cannot find a referenced conversation"""
|
||||
|
||||
def __init__(self, conversation_id: str) -> None:
|
||||
super().__init__(conversation_id, "Conversation", "client.conversations.list()")
|
||||
|
||||
|
||||
class InvalidConversationIdError(ValueError):
|
||||
"""raised when a conversation ID has an invalid format"""
|
||||
|
||||
def __init__(self, conversation_id: str) -> None:
|
||||
message = f"Invalid conversation ID '{conversation_id}'. Expected an ID that begins with 'conv_'."
|
||||
super().__init__(message)
|
||||
|
|
|
@ -96,7 +96,6 @@ class Api(Enum, metaclass=DynamicApiMeta):
|
|||
:cvar telemetry: Observability and system monitoring
|
||||
:cvar models: Model metadata and management
|
||||
:cvar shields: Safety shield implementations
|
||||
:cvar vector_dbs: Vector database management
|
||||
:cvar datasets: Dataset creation and management
|
||||
:cvar scoring_functions: Scoring function definitions
|
||||
:cvar benchmarks: Benchmark suite management
|
||||
|
@ -122,7 +121,6 @@ class Api(Enum, metaclass=DynamicApiMeta):
|
|||
|
||||
models = "models"
|
||||
shields = "shields"
|
||||
vector_dbs = "vector_dbs"
|
||||
datasets = "datasets"
|
||||
scoring_functions = "scoring_functions"
|
||||
benchmarks = "benchmarks"
|
||||
|
|
|
@ -104,6 +104,11 @@ class OpenAIFileDeleteResponse(BaseModel):
|
|||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Files(Protocol):
|
||||
"""Files
|
||||
|
||||
This API is used to upload documents that can be used with other Llama Stack APIs.
|
||||
"""
|
||||
|
||||
# OpenAI Files API Endpoints
|
||||
@webmethod(route="/openai/v1/files", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/files", method="POST", level=LLAMA_STACK_API_V1)
|
||||
|
@ -113,7 +118,8 @@ class Files(Protocol):
|
|||
purpose: Annotated[OpenAIFilePurpose, Form()],
|
||||
expires_after: Annotated[ExpiresAfter | None, Form()] = None,
|
||||
) -> OpenAIFileObject:
|
||||
"""
|
||||
"""Upload file.
|
||||
|
||||
Upload a file that can be used across various endpoints.
|
||||
|
||||
The file upload should be a multipart form request with:
|
||||
|
@ -137,7 +143,8 @@ class Files(Protocol):
|
|||
order: Order | None = Order.desc,
|
||||
purpose: OpenAIFilePurpose | None = None,
|
||||
) -> ListOpenAIFileResponse:
|
||||
"""
|
||||
"""List files.
|
||||
|
||||
Returns a list of files that belong to the user's organization.
|
||||
|
||||
:param after: A cursor for use in pagination. `after` is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list.
|
||||
|
@ -154,7 +161,8 @@ class Files(Protocol):
|
|||
self,
|
||||
file_id: str,
|
||||
) -> OpenAIFileObject:
|
||||
"""
|
||||
"""Retrieve file.
|
||||
|
||||
Returns information about a specific file.
|
||||
|
||||
:param file_id: The ID of the file to use for this request.
|
||||
|
@ -168,8 +176,7 @@ class Files(Protocol):
|
|||
self,
|
||||
file_id: str,
|
||||
) -> OpenAIFileDeleteResponse:
|
||||
"""
|
||||
Delete a file.
|
||||
"""Delete file.
|
||||
|
||||
:param file_id: The ID of the file to use for this request.
|
||||
:returns: An OpenAIFileDeleteResponse indicating successful deletion.
|
||||
|
@ -182,7 +189,8 @@ class Files(Protocol):
|
|||
self,
|
||||
file_id: str,
|
||||
) -> Response:
|
||||
"""
|
||||
"""Retrieve file content.
|
||||
|
||||
Returns the contents of the specified file.
|
||||
|
||||
:param file_id: The ID of the file to use for this request.
|
||||
|
|
|
@ -14,6 +14,7 @@ from typing import (
|
|||
runtime_checkable,
|
||||
)
|
||||
|
||||
from fastapi import Body
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
@ -776,12 +777,14 @@ class OpenAIChoiceDelta(BaseModel):
|
|||
:param refusal: (Optional) The refusal of the delta
|
||||
:param role: (Optional) The role of the delta
|
||||
:param tool_calls: (Optional) The tool calls of the delta
|
||||
:param reasoning_content: (Optional) The reasoning content from the model (non-standard, for o1/o3 models)
|
||||
"""
|
||||
|
||||
content: str | None = None
|
||||
refusal: str | None = None
|
||||
role: str | None = None
|
||||
tool_calls: list[OpenAIChatCompletionToolCall] | None = None
|
||||
reasoning_content: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -816,6 +819,42 @@ class OpenAIChoice(BaseModel):
|
|||
logprobs: OpenAIChoiceLogprobs | None = None
|
||||
|
||||
|
||||
class OpenAIChatCompletionUsageCompletionTokensDetails(BaseModel):
|
||||
"""Token details for output tokens in OpenAI chat completion usage.
|
||||
|
||||
:param reasoning_tokens: Number of tokens used for reasoning (o1/o3 models)
|
||||
"""
|
||||
|
||||
reasoning_tokens: int | None = None
|
||||
|
||||
|
||||
class OpenAIChatCompletionUsagePromptTokensDetails(BaseModel):
|
||||
"""Token details for prompt tokens in OpenAI chat completion usage.
|
||||
|
||||
:param cached_tokens: Number of tokens retrieved from cache
|
||||
"""
|
||||
|
||||
cached_tokens: int | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletionUsage(BaseModel):
|
||||
"""Usage information for OpenAI chat completion.
|
||||
|
||||
:param prompt_tokens: Number of tokens in the prompt
|
||||
:param completion_tokens: Number of tokens in the completion
|
||||
:param total_tokens: Total tokens used (prompt + completion)
|
||||
:param input_tokens_details: Detailed breakdown of input token usage
|
||||
:param output_tokens_details: Detailed breakdown of output token usage
|
||||
"""
|
||||
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
prompt_tokens_details: OpenAIChatCompletionUsagePromptTokensDetails | None = None
|
||||
completion_tokens_details: OpenAIChatCompletionUsageCompletionTokensDetails | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletion(BaseModel):
|
||||
"""Response from an OpenAI-compatible chat completion request.
|
||||
|
@ -825,6 +864,7 @@ class OpenAIChatCompletion(BaseModel):
|
|||
:param object: The object type, which will be "chat.completion"
|
||||
:param created: The Unix timestamp in seconds when the chat completion was created
|
||||
:param model: The model that was used to generate the chat completion
|
||||
:param usage: Token usage information for the completion
|
||||
"""
|
||||
|
||||
id: str
|
||||
|
@ -832,6 +872,7 @@ class OpenAIChatCompletion(BaseModel):
|
|||
object: Literal["chat.completion"] = "chat.completion"
|
||||
created: int
|
||||
model: str
|
||||
usage: OpenAIChatCompletionUsage | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -843,6 +884,7 @@ class OpenAIChatCompletionChunk(BaseModel):
|
|||
:param object: The object type, which will be "chat.completion.chunk"
|
||||
:param created: The Unix timestamp in seconds when the chat completion was created
|
||||
:param model: The model that was used to generate the chat completion
|
||||
:param usage: Token usage information (typically included in final chunk with stream_options)
|
||||
"""
|
||||
|
||||
id: str
|
||||
|
@ -850,6 +892,7 @@ class OpenAIChatCompletionChunk(BaseModel):
|
|||
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
||||
created: int
|
||||
model: str
|
||||
usage: OpenAIChatCompletionUsage | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -995,6 +1038,127 @@ class ListOpenAIChatCompletionResponse(BaseModel):
|
|||
object: Literal["list"] = "list"
|
||||
|
||||
|
||||
# extra_body can be accessed via .model_extra
|
||||
@json_schema_type
|
||||
class OpenAICompletionRequestWithExtraBody(BaseModel, extra="allow"):
|
||||
"""Request parameters for OpenAI-compatible completion endpoint.
|
||||
|
||||
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param prompt: The prompt to generate a completion for.
|
||||
:param best_of: (Optional) The number of completions to generate.
|
||||
:param echo: (Optional) Whether to echo the prompt.
|
||||
:param frequency_penalty: (Optional) The penalty for repeated tokens.
|
||||
:param logit_bias: (Optional) The logit bias to use.
|
||||
:param logprobs: (Optional) The log probabilities to use.
|
||||
:param max_tokens: (Optional) The maximum number of tokens to generate.
|
||||
:param n: (Optional) The number of completions to generate.
|
||||
:param presence_penalty: (Optional) The penalty for repeated tokens.
|
||||
:param seed: (Optional) The seed to use.
|
||||
:param stop: (Optional) The stop tokens to use.
|
||||
:param stream: (Optional) Whether to stream the response.
|
||||
:param stream_options: (Optional) The stream options to use.
|
||||
:param temperature: (Optional) The temperature to use.
|
||||
:param top_p: (Optional) The top p to use.
|
||||
:param user: (Optional) The user to use.
|
||||
:param suffix: (Optional) The suffix that should be appended to the completion.
|
||||
"""
|
||||
|
||||
# Standard OpenAI completion parameters
|
||||
model: str
|
||||
prompt: str | list[str] | list[int] | list[list[int]]
|
||||
best_of: int | None = None
|
||||
echo: bool | None = None
|
||||
frequency_penalty: float | None = None
|
||||
logit_bias: dict[str, float] | None = None
|
||||
logprobs: bool | None = None
|
||||
max_tokens: int | None = None
|
||||
n: int | None = None
|
||||
presence_penalty: float | None = None
|
||||
seed: int | None = None
|
||||
stop: str | list[str] | None = None
|
||||
stream: bool | None = None
|
||||
stream_options: dict[str, Any] | None = None
|
||||
temperature: float | None = None
|
||||
top_p: float | None = None
|
||||
user: str | None = None
|
||||
suffix: str | None = None
|
||||
|
||||
|
||||
# extra_body can be accessed via .model_extra
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletionRequestWithExtraBody(BaseModel, extra="allow"):
|
||||
"""Request parameters for OpenAI-compatible chat completion endpoint.
|
||||
|
||||
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param messages: List of messages in the conversation.
|
||||
:param frequency_penalty: (Optional) The penalty for repeated tokens.
|
||||
:param function_call: (Optional) The function call to use.
|
||||
:param functions: (Optional) List of functions to use.
|
||||
:param logit_bias: (Optional) The logit bias to use.
|
||||
:param logprobs: (Optional) The log probabilities to use.
|
||||
:param max_completion_tokens: (Optional) The maximum number of tokens to generate.
|
||||
:param max_tokens: (Optional) The maximum number of tokens to generate.
|
||||
:param n: (Optional) The number of completions to generate.
|
||||
:param parallel_tool_calls: (Optional) Whether to parallelize tool calls.
|
||||
:param presence_penalty: (Optional) The penalty for repeated tokens.
|
||||
:param response_format: (Optional) The response format to use.
|
||||
:param seed: (Optional) The seed to use.
|
||||
:param stop: (Optional) The stop tokens to use.
|
||||
:param stream: (Optional) Whether to stream the response.
|
||||
:param stream_options: (Optional) The stream options to use.
|
||||
:param temperature: (Optional) The temperature to use.
|
||||
:param tool_choice: (Optional) The tool choice to use.
|
||||
:param tools: (Optional) The tools to use.
|
||||
:param top_logprobs: (Optional) The top log probabilities to use.
|
||||
:param top_p: (Optional) The top p to use.
|
||||
:param user: (Optional) The user to use.
|
||||
"""
|
||||
|
||||
# Standard OpenAI chat completion parameters
|
||||
model: str
|
||||
messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)]
|
||||
frequency_penalty: float | None = None
|
||||
function_call: str | dict[str, Any] | None = None
|
||||
functions: list[dict[str, Any]] | None = None
|
||||
logit_bias: dict[str, float] | None = None
|
||||
logprobs: bool | None = None
|
||||
max_completion_tokens: int | None = None
|
||||
max_tokens: int | None = None
|
||||
n: int | None = None
|
||||
parallel_tool_calls: bool | None = None
|
||||
presence_penalty: float | None = None
|
||||
response_format: OpenAIResponseFormatParam | None = None
|
||||
seed: int | None = None
|
||||
stop: str | list[str] | None = None
|
||||
stream: bool | None = None
|
||||
stream_options: dict[str, Any] | None = None
|
||||
temperature: float | None = None
|
||||
tool_choice: str | dict[str, Any] | None = None
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
top_logprobs: int | None = None
|
||||
top_p: float | None = None
|
||||
user: str | None = None
|
||||
|
||||
|
||||
# extra_body can be accessed via .model_extra
|
||||
@json_schema_type
|
||||
class OpenAIEmbeddingsRequestWithExtraBody(BaseModel, extra="allow"):
|
||||
"""Request parameters for OpenAI-compatible embeddings endpoint.
|
||||
|
||||
:param model: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
|
||||
:param input: Input text to embed, encoded as a string or array of strings. To embed multiple inputs in a single request, pass an array of strings.
|
||||
:param encoding_format: (Optional) The format to return the embeddings in. Can be either "float" or "base64". Defaults to "float".
|
||||
:param dimensions: (Optional) The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
|
||||
:param user: (Optional) A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
|
||||
"""
|
||||
|
||||
model: str
|
||||
input: str | list[str]
|
||||
encoding_format: str | None = "float"
|
||||
dimensions: int | None = None
|
||||
user: str | None = None
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class InferenceProvider(Protocol):
|
||||
|
@ -1029,50 +1193,11 @@ class InferenceProvider(Protocol):
|
|||
@webmethod(route="/completions", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def openai_completion(
|
||||
self,
|
||||
# Standard OpenAI completion parameters
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
# vLLM-specific parameters
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
# for fill-in-the-middle type completion
|
||||
suffix: str | None = None,
|
||||
params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)],
|
||||
) -> OpenAICompletion:
|
||||
"""Generate an OpenAI-compatible completion for the given prompt using the specified model.
|
||||
"""Create completion.
|
||||
|
||||
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param prompt: The prompt to generate a completion for.
|
||||
:param best_of: (Optional) The number of completions to generate.
|
||||
:param echo: (Optional) Whether to echo the prompt.
|
||||
:param frequency_penalty: (Optional) The penalty for repeated tokens.
|
||||
:param logit_bias: (Optional) The logit bias to use.
|
||||
:param logprobs: (Optional) The log probabilities to use.
|
||||
:param max_tokens: (Optional) The maximum number of tokens to generate.
|
||||
:param n: (Optional) The number of completions to generate.
|
||||
:param presence_penalty: (Optional) The penalty for repeated tokens.
|
||||
:param seed: (Optional) The seed to use.
|
||||
:param stop: (Optional) The stop tokens to use.
|
||||
:param stream: (Optional) Whether to stream the response.
|
||||
:param stream_options: (Optional) The stream options to use.
|
||||
:param temperature: (Optional) The temperature to use.
|
||||
:param top_p: (Optional) The top p to use.
|
||||
:param user: (Optional) The user to use.
|
||||
:param suffix: (Optional) The suffix that should be appended to the completion.
|
||||
Generate an OpenAI-compatible completion for the given prompt using the specified model.
|
||||
:returns: An OpenAICompletion.
|
||||
"""
|
||||
...
|
||||
|
@ -1081,55 +1206,11 @@ class InferenceProvider(Protocol):
|
|||
@webmethod(route="/chat/completions", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
params: Annotated[OpenAIChatCompletionRequestWithExtraBody, Body(...)],
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
"""Generate an OpenAI-compatible chat completion for the given messages using the specified model.
|
||||
"""Create chat completions.
|
||||
|
||||
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param messages: List of messages in the conversation.
|
||||
:param frequency_penalty: (Optional) The penalty for repeated tokens.
|
||||
:param function_call: (Optional) The function call to use.
|
||||
:param functions: (Optional) List of functions to use.
|
||||
:param logit_bias: (Optional) The logit bias to use.
|
||||
:param logprobs: (Optional) The log probabilities to use.
|
||||
:param max_completion_tokens: (Optional) The maximum number of tokens to generate.
|
||||
:param max_tokens: (Optional) The maximum number of tokens to generate.
|
||||
:param n: (Optional) The number of completions to generate.
|
||||
:param parallel_tool_calls: (Optional) Whether to parallelize tool calls.
|
||||
:param presence_penalty: (Optional) The penalty for repeated tokens.
|
||||
:param response_format: (Optional) The response format to use.
|
||||
:param seed: (Optional) The seed to use.
|
||||
:param stop: (Optional) The stop tokens to use.
|
||||
:param stream: (Optional) Whether to stream the response.
|
||||
:param stream_options: (Optional) The stream options to use.
|
||||
:param temperature: (Optional) The temperature to use.
|
||||
:param tool_choice: (Optional) The tool choice to use.
|
||||
:param tools: (Optional) The tools to use.
|
||||
:param top_logprobs: (Optional) The top log probabilities to use.
|
||||
:param top_p: (Optional) The top p to use.
|
||||
:param user: (Optional) The user to use.
|
||||
Generate an OpenAI-compatible chat completion for the given messages using the specified model.
|
||||
:returns: An OpenAIChatCompletion.
|
||||
"""
|
||||
...
|
||||
|
@ -1138,26 +1219,20 @@ class InferenceProvider(Protocol):
|
|||
@webmethod(route="/embeddings", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
params: Annotated[OpenAIEmbeddingsRequestWithExtraBody, Body(...)],
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
"""Generate OpenAI-compatible embeddings for the given input using the specified model.
|
||||
"""Create embeddings.
|
||||
|
||||
:param model: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
|
||||
:param input: Input text to embed, encoded as a string or array of strings. To embed multiple inputs in a single request, pass an array of strings.
|
||||
:param encoding_format: (Optional) The format to return the embeddings in. Can be either "float" or "base64". Defaults to "float".
|
||||
:param dimensions: (Optional) The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
|
||||
:param user: (Optional) A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
|
||||
Generate OpenAI-compatible embeddings for the given input using the specified model.
|
||||
:returns: An OpenAIEmbeddingsResponse containing the embeddings.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class Inference(InferenceProvider):
|
||||
"""Llama Stack Inference API for generating completions, chat completions, and embeddings.
|
||||
"""Inference
|
||||
|
||||
Llama Stack Inference API for generating completions, chat completions, and embeddings.
|
||||
|
||||
This API provides the raw interface to the underlying models. Two kinds of models are supported:
|
||||
- LLM models: these models generate "raw" and "chat" (conversational) completions.
|
||||
|
@ -1173,7 +1248,7 @@ class Inference(InferenceProvider):
|
|||
model: str | None = None,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIChatCompletionResponse:
|
||||
"""List all chat completions.
|
||||
"""List chat completions.
|
||||
|
||||
:param after: The ID of the last chat completion to return.
|
||||
:param limit: The maximum number of chat completions to return.
|
||||
|
@ -1188,7 +1263,9 @@ class Inference(InferenceProvider):
|
|||
)
|
||||
@webmethod(route="/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
|
||||
"""Describe a chat completion by its ID.
|
||||
"""Get chat completion.
|
||||
|
||||
Describe a chat completion by its ID.
|
||||
|
||||
:param completion_id: ID of the chat completion.
|
||||
:returns: A OpenAICompletionWithInputMessages.
|
||||
|
|
|
@ -58,25 +58,36 @@ class ListRoutesResponse(BaseModel):
|
|||
|
||||
@runtime_checkable
|
||||
class Inspect(Protocol):
|
||||
"""Inspect
|
||||
|
||||
APIs for inspecting the Llama Stack service, including health status, available API routes with methods and implementing providers.
|
||||
"""
|
||||
|
||||
@webmethod(route="/inspect/routes", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_routes(self) -> ListRoutesResponse:
|
||||
"""List all available API routes with their methods and implementing providers.
|
||||
"""List routes.
|
||||
|
||||
List all available API routes with their methods and implementing providers.
|
||||
|
||||
:returns: Response containing information about all available routes.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/health", method="GET", level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/health", method="GET", level=LLAMA_STACK_API_V1, require_authentication=False)
|
||||
async def health(self) -> HealthInfo:
|
||||
"""Get the current health status of the service.
|
||||
"""Get health status.
|
||||
|
||||
Get the current health status of the service.
|
||||
|
||||
:returns: Health information indicating if the service is operational.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/version", method="GET", level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/version", method="GET", level=LLAMA_STACK_API_V1, require_authentication=False)
|
||||
async def version(self) -> VersionInfo:
|
||||
"""Get the version of the service.
|
||||
"""Get version.
|
||||
|
||||
Get the version of the service.
|
||||
|
||||
:returns: Version information containing the service version number.
|
||||
"""
|
||||
|
|
|
@ -124,7 +124,9 @@ class Models(Protocol):
|
|||
self,
|
||||
model_id: str,
|
||||
) -> Model:
|
||||
"""Get a model by its identifier.
|
||||
"""Get model.
|
||||
|
||||
Get a model by its identifier.
|
||||
|
||||
:param model_id: The identifier of the model to get.
|
||||
:returns: A Model.
|
||||
|
@ -140,7 +142,9 @@ class Models(Protocol):
|
|||
metadata: dict[str, Any] | None = None,
|
||||
model_type: ModelType | None = None,
|
||||
) -> Model:
|
||||
"""Register a model.
|
||||
"""Register model.
|
||||
|
||||
Register a model.
|
||||
|
||||
:param model_id: The identifier of the model to register.
|
||||
:param provider_model_id: The identifier of the model in the provider.
|
||||
|
@ -156,7 +160,9 @@ class Models(Protocol):
|
|||
self,
|
||||
model_id: str,
|
||||
) -> None:
|
||||
"""Unregister a model.
|
||||
"""Unregister model.
|
||||
|
||||
Unregister a model.
|
||||
|
||||
:param model_id: The identifier of the model to unregister.
|
||||
"""
|
||||
|
|
|
@ -94,7 +94,9 @@ class ListPromptsResponse(BaseModel):
|
|||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Prompts(Protocol):
|
||||
"""Protocol for prompt management operations."""
|
||||
"""Prompts
|
||||
|
||||
Protocol for prompt management operations."""
|
||||
|
||||
@webmethod(route="/prompts", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_prompts(self) -> ListPromptsResponse:
|
||||
|
@ -109,7 +111,9 @@ class Prompts(Protocol):
|
|||
self,
|
||||
prompt_id: str,
|
||||
) -> ListPromptsResponse:
|
||||
"""List all versions of a specific prompt.
|
||||
"""List prompt versions.
|
||||
|
||||
List all versions of a specific prompt.
|
||||
|
||||
:param prompt_id: The identifier of the prompt to list versions for.
|
||||
:returns: A ListPromptsResponse containing all versions of the prompt.
|
||||
|
@ -122,7 +126,9 @@ class Prompts(Protocol):
|
|||
prompt_id: str,
|
||||
version: int | None = None,
|
||||
) -> Prompt:
|
||||
"""Get a prompt by its identifier and optional version.
|
||||
"""Get prompt.
|
||||
|
||||
Get a prompt by its identifier and optional version.
|
||||
|
||||
:param prompt_id: The identifier of the prompt to get.
|
||||
:param version: The version of the prompt to get (defaults to latest).
|
||||
|
@ -136,7 +142,9 @@ class Prompts(Protocol):
|
|||
prompt: str,
|
||||
variables: list[str] | None = None,
|
||||
) -> Prompt:
|
||||
"""Create a new prompt.
|
||||
"""Create prompt.
|
||||
|
||||
Create a new prompt.
|
||||
|
||||
:param prompt: The prompt text content with variable placeholders.
|
||||
:param variables: List of variable names that can be used in the prompt template.
|
||||
|
@ -153,7 +161,9 @@ class Prompts(Protocol):
|
|||
variables: list[str] | None = None,
|
||||
set_as_default: bool = True,
|
||||
) -> Prompt:
|
||||
"""Update an existing prompt (increments version).
|
||||
"""Update prompt.
|
||||
|
||||
Update an existing prompt (increments version).
|
||||
|
||||
:param prompt_id: The identifier of the prompt to update.
|
||||
:param prompt: The updated prompt text content.
|
||||
|
@ -169,7 +179,9 @@ class Prompts(Protocol):
|
|||
self,
|
||||
prompt_id: str,
|
||||
) -> None:
|
||||
"""Delete a prompt.
|
||||
"""Delete prompt.
|
||||
|
||||
Delete a prompt.
|
||||
|
||||
:param prompt_id: The identifier of the prompt to delete.
|
||||
"""
|
||||
|
@ -181,7 +193,9 @@ class Prompts(Protocol):
|
|||
prompt_id: str,
|
||||
version: int,
|
||||
) -> Prompt:
|
||||
"""Set which version of a prompt should be the default in get_prompt (latest).
|
||||
"""Set prompt version.
|
||||
|
||||
Set which version of a prompt should be the default in get_prompt (latest).
|
||||
|
||||
:param prompt_id: The identifier of the prompt.
|
||||
:param version: The version to set as default.
|
||||
|
|
|
@ -42,13 +42,16 @@ class ListProvidersResponse(BaseModel):
|
|||
|
||||
@runtime_checkable
|
||||
class Providers(Protocol):
|
||||
"""
|
||||
"""Providers
|
||||
|
||||
Providers API for inspecting, listing, and modifying providers and their configurations.
|
||||
"""
|
||||
|
||||
@webmethod(route="/providers", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_providers(self) -> ListProvidersResponse:
|
||||
"""List all available providers.
|
||||
"""List providers.
|
||||
|
||||
List all available providers.
|
||||
|
||||
:returns: A ListProvidersResponse containing information about all providers.
|
||||
"""
|
||||
|
@ -56,7 +59,9 @@ class Providers(Protocol):
|
|||
|
||||
@webmethod(route="/providers/{provider_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
|
||||
"""Get detailed information about a specific provider.
|
||||
"""Get provider.
|
||||
|
||||
Get detailed information about a specific provider.
|
||||
|
||||
:param provider_id: The ID of the provider to inspect.
|
||||
:returns: A ProviderInfo object containing the provider's details.
|
||||
|
|
|
@ -9,7 +9,7 @@ from typing import Any, Protocol, runtime_checkable
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.inference import OpenAIMessageParam
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
|
@ -96,16 +96,23 @@ class ShieldStore(Protocol):
|
|||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Safety(Protocol):
|
||||
"""Safety
|
||||
|
||||
OpenAI-compatible Moderations API.
|
||||
"""
|
||||
|
||||
shield_store: ShieldStore
|
||||
|
||||
@webmethod(route="/safety/run-shield", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
messages: list[Message],
|
||||
messages: list[OpenAIMessageParam],
|
||||
params: dict[str, Any],
|
||||
) -> RunShieldResponse:
|
||||
"""Run a shield.
|
||||
"""Run shield.
|
||||
|
||||
Run a shield.
|
||||
|
||||
:param shield_id: The identifier of the shield to run.
|
||||
:param messages: The messages to run the shield on.
|
||||
|
@ -117,7 +124,9 @@ class Safety(Protocol):
|
|||
@webmethod(route="/openai/v1/moderations", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/moderations", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
||||
"""Classifies if text and/or image inputs are potentially harmful.
|
||||
"""Create moderation.
|
||||
|
||||
Classifies if text and/or image inputs are potentially harmful.
|
||||
:param input: Input (or inputs) to classify.
|
||||
Can be a single string, an array of strings, or an array of multi-modal input objects similar to other models.
|
||||
:param model: The content moderation model you would like to use.
|
||||
|
|
|
@ -4,14 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Literal, Protocol, runtime_checkable
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -61,57 +59,3 @@ class ListVectorDBsResponse(BaseModel):
|
|||
"""
|
||||
|
||||
data: list[VectorDB]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class VectorDBs(Protocol):
|
||||
@webmethod(route="/vector-dbs", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_vector_dbs(self) -> ListVectorDBsResponse:
|
||||
"""List all vector databases.
|
||||
|
||||
:returns: A ListVectorDBsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_vector_db(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
) -> VectorDB:
|
||||
"""Get a vector database by its identifier.
|
||||
|
||||
:param vector_db_id: The identifier of the vector database to get.
|
||||
:returns: A VectorDB.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/vector-dbs", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
embedding_model: str,
|
||||
embedding_dimension: int | None = 384,
|
||||
provider_id: str | None = None,
|
||||
vector_db_name: str | None = None,
|
||||
provider_vector_db_id: str | None = None,
|
||||
) -> VectorDB:
|
||||
"""Register a vector database.
|
||||
|
||||
:param vector_db_id: The identifier of the vector database to register.
|
||||
:param embedding_model: The embedding model to use.
|
||||
:param embedding_dimension: The dimension of the embedding model.
|
||||
:param provider_id: The identifier of the provider.
|
||||
:param vector_db_name: The name of the vector database.
|
||||
:param provider_vector_db_id: The identifier of the vector database in the provider.
|
||||
:returns: A VectorDB.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
"""Unregister a vector database.
|
||||
|
||||
:param vector_db_id: The identifier of the vector database to unregister.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
import uuid
|
||||
from typing import Annotated, Any, Literal, Protocol, runtime_checkable
|
||||
|
||||
from fastapi import Body
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
|
@ -466,6 +467,40 @@ class VectorStoreFilesListInBatchResponse(BaseModel):
|
|||
has_more: bool = False
|
||||
|
||||
|
||||
# extra_body can be accessed via .model_extra
|
||||
@json_schema_type
|
||||
class OpenAICreateVectorStoreRequestWithExtraBody(BaseModel, extra="allow"):
|
||||
"""Request to create a vector store with extra_body support.
|
||||
|
||||
:param name: (Optional) A name for the vector store
|
||||
:param file_ids: List of file IDs to include in the vector store
|
||||
:param expires_after: (Optional) Expiration policy for the vector store
|
||||
:param chunking_strategy: (Optional) Strategy for splitting files into chunks
|
||||
:param metadata: Set of key-value pairs that can be attached to the vector store
|
||||
"""
|
||||
|
||||
name: str | None = None
|
||||
file_ids: list[str] | None = None
|
||||
expires_after: dict[str, Any] | None = None
|
||||
chunking_strategy: dict[str, Any] | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
# extra_body can be accessed via .model_extra
|
||||
@json_schema_type
|
||||
class OpenAICreateVectorStoreFileBatchRequestWithExtraBody(BaseModel, extra="allow"):
|
||||
"""Request to create a vector store file batch with extra_body support.
|
||||
|
||||
:param file_ids: A list of File IDs that the vector store should use
|
||||
:param attributes: (Optional) Key-value attributes to store with the files
|
||||
:param chunking_strategy: (Optional) The chunking strategy used to chunk the file(s). Defaults to auto
|
||||
"""
|
||||
|
||||
file_ids: list[str]
|
||||
attributes: dict[str, Any] | None = None
|
||||
chunking_strategy: VectorStoreChunkingStrategy | None = None
|
||||
|
||||
|
||||
class VectorDBStore(Protocol):
|
||||
def get_vector_db(self, vector_db_id: str) -> VectorDB | None: ...
|
||||
|
||||
|
@ -516,25 +551,11 @@ class VectorIO(Protocol):
|
|||
@webmethod(route="/vector_stores", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def openai_create_vector_store(
|
||||
self,
|
||||
name: str | None = None,
|
||||
file_ids: list[str] | None = None,
|
||||
expires_after: dict[str, Any] | None = None,
|
||||
chunking_strategy: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
embedding_model: str | None = None,
|
||||
embedding_dimension: int | None = 384,
|
||||
provider_id: str | None = None,
|
||||
params: Annotated[OpenAICreateVectorStoreRequestWithExtraBody, Body(...)],
|
||||
) -> VectorStoreObject:
|
||||
"""Creates a vector store.
|
||||
|
||||
:param name: A name for the vector store.
|
||||
:param file_ids: A list of File IDs that the vector store should use. Useful for tools like `file_search` that can access files.
|
||||
:param expires_after: The expiration policy for a vector store.
|
||||
:param chunking_strategy: The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy.
|
||||
:param metadata: Set of 16 key-value pairs that can be attached to an object.
|
||||
:param embedding_model: The embedding model to use for this vector store.
|
||||
:param embedding_dimension: The dimension of the embedding vectors (default: 384).
|
||||
:param provider_id: The ID of the provider to use for this vector store.
|
||||
Generate an OpenAI-compatible vector store with the given parameters.
|
||||
:returns: A VectorStoreObject representing the created vector store.
|
||||
"""
|
||||
...
|
||||
|
@ -827,16 +848,12 @@ class VectorIO(Protocol):
|
|||
async def openai_create_vector_store_file_batch(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_ids: list[str],
|
||||
attributes: dict[str, Any] | None = None,
|
||||
chunking_strategy: VectorStoreChunkingStrategy | None = None,
|
||||
params: Annotated[OpenAICreateVectorStoreFileBatchRequestWithExtraBody, Body(...)],
|
||||
) -> VectorStoreFileBatchObject:
|
||||
"""Create a vector store file batch.
|
||||
|
||||
Generate an OpenAI-compatible vector store file batch for the given vector store.
|
||||
:param vector_store_id: The ID of the vector store to create the file batch for.
|
||||
:param file_ids: A list of File IDs that the vector store should use.
|
||||
:param attributes: (Optional) Key-value attributes to store with the files.
|
||||
:param chunking_strategy: (Optional) The chunking strategy used to chunk the file(s). Defaults to auto.
|
||||
:returns: A VectorStoreFileBatchObject representing the created file batch.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -1,495 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from rich.console import Console
|
||||
from rich.progress import (
|
||||
BarColumn,
|
||||
DownloadColumn,
|
||||
Progress,
|
||||
TextColumn,
|
||||
TimeRemainingColumn,
|
||||
TransferSpeedColumn,
|
||||
)
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
|
||||
from llama_stack.models.llama.sku_types import Model
|
||||
|
||||
|
||||
class Download(Subcommand):
|
||||
"""Llama cli for downloading llama toolchain assets"""
|
||||
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"download",
|
||||
prog="llama download",
|
||||
description="Download a model from llama.meta.com or Hugging Face Hub",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
setup_download_parser(self.parser)
|
||||
|
||||
|
||||
def setup_download_parser(parser: argparse.ArgumentParser) -> None:
|
||||
parser.add_argument(
|
||||
"--source",
|
||||
choices=["meta", "huggingface"],
|
||||
default="meta",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-id",
|
||||
required=False,
|
||||
help="See `llama model list` or `llama model list --show-all` for the list of available models. Specify multiple model IDs with commas, e.g. --model-id Llama3.2-1B,Llama3.2-3B",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-token",
|
||||
type=str,
|
||||
required=False,
|
||||
default=None,
|
||||
help="Hugging Face API token. Needed for gated models like llama2/3. Will also try to read environment variable `HF_TOKEN` as default.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--meta-url",
|
||||
type=str,
|
||||
required=False,
|
||||
help="For source=meta, URL obtained from llama.meta.com after accepting license terms",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-parallel",
|
||||
type=int,
|
||||
required=False,
|
||||
default=3,
|
||||
help="Maximum number of concurrent downloads",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ignore-patterns",
|
||||
type=str,
|
||||
required=False,
|
||||
default="*.safetensors",
|
||||
help="""For source=huggingface, files matching any of the patterns are not downloaded. Defaults to ignoring
|
||||
safetensors files to avoid downloading duplicate weights.
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--manifest-file",
|
||||
type=str,
|
||||
help="For source=meta, you can download models from a manifest file containing a file => URL mapping",
|
||||
required=False,
|
||||
)
|
||||
parser.set_defaults(func=partial(run_download_cmd, parser=parser))
|
||||
|
||||
|
||||
@dataclass
|
||||
class DownloadTask:
|
||||
url: str
|
||||
output_file: str
|
||||
total_size: int = 0
|
||||
downloaded_size: int = 0
|
||||
task_id: int | None = None
|
||||
retries: int = 0
|
||||
max_retries: int = 3
|
||||
|
||||
|
||||
class DownloadError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class CustomTransferSpeedColumn(TransferSpeedColumn):
|
||||
def render(self, task):
|
||||
if task.finished:
|
||||
return "-"
|
||||
return super().render(task)
|
||||
|
||||
|
||||
class ParallelDownloader:
|
||||
def __init__(
|
||||
self,
|
||||
max_concurrent_downloads: int = 3,
|
||||
buffer_size: int = 1024 * 1024,
|
||||
timeout: int = 30,
|
||||
):
|
||||
self.max_concurrent_downloads = max_concurrent_downloads
|
||||
self.buffer_size = buffer_size
|
||||
self.timeout = timeout
|
||||
self.console = Console()
|
||||
self.progress = Progress(
|
||||
TextColumn("[bold blue]{task.description}"),
|
||||
BarColumn(bar_width=40),
|
||||
"[progress.percentage]{task.percentage:>3.1f}%",
|
||||
DownloadColumn(),
|
||||
CustomTransferSpeedColumn(),
|
||||
TimeRemainingColumn(),
|
||||
console=self.console,
|
||||
expand=True,
|
||||
)
|
||||
self.client_options = {
|
||||
"timeout": httpx.Timeout(timeout),
|
||||
"follow_redirects": True,
|
||||
}
|
||||
|
||||
async def retry_with_exponential_backoff(self, task: DownloadTask, func, *args, **kwargs):
|
||||
last_exception = None
|
||||
for attempt in range(task.max_retries):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
if attempt < task.max_retries - 1:
|
||||
wait_time = min(30, 2**attempt) # Cap at 30 seconds
|
||||
self.console.print(
|
||||
f"[yellow]Attempt {attempt + 1}/{task.max_retries} failed, "
|
||||
f"retrying in {wait_time} seconds: {str(e)}[/yellow]"
|
||||
)
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
raise last_exception
|
||||
|
||||
async def get_file_info(self, client: httpx.AsyncClient, task: DownloadTask) -> None:
|
||||
if task.total_size > 0:
|
||||
self.progress.update(task.task_id, total=task.total_size)
|
||||
return
|
||||
|
||||
async def _get_info():
|
||||
response = await client.head(task.url, headers={"Accept-Encoding": "identity"}, **self.client_options)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
try:
|
||||
response = await self.retry_with_exponential_backoff(task, _get_info)
|
||||
|
||||
task.url = str(response.url)
|
||||
task.total_size = int(response.headers.get("Content-Length", 0))
|
||||
|
||||
if task.total_size == 0:
|
||||
raise DownloadError(
|
||||
f"Unable to determine file size for {task.output_file}. "
|
||||
"The server might not support range requests."
|
||||
)
|
||||
|
||||
# Update the progress bar's total size once we know it
|
||||
if task.task_id is not None:
|
||||
self.progress.update(task.task_id, total=task.total_size)
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
self.console.print(f"[red]Error getting file info: {str(e)}[/red]")
|
||||
raise
|
||||
|
||||
def verify_file_integrity(self, task: DownloadTask) -> bool:
|
||||
if not os.path.exists(task.output_file):
|
||||
return False
|
||||
return os.path.getsize(task.output_file) == task.total_size
|
||||
|
||||
async def download_chunk(self, client: httpx.AsyncClient, task: DownloadTask, start: int, end: int) -> None:
|
||||
async def _download_chunk():
|
||||
headers = {"Range": f"bytes={start}-{end}"}
|
||||
async with client.stream("GET", task.url, headers=headers, **self.client_options) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
with open(task.output_file, "ab") as file:
|
||||
file.seek(start)
|
||||
async for chunk in response.aiter_bytes(self.buffer_size):
|
||||
file.write(chunk)
|
||||
task.downloaded_size += len(chunk)
|
||||
self.progress.update(
|
||||
task.task_id,
|
||||
completed=task.downloaded_size,
|
||||
)
|
||||
|
||||
try:
|
||||
await self.retry_with_exponential_backoff(task, _download_chunk)
|
||||
except Exception as e:
|
||||
raise DownloadError(
|
||||
f"Failed to download chunk {start}-{end} after {task.max_retries} attempts: {str(e)}"
|
||||
) from e
|
||||
|
||||
async def prepare_download(self, task: DownloadTask) -> None:
|
||||
output_dir = os.path.dirname(task.output_file)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
if os.path.exists(task.output_file):
|
||||
task.downloaded_size = os.path.getsize(task.output_file)
|
||||
|
||||
async def download_file(self, task: DownloadTask) -> None:
|
||||
try:
|
||||
async with httpx.AsyncClient(**self.client_options) as client:
|
||||
await self.get_file_info(client, task)
|
||||
|
||||
# Check if file is already downloaded
|
||||
if os.path.exists(task.output_file):
|
||||
if self.verify_file_integrity(task):
|
||||
self.console.print(f"[green]Already downloaded {task.output_file}[/green]")
|
||||
self.progress.update(task.task_id, completed=task.total_size)
|
||||
return
|
||||
|
||||
await self.prepare_download(task)
|
||||
|
||||
try:
|
||||
# Split the remaining download into chunks
|
||||
chunk_size = 27_000_000_000 # Cloudfront max chunk size
|
||||
chunks = []
|
||||
|
||||
current_pos = task.downloaded_size
|
||||
while current_pos < task.total_size:
|
||||
chunk_end = min(current_pos + chunk_size - 1, task.total_size - 1)
|
||||
chunks.append((current_pos, chunk_end))
|
||||
current_pos = chunk_end + 1
|
||||
|
||||
# Download chunks in sequence
|
||||
for chunk_start, chunk_end in chunks:
|
||||
await self.download_chunk(client, task, chunk_start, chunk_end)
|
||||
|
||||
except Exception as e:
|
||||
raise DownloadError(f"Download failed: {str(e)}") from e
|
||||
|
||||
except Exception as e:
|
||||
self.progress.update(task.task_id, description=f"[red]Failed: {task.output_file}[/red]")
|
||||
raise DownloadError(f"Download failed for {task.output_file}: {str(e)}") from e
|
||||
|
||||
def has_disk_space(self, tasks: list[DownloadTask]) -> bool:
|
||||
try:
|
||||
total_remaining_size = sum(task.total_size - task.downloaded_size for task in tasks)
|
||||
dir_path = os.path.dirname(os.path.abspath(tasks[0].output_file))
|
||||
free_space = shutil.disk_usage(dir_path).free
|
||||
|
||||
# Add 10% buffer for safety
|
||||
required_space = int(total_remaining_size * 1.1)
|
||||
|
||||
if free_space < required_space:
|
||||
self.console.print(
|
||||
f"[red]Not enough disk space. Required: {required_space // (1024 * 1024)} MB, "
|
||||
f"Available: {free_space // (1024 * 1024)} MB[/red]"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
raise DownloadError(f"Failed to check disk space: {str(e)}") from e
|
||||
|
||||
async def download_all(self, tasks: list[DownloadTask]) -> None:
|
||||
if not tasks:
|
||||
raise ValueError("No download tasks provided")
|
||||
|
||||
if not os.environ.get("LLAMA_DOWNLOAD_NO_SPACE_CHECK") and not self.has_disk_space(tasks):
|
||||
raise DownloadError("Insufficient disk space for downloads")
|
||||
|
||||
failed_tasks = []
|
||||
|
||||
with self.progress:
|
||||
for task in tasks:
|
||||
desc = f"Downloading {Path(task.output_file).name}"
|
||||
task.task_id = self.progress.add_task(desc, total=task.total_size, completed=task.downloaded_size)
|
||||
|
||||
semaphore = asyncio.Semaphore(self.max_concurrent_downloads)
|
||||
|
||||
async def download_with_semaphore(task: DownloadTask):
|
||||
async with semaphore:
|
||||
try:
|
||||
await self.download_file(task)
|
||||
except Exception as e:
|
||||
failed_tasks.append((task, str(e)))
|
||||
|
||||
await asyncio.gather(*(download_with_semaphore(task) for task in tasks))
|
||||
|
||||
if failed_tasks:
|
||||
self.console.print("\n[red]Some downloads failed:[/red]")
|
||||
for task, error in failed_tasks:
|
||||
self.console.print(f"[red]- {Path(task.output_file).name}: {error}[/red]")
|
||||
raise DownloadError(f"{len(failed_tasks)} downloads failed")
|
||||
|
||||
|
||||
def _hf_download(
|
||||
model: "Model",
|
||||
hf_token: str,
|
||||
ignore_patterns: str,
|
||||
parser: argparse.ArgumentParser,
|
||||
):
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
|
||||
|
||||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
|
||||
repo_id = model.huggingface_repo
|
||||
if repo_id is None:
|
||||
raise ValueError(f"No repo id found for model {model.descriptor()}")
|
||||
|
||||
output_dir = model_local_dir(model.descriptor())
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
try:
|
||||
true_output_dir = snapshot_download(
|
||||
repo_id,
|
||||
local_dir=output_dir,
|
||||
ignore_patterns=ignore_patterns,
|
||||
token=hf_token,
|
||||
library_name="llama-stack",
|
||||
)
|
||||
except GatedRepoError:
|
||||
parser.error(
|
||||
"It looks like you are trying to access a gated repository. Please ensure you "
|
||||
"have access to the repository and have provided the proper Hugging Face API token "
|
||||
"using the option `--hf-token` or by running `huggingface-cli login`."
|
||||
"You can find your token by visiting https://huggingface.co/settings/tokens"
|
||||
)
|
||||
except RepositoryNotFoundError:
|
||||
parser.error(f"Repository '{repo_id}' not found on the Hugging Face Hub or incorrect Hugging Face token.")
|
||||
except Exception as e:
|
||||
parser.error(e)
|
||||
|
||||
print(f"\nSuccessfully downloaded model to {true_output_dir}")
|
||||
|
||||
|
||||
def _meta_download(
|
||||
model: "Model",
|
||||
model_id: str,
|
||||
meta_url: str,
|
||||
info: "LlamaDownloadInfo",
|
||||
max_concurrent_downloads: int,
|
||||
):
|
||||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
|
||||
output_dir = Path(model_local_dir(model.descriptor()))
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Create download tasks for each file
|
||||
tasks = []
|
||||
for f in info.files:
|
||||
output_file = str(output_dir / f)
|
||||
url = meta_url.replace("*", f"{info.folder}/{f}")
|
||||
total_size = info.pth_size if "consolidated" in f else 0
|
||||
tasks.append(DownloadTask(url=url, output_file=output_file, total_size=total_size, max_retries=3))
|
||||
|
||||
# Initialize and run parallel downloader
|
||||
downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
|
||||
asyncio.run(downloader.download_all(tasks))
|
||||
|
||||
cprint(f"\nSuccessfully downloaded model to {output_dir}", color="green", file=sys.stderr)
|
||||
cprint(
|
||||
f"\nView MD5 checksum files at: {output_dir / 'checklist.chk'}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
cprint(
|
||||
f"\n[Optionally] To run MD5 checksums, use the following command: llama model verify-download --model-id {model_id}",
|
||||
color="yellow",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
|
||||
class ModelEntry(BaseModel):
|
||||
model_id: str
|
||||
files: dict[str, str]
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class Manifest(BaseModel):
|
||||
models: list[ModelEntry]
|
||||
expires_on: datetime
|
||||
|
||||
|
||||
def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
|
||||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
|
||||
with open(manifest_file) as f:
|
||||
d = json.load(f)
|
||||
manifest = Manifest(**d)
|
||||
|
||||
if datetime.now(UTC) > manifest.expires_on.astimezone(UTC):
|
||||
raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
|
||||
|
||||
console = Console()
|
||||
for entry in manifest.models:
|
||||
console.print(f"[blue]Downloading model {entry.model_id}...[/blue]")
|
||||
output_dir = Path(model_local_dir(entry.model_id))
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
if any(output_dir.iterdir()):
|
||||
console.print(f"[yellow]Output directory {output_dir} is not empty.[/yellow]")
|
||||
|
||||
while True:
|
||||
resp = input("Do you want to (C)ontinue download or (R)estart completely? (continue/restart): ")
|
||||
if resp.lower() in ["restart", "r"]:
|
||||
shutil.rmtree(output_dir)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
break
|
||||
elif resp.lower() in ["continue", "c"]:
|
||||
console.print("[blue]Continuing download...[/blue]")
|
||||
break
|
||||
else:
|
||||
console.print("[red]Invalid response. Please try again.[/red]")
|
||||
|
||||
# Create download tasks for all files in the manifest
|
||||
tasks = [
|
||||
DownloadTask(url=url, output_file=str(output_dir / fname), max_retries=3)
|
||||
for fname, url in entry.files.items()
|
||||
]
|
||||
|
||||
# Initialize and run parallel downloader
|
||||
downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
|
||||
asyncio.run(downloader.download_all(tasks))
|
||||
|
||||
|
||||
def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
||||
"""Main download command handler"""
|
||||
try:
|
||||
if args.manifest_file:
|
||||
_download_from_manifest(args.manifest_file, args.max_parallel)
|
||||
return
|
||||
|
||||
if args.model_id is None:
|
||||
parser.error("Please provide a model id")
|
||||
return
|
||||
|
||||
# Handle comma-separated model IDs
|
||||
model_ids = [model_id.strip() for model_id in args.model_id.split(",")]
|
||||
|
||||
from llama_stack.models.llama.sku_list import llama_meta_net_info, resolve_model
|
||||
|
||||
from .model.safety_models import (
|
||||
prompt_guard_download_info_map,
|
||||
prompt_guard_model_sku_map,
|
||||
)
|
||||
|
||||
prompt_guard_model_sku_map = prompt_guard_model_sku_map()
|
||||
prompt_guard_download_info_map = prompt_guard_download_info_map()
|
||||
|
||||
for model_id in model_ids:
|
||||
if model_id in prompt_guard_model_sku_map.keys():
|
||||
model = prompt_guard_model_sku_map[model_id]
|
||||
info = prompt_guard_download_info_map[model_id]
|
||||
else:
|
||||
model = resolve_model(model_id)
|
||||
if model is None:
|
||||
parser.error(f"Model {model_id} not found")
|
||||
continue
|
||||
info = llama_meta_net_info(model)
|
||||
|
||||
if args.source == "huggingface":
|
||||
_hf_download(model, args.hf_token, args.ignore_patterns, parser)
|
||||
else:
|
||||
meta_url = args.meta_url or input(
|
||||
f"Please provide the signed URL for model {model_id} you received via email "
|
||||
f"after visiting https://www.llama.com/llama-downloads/ "
|
||||
f"(e.g., https://llama3-1.llamameta.net/*?Policy...): "
|
||||
)
|
||||
if "llamameta.net" not in meta_url:
|
||||
parser.error("Invalid Meta URL provided")
|
||||
_meta_download(model, model_id, meta_url, info, args.max_parallel)
|
||||
|
||||
except Exception as e:
|
||||
parser.error(f"Download failed: {str(e)}")
|
|
@ -6,11 +6,8 @@
|
|||
|
||||
import argparse
|
||||
|
||||
from .download import Download
|
||||
from .model import ModelParser
|
||||
from .stack import StackParser
|
||||
from .stack.utils import print_subcommand_description
|
||||
from .verify_download import VerifyDownload
|
||||
|
||||
|
||||
class LlamaCLIParser:
|
||||
|
@ -30,10 +27,7 @@ class LlamaCLIParser:
|
|||
subparsers = self.parser.add_subparsers(title="subcommands")
|
||||
|
||||
# Add sub-commands
|
||||
ModelParser.create(subparsers)
|
||||
StackParser.create(subparsers)
|
||||
Download.create(subparsers)
|
||||
VerifyDownload.create(subparsers)
|
||||
|
||||
print_subcommand_description(self.parser, subparsers)
|
||||
|
||||
|
|
|
@ -1,7 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .model import ModelParser # noqa
|
|
@ -1,70 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.cli.table import print_table
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
|
||||
|
||||
class ModelDescribe(Subcommand):
|
||||
"""Show details about a model"""
|
||||
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"describe",
|
||||
prog="llama model describe",
|
||||
description="Show details about a llama model",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
self._add_arguments()
|
||||
self.parser.set_defaults(func=self._run_model_describe_cmd)
|
||||
|
||||
def _add_arguments(self):
|
||||
self.parser.add_argument(
|
||||
"-m",
|
||||
"--model-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="See `llama model list` or `llama model list --show-all` for the list of available models",
|
||||
)
|
||||
|
||||
def _run_model_describe_cmd(self, args: argparse.Namespace) -> None:
|
||||
from .safety_models import prompt_guard_model_sku_map
|
||||
|
||||
prompt_guard_model_map = prompt_guard_model_sku_map()
|
||||
if args.model_id in prompt_guard_model_map.keys():
|
||||
model = prompt_guard_model_map[args.model_id]
|
||||
else:
|
||||
model = resolve_model(args.model_id)
|
||||
|
||||
if model is None:
|
||||
self.parser.error(
|
||||
f"Model {args.model_id} not found; try 'llama model list' for a list of available models."
|
||||
)
|
||||
return
|
||||
|
||||
headers = [
|
||||
"Model",
|
||||
model.descriptor(),
|
||||
]
|
||||
|
||||
rows = [
|
||||
("Hugging Face ID", model.huggingface_repo or "<Not Available>"),
|
||||
("Description", model.description),
|
||||
("Context Length", f"{model.max_seq_length // 1024}K tokens"),
|
||||
("Weights format", model.quantization_format.value),
|
||||
("Model params.json", json.dumps(model.arch_args, indent=4)),
|
||||
]
|
||||
|
||||
print_table(
|
||||
rows,
|
||||
headers,
|
||||
separate_rows=True,
|
||||
)
|
|
@ -1,24 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
|
||||
|
||||
class ModelDownload(Subcommand):
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"download",
|
||||
prog="llama model download",
|
||||
description="Download a model from llama.meta.com or Hugging Face Hub",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
|
||||
from llama_stack.cli.download import setup_download_parser
|
||||
|
||||
setup_download_parser(self.parser)
|
|
@ -1,119 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.cli.table import print_table
|
||||
from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
from llama_stack.models.llama.sku_list import all_registered_models
|
||||
|
||||
|
||||
def _get_model_size(model_dir):
|
||||
return sum(f.stat().st_size for f in Path(model_dir).rglob("*") if f.is_file())
|
||||
|
||||
|
||||
def _convert_to_model_descriptor(model):
|
||||
for m in all_registered_models():
|
||||
if model == m.descriptor().replace(":", "-"):
|
||||
return str(m.descriptor())
|
||||
return str(model)
|
||||
|
||||
|
||||
def _run_model_list_downloaded_cmd() -> None:
|
||||
headers = ["Model", "Size", "Modified Time"]
|
||||
|
||||
rows = []
|
||||
for model in os.listdir(DEFAULT_CHECKPOINT_DIR):
|
||||
abs_path = os.path.join(DEFAULT_CHECKPOINT_DIR, model)
|
||||
space_usage = _get_model_size(abs_path)
|
||||
model_size = f"{space_usage / (1024**3):.2f} GB"
|
||||
modified_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(os.path.getmtime(abs_path)))
|
||||
rows.append(
|
||||
[
|
||||
_convert_to_model_descriptor(model),
|
||||
model_size,
|
||||
modified_time,
|
||||
]
|
||||
)
|
||||
|
||||
print_table(
|
||||
rows,
|
||||
headers,
|
||||
separate_rows=True,
|
||||
)
|
||||
|
||||
|
||||
class ModelList(Subcommand):
|
||||
"""List available llama models"""
|
||||
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"list",
|
||||
prog="llama model list",
|
||||
description="Show available llama models",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
self._add_arguments()
|
||||
self.parser.set_defaults(func=self._run_model_list_cmd)
|
||||
|
||||
def _add_arguments(self):
|
||||
self.parser.add_argument(
|
||||
"--show-all",
|
||||
action="store_true",
|
||||
help="Show all models (not just defaults)",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--downloaded",
|
||||
action="store_true",
|
||||
help="List the downloaded models",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"-s",
|
||||
"--search",
|
||||
type=str,
|
||||
required=False,
|
||||
help="Search for the input string as a substring in the model descriptor(ID)",
|
||||
)
|
||||
|
||||
def _run_model_list_cmd(self, args: argparse.Namespace) -> None:
|
||||
from .safety_models import prompt_guard_model_skus
|
||||
|
||||
if args.downloaded:
|
||||
return _run_model_list_downloaded_cmd()
|
||||
|
||||
headers = [
|
||||
"Model Descriptor(ID)",
|
||||
"Hugging Face Repo",
|
||||
"Context Length",
|
||||
]
|
||||
|
||||
rows = []
|
||||
for model in all_registered_models() + prompt_guard_model_skus():
|
||||
if not args.show_all and not model.is_featured:
|
||||
continue
|
||||
|
||||
descriptor = model.descriptor()
|
||||
if not args.search or args.search.lower() in descriptor.lower():
|
||||
rows.append(
|
||||
[
|
||||
descriptor,
|
||||
model.huggingface_repo,
|
||||
f"{model.max_seq_length // 1024}K",
|
||||
]
|
||||
)
|
||||
if len(rows) == 0:
|
||||
print(f"Did not find any model matching `{args.search}`.")
|
||||
else:
|
||||
print_table(
|
||||
rows,
|
||||
headers,
|
||||
separate_rows=True,
|
||||
)
|
|
@ -1,43 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
|
||||
from llama_stack.cli.model.describe import ModelDescribe
|
||||
from llama_stack.cli.model.download import ModelDownload
|
||||
from llama_stack.cli.model.list import ModelList
|
||||
from llama_stack.cli.model.prompt_format import ModelPromptFormat
|
||||
from llama_stack.cli.model.remove import ModelRemove
|
||||
from llama_stack.cli.model.verify_download import ModelVerifyDownload
|
||||
from llama_stack.cli.stack.utils import print_subcommand_description
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
|
||||
|
||||
class ModelParser(Subcommand):
|
||||
"""Llama cli for model interface apis"""
|
||||
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"model",
|
||||
prog="llama model",
|
||||
description="Work with llama models",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
|
||||
self.parser.set_defaults(func=lambda args: self.parser.print_help())
|
||||
|
||||
subparsers = self.parser.add_subparsers(title="model_subcommands")
|
||||
|
||||
# Add sub-commands
|
||||
ModelDownload.create(subparsers)
|
||||
ModelList.create(subparsers)
|
||||
ModelPromptFormat.create(subparsers)
|
||||
ModelDescribe.create(subparsers)
|
||||
ModelVerifyDownload.create(subparsers)
|
||||
ModelRemove.create(subparsers)
|
||||
|
||||
print_subcommand_description(self.parser, subparsers)
|
|
@ -1,133 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import textwrap
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.cli.table import print_table
|
||||
from llama_stack.models.llama.sku_types import CoreModelId, ModelFamily, is_multimodal, model_family
|
||||
|
||||
ROOT_DIR = Path(__file__).parent.parent.parent
|
||||
|
||||
|
||||
class ModelPromptFormat(Subcommand):
|
||||
"""Llama model cli for describe a model prompt format (message formats)"""
|
||||
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"prompt-format",
|
||||
prog="llama model prompt-format",
|
||||
description="Show llama model message formats",
|
||||
epilog=textwrap.dedent(
|
||||
"""
|
||||
Example:
|
||||
llama model prompt-format <options>
|
||||
"""
|
||||
),
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
self._add_arguments()
|
||||
self.parser.set_defaults(func=self._run_model_template_cmd)
|
||||
|
||||
def _add_arguments(self):
|
||||
self.parser.add_argument(
|
||||
"-m",
|
||||
"--model-name",
|
||||
type=str,
|
||||
help="Example: Llama3.1-8B or Llama3.2-11B-Vision, etc\n"
|
||||
"(Run `llama model list` to see a list of valid model names)",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"-l",
|
||||
"--list",
|
||||
action="store_true",
|
||||
help="List all available models",
|
||||
)
|
||||
|
||||
def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
|
||||
import importlib.resources
|
||||
|
||||
# Only Llama 3.1 and 3.2 are supported
|
||||
supported_model_ids = [
|
||||
m for m in CoreModelId if model_family(m) in {ModelFamily.llama3_1, ModelFamily.llama3_2}
|
||||
]
|
||||
|
||||
model_list = [m.value for m in supported_model_ids]
|
||||
|
||||
if args.list:
|
||||
headers = ["Model(s)"]
|
||||
rows = []
|
||||
for m in model_list:
|
||||
rows.append(
|
||||
[
|
||||
m,
|
||||
]
|
||||
)
|
||||
print_table(
|
||||
rows,
|
||||
headers,
|
||||
separate_rows=True,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
model_id = CoreModelId(args.model_name)
|
||||
except ValueError:
|
||||
self.parser.error(
|
||||
f"{args.model_name} is not a valid Model. Choose one from the list of valid models. "
|
||||
f"Run `llama model list` to see the valid model names."
|
||||
)
|
||||
|
||||
if model_id not in supported_model_ids:
|
||||
self.parser.error(
|
||||
f"{model_id} is not a valid Model. Choose one from the list of valid models. "
|
||||
f"Run `llama model list` to see the valid model names."
|
||||
)
|
||||
|
||||
llama_3_1_file = ROOT_DIR / "models" / "llama" / "llama3_1" / "prompt_format.md"
|
||||
llama_3_2_text_file = ROOT_DIR / "models" / "llama" / "llama3_2" / "text_prompt_format.md"
|
||||
llama_3_2_vision_file = ROOT_DIR / "models" / "llama" / "llama3_2" / "vision_prompt_format.md"
|
||||
if model_family(model_id) == ModelFamily.llama3_1:
|
||||
with importlib.resources.as_file(llama_3_1_file) as f:
|
||||
content = f.open("r").read()
|
||||
elif model_family(model_id) == ModelFamily.llama3_2:
|
||||
if is_multimodal(model_id):
|
||||
with importlib.resources.as_file(llama_3_2_vision_file) as f:
|
||||
content = f.open("r").read()
|
||||
else:
|
||||
with importlib.resources.as_file(llama_3_2_text_file) as f:
|
||||
content = f.open("r").read()
|
||||
|
||||
render_markdown_to_pager(content)
|
||||
|
||||
|
||||
def render_markdown_to_pager(markdown_content: str):
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
from rich.style import Style
|
||||
from rich.text import Text
|
||||
|
||||
class LeftAlignedHeaderMarkdown(Markdown):
|
||||
def parse_header(self, token):
|
||||
level = token.type.count("h")
|
||||
content = Text(token.content)
|
||||
header_style = Style(color="bright_blue", bold=True)
|
||||
header = Text(f"{'#' * level} ", style=header_style) + content
|
||||
self.add_text(header)
|
||||
|
||||
# Render the Markdown
|
||||
md = LeftAlignedHeaderMarkdown(markdown_content)
|
||||
|
||||
# Capture the rendered output
|
||||
output = StringIO()
|
||||
console = Console(file=output, force_terminal=True, width=100) # Set a fixed width
|
||||
console.print(md)
|
||||
rendered_content = output.getvalue()
|
||||
print(rendered_content)
|
|
@ -1,68 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
|
||||
|
||||
class ModelRemove(Subcommand):
|
||||
"""Remove the downloaded llama model"""
|
||||
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"remove",
|
||||
prog="llama model remove",
|
||||
description="Remove the downloaded llama model",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
self._add_arguments()
|
||||
self.parser.set_defaults(func=self._run_model_remove_cmd)
|
||||
|
||||
def _add_arguments(self):
|
||||
self.parser.add_argument(
|
||||
"-m",
|
||||
"--model",
|
||||
required=True,
|
||||
help="Specify the llama downloaded model name, see `llama model list --downloaded`",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"-f",
|
||||
"--force",
|
||||
action="store_true",
|
||||
help="Used to forcefully remove the llama model from the storage without further confirmation",
|
||||
)
|
||||
|
||||
def _run_model_remove_cmd(self, args: argparse.Namespace) -> None:
|
||||
from .safety_models import prompt_guard_model_sku_map
|
||||
|
||||
prompt_guard_model_map = prompt_guard_model_sku_map()
|
||||
|
||||
if args.model in prompt_guard_model_map.keys():
|
||||
model = prompt_guard_model_map[args.model]
|
||||
else:
|
||||
model = resolve_model(args.model)
|
||||
|
||||
model_path = os.path.join(DEFAULT_CHECKPOINT_DIR, args.model.replace(":", "-"))
|
||||
|
||||
if model is None or not os.path.isdir(model_path):
|
||||
print(f"'{args.model}' is not a valid llama model or does not exist.")
|
||||
return
|
||||
|
||||
if args.force:
|
||||
shutil.rmtree(model_path)
|
||||
print(f"{args.model} removed.")
|
||||
else:
|
||||
if input(f"Are you sure you want to remove {args.model}? (y/n): ").strip().lower() == "y":
|
||||
shutil.rmtree(model_path)
|
||||
print(f"{args.model} removed.")
|
||||
else:
|
||||
print("Removal aborted.")
|
|
@ -1,64 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
|
||||
from llama_stack.models.llama.sku_types import CheckpointQuantizationFormat
|
||||
|
||||
|
||||
class PromptGuardModel(BaseModel):
|
||||
"""Make a 'fake' Model-like object for Prompt Guard. Eventually this will be removed."""
|
||||
|
||||
model_id: str
|
||||
huggingface_repo: str
|
||||
description: str = "Prompt Guard. NOTE: this model will not be provided via `llama` CLI soon."
|
||||
is_featured: bool = False
|
||||
max_seq_length: int = 512
|
||||
is_instruct_model: bool = False
|
||||
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
|
||||
arch_args: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
def descriptor(self) -> str:
|
||||
return self.model_id
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
def prompt_guard_model_skus():
|
||||
return [
|
||||
PromptGuardModel(model_id="Prompt-Guard-86M", huggingface_repo="meta-llama/Prompt-Guard-86M"),
|
||||
PromptGuardModel(
|
||||
model_id="Llama-Prompt-Guard-2-86M",
|
||||
huggingface_repo="meta-llama/Llama-Prompt-Guard-2-86M",
|
||||
),
|
||||
PromptGuardModel(
|
||||
model_id="Llama-Prompt-Guard-2-22M",
|
||||
huggingface_repo="meta-llama/Llama-Prompt-Guard-2-22M",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def prompt_guard_model_sku_map() -> dict[str, Any]:
|
||||
return {model.model_id: model for model in prompt_guard_model_skus()}
|
||||
|
||||
|
||||
def prompt_guard_download_info_map() -> dict[str, LlamaDownloadInfo]:
|
||||
return {
|
||||
model.model_id: LlamaDownloadInfo(
|
||||
folder="Prompt-Guard" if model.model_id == "Prompt-Guard-86M" else model.model_id,
|
||||
files=[
|
||||
"model.safetensors",
|
||||
"special_tokens_map.json",
|
||||
"tokenizer.json",
|
||||
"tokenizer_config.json",
|
||||
],
|
||||
pth_size=1,
|
||||
)
|
||||
for model in prompt_guard_model_skus()
|
||||
}
|
|
@ -1,24 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
|
||||
|
||||
class ModelVerifyDownload(Subcommand):
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"verify-download",
|
||||
prog="llama model verify-download",
|
||||
description="Verify the downloaded checkpoints' checksums for models downloaded from Meta",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
|
||||
from llama_stack.cli.verify_download import setup_verify_download_parser
|
||||
|
||||
setup_verify_download_parser(self.parser)
|
|
@ -439,12 +439,24 @@ def _run_stack_build_command_from_build_config(
|
|||
|
||||
cprint("Build Successful!", color="green", file=sys.stderr)
|
||||
cprint(f"You can find the newly-built distribution here: {run_config_file}", color="blue", file=sys.stderr)
|
||||
cprint(
|
||||
"You can run the new Llama Stack distro via: "
|
||||
+ colored(f"llama stack run {run_config_file} --image-type {build_config.image_type}", "blue"),
|
||||
color="green",
|
||||
file=sys.stderr,
|
||||
)
|
||||
if build_config.image_type == LlamaStackImageType.VENV:
|
||||
cprint(
|
||||
"You can run the new Llama Stack distro (after activating "
|
||||
+ colored(image_name, "cyan")
|
||||
+ ") via: "
|
||||
+ colored(f"llama stack run {run_config_file}", "blue"),
|
||||
color="green",
|
||||
file=sys.stderr,
|
||||
)
|
||||
elif build_config.image_type == LlamaStackImageType.CONTAINER:
|
||||
cprint(
|
||||
"You can run the container with: "
|
||||
+ colored(
|
||||
f"docker run -p 8321:8321 -v ~/.llama:/root/.llama localhost/{image_name} --port 8321", "blue"
|
||||
),
|
||||
color="green",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return distro_path
|
||||
else:
|
||||
return _generate_run_config(build_config, build_dir, image_name)
|
||||
|
|
|
@ -6,11 +6,18 @@
|
|||
|
||||
import argparse
|
||||
import os
|
||||
import ssl
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import uvicorn
|
||||
import yaml
|
||||
|
||||
from llama_stack.cli.stack.utils import ImageType
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.core.datatypes import LoggingConfig, StackRunConfig
|
||||
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
|
||||
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
@ -48,18 +55,12 @@ class StackRun(Subcommand):
|
|||
"--image-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Name of the image to run. Defaults to the current environment",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--env",
|
||||
action="append",
|
||||
help="Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times.",
|
||||
metavar="KEY=VALUE",
|
||||
help="[DEPRECATED] This flag is no longer supported. Please activate your virtual environment before running.",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--image-type",
|
||||
type=str,
|
||||
help="Image Type used during the build. This can be only venv.",
|
||||
help="[DEPRECATED] This flag is no longer supported. Please activate your virtual environment before running.",
|
||||
choices=[e.value for e in ImageType if e.value != ImageType.CONTAINER.value],
|
||||
)
|
||||
self.parser.add_argument(
|
||||
|
@ -68,48 +69,22 @@ class StackRun(Subcommand):
|
|||
help="Start the UI server",
|
||||
)
|
||||
|
||||
def _resolve_config_and_distro(self, args: argparse.Namespace) -> tuple[Path | None, str | None]:
|
||||
"""Resolve config file path and distribution name from args.config"""
|
||||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
|
||||
if not args.config:
|
||||
return None, None
|
||||
|
||||
config_file = Path(args.config)
|
||||
has_yaml_suffix = args.config.endswith(".yaml")
|
||||
distro_name = None
|
||||
|
||||
if not config_file.exists() and not has_yaml_suffix:
|
||||
# check if this is a distribution
|
||||
config_file = Path(REPO_ROOT) / "llama_stack" / "distributions" / args.config / "run.yaml"
|
||||
if config_file.exists():
|
||||
distro_name = args.config
|
||||
|
||||
if not config_file.exists() and not has_yaml_suffix:
|
||||
# check if it's a build config saved to ~/.llama dir
|
||||
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
|
||||
|
||||
if not config_file.exists():
|
||||
self.parser.error(
|
||||
f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
|
||||
)
|
||||
|
||||
if not config_file.is_file():
|
||||
self.parser.error(
|
||||
f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
|
||||
)
|
||||
|
||||
return config_file, distro_name
|
||||
|
||||
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
|
||||
import yaml
|
||||
|
||||
from llama_stack.core.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.core.utils.exec import formulate_run_args, run_command
|
||||
|
||||
if args.image_type or args.image_name:
|
||||
self.parser.error(
|
||||
"The --image-type and --image-name flags are no longer supported.\n\n"
|
||||
"Please activate your virtual environment manually before running `llama stack run`.\n\n"
|
||||
"For example:\n"
|
||||
" source /path/to/venv/bin/activate\n"
|
||||
" llama stack run <config>\n"
|
||||
)
|
||||
|
||||
if args.enable_ui:
|
||||
self._start_ui_development_server(args.port)
|
||||
image_type, image_name = args.image_type, args.image_name
|
||||
|
||||
if args.config:
|
||||
try:
|
||||
|
@ -121,10 +96,6 @@ class StackRun(Subcommand):
|
|||
else:
|
||||
config_file = None
|
||||
|
||||
# Check if config is required based on image type
|
||||
if image_type == ImageType.VENV.value and not config_file:
|
||||
self.parser.error("Config file is required for venv environment")
|
||||
|
||||
if config_file:
|
||||
logger.info(f"Using run configuration: {config_file}")
|
||||
|
||||
|
@ -139,50 +110,67 @@ class StackRun(Subcommand):
|
|||
os.makedirs(str(config.external_providers_dir), exist_ok=True)
|
||||
except AttributeError as e:
|
||||
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
|
||||
|
||||
self._uvicorn_run(config_file, args)
|
||||
|
||||
def _uvicorn_run(self, config_file: Path | None, args: argparse.Namespace) -> None:
|
||||
if not config_file:
|
||||
self.parser.error("Config file is required")
|
||||
|
||||
config_file = resolve_config_or_distro(str(config_file), Mode.RUN)
|
||||
with open(config_file) as fp:
|
||||
config_contents = yaml.safe_load(fp)
|
||||
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
|
||||
logger_config = LoggingConfig(**cfg)
|
||||
else:
|
||||
logger_config = None
|
||||
config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
|
||||
|
||||
port = args.port or config.server.port
|
||||
host = config.server.host or ["::", "0.0.0.0"]
|
||||
|
||||
# Set the config file in environment so create_app can find it
|
||||
os.environ["LLAMA_STACK_CONFIG"] = str(config_file)
|
||||
|
||||
uvicorn_config = {
|
||||
"factory": True,
|
||||
"host": host,
|
||||
"port": port,
|
||||
"lifespan": "on",
|
||||
"log_level": logger.getEffectiveLevel(),
|
||||
"log_config": logger_config,
|
||||
}
|
||||
|
||||
keyfile = config.server.tls_keyfile
|
||||
certfile = config.server.tls_certfile
|
||||
if keyfile and certfile:
|
||||
uvicorn_config["ssl_keyfile"] = config.server.tls_keyfile
|
||||
uvicorn_config["ssl_certfile"] = config.server.tls_certfile
|
||||
if config.server.tls_cafile:
|
||||
uvicorn_config["ssl_ca_certs"] = config.server.tls_cafile
|
||||
uvicorn_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED
|
||||
|
||||
logger.info(
|
||||
f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
|
||||
)
|
||||
else:
|
||||
config = None
|
||||
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
|
||||
|
||||
# If neither image type nor image name is provided, assume the server should be run directly
|
||||
# using the current environment packages.
|
||||
if not image_type and not image_name:
|
||||
logger.info("No image type or image name provided. Assuming environment packages.")
|
||||
from llama_stack.core.server.server import main as server_main
|
||||
logger.info(f"Listening on {host}:{port}")
|
||||
|
||||
# Build the server args from the current args passed to the CLI
|
||||
server_args = argparse.Namespace()
|
||||
for arg in vars(args):
|
||||
# If this is a function, avoid passing it
|
||||
# "args" contains:
|
||||
# func=<bound method StackRun._run_stack_run_cmd of <llama_stack.cli.stack.run.StackRun object at 0x10484b010>>
|
||||
if callable(getattr(args, arg)):
|
||||
continue
|
||||
if arg == "config":
|
||||
server_args.config = str(config_file)
|
||||
else:
|
||||
setattr(server_args, arg, getattr(args, arg))
|
||||
|
||||
# Run the server
|
||||
server_main(server_args)
|
||||
else:
|
||||
run_args = formulate_run_args(image_type, image_name)
|
||||
|
||||
run_args.extend([str(args.port)])
|
||||
|
||||
if config_file:
|
||||
run_args.extend(["--config", str(config_file)])
|
||||
|
||||
if args.env:
|
||||
for env_var in args.env:
|
||||
if "=" not in env_var:
|
||||
self.parser.error(f"Environment variable '{env_var}' must be in KEY=VALUE format")
|
||||
return
|
||||
key, value = env_var.split("=", 1) # split on first = only
|
||||
if not key:
|
||||
self.parser.error(f"Environment variable '{env_var}' has empty key")
|
||||
return
|
||||
run_args.extend(["--env", f"{key}={value}"])
|
||||
|
||||
run_command(run_args)
|
||||
# We need to catch KeyboardInterrupt because uvicorn's signal handling
|
||||
# re-raises SIGINT signals using signal.raise_signal(), which Python
|
||||
# converts to KeyboardInterrupt. Without this catch, we'd get a confusing
|
||||
# stack trace when using Ctrl+C or kill -2 (SIGINT).
|
||||
# SIGTERM (kill -15) works fine without this because Python doesn't
|
||||
# have a default handler for it.
|
||||
#
|
||||
# Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
|
||||
# signal handling but this is quite intrusive and not worth the effort.
|
||||
try:
|
||||
uvicorn.run("llama_stack.core.server.server:create_app", **uvicorn_config)
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
logger.info("Received interrupt signal, shutting down gracefully...")
|
||||
|
||||
def _start_ui_development_server(self, stack_server_port: int):
|
||||
logger.info("Attempting to start UI development server...")
|
||||
|
|
|
@ -1,141 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
from rich.console import Console
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
|
||||
|
||||
@dataclass
|
||||
class VerificationResult:
|
||||
filename: str
|
||||
expected_hash: str
|
||||
actual_hash: str | None
|
||||
exists: bool
|
||||
matches: bool
|
||||
|
||||
|
||||
class VerifyDownload(Subcommand):
|
||||
"""Llama cli for verifying downloaded model files"""
|
||||
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"verify-download",
|
||||
prog="llama verify-download",
|
||||
description="Verify integrity of downloaded model files",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
setup_verify_download_parser(self.parser)
|
||||
|
||||
|
||||
def setup_verify_download_parser(parser: argparse.ArgumentParser) -> None:
|
||||
parser.add_argument(
|
||||
"--model-id",
|
||||
required=True,
|
||||
help="Model ID to verify (only for models downloaded from Meta)",
|
||||
)
|
||||
parser.set_defaults(func=partial(run_verify_cmd, parser=parser))
|
||||
|
||||
|
||||
def calculate_sha256(filepath: Path, chunk_size: int = 8192) -> str:
|
||||
sha256_hash = hashlib.sha256()
|
||||
with open(filepath, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(chunk_size), b""):
|
||||
sha256_hash.update(chunk)
|
||||
return sha256_hash.hexdigest()
|
||||
|
||||
|
||||
def load_checksums(checklist_path: Path) -> dict[str, str]:
|
||||
checksums = {}
|
||||
with open(checklist_path) as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
sha256sum, filepath = line.strip().split(" ", 1)
|
||||
# Remove leading './' if present
|
||||
filepath = filepath.lstrip("./")
|
||||
checksums[filepath] = sha256sum
|
||||
return checksums
|
||||
|
||||
|
||||
def verify_files(model_dir: Path, checksums: dict[str, str], console: Console) -> list[VerificationResult]:
|
||||
results = []
|
||||
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
console=console,
|
||||
) as progress:
|
||||
for filepath, expected_hash in checksums.items():
|
||||
full_path = model_dir / filepath
|
||||
task_id = progress.add_task(f"Verifying {filepath}...", total=None)
|
||||
|
||||
exists = full_path.exists()
|
||||
actual_hash = None
|
||||
matches = False
|
||||
|
||||
if exists:
|
||||
actual_hash = calculate_sha256(full_path)
|
||||
matches = actual_hash == expected_hash
|
||||
|
||||
results.append(
|
||||
VerificationResult(
|
||||
filename=filepath,
|
||||
expected_hash=expected_hash,
|
||||
actual_hash=actual_hash,
|
||||
exists=exists,
|
||||
matches=matches,
|
||||
)
|
||||
)
|
||||
|
||||
progress.remove_task(task_id)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def run_verify_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
||||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
|
||||
console = Console()
|
||||
model_dir = Path(model_local_dir(args.model_id))
|
||||
checklist_path = model_dir / "checklist.chk"
|
||||
|
||||
if not model_dir.exists():
|
||||
parser.error(f"Model directory not found: {model_dir}")
|
||||
|
||||
if not checklist_path.exists():
|
||||
parser.error(f"Checklist file not found: {checklist_path}")
|
||||
|
||||
checksums = load_checksums(checklist_path)
|
||||
results = verify_files(model_dir, checksums, console)
|
||||
|
||||
# Print results
|
||||
console.print("\nVerification Results:")
|
||||
|
||||
all_good = True
|
||||
for result in results:
|
||||
if not result.exists:
|
||||
console.print(f"[red]❌ {result.filename}: File not found[/red]")
|
||||
all_good = False
|
||||
elif not result.matches:
|
||||
console.print(
|
||||
f"[red]❌ {result.filename}: Hash mismatch[/red]\n"
|
||||
f" Expected: {result.expected_hash}\n"
|
||||
f" Got: {result.actual_hash}"
|
||||
)
|
||||
all_good = False
|
||||
else:
|
||||
console.print(f"[green]✓ {result.filename}: Verified[/green]")
|
||||
|
||||
if all_good:
|
||||
console.print("\n[green]All files verified successfully![/green]")
|
|
@ -324,14 +324,14 @@ fi
|
|||
RUN pip uninstall -y uv
|
||||
EOF
|
||||
|
||||
# If a run config is provided, we use the --config flag
|
||||
# If a run config is provided, we use the llama stack CLI
|
||||
if [[ -n "$run_config" ]]; then
|
||||
add_to_container << EOF
|
||||
ENTRYPOINT ["python", "-m", "llama_stack.core.server.server", "$RUN_CONFIG_PATH"]
|
||||
ENTRYPOINT ["llama", "stack", "run", "$RUN_CONFIG_PATH"]
|
||||
EOF
|
||||
elif [[ "$distro_or_config" != *.yaml ]]; then
|
||||
add_to_container << EOF
|
||||
ENTRYPOINT ["python", "-m", "llama_stack.core.server.server", "$distro_or_config"]
|
||||
ENTRYPOINT ["llama", "stack", "run", "$distro_or_config"]
|
||||
EOF
|
||||
fi
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ from llama_stack.providers.utils.sqlstore.sqlstore import (
|
|||
sqlstore_impl,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="openai::conversations")
|
||||
logger = get_logger(name=__name__, category="openai_conversations")
|
||||
|
||||
|
||||
class ConversationServiceConfig(BaseModel):
|
||||
|
@ -196,12 +196,15 @@ class ConversationServiceImpl(Conversations):
|
|||
await self._get_validated_conversation(conversation_id)
|
||||
|
||||
created_items = []
|
||||
created_at = int(time.time())
|
||||
base_time = int(time.time())
|
||||
|
||||
for item in items:
|
||||
for i, item in enumerate(items):
|
||||
item_dict = item.model_dump()
|
||||
item_id = self._get_or_generate_item_id(item, item_dict)
|
||||
|
||||
# make each timestamp unique to maintain order
|
||||
created_at = base_time + i
|
||||
|
||||
item_record = {
|
||||
"id": item_id,
|
||||
"conversation_id": conversation_id,
|
||||
|
|
|
@ -47,10 +47,6 @@ def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
|
|||
routing_table_api=Api.shields,
|
||||
router_api=Api.safety,
|
||||
),
|
||||
AutoRoutedApiInfo(
|
||||
routing_table_api=Api.vector_dbs,
|
||||
router_api=Api.vector_io,
|
||||
),
|
||||
AutoRoutedApiInfo(
|
||||
routing_table_api=Api.datasets,
|
||||
router_api=Api.datasetio,
|
||||
|
@ -243,6 +239,7 @@ def get_external_providers_from_module(
|
|||
spec = module.get_provider_spec()
|
||||
else:
|
||||
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run
|
||||
# in the case we are building we CANNOT import this module of course because it has not been installed.
|
||||
spec = ProviderSpec(
|
||||
api=Api(provider_api),
|
||||
provider_type=provider.provider_type,
|
||||
|
@ -251,9 +248,20 @@ def get_external_providers_from_module(
|
|||
config_class="",
|
||||
)
|
||||
provider_type = provider.provider_type
|
||||
# in the case we are building we CANNOT import this module of course because it has not been installed.
|
||||
# return a partially filled out spec that the build script will populate.
|
||||
registry[Api(provider_api)][provider_type] = spec
|
||||
if isinstance(spec, list):
|
||||
# optionally allow people to pass inline and remote provider specs as a returned list.
|
||||
# with the old method, users could pass in directories of specs using overlapping code
|
||||
# we want to ensure we preserve that flexibility in this method.
|
||||
logger.info(
|
||||
f"Detected a list of external provider specs from {provider.module} adding all to the registry"
|
||||
)
|
||||
for provider_spec in spec:
|
||||
if provider_spec.provider_type != provider.provider_type:
|
||||
continue
|
||||
logger.info(f"Adding {provider.provider_type} to registry")
|
||||
registry[Api(provider_api)][provider.provider_type] = provider_spec
|
||||
else:
|
||||
registry[Api(provider_api)][provider_type] = spec
|
||||
except ModuleNotFoundError as exc:
|
||||
raise ValueError(
|
||||
"get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available"
|
||||
|
|
42
llama_stack/core/id_generation.py
Normal file
42
llama_stack/core/id_generation.py
Normal file
|
@ -0,0 +1,42 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
IdFactory = Callable[[], str]
|
||||
IdOverride = Callable[[str, IdFactory], str]
|
||||
|
||||
_id_override: IdOverride | None = None
|
||||
|
||||
|
||||
def generate_object_id(kind: str, factory: IdFactory) -> str:
|
||||
"""Generate an identifier for the given kind using the provided factory.
|
||||
|
||||
Allows tests to override ID generation deterministically by installing an
|
||||
override callback via :func:`set_id_override`.
|
||||
"""
|
||||
|
||||
override = _id_override
|
||||
if override is not None:
|
||||
return override(kind, factory)
|
||||
return factory()
|
||||
|
||||
|
||||
def set_id_override(override: IdOverride) -> IdOverride | None:
|
||||
"""Install an override used to generate deterministic identifiers."""
|
||||
|
||||
global _id_override
|
||||
|
||||
previous = _id_override
|
||||
_id_override = override
|
||||
return previous
|
||||
|
||||
|
||||
def reset_id_override(previous: IdOverride | None) -> None:
|
||||
"""Restore the previous override returned by :func:`set_id_override`."""
|
||||
|
||||
global _id_override
|
||||
_id_override = previous
|
|
@ -54,6 +54,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
|||
setup_logger,
|
||||
start_trace,
|
||||
)
|
||||
from llama_stack.strong_typing.inspection import is_unwrapped_body_param
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
@ -383,7 +384,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
|
||||
body, field_names = self._handle_file_uploads(options, body)
|
||||
|
||||
body = self._convert_body(path, options.method, body, exclude_params=set(field_names))
|
||||
body = self._convert_body(matched_func, body, exclude_params=set(field_names))
|
||||
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
await start_trace(trace_path, {"__location__": "library_client"})
|
||||
|
@ -446,7 +447,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
|
||||
body |= path_params
|
||||
|
||||
body = self._convert_body(path, options.method, body)
|
||||
# Prepare body for the function call (handles both Pydantic and traditional params)
|
||||
body = self._convert_body(func, body)
|
||||
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
await start_trace(trace_path, {"__location__": "library_client"})
|
||||
|
@ -493,21 +495,32 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
)
|
||||
return await response.parse()
|
||||
|
||||
def _convert_body(
|
||||
self, path: str, method: str, body: dict | None = None, exclude_params: set[str] | None = None
|
||||
) -> dict:
|
||||
def _convert_body(self, func: Any, body: dict | None = None, exclude_params: set[str] | None = None) -> dict:
|
||||
if not body:
|
||||
return {}
|
||||
|
||||
assert self.route_impls is not None # Should be guaranteed by request() method, assertion for mypy
|
||||
exclude_params = exclude_params or set()
|
||||
|
||||
func, _, _, _ = find_matching_route(method, path, self.route_impls)
|
||||
sig = inspect.signature(func)
|
||||
params_list = [p for p in sig.parameters.values() if p.name != "self"]
|
||||
# Flatten if there's a single unwrapped body parameter (BaseModel or Annotated[BaseModel, Body(embed=False)])
|
||||
if len(params_list) == 1:
|
||||
param = params_list[0]
|
||||
param_type = param.annotation
|
||||
if is_unwrapped_body_param(param_type):
|
||||
base_type = get_args(param_type)[0]
|
||||
return {param.name: base_type(**body)}
|
||||
|
||||
# Strip NOT_GIVENs to use the defaults in signature
|
||||
body = {k: v for k, v in body.items() if v is not NOT_GIVEN}
|
||||
|
||||
# Check if there's an unwrapped body parameter among multiple parameters
|
||||
# (e.g., path param + body param like: vector_store_id: str, params: Annotated[Model, Body(...)])
|
||||
unwrapped_body_param = None
|
||||
for param in params_list:
|
||||
if is_unwrapped_body_param(param.annotation):
|
||||
unwrapped_body_param = param
|
||||
break
|
||||
|
||||
# Convert parameters to Pydantic models where needed
|
||||
converted_body = {}
|
||||
for param_name, param in sig.parameters.items():
|
||||
|
@ -517,5 +530,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
converted_body[param_name] = value
|
||||
else:
|
||||
converted_body[param_name] = convert_to_pydantic(param.annotation, value)
|
||||
elif unwrapped_body_param and param.name == unwrapped_body_param.name:
|
||||
# This is the unwrapped body param - construct it from remaining body keys
|
||||
base_type = get_args(param.annotation)[0]
|
||||
# Extract only the keys that aren't already used by other params
|
||||
remaining_keys = {k: v for k, v in body.items() if k not in converted_body}
|
||||
converted_body[param.name] = base_type(**remaining_keys)
|
||||
|
||||
return converted_body
|
||||
|
|
|
@ -28,7 +28,6 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
|
|||
from llama_stack.apis.shields import Shields
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDBs
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.core.client import get_client_impl
|
||||
|
@ -55,7 +54,6 @@ from llama_stack.providers.datatypes import (
|
|||
ScoringFunctionsProtocolPrivate,
|
||||
ShieldsProtocolPrivate,
|
||||
ToolGroupsProtocolPrivate,
|
||||
VectorDBsProtocolPrivate,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
@ -81,7 +79,6 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
|
|||
Api.inspect: Inspect,
|
||||
Api.batches: Batches,
|
||||
Api.vector_io: VectorIO,
|
||||
Api.vector_dbs: VectorDBs,
|
||||
Api.models: Models,
|
||||
Api.safety: Safety,
|
||||
Api.shields: Shields,
|
||||
|
@ -125,7 +122,6 @@ def additional_protocols_map() -> dict[Api, Any]:
|
|||
return {
|
||||
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
|
||||
Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups),
|
||||
Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
|
||||
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
|
||||
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
|
||||
Api.scoring: (
|
||||
|
@ -150,6 +146,7 @@ async def resolve_impls(
|
|||
provider_registry: ProviderRegistry,
|
||||
dist_registry: DistributionRegistry,
|
||||
policy: list[AccessRule],
|
||||
internal_impls: dict[Api, Any] | None = None,
|
||||
) -> dict[Api, Any]:
|
||||
"""
|
||||
Resolves provider implementations by:
|
||||
|
@ -172,7 +169,7 @@ async def resolve_impls(
|
|||
|
||||
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
|
||||
|
||||
return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config, policy)
|
||||
return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config, policy, internal_impls)
|
||||
|
||||
|
||||
def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
|
||||
|
@ -280,9 +277,10 @@ async def instantiate_providers(
|
|||
dist_registry: DistributionRegistry,
|
||||
run_config: StackRunConfig,
|
||||
policy: list[AccessRule],
|
||||
internal_impls: dict[Api, Any] | None = None,
|
||||
) -> dict[Api, Any]:
|
||||
"""Instantiates providers asynchronously while managing dependencies."""
|
||||
impls: dict[Api, Any] = {}
|
||||
impls: dict[Api, Any] = internal_impls.copy() if internal_impls else {}
|
||||
inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
|
||||
for api_str, provider in sorted_providers:
|
||||
# Skip providers that are not enabled
|
||||
|
|
|
@ -31,10 +31,8 @@ async def get_routing_table_impl(
|
|||
from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
|
||||
from ..routing_tables.shields import ShieldsRoutingTable
|
||||
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
|
||||
from ..routing_tables.vector_dbs import VectorDBsRoutingTable
|
||||
|
||||
api_to_tables = {
|
||||
"vector_dbs": VectorDBsRoutingTable,
|
||||
"models": ModelsRoutingTable,
|
||||
"shields": ShieldsRoutingTable,
|
||||
"datasets": DatasetsRoutingTable,
|
||||
|
|
|
@ -10,9 +10,10 @@ from collections.abc import AsyncGenerator, AsyncIterator
|
|||
from datetime import UTC, datetime
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import Body
|
||||
from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
|
||||
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
|
||||
from pydantic import Field, TypeAdapter
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
|
@ -31,15 +32,17 @@ from llama_stack.apis.inference import (
|
|||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIChoice,
|
||||
OpenAIChoiceLogprobs,
|
||||
OpenAICompletion,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
OpenAICompletionWithInputMessages,
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
Order,
|
||||
StopReason,
|
||||
ToolPromptFormat,
|
||||
|
@ -181,61 +184,23 @@ class InferenceRouter(Inference):
|
|||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)],
|
||||
) -> OpenAICompletion:
|
||||
logger.debug(
|
||||
f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
|
||||
)
|
||||
model_obj = await self._get_model(model, ModelType.llm)
|
||||
params = dict(
|
||||
model=model_obj.identifier,
|
||||
prompt=prompt,
|
||||
best_of=best_of,
|
||||
echo=echo,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
guided_choice=guided_choice,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
suffix=suffix,
|
||||
f"InferenceRouter.openai_completion: model={params.model}, stream={params.stream}, prompt={params.prompt}",
|
||||
)
|
||||
model_obj = await self._get_model(params.model, ModelType.llm)
|
||||
|
||||
# Update params with the resolved model identifier
|
||||
params.model = model_obj.identifier
|
||||
|
||||
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
if stream:
|
||||
return await provider.openai_completion(**params)
|
||||
if params.stream:
|
||||
return await provider.openai_completion(params)
|
||||
# TODO: Metrics do NOT work with openai_completion stream=True due to the fact
|
||||
# that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.
|
||||
# response_stream = await provider.openai_completion(**params)
|
||||
|
||||
response = await provider.openai_completion(**params)
|
||||
response = await provider.openai_completion(params)
|
||||
if self.telemetry:
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
|
@ -254,93 +219,49 @@ class InferenceRouter(Inference):
|
|||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
params: Annotated[OpenAIChatCompletionRequestWithExtraBody, Body(...)],
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
logger.debug(
|
||||
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
|
||||
f"InferenceRouter.openai_chat_completion: model={params.model}, stream={params.stream}, messages={params.messages}",
|
||||
)
|
||||
model_obj = await self._get_model(model, ModelType.llm)
|
||||
model_obj = await self._get_model(params.model, ModelType.llm)
|
||||
|
||||
# Use the OpenAI client for a bit of extra input validation without
|
||||
# exposing the OpenAI client itself as part of our API surface
|
||||
if tool_choice:
|
||||
TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice)
|
||||
if tools is None:
|
||||
if params.tool_choice:
|
||||
TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(params.tool_choice)
|
||||
if params.tools is None:
|
||||
raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
|
||||
if tools:
|
||||
for tool in tools:
|
||||
if params.tools:
|
||||
for tool in params.tools:
|
||||
TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)
|
||||
|
||||
# Some providers make tool calls even when tool_choice is "none"
|
||||
# so just clear them both out to avoid unexpected tool calls
|
||||
if tool_choice == "none" and tools is not None:
|
||||
tool_choice = None
|
||||
tools = None
|
||||
if params.tool_choice == "none" and params.tools is not None:
|
||||
params.tool_choice = None
|
||||
params.tools = None
|
||||
|
||||
# Update params with the resolved model identifier
|
||||
params.model = model_obj.identifier
|
||||
|
||||
params = dict(
|
||||
model=model_obj.identifier,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
if stream:
|
||||
response_stream = await provider.openai_chat_completion(**params)
|
||||
if params.stream:
|
||||
response_stream = await provider.openai_chat_completion(params)
|
||||
|
||||
# For streaming, the provider returns AsyncIterator[OpenAIChatCompletionChunk]
|
||||
# We need to add metrics to each chunk and store the final completion
|
||||
return self.stream_tokens_and_compute_metrics_openai_chat(
|
||||
response=response_stream,
|
||||
model=model_obj,
|
||||
messages=messages,
|
||||
messages=params.messages,
|
||||
)
|
||||
|
||||
response = await self._nonstream_openai_chat_completion(provider, params)
|
||||
|
||||
# Store the response with the ID that will be returned to the client
|
||||
if self.store:
|
||||
asyncio.create_task(self.store.store_chat_completion(response, messages))
|
||||
asyncio.create_task(self.store.store_chat_completion(response, params.messages))
|
||||
|
||||
if self.telemetry:
|
||||
metrics = self._construct_metrics(
|
||||
|
@ -359,26 +280,18 @@ class InferenceRouter(Inference):
|
|||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
params: Annotated[OpenAIEmbeddingsRequestWithExtraBody, Body(...)],
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
logger.debug(
|
||||
f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}",
|
||||
)
|
||||
model_obj = await self._get_model(model, ModelType.embedding)
|
||||
params = dict(
|
||||
model=model_obj.identifier,
|
||||
input=input,
|
||||
encoding_format=encoding_format,
|
||||
dimensions=dimensions,
|
||||
user=user,
|
||||
f"InferenceRouter.openai_embeddings: model={params.model}, input_type={type(params.input)}, encoding_format={params.encoding_format}, dimensions={params.dimensions}",
|
||||
)
|
||||
model_obj = await self._get_model(params.model, ModelType.embedding)
|
||||
|
||||
# Update model to use resolved identifier
|
||||
params.model = model_obj.identifier
|
||||
|
||||
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
return await provider.openai_embeddings(**params)
|
||||
return await provider.openai_embeddings(params)
|
||||
|
||||
async def list_chat_completions(
|
||||
self,
|
||||
|
@ -396,8 +309,10 @@ class InferenceRouter(Inference):
|
|||
return await self.store.get_chat_completion(completion_id)
|
||||
raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
|
||||
|
||||
async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
|
||||
response = await provider.openai_chat_completion(**params)
|
||||
async def _nonstream_openai_chat_completion(
|
||||
self, provider: Inference, params: OpenAIChatCompletionRequestWithExtraBody
|
||||
) -> OpenAIChatCompletion:
|
||||
response = await provider.openai_chat_completion(params)
|
||||
for choice in response.choices:
|
||||
# some providers return an empty list for no tool calls in non-streaming responses
|
||||
# but the OpenAI API returns None. So, set tool_calls to None if it's empty
|
||||
|
@ -611,7 +526,7 @@ class InferenceRouter(Inference):
|
|||
completion_text += "".join(choice_data["content_parts"])
|
||||
|
||||
# Add metrics to the chunk
|
||||
if self.telemetry and chunk.usage:
|
||||
if self.telemetry and hasattr(chunk, "usage") and chunk.usage:
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens=chunk.usage.prompt_tokens,
|
||||
completion_tokens=chunk.usage.completion_tokens,
|
||||
|
|
|
@ -6,12 +6,16 @@
|
|||
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import Body
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
|
||||
OpenAICreateVectorStoreRequestWithExtraBody,
|
||||
QueryChunksResponse,
|
||||
SearchRankingOptions,
|
||||
VectorIO,
|
||||
|
@ -51,30 +55,18 @@ class VectorIORouter(VectorIO):
|
|||
logger.debug("VectorIORouter.shutdown")
|
||||
pass
|
||||
|
||||
async def _get_first_embedding_model(self) -> tuple[str, int] | None:
|
||||
"""Get the first available embedding model identifier."""
|
||||
try:
|
||||
# Get all models from the routing table
|
||||
all_models = await self.routing_table.get_all_with_type("model")
|
||||
async def _get_embedding_model_dimension(self, embedding_model_id: str) -> int:
|
||||
"""Get the embedding dimension for a specific embedding model."""
|
||||
all_models = await self.routing_table.get_all_with_type("model")
|
||||
|
||||
# Filter for embedding models
|
||||
embedding_models = [
|
||||
model
|
||||
for model in all_models
|
||||
if hasattr(model, "model_type") and model.model_type == ModelType.embedding
|
||||
]
|
||||
|
||||
if embedding_models:
|
||||
dimension = embedding_models[0].metadata.get("embedding_dimension", None)
|
||||
for model in all_models:
|
||||
if model.identifier == embedding_model_id and model.model_type == ModelType.embedding:
|
||||
dimension = model.metadata.get("embedding_dimension")
|
||||
if dimension is None:
|
||||
raise ValueError(f"Embedding model {embedding_models[0].identifier} has no embedding dimension")
|
||||
return embedding_models[0].identifier, dimension
|
||||
else:
|
||||
logger.warning("No embedding models found in the routing table")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting embedding models: {e}")
|
||||
return None
|
||||
raise ValueError(f"Embedding model '{embedding_model_id}' has no embedding_dimension in metadata")
|
||||
return int(dimension)
|
||||
|
||||
raise ValueError(f"Embedding model '{embedding_model_id}' not found or not an embedding model")
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
|
@ -120,24 +112,35 @@ class VectorIORouter(VectorIO):
|
|||
# OpenAI Vector Stores API endpoints
|
||||
async def openai_create_vector_store(
|
||||
self,
|
||||
name: str,
|
||||
file_ids: list[str] | None = None,
|
||||
expires_after: dict[str, Any] | None = None,
|
||||
chunking_strategy: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
embedding_model: str | None = None,
|
||||
embedding_dimension: int | None = None,
|
||||
provider_id: str | None = None,
|
||||
params: Annotated[OpenAICreateVectorStoreRequestWithExtraBody, Body(...)],
|
||||
) -> VectorStoreObject:
|
||||
logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")
|
||||
# Extract llama-stack-specific parameters from extra_body
|
||||
extra = params.model_extra or {}
|
||||
embedding_model = extra.get("embedding_model")
|
||||
embedding_dimension = extra.get("embedding_dimension")
|
||||
provider_id = extra.get("provider_id")
|
||||
|
||||
# If no embedding model is provided, use the first available one
|
||||
logger.debug(f"VectorIORouter.openai_create_vector_store: name={params.name}, provider_id={provider_id}")
|
||||
|
||||
# Require explicit embedding model specification
|
||||
if embedding_model is None:
|
||||
embedding_model_info = await self._get_first_embedding_model()
|
||||
if embedding_model_info is None:
|
||||
raise ValueError("No embedding model provided and no embedding models available in the system")
|
||||
embedding_model, embedding_dimension = embedding_model_info
|
||||
logger.info(f"No embedding model specified, using first available: {embedding_model}")
|
||||
raise ValueError("embedding_model is required in extra_body when creating a vector store")
|
||||
|
||||
if embedding_dimension is None:
|
||||
embedding_dimension = await self._get_embedding_model_dimension(embedding_model)
|
||||
|
||||
# Auto-select provider if not specified
|
||||
if provider_id is None:
|
||||
num_providers = len(self.routing_table.impls_by_provider_id)
|
||||
if num_providers == 0:
|
||||
raise ValueError("No vector_io providers available")
|
||||
if num_providers > 1:
|
||||
available_providers = list(self.routing_table.impls_by_provider_id.keys())
|
||||
raise ValueError(
|
||||
f"Multiple vector_io providers available. Please specify provider_id in extra_body. "
|
||||
f"Available providers: {available_providers}"
|
||||
)
|
||||
provider_id = list(self.routing_table.impls_by_provider_id.keys())[0]
|
||||
|
||||
vector_db_id = f"vs_{uuid.uuid4()}"
|
||||
registered_vector_db = await self.routing_table.register_vector_db(
|
||||
|
@ -146,20 +149,19 @@ class VectorIORouter(VectorIO):
|
|||
embedding_dimension=embedding_dimension,
|
||||
provider_id=provider_id,
|
||||
provider_vector_db_id=vector_db_id,
|
||||
vector_db_name=name,
|
||||
vector_db_name=params.name,
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl(registered_vector_db.identifier)
|
||||
return await provider.openai_create_vector_store(
|
||||
name=name,
|
||||
file_ids=file_ids,
|
||||
expires_after=expires_after,
|
||||
chunking_strategy=chunking_strategy,
|
||||
metadata=metadata,
|
||||
embedding_model=embedding_model,
|
||||
embedding_dimension=embedding_dimension,
|
||||
provider_id=registered_vector_db.provider_id,
|
||||
provider_vector_db_id=registered_vector_db.provider_resource_id,
|
||||
)
|
||||
|
||||
# Update model_extra with registered values so provider uses the already-registered vector_db
|
||||
if params.model_extra is None:
|
||||
params.model_extra = {}
|
||||
params.model_extra["provider_vector_db_id"] = registered_vector_db.provider_resource_id
|
||||
params.model_extra["provider_id"] = registered_vector_db.provider_id
|
||||
params.model_extra["embedding_model"] = embedding_model
|
||||
params.model_extra["embedding_dimension"] = embedding_dimension
|
||||
|
||||
return await provider.openai_create_vector_store(params)
|
||||
|
||||
async def openai_list_vector_stores(
|
||||
self,
|
||||
|
@ -219,7 +221,8 @@ class VectorIORouter(VectorIO):
|
|||
vector_store_id: str,
|
||||
) -> VectorStoreObject:
|
||||
logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
|
||||
return await self.routing_table.openai_retrieve_vector_store(vector_store_id)
|
||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store(vector_store_id)
|
||||
|
||||
async def openai_update_vector_store(
|
||||
self,
|
||||
|
@ -229,7 +232,8 @@ class VectorIORouter(VectorIO):
|
|||
metadata: dict[str, Any] | None = None,
|
||||
) -> VectorStoreObject:
|
||||
logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
|
||||
return await self.routing_table.openai_update_vector_store(
|
||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_update_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
name=name,
|
||||
expires_after=expires_after,
|
||||
|
@ -241,7 +245,8 @@ class VectorIORouter(VectorIO):
|
|||
vector_store_id: str,
|
||||
) -> VectorStoreDeleteResponse:
|
||||
logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
|
||||
return await self.routing_table.openai_delete_vector_store(vector_store_id)
|
||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_delete_vector_store(vector_store_id)
|
||||
|
||||
async def openai_search_vector_store(
|
||||
self,
|
||||
|
@ -254,7 +259,8 @@ class VectorIORouter(VectorIO):
|
|||
search_mode: str | None = "vector",
|
||||
) -> VectorStoreSearchResponsePage:
|
||||
logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
|
||||
return await self.routing_table.openai_search_vector_store(
|
||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_search_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
query=query,
|
||||
filters=filters,
|
||||
|
@ -272,7 +278,8 @@ class VectorIORouter(VectorIO):
|
|||
chunking_strategy: VectorStoreChunkingStrategy | None = None,
|
||||
) -> VectorStoreFileObject:
|
||||
logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
|
||||
return await self.routing_table.openai_attach_file_to_vector_store(
|
||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_attach_file_to_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
attributes=attributes,
|
||||
|
@ -289,7 +296,8 @@ class VectorIORouter(VectorIO):
|
|||
filter: VectorStoreFileStatus | None = None,
|
||||
) -> list[VectorStoreFileObject]:
|
||||
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
|
||||
return await self.routing_table.openai_list_files_in_vector_store(
|
||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_list_files_in_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
limit=limit,
|
||||
order=order,
|
||||
|
@ -304,7 +312,8 @@ class VectorIORouter(VectorIO):
|
|||
file_id: str,
|
||||
) -> VectorStoreFileObject:
|
||||
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}")
|
||||
return await self.routing_table.openai_retrieve_vector_store_file(
|
||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
@ -315,7 +324,8 @@ class VectorIORouter(VectorIO):
|
|||
file_id: str,
|
||||
) -> VectorStoreFileContentsResponse:
|
||||
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
|
||||
return await self.routing_table.openai_retrieve_vector_store_file_contents(
|
||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store_file_contents(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
@ -327,7 +337,8 @@ class VectorIORouter(VectorIO):
|
|||
attributes: dict[str, Any],
|
||||
) -> VectorStoreFileObject:
|
||||
logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}")
|
||||
return await self.routing_table.openai_update_vector_store_file(
|
||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_update_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
attributes=attributes,
|
||||
|
@ -339,7 +350,8 @@ class VectorIORouter(VectorIO):
|
|||
file_id: str,
|
||||
) -> VectorStoreFileDeleteResponse:
|
||||
logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}")
|
||||
return await self.routing_table.openai_delete_vector_store_file(
|
||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_delete_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
@ -370,17 +382,13 @@ class VectorIORouter(VectorIO):
|
|||
async def openai_create_vector_store_file_batch(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_ids: list[str],
|
||||
attributes: dict[str, Any] | None = None,
|
||||
chunking_strategy: VectorStoreChunkingStrategy | None = None,
|
||||
params: Annotated[OpenAICreateVectorStoreFileBatchRequestWithExtraBody, Body(...)],
|
||||
) -> VectorStoreFileBatchObject:
|
||||
logger.debug(f"VectorIORouter.openai_create_vector_store_file_batch: {vector_store_id}, {len(file_ids)} files")
|
||||
return await self.routing_table.openai_create_vector_store_file_batch(
|
||||
vector_store_id=vector_store_id,
|
||||
file_ids=file_ids,
|
||||
attributes=attributes,
|
||||
chunking_strategy=chunking_strategy,
|
||||
logger.debug(
|
||||
f"VectorIORouter.openai_create_vector_store_file_batch: {vector_store_id}, {len(params.file_ids)} files"
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_create_vector_store_file_batch(vector_store_id, params)
|
||||
|
||||
async def openai_retrieve_vector_store_file_batch(
|
||||
self,
|
||||
|
@ -388,7 +396,8 @@ class VectorIORouter(VectorIO):
|
|||
vector_store_id: str,
|
||||
) -> VectorStoreFileBatchObject:
|
||||
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_batch: {batch_id}, {vector_store_id}")
|
||||
return await self.routing_table.openai_retrieve_vector_store_file_batch(
|
||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store_file_batch(
|
||||
batch_id=batch_id,
|
||||
vector_store_id=vector_store_id,
|
||||
)
|
||||
|
@ -404,7 +413,8 @@ class VectorIORouter(VectorIO):
|
|||
order: str | None = "desc",
|
||||
) -> VectorStoreFilesListInBatchResponse:
|
||||
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store_file_batch: {batch_id}, {vector_store_id}")
|
||||
return await self.routing_table.openai_list_files_in_vector_store_file_batch(
|
||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_list_files_in_vector_store_file_batch(
|
||||
batch_id=batch_id,
|
||||
vector_store_id=vector_store_id,
|
||||
after=after,
|
||||
|
@ -420,7 +430,8 @@ class VectorIORouter(VectorIO):
|
|||
vector_store_id: str,
|
||||
) -> VectorStoreFileBatchObject:
|
||||
logger.debug(f"VectorIORouter.openai_cancel_vector_store_file_batch: {batch_id}, {vector_store_id}")
|
||||
return await self.routing_table.openai_cancel_vector_store_file_batch(
|
||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_cancel_vector_store_file_batch(
|
||||
batch_id=batch_id,
|
||||
vector_store_id=vector_store_id,
|
||||
)
|
||||
|
|
|
@ -9,7 +9,6 @@ from typing import Any
|
|||
from llama_stack.apis.common.errors import ModelNotFoundError
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.scoring_functions import ScoringFn
|
||||
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
from llama_stack.core.access_control.datatypes import Action
|
||||
from llama_stack.core.datatypes import (
|
||||
|
@ -17,6 +16,7 @@ from llama_stack.core.datatypes import (
|
|||
RoutableObject,
|
||||
RoutableObjectWithProvider,
|
||||
RoutedProtocol,
|
||||
ScoringFnWithOwner,
|
||||
)
|
||||
from llama_stack.core.request_headers import get_authenticated_user
|
||||
from llama_stack.core.store import DistributionRegistry
|
||||
|
@ -114,7 +114,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
elif api == Api.scoring:
|
||||
p.scoring_function_store = self
|
||||
scoring_functions = await p.list_scoring_functions()
|
||||
await add_objects(scoring_functions, pid, ScoringFn)
|
||||
await add_objects(scoring_functions, pid, ScoringFnWithOwner)
|
||||
elif api == Api.eval:
|
||||
p.benchmark_store = self
|
||||
elif api == Api.tool_runtime:
|
||||
|
@ -134,15 +134,12 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
from .scoring_functions import ScoringFunctionsRoutingTable
|
||||
from .shields import ShieldsRoutingTable
|
||||
from .toolgroups import ToolGroupsRoutingTable
|
||||
from .vector_dbs import VectorDBsRoutingTable
|
||||
|
||||
def apiname_object():
|
||||
if isinstance(self, ModelsRoutingTable):
|
||||
return ("Inference", "model")
|
||||
elif isinstance(self, ShieldsRoutingTable):
|
||||
return ("Safety", "shield")
|
||||
elif isinstance(self, VectorDBsRoutingTable):
|
||||
return ("VectorIO", "vector_db")
|
||||
elif isinstance(self, DatasetsRoutingTable):
|
||||
return ("DatasetIO", "dataset")
|
||||
elif isinstance(self, ScoringFunctionsRoutingTable):
|
||||
|
|
|
@ -33,7 +33,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
try:
|
||||
models = await provider.list_models()
|
||||
except Exception as e:
|
||||
logger.warning(f"Model refresh failed for provider {provider_id}: {e}")
|
||||
logger.debug(f"Model refresh failed for provider {provider_id}: {e}")
|
||||
continue
|
||||
|
||||
self.listed_providers.add(provider_id)
|
||||
|
@ -67,6 +67,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
raise ValueError(f"Provider {model.provider_id} not found in the routing table")
|
||||
return self.impls_by_provider_id[model.provider_id]
|
||||
|
||||
async def has_model(self, model_id: str) -> bool:
|
||||
"""
|
||||
Check if a model exists in the routing table.
|
||||
|
||||
:param model_id: The model identifier to check
|
||||
:return: True if the model exists, False otherwise
|
||||
"""
|
||||
try:
|
||||
await lookup_model(self, model_id)
|
||||
return True
|
||||
except ModelNotFoundError:
|
||||
return False
|
||||
|
||||
async def register_model(
|
||||
self,
|
||||
model_id: str,
|
||||
|
|
|
@ -1,247 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError, VectorStoreNotFoundError
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
||||
from llama_stack.apis.vector_io.vector_io import (
|
||||
SearchRankingOptions,
|
||||
VectorStoreChunkingStrategy,
|
||||
VectorStoreDeleteResponse,
|
||||
VectorStoreFileContentsResponse,
|
||||
VectorStoreFileDeleteResponse,
|
||||
VectorStoreFileObject,
|
||||
VectorStoreFileStatus,
|
||||
VectorStoreObject,
|
||||
VectorStoreSearchResponsePage,
|
||||
)
|
||||
from llama_stack.core.datatypes import (
|
||||
VectorDBWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .common import CommonRoutingTableImpl, lookup_model
|
||||
|
||||
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||
|
||||
|
||||
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
||||
async def list_vector_dbs(self) -> ListVectorDBsResponse:
|
||||
return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
|
||||
|
||||
async def get_vector_db(self, vector_db_id: str) -> VectorDB:
|
||||
vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
|
||||
if vector_db is None:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
return vector_db
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
embedding_model: str,
|
||||
embedding_dimension: int | None = 384,
|
||||
provider_id: str | None = None,
|
||||
provider_vector_db_id: str | None = None,
|
||||
vector_db_name: str | None = None,
|
||||
) -> VectorDB:
|
||||
if provider_id is None:
|
||||
if len(self.impls_by_provider_id) > 0:
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
if len(self.impls_by_provider_id) > 1:
|
||||
logger.warning(
|
||||
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
|
||||
)
|
||||
else:
|
||||
raise ValueError("No provider available. Please configure a vector_io provider.")
|
||||
model = await lookup_model(self, embedding_model)
|
||||
if model is None:
|
||||
raise ModelNotFoundError(embedding_model)
|
||||
if model.model_type != ModelType.embedding:
|
||||
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
|
||||
if "embedding_dimension" not in model.metadata:
|
||||
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
|
||||
|
||||
provider = self.impls_by_provider_id[provider_id]
|
||||
logger.warning(
|
||||
"VectorDB is being deprecated in future releases in favor of VectorStore. Please migrate your usage accordingly."
|
||||
)
|
||||
vector_store = await provider.openai_create_vector_store(
|
||||
name=vector_db_name or vector_db_id,
|
||||
embedding_model=embedding_model,
|
||||
embedding_dimension=model.metadata["embedding_dimension"],
|
||||
provider_id=provider_id,
|
||||
provider_vector_db_id=provider_vector_db_id,
|
||||
)
|
||||
|
||||
vector_store_id = vector_store.id
|
||||
actual_provider_vector_db_id = provider_vector_db_id or vector_store_id
|
||||
logger.warning(
|
||||
f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. Setting VectorDB {vector_db_id} to VectorDB.vector_db_name"
|
||||
)
|
||||
|
||||
vector_db_data = {
|
||||
"identifier": vector_store_id,
|
||||
"type": ResourceType.vector_db.value,
|
||||
"provider_id": provider_id,
|
||||
"provider_resource_id": actual_provider_vector_db_id,
|
||||
"embedding_model": embedding_model,
|
||||
"embedding_dimension": model.metadata["embedding_dimension"],
|
||||
"vector_db_name": vector_store.name,
|
||||
}
|
||||
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
|
||||
await self.register_object(vector_db)
|
||||
return vector_db
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
existing_vector_db = await self.get_vector_db(vector_db_id)
|
||||
await self.unregister_object(existing_vector_db)
|
||||
|
||||
async def openai_retrieve_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreObject:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store(vector_store_id)
|
||||
|
||||
async def openai_update_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
name: str | None = None,
|
||||
expires_after: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> VectorStoreObject:
|
||||
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_update_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
name=name,
|
||||
expires_after=expires_after,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
async def openai_delete_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreDeleteResponse:
|
||||
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
result = await provider.openai_delete_vector_store(vector_store_id)
|
||||
await self.unregister_vector_db(vector_store_id)
|
||||
return result
|
||||
|
||||
async def openai_search_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
query: str | list[str],
|
||||
filters: dict[str, Any] | None = None,
|
||||
max_num_results: int | None = 10,
|
||||
ranking_options: SearchRankingOptions | None = None,
|
||||
rewrite_query: bool | None = False,
|
||||
search_mode: str | None = "vector",
|
||||
) -> VectorStoreSearchResponsePage:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_search_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
query=query,
|
||||
filters=filters,
|
||||
max_num_results=max_num_results,
|
||||
ranking_options=ranking_options,
|
||||
rewrite_query=rewrite_query,
|
||||
search_mode=search_mode,
|
||||
)
|
||||
|
||||
async def openai_attach_file_to_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
attributes: dict[str, Any] | None = None,
|
||||
chunking_strategy: VectorStoreChunkingStrategy | None = None,
|
||||
) -> VectorStoreFileObject:
|
||||
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_attach_file_to_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
attributes=attributes,
|
||||
chunking_strategy=chunking_strategy,
|
||||
)
|
||||
|
||||
async def openai_list_files_in_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
limit: int | None = 20,
|
||||
order: str | None = "desc",
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
filter: VectorStoreFileStatus | None = None,
|
||||
) -> list[VectorStoreFileObject]:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_list_files_in_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
limit=limit,
|
||||
order=order,
|
||||
after=after,
|
||||
before=before,
|
||||
filter=filter,
|
||||
)
|
||||
|
||||
async def openai_retrieve_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileObject:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
||||
async def openai_retrieve_vector_store_file_contents(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileContentsResponse:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store_file_contents(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
||||
async def openai_update_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
attributes: dict[str, Any],
|
||||
) -> VectorStoreFileObject:
|
||||
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_update_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
attributes=attributes,
|
||||
)
|
||||
|
||||
async def openai_delete_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileDeleteResponse:
|
||||
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_delete_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
|
@ -27,6 +27,11 @@ class AuthenticationMiddleware:
|
|||
3. Extracts user attributes from the provider's response
|
||||
4. Makes these attributes available to the route handlers for access control
|
||||
|
||||
Unauthenticated Access:
|
||||
Endpoints can opt out of authentication by setting require_authentication=False
|
||||
in their @webmethod decorator. This is typically used for operational endpoints
|
||||
like /health and /version to support monitoring, load balancers, and observability tools.
|
||||
|
||||
The middleware supports multiple authentication providers through the AuthProvider interface:
|
||||
- Kubernetes: Validates tokens against the Kubernetes API server
|
||||
- Custom: Validates tokens against a custom endpoint
|
||||
|
@ -88,7 +93,26 @@ class AuthenticationMiddleware:
|
|||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope["type"] == "http":
|
||||
# First, handle authentication
|
||||
# Find the route and check if authentication is required
|
||||
path = scope.get("path", "")
|
||||
method = scope.get("method", hdrs.METH_GET)
|
||||
|
||||
if not hasattr(self, "route_impls"):
|
||||
self.route_impls = initialize_route_impls(self.impls)
|
||||
|
||||
webmethod = None
|
||||
try:
|
||||
_, _, _, webmethod = find_matching_route(method, path, self.route_impls)
|
||||
except ValueError:
|
||||
# If no matching endpoint is found, pass here to run auth anyways
|
||||
pass
|
||||
|
||||
# If webmethod explicitly sets require_authentication=False, allow without auth
|
||||
if webmethod and webmethod.require_authentication is False:
|
||||
logger.debug(f"Allowing unauthenticated access to endpoint: {path}")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
# Handle authentication
|
||||
headers = dict(scope.get("headers", []))
|
||||
auth_header = headers.get(b"authorization", b"").decode()
|
||||
|
||||
|
@ -127,19 +151,7 @@ class AuthenticationMiddleware:
|
|||
)
|
||||
|
||||
# Scope-based API access control
|
||||
path = scope.get("path", "")
|
||||
method = scope.get("method", hdrs.METH_GET)
|
||||
|
||||
if not hasattr(self, "route_impls"):
|
||||
self.route_impls = initialize_route_impls(self.impls)
|
||||
|
||||
try:
|
||||
_, _, _, webmethod = find_matching_route(method, path, self.route_impls)
|
||||
except ValueError:
|
||||
# If no matching endpoint is found, pass through to FastAPI
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
if webmethod.required_scope:
|
||||
if webmethod and webmethod.required_scope:
|
||||
user = user_from_scope(scope)
|
||||
if not _has_required_scope(webmethod.required_scope, user):
|
||||
return await self._send_auth_error(
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import functools
|
||||
|
@ -12,7 +11,6 @@ import inspect
|
|||
import json
|
||||
import logging # allow-direct-logging
|
||||
import os
|
||||
import ssl
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
|
@ -35,7 +33,6 @@ from pydantic import BaseModel, ValidationError
|
|||
|
||||
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.cli.utils import add_config_distro_args, get_config_from_args
|
||||
from llama_stack.core.access_control.access_control import AccessDeniedError
|
||||
from llama_stack.core.datatypes import (
|
||||
AuthenticationRequiredError,
|
||||
|
@ -55,7 +52,6 @@ from llama_stack.core.stack import (
|
|||
Stack,
|
||||
cast_image_name_to_string,
|
||||
replace_env_vars,
|
||||
validate_env_pair,
|
||||
)
|
||||
from llama_stack.core.utils.config import redact_sensitive_fields
|
||||
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
|
||||
|
@ -142,6 +138,13 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
|
|||
return HTTPException(status_code=httpx.codes.NOT_IMPLEMENTED, detail=f"Not implemented: {str(exc)}")
|
||||
elif isinstance(exc, AuthenticationRequiredError):
|
||||
return HTTPException(status_code=httpx.codes.UNAUTHORIZED, detail=f"Authentication required: {str(exc)}")
|
||||
elif hasattr(exc, "status_code") and isinstance(getattr(exc, "status_code", None), int):
|
||||
# Handle provider SDK exceptions (e.g., OpenAI's APIStatusError and subclasses)
|
||||
# These include AuthenticationError (401), PermissionDeniedError (403), etc.
|
||||
# This preserves the actual HTTP status code from the provider
|
||||
status_code = exc.status_code
|
||||
detail = str(exc)
|
||||
return HTTPException(status_code=status_code, detail=detail)
|
||||
else:
|
||||
return HTTPException(
|
||||
status_code=httpx.codes.INTERNAL_SERVER_ERROR,
|
||||
|
@ -181,7 +184,17 @@ async def lifespan(app: StackApp):
|
|||
|
||||
def is_streaming_request(func_name: str, request: Request, **kwargs):
|
||||
# TODO: pass the api method and punt it to the Protocol definition directly
|
||||
return kwargs.get("stream", False)
|
||||
# If there's a stream parameter at top level, use it
|
||||
if "stream" in kwargs:
|
||||
return kwargs["stream"]
|
||||
|
||||
# If there's a stream parameter inside a "params" parameter, e.g. openai_chat_completion() use it
|
||||
if "params" in kwargs:
|
||||
params = kwargs["params"]
|
||||
if hasattr(params, "stream"):
|
||||
return params.stream
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def maybe_await(value):
|
||||
|
@ -236,15 +249,31 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
|||
|
||||
await log_request_pre_validation(request)
|
||||
|
||||
test_context_token = None
|
||||
test_context_var = None
|
||||
reset_test_context_fn = None
|
||||
|
||||
# Use context manager with both provider data and auth attributes
|
||||
with request_provider_data_context(request.headers, user):
|
||||
if os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE"):
|
||||
from llama_stack.core.testing_context import (
|
||||
TEST_CONTEXT,
|
||||
reset_test_context,
|
||||
sync_test_context_from_provider_data,
|
||||
)
|
||||
|
||||
test_context_token = sync_test_context_from_provider_data()
|
||||
test_context_var = TEST_CONTEXT
|
||||
reset_test_context_fn = reset_test_context
|
||||
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
|
||||
try:
|
||||
if is_streaming:
|
||||
gen = preserve_contexts_async_generator(
|
||||
sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
|
||||
)
|
||||
context_vars = [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
|
||||
if test_context_var is not None:
|
||||
context_vars.append(test_context_var)
|
||||
gen = preserve_contexts_async_generator(sse_generator(func(**kwargs)), context_vars)
|
||||
return StreamingResponse(gen, media_type="text/event-stream")
|
||||
else:
|
||||
value = func(**kwargs)
|
||||
|
@ -262,6 +291,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
|||
else:
|
||||
logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
|
||||
raise translate_exception(e) from e
|
||||
finally:
|
||||
if test_context_token is not None and reset_test_context_fn is not None:
|
||||
reset_test_context_fn(test_context_token)
|
||||
|
||||
sig = inspect.signature(func)
|
||||
|
||||
|
@ -333,23 +365,18 @@ class ClientVersionMiddleware:
|
|||
return await self.app(scope, receive, send)
|
||||
|
||||
|
||||
def create_app(
|
||||
config_file: str | None = None,
|
||||
env_vars: list[str] | None = None,
|
||||
) -> StackApp:
|
||||
def create_app() -> StackApp:
|
||||
"""Create and configure the FastAPI application.
|
||||
|
||||
Args:
|
||||
config_file: Path to config file. If None, uses LLAMA_STACK_CONFIG env var or default resolution.
|
||||
env_vars: List of environment variables in KEY=value format.
|
||||
disable_version_check: Whether to disable version checking. If None, uses LLAMA_STACK_DISABLE_VERSION_CHECK env var.
|
||||
This factory function reads configuration from environment variables:
|
||||
- LLAMA_STACK_CONFIG: Path to config file (required)
|
||||
|
||||
Returns:
|
||||
Configured StackApp instance.
|
||||
"""
|
||||
config_file = config_file or os.getenv("LLAMA_STACK_CONFIG")
|
||||
config_file = os.getenv("LLAMA_STACK_CONFIG")
|
||||
if config_file is None:
|
||||
raise ValueError("No config file provided and LLAMA_STACK_CONFIG env var is not set")
|
||||
raise ValueError("LLAMA_STACK_CONFIG environment variable is required")
|
||||
|
||||
config_file = resolve_config_or_distro(config_file, Mode.RUN)
|
||||
|
||||
|
@ -361,16 +388,6 @@ def create_app(
|
|||
logger_config = LoggingConfig(**cfg)
|
||||
logger = get_logger(name=__name__, category="core::server", config=logger_config)
|
||||
|
||||
if env_vars:
|
||||
for env_pair in env_vars:
|
||||
try:
|
||||
key, value = validate_env_pair(env_pair)
|
||||
logger.info(f"Setting environment variable {key} => {value}")
|
||||
os.environ[key] = value
|
||||
except ValueError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
raise ValueError(f"Invalid environment variable format: {env_pair}") from e
|
||||
|
||||
config = replace_env_vars(config_contents)
|
||||
config = StackRunConfig(**cast_image_name_to_string(config))
|
||||
|
||||
|
@ -494,101 +511,6 @@ def create_app(
|
|||
return app
|
||||
|
||||
|
||||
def main(args: argparse.Namespace | None = None):
|
||||
"""Start the LlamaStack server."""
|
||||
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
|
||||
|
||||
add_config_distro_args(parser)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
|
||||
help="Port to listen on",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--env",
|
||||
action="append",
|
||||
help="Environment variables in KEY=value format. Can be specified multiple times.",
|
||||
)
|
||||
|
||||
# Determine whether the server args are being passed by the "run" command, if this is the case
|
||||
# the args will be passed as a Namespace object to the main function, otherwise they will be
|
||||
# parsed from the command line
|
||||
if args is None:
|
||||
args = parser.parse_args()
|
||||
|
||||
config_or_distro = get_config_from_args(args)
|
||||
|
||||
try:
|
||||
app = create_app(
|
||||
config_file=config_or_distro,
|
||||
env_vars=args.env,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating app: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
|
||||
with open(config_file) as fp:
|
||||
config_contents = yaml.safe_load(fp)
|
||||
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
|
||||
logger_config = LoggingConfig(**cfg)
|
||||
else:
|
||||
logger_config = None
|
||||
config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
|
||||
|
||||
import uvicorn
|
||||
|
||||
# Configure SSL if certificates are provided
|
||||
port = args.port or config.server.port
|
||||
|
||||
ssl_config = None
|
||||
keyfile = config.server.tls_keyfile
|
||||
certfile = config.server.tls_certfile
|
||||
|
||||
if keyfile and certfile:
|
||||
ssl_config = {
|
||||
"ssl_keyfile": keyfile,
|
||||
"ssl_certfile": certfile,
|
||||
}
|
||||
if config.server.tls_cafile:
|
||||
ssl_config["ssl_ca_certs"] = config.server.tls_cafile
|
||||
ssl_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED
|
||||
logger.info(
|
||||
f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
|
||||
|
||||
listen_host = config.server.host or ["::", "0.0.0.0"]
|
||||
logger.info(f"Listening on {listen_host}:{port}")
|
||||
|
||||
uvicorn_config = {
|
||||
"app": app,
|
||||
"host": listen_host,
|
||||
"port": port,
|
||||
"lifespan": "on",
|
||||
"log_level": logger.getEffectiveLevel(),
|
||||
"log_config": logger_config,
|
||||
}
|
||||
if ssl_config:
|
||||
uvicorn_config.update(ssl_config)
|
||||
|
||||
# We need to catch KeyboardInterrupt because uvicorn's signal handling
|
||||
# re-raises SIGINT signals using signal.raise_signal(), which Python
|
||||
# converts to KeyboardInterrupt. Without this catch, we'd get a confusing
|
||||
# stack trace when using Ctrl+C or kill -2 (SIGINT).
|
||||
# SIGTERM (kill -15) works fine without this because Python doesn't
|
||||
# have a default handler for it.
|
||||
#
|
||||
# Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
|
||||
# signal handling but this is quite intrusive and not worth the effort.
|
||||
try:
|
||||
asyncio.run(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
logger.info("Received interrupt signal, shutting down gracefully...")
|
||||
|
||||
|
||||
def _log_run_config(run_config: StackRunConfig):
|
||||
"""Logs the run config with redacted fields and disabled providers removed."""
|
||||
logger.info("Run configuration:")
|
||||
|
@ -615,7 +537,3 @@ def remove_disabled_providers(obj):
|
|||
return [item for item in (remove_disabled_providers(i) for i in obj) if item is not None]
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -33,7 +33,6 @@ from llama_stack.apis.shields import Shields
|
|||
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDBs
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
|
@ -53,7 +52,6 @@ logger = get_logger(name=__name__, category="core")
|
|||
|
||||
class LlamaStack(
|
||||
Providers,
|
||||
VectorDBs,
|
||||
Inference,
|
||||
Agents,
|
||||
Safety,
|
||||
|
@ -83,7 +81,6 @@ class LlamaStack(
|
|||
RESOURCES = [
|
||||
("models", Api.models, "register_model", "list_models"),
|
||||
("shields", Api.shields, "register_shield", "list_shields"),
|
||||
("vector_dbs", Api.vector_dbs, "register_vector_db", "list_vector_dbs"),
|
||||
("datasets", Api.datasets, "register_dataset", "list_datasets"),
|
||||
(
|
||||
"scoring_fns",
|
||||
|
@ -274,22 +271,6 @@ def cast_image_name_to_string(config_dict: dict[str, Any]) -> dict[str, Any]:
|
|||
return config_dict
|
||||
|
||||
|
||||
def validate_env_pair(env_pair: str) -> tuple[str, str]:
|
||||
"""Validate and split an environment variable key-value pair."""
|
||||
try:
|
||||
key, value = env_pair.split("=", 1)
|
||||
key = key.strip()
|
||||
if not key:
|
||||
raise ValueError(f"Empty key in environment variable pair: {env_pair}")
|
||||
if not all(c.isalnum() or c == "_" for c in key):
|
||||
raise ValueError(f"Key must contain only alphanumeric characters and underscores: {key}")
|
||||
return key, value
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"Invalid environment variable format '{env_pair}': {str(e)}. Expected format: KEY=value"
|
||||
) from e
|
||||
|
||||
|
||||
def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig) -> None:
|
||||
"""Add internal implementations (inspect and providers) to the implementations dictionary.
|
||||
|
||||
|
@ -332,22 +313,27 @@ class Stack:
|
|||
# asked for in the run config.
|
||||
async def initialize(self):
|
||||
if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
|
||||
from llama_stack.testing.inference_recorder import setup_inference_recording
|
||||
from llama_stack.testing.api_recorder import setup_api_recording
|
||||
|
||||
global TEST_RECORDING_CONTEXT
|
||||
TEST_RECORDING_CONTEXT = setup_inference_recording()
|
||||
TEST_RECORDING_CONTEXT = setup_api_recording()
|
||||
if TEST_RECORDING_CONTEXT:
|
||||
TEST_RECORDING_CONTEXT.__enter__()
|
||||
logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
|
||||
logger.info(f"API recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
|
||||
|
||||
dist_registry, _ = await create_dist_registry(self.run_config.persistence, self.run_config.image_name)
|
||||
policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
|
||||
impls = await resolve_impls(
|
||||
self.run_config, self.provider_registry or get_provider_registry(self.run_config), dist_registry, policy
|
||||
)
|
||||
|
||||
# Add internal implementations after all other providers are resolved
|
||||
add_internal_implementations(impls, self.run_config)
|
||||
internal_impls = {}
|
||||
add_internal_implementations(internal_impls, self.run_config)
|
||||
|
||||
impls = await resolve_impls(
|
||||
self.run_config,
|
||||
self.provider_registry or get_provider_registry(self.run_config),
|
||||
dist_registry,
|
||||
policy,
|
||||
internal_impls,
|
||||
)
|
||||
|
||||
if Api.prompts in impls:
|
||||
await impls[Api.prompts].initialize()
|
||||
|
@ -397,7 +383,7 @@ class Stack:
|
|||
try:
|
||||
TEST_RECORDING_CONTEXT.__exit__(None, None, None)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during inference recording cleanup: {e}")
|
||||
logger.error(f"Error during API recording cleanup: {e}")
|
||||
|
||||
global REGISTRY_REFRESH_TASK
|
||||
if REGISTRY_REFRESH_TASK:
|
||||
|
|
|
@ -25,7 +25,7 @@ error_handler() {
|
|||
trap 'error_handler ${LINENO}' ERR
|
||||
|
||||
if [ $# -lt 3 ]; then
|
||||
echo "Usage: $0 <env_type> <env_path_or_name> <port> [--config <yaml_config>] [--env KEY=VALUE]..."
|
||||
echo "Usage: $0 <env_type> <env_path_or_name> <port> [--config <yaml_config>]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -43,7 +43,6 @@ SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
|||
|
||||
# Initialize variables
|
||||
yaml_config=""
|
||||
env_vars=""
|
||||
other_args=""
|
||||
|
||||
# Process remaining arguments
|
||||
|
@ -58,15 +57,6 @@ while [[ $# -gt 0 ]]; do
|
|||
exit 1
|
||||
fi
|
||||
;;
|
||||
--env)
|
||||
if [[ -n "$2" ]]; then
|
||||
env_vars="$env_vars --env $2"
|
||||
shift 2
|
||||
else
|
||||
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
;;
|
||||
*)
|
||||
other_args="$other_args $1"
|
||||
shift
|
||||
|
@ -116,10 +106,9 @@ if [[ "$env_type" == "venv" ]]; then
|
|||
yaml_config_arg=""
|
||||
fi
|
||||
|
||||
$PYTHON_BINARY -m llama_stack.core.server.server \
|
||||
llama stack run \
|
||||
$yaml_config_arg \
|
||||
--port "$port" \
|
||||
$env_vars \
|
||||
$other_args
|
||||
elif [[ "$env_type" == "container" ]]; then
|
||||
echo -e "${RED}Warning: Llama Stack no longer supports running Containers via the 'llama stack run' command.${NC}"
|
||||
|
|
|
@ -95,9 +95,11 @@ class DiskDistributionRegistry(DistributionRegistry):
|
|||
|
||||
async def register(self, obj: RoutableObjectWithProvider) -> bool:
|
||||
existing_obj = await self.get(obj.type, obj.identifier)
|
||||
# dont register if the object's providerid already exists
|
||||
if existing_obj and existing_obj.provider_id == obj.provider_id:
|
||||
return False
|
||||
if existing_obj and existing_obj != obj:
|
||||
raise ValueError(
|
||||
f"Object of type '{obj.type}' and identifier '{obj.identifier}' already exists. "
|
||||
"Unregister it first if you want to replace it."
|
||||
)
|
||||
|
||||
await self.kvstore.set(
|
||||
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
|
||||
|
|
44
llama_stack/core/testing_context.py
Normal file
44
llama_stack/core/testing_context.py
Normal file
|
@ -0,0 +1,44 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from contextvars import ContextVar
|
||||
|
||||
from llama_stack.core.request_headers import PROVIDER_DATA_VAR
|
||||
|
||||
TEST_CONTEXT: ContextVar[str | None] = ContextVar("llama_stack_test_context", default=None)
|
||||
|
||||
|
||||
def get_test_context() -> str | None:
|
||||
return TEST_CONTEXT.get()
|
||||
|
||||
|
||||
def set_test_context(value: str | None):
|
||||
return TEST_CONTEXT.set(value)
|
||||
|
||||
|
||||
def reset_test_context(token) -> None:
|
||||
TEST_CONTEXT.reset(token)
|
||||
|
||||
|
||||
def sync_test_context_from_provider_data():
|
||||
"""Sync test context from provider data when running in server test mode."""
|
||||
if "LLAMA_STACK_TEST_INFERENCE_MODE" not in os.environ:
|
||||
return None
|
||||
|
||||
stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
|
||||
if stack_config_type != "server":
|
||||
return None
|
||||
|
||||
try:
|
||||
provider_data = PROVIDER_DATA_VAR.get()
|
||||
except LookupError:
|
||||
provider_data = None
|
||||
|
||||
if provider_data and "__test_id" in provider_data:
|
||||
return TEST_CONTEXT.set(provider_data["__test_id"])
|
||||
|
||||
return None
|
|
@ -11,19 +11,17 @@ from llama_stack.core.ui.page.distribution.eval_tasks import benchmarks
|
|||
from llama_stack.core.ui.page.distribution.models import models
|
||||
from llama_stack.core.ui.page.distribution.scoring_functions import scoring_functions
|
||||
from llama_stack.core.ui.page.distribution.shields import shields
|
||||
from llama_stack.core.ui.page.distribution.vector_dbs import vector_dbs
|
||||
|
||||
|
||||
def resources_page():
|
||||
options = [
|
||||
"Models",
|
||||
"Vector Databases",
|
||||
"Shields",
|
||||
"Scoring Functions",
|
||||
"Datasets",
|
||||
"Benchmarks",
|
||||
]
|
||||
icons = ["magic", "memory", "shield", "file-bar-graph", "database", "list-task"]
|
||||
icons = ["magic", "shield", "file-bar-graph", "database", "list-task"]
|
||||
selected_resource = option_menu(
|
||||
None,
|
||||
options,
|
||||
|
@ -37,8 +35,6 @@ def resources_page():
|
|||
)
|
||||
if selected_resource == "Benchmarks":
|
||||
benchmarks()
|
||||
elif selected_resource == "Vector Databases":
|
||||
vector_dbs()
|
||||
elif selected_resource == "Datasets":
|
||||
datasets()
|
||||
elif selected_resource == "Models":
|
||||
|
|
|
@ -1,20 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from llama_stack.core.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def vector_dbs():
|
||||
st.header("Vector Databases")
|
||||
vector_dbs_info = {v.identifier: v.to_dict() for v in llama_stack_api.client.vector_dbs.list()}
|
||||
|
||||
if len(vector_dbs_info) > 0:
|
||||
selected_vector_db = st.selectbox("Select a vector database", list(vector_dbs_info.keys()))
|
||||
st.json(vector_dbs_info[selected_vector_db])
|
||||
else:
|
||||
st.info("No vector databases found")
|
|
@ -1,301 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
|
||||
import streamlit as st
|
||||
from llama_stack_client import Agent, AgentEventLogger, RAGDocument
|
||||
|
||||
from llama_stack.apis.common.content_types import ToolCallDelta
|
||||
from llama_stack.core.ui.modules.api import llama_stack_api
|
||||
from llama_stack.core.ui.modules.utils import data_url_from_file
|
||||
|
||||
|
||||
def rag_chat_page():
|
||||
st.title("🦙 RAG")
|
||||
|
||||
def reset_agent_and_chat():
|
||||
st.session_state.clear()
|
||||
st.cache_resource.clear()
|
||||
|
||||
def should_disable_input():
|
||||
return "displayed_messages" in st.session_state and len(st.session_state.displayed_messages) > 0
|
||||
|
||||
def log_message(message):
|
||||
with st.chat_message(message["role"]):
|
||||
if "tool_output" in message and message["tool_output"]:
|
||||
with st.expander(label="Tool Output", expanded=False, icon="🛠"):
|
||||
st.write(message["tool_output"])
|
||||
st.markdown(message["content"])
|
||||
|
||||
with st.sidebar:
|
||||
# File/Directory Upload Section
|
||||
st.subheader("Upload Documents", divider=True)
|
||||
uploaded_files = st.file_uploader(
|
||||
"Upload file(s) or directory",
|
||||
accept_multiple_files=True,
|
||||
type=["txt", "pdf", "doc", "docx"], # Add more file types as needed
|
||||
)
|
||||
# Process uploaded files
|
||||
if uploaded_files:
|
||||
st.success(f"Successfully uploaded {len(uploaded_files)} files")
|
||||
# Add memory bank name input field
|
||||
vector_db_name = st.text_input(
|
||||
"Document Collection Name",
|
||||
value="rag_vector_db",
|
||||
help="Enter a unique identifier for this document collection",
|
||||
)
|
||||
if st.button("Create Document Collection"):
|
||||
documents = [
|
||||
RAGDocument(
|
||||
document_id=uploaded_file.name,
|
||||
content=data_url_from_file(uploaded_file),
|
||||
)
|
||||
for i, uploaded_file in enumerate(uploaded_files)
|
||||
]
|
||||
|
||||
providers = llama_stack_api.client.providers.list()
|
||||
vector_io_provider = None
|
||||
|
||||
for x in providers:
|
||||
if x.api == "vector_io":
|
||||
vector_io_provider = x.provider_id
|
||||
|
||||
llama_stack_api.client.vector_dbs.register(
|
||||
vector_db_id=vector_db_name, # Use the user-provided name
|
||||
embedding_dimension=384,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
provider_id=vector_io_provider,
|
||||
)
|
||||
|
||||
# insert documents using the custom vector db name
|
||||
llama_stack_api.client.tool_runtime.rag_tool.insert(
|
||||
vector_db_id=vector_db_name, # Use the user-provided name
|
||||
documents=documents,
|
||||
chunk_size_in_tokens=512,
|
||||
)
|
||||
st.success("Vector database created successfully!")
|
||||
|
||||
st.subheader("RAG Parameters", divider=True)
|
||||
|
||||
rag_mode = st.radio(
|
||||
"RAG mode",
|
||||
["Direct", "Agent-based"],
|
||||
captions=[
|
||||
"RAG is performed by directly retrieving the information and augmenting the user query",
|
||||
"RAG is performed by an agent activating a dedicated knowledge search tool.",
|
||||
],
|
||||
on_change=reset_agent_and_chat,
|
||||
disabled=should_disable_input(),
|
||||
)
|
||||
|
||||
# select memory banks
|
||||
vector_dbs = llama_stack_api.client.vector_dbs.list()
|
||||
vector_dbs = [vector_db.identifier for vector_db in vector_dbs]
|
||||
selected_vector_dbs = st.multiselect(
|
||||
label="Select Document Collections to use in RAG queries",
|
||||
options=vector_dbs,
|
||||
on_change=reset_agent_and_chat,
|
||||
disabled=should_disable_input(),
|
||||
)
|
||||
|
||||
st.subheader("Inference Parameters", divider=True)
|
||||
available_models = llama_stack_api.client.models.list()
|
||||
available_models = [model.identifier for model in available_models if model.model_type == "llm"]
|
||||
selected_model = st.selectbox(
|
||||
label="Choose a model",
|
||||
options=available_models,
|
||||
index=0,
|
||||
on_change=reset_agent_and_chat,
|
||||
disabled=should_disable_input(),
|
||||
)
|
||||
system_prompt = st.text_area(
|
||||
"System Prompt",
|
||||
value="You are a helpful assistant. ",
|
||||
help="Initial instructions given to the AI to set its behavior and context",
|
||||
on_change=reset_agent_and_chat,
|
||||
disabled=should_disable_input(),
|
||||
)
|
||||
temperature = st.slider(
|
||||
"Temperature",
|
||||
min_value=0.0,
|
||||
max_value=1.0,
|
||||
value=0.0,
|
||||
step=0.1,
|
||||
help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
|
||||
on_change=reset_agent_and_chat,
|
||||
disabled=should_disable_input(),
|
||||
)
|
||||
|
||||
top_p = st.slider(
|
||||
"Top P",
|
||||
min_value=0.0,
|
||||
max_value=1.0,
|
||||
value=0.95,
|
||||
step=0.1,
|
||||
on_change=reset_agent_and_chat,
|
||||
disabled=should_disable_input(),
|
||||
)
|
||||
|
||||
# Add clear chat button to sidebar
|
||||
if st.button("Clear Chat", use_container_width=True):
|
||||
reset_agent_and_chat()
|
||||
st.rerun()
|
||||
|
||||
# Chat Interface
|
||||
if "messages" not in st.session_state:
|
||||
st.session_state.messages = []
|
||||
if "displayed_messages" not in st.session_state:
|
||||
st.session_state.displayed_messages = []
|
||||
|
||||
# Display chat history
|
||||
for message in st.session_state.displayed_messages:
|
||||
log_message(message)
|
||||
|
||||
if temperature > 0.0:
|
||||
strategy = {
|
||||
"type": "top_p",
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
}
|
||||
else:
|
||||
strategy = {"type": "greedy"}
|
||||
|
||||
@st.cache_resource
|
||||
def create_agent():
|
||||
return Agent(
|
||||
llama_stack_api.client,
|
||||
model=selected_model,
|
||||
instructions=system_prompt,
|
||||
sampling_params={
|
||||
"strategy": strategy,
|
||||
},
|
||||
tools=[
|
||||
dict(
|
||||
name="builtin::rag/knowledge_search",
|
||||
args={
|
||||
"vector_db_ids": list(selected_vector_dbs),
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
if rag_mode == "Agent-based":
|
||||
agent = create_agent()
|
||||
if "agent_session_id" not in st.session_state:
|
||||
st.session_state["agent_session_id"] = agent.create_session(session_name=f"rag_demo_{uuid.uuid4()}")
|
||||
|
||||
session_id = st.session_state["agent_session_id"]
|
||||
|
||||
def agent_process_prompt(prompt):
|
||||
# Add user message to chat history
|
||||
st.session_state.messages.append({"role": "user", "content": prompt})
|
||||
|
||||
# Send the prompt to the agent
|
||||
response = agent.create_turn(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Display assistant response
|
||||
with st.chat_message("assistant"):
|
||||
retrieval_message_placeholder = st.expander(label="Tool Output", expanded=False, icon="🛠")
|
||||
message_placeholder = st.empty()
|
||||
full_response = ""
|
||||
retrieval_response = ""
|
||||
for log in AgentEventLogger().log(response):
|
||||
log.print()
|
||||
if log.role == "tool_execution":
|
||||
retrieval_response += log.content.replace("====", "").strip()
|
||||
retrieval_message_placeholder.write(retrieval_response)
|
||||
else:
|
||||
full_response += log.content
|
||||
message_placeholder.markdown(full_response + "▌")
|
||||
message_placeholder.markdown(full_response)
|
||||
|
||||
st.session_state.messages.append({"role": "assistant", "content": full_response})
|
||||
st.session_state.displayed_messages.append(
|
||||
{"role": "assistant", "content": full_response, "tool_output": retrieval_response}
|
||||
)
|
||||
|
||||
def direct_process_prompt(prompt):
|
||||
# Add the system prompt in the beginning of the conversation
|
||||
if len(st.session_state.messages) == 0:
|
||||
st.session_state.messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
# Query the vector DB
|
||||
rag_response = llama_stack_api.client.tool_runtime.rag_tool.query(
|
||||
content=prompt, vector_db_ids=list(selected_vector_dbs)
|
||||
)
|
||||
prompt_context = rag_response.content
|
||||
|
||||
with st.chat_message("assistant"):
|
||||
with st.expander(label="Retrieval Output", expanded=False):
|
||||
st.write(prompt_context)
|
||||
|
||||
retrieval_message_placeholder = st.empty()
|
||||
message_placeholder = st.empty()
|
||||
full_response = ""
|
||||
retrieval_response = ""
|
||||
|
||||
# Construct the extended prompt
|
||||
extended_prompt = f"Please answer the following query using the context below.\n\nCONTEXT:\n{prompt_context}\n\nQUERY:\n{prompt}"
|
||||
|
||||
# Run inference directly
|
||||
st.session_state.messages.append({"role": "user", "content": extended_prompt})
|
||||
response = llama_stack_api.client.inference.chat_completion(
|
||||
messages=st.session_state.messages,
|
||||
model_id=selected_model,
|
||||
sampling_params={
|
||||
"strategy": strategy,
|
||||
},
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Display assistant response
|
||||
for chunk in response:
|
||||
response_delta = chunk.event.delta
|
||||
if isinstance(response_delta, ToolCallDelta):
|
||||
retrieval_response += response_delta.tool_call.replace("====", "").strip()
|
||||
retrieval_message_placeholder.info(retrieval_response)
|
||||
else:
|
||||
full_response += chunk.event.delta.text
|
||||
message_placeholder.markdown(full_response + "▌")
|
||||
message_placeholder.markdown(full_response)
|
||||
|
||||
response_dict = {"role": "assistant", "content": full_response, "stop_reason": "end_of_message"}
|
||||
st.session_state.messages.append(response_dict)
|
||||
st.session_state.displayed_messages.append(response_dict)
|
||||
|
||||
# Chat input
|
||||
if prompt := st.chat_input("Ask a question about your documents"):
|
||||
# Add user message to chat history
|
||||
st.session_state.displayed_messages.append({"role": "user", "content": prompt})
|
||||
|
||||
# Display user message
|
||||
with st.chat_message("user"):
|
||||
st.markdown(prompt)
|
||||
|
||||
# store the prompt to process it after page refresh
|
||||
st.session_state.prompt = prompt
|
||||
|
||||
# force page refresh to disable the settings widgets
|
||||
st.rerun()
|
||||
|
||||
if "prompt" in st.session_state and st.session_state.prompt is not None:
|
||||
if rag_mode == "Agent-based":
|
||||
agent_process_prompt(st.session_state.prompt)
|
||||
else: # rag_mode == "Direct"
|
||||
direct_process_prompt(st.session_state.prompt)
|
||||
st.session_state.prompt = None
|
||||
|
||||
|
||||
rag_chat_page()
|
|
@ -117,11 +117,11 @@ docker run -it \
|
|||
# NOTE: mount the llama-stack directory if testing local changes else not needed
|
||||
-v $HOME/git/llama-stack:/app/llama-stack-source \
|
||||
# localhost/distribution-dell:dev if building / testing locally
|
||||
-e INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
-e DEH_URL=$DEH_URL \
|
||||
-e CHROMA_URL=$CHROMA_URL \
|
||||
llamastack/distribution-{{ name }}\
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
--env DEH_URL=$DEH_URL \
|
||||
--env CHROMA_URL=$CHROMA_URL
|
||||
--port $LLAMA_STACK_PORT
|
||||
|
||||
```
|
||||
|
||||
|
@ -142,14 +142,14 @@ docker run \
|
|||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v $HOME/.llama:/root/.llama \
|
||||
-v ./llama_stack/distributions/tgi/run-with-safety.yaml:/root/my-run.yaml \
|
||||
-e INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
-e DEH_URL=$DEH_URL \
|
||||
-e SAFETY_MODEL=$SAFETY_MODEL \
|
||||
-e DEH_SAFETY_URL=$DEH_SAFETY_URL \
|
||||
-e CHROMA_URL=$CHROMA_URL \
|
||||
llamastack/distribution-{{ name }} \
|
||||
--config /root/my-run.yaml \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
--env DEH_URL=$DEH_URL \
|
||||
--env SAFETY_MODEL=$SAFETY_MODEL \
|
||||
--env DEH_SAFETY_URL=$DEH_SAFETY_URL \
|
||||
--env CHROMA_URL=$CHROMA_URL
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
### Via Conda
|
||||
|
@ -158,21 +158,21 @@ Make sure you have done `pip install llama-stack` and have the Llama Stack CLI a
|
|||
|
||||
```bash
|
||||
llama stack build --distro {{ name }} --image-type conda
|
||||
llama stack run {{ name }}
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
--env DEH_URL=$DEH_URL \
|
||||
--env CHROMA_URL=$CHROMA_URL
|
||||
INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
DEH_URL=$DEH_URL \
|
||||
CHROMA_URL=$CHROMA_URL \
|
||||
llama stack run {{ name }} \
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
If you are using Llama Stack Safety / Shield APIs, use:
|
||||
|
||||
```bash
|
||||
INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
DEH_URL=$DEH_URL \
|
||||
SAFETY_MODEL=$SAFETY_MODEL \
|
||||
DEH_SAFETY_URL=$DEH_SAFETY_URL \
|
||||
CHROMA_URL=$CHROMA_URL \
|
||||
llama stack run ./run-with-safety.yaml \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
--env DEH_URL=$DEH_URL \
|
||||
--env SAFETY_MODEL=$SAFETY_MODEL \
|
||||
--env DEH_SAFETY_URL=$DEH_SAFETY_URL \
|
||||
--env CHROMA_URL=$CHROMA_URL
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
|
|
@ -101,6 +101,9 @@ metadata_store:
|
|||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/inference_store.db
|
||||
conversations_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/conversations.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
|
|
|
@ -29,31 +29,7 @@ The following environment variables can be configured:
|
|||
|
||||
## Prerequisite: Downloading Models
|
||||
|
||||
Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](../../references/llama_cli_reference/download_models.md) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
|
||||
|
||||
```
|
||||
$ llama model list --downloaded
|
||||
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
|
||||
┃ Model ┃ Size ┃ Modified Time ┃
|
||||
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
|
||||
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
|
||||
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
|
||||
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
|
||||
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
|
||||
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
|
||||
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
|
||||
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
|
||||
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
|
||||
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
|
||||
└─────────────────────────────────────────┴──────────┴─────────────────────┘
|
||||
Please check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](../../references/llama_cli_reference/download_models.md) here to download the models using the Hugging Face CLI.
|
||||
```
|
||||
|
||||
## Running the Distribution
|
||||
|
@ -72,9 +48,9 @@ docker run \
|
|||
--gpu all \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ~/.llama:/root/.llama \
|
||||
-e INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
||||
llamastack/distribution-{{ name }} \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
If you are using Llama Stack Safety / Shield APIs, use:
|
||||
|
@ -86,10 +62,10 @@ docker run \
|
|||
--gpu all \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ~/.llama:/root/.llama \
|
||||
-e INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
||||
-e SAFETY_MODEL=meta-llama/Llama-Guard-3-1B \
|
||||
llamastack/distribution-{{ name }} \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
||||
--env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
### Via venv
|
||||
|
@ -98,16 +74,16 @@ Make sure you have done `uv pip install llama-stack` and have the Llama Stack CL
|
|||
|
||||
```bash
|
||||
llama stack build --distro {{ name }} --image-type venv
|
||||
INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
||||
llama stack run distributions/{{ name }}/run.yaml \
|
||||
--port 8321 \
|
||||
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
|
||||
--port 8321
|
||||
```
|
||||
|
||||
If you are using Llama Stack Safety / Shield APIs, use:
|
||||
|
||||
```bash
|
||||
INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
||||
SAFETY_MODEL=meta-llama/Llama-Guard-3-1B \
|
||||
llama stack run distributions/{{ name }}/run-with-safety.yaml \
|
||||
--port 8321 \
|
||||
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
||||
--env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
|
||||
--port 8321
|
||||
```
|
||||
|
|
|
@ -114,6 +114,9 @@ metadata_store:
|
|||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/inference_store.db
|
||||
conversations_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/conversations.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
|
|
|
@ -118,10 +118,10 @@ docker run \
|
|||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ./run.yaml:/root/my-run.yaml \
|
||||
-e NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||
llamastack/distribution-{{ name }} \
|
||||
--config /root/my-run.yaml \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env NVIDIA_API_KEY=$NVIDIA_API_KEY
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
### Via venv
|
||||
|
@ -131,10 +131,10 @@ If you've set up your local development environment, you can also build the imag
|
|||
```bash
|
||||
INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct
|
||||
llama stack build --distro nvidia --image-type venv
|
||||
NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||
INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
llama stack run ./run.yaml \
|
||||
--port 8321 \
|
||||
--env NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL
|
||||
--port 8321
|
||||
```
|
||||
|
||||
## Example Notebooks
|
||||
|
|
|
@ -103,6 +103,9 @@ metadata_store:
|
|||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/inference_store.db
|
||||
conversations_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/conversations.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
|
|
|
@ -181,6 +181,7 @@ class RunConfigSettings(BaseModel):
|
|||
default_benchmarks: list[BenchmarkInput] | None = None
|
||||
metadata_store: dict | None = None
|
||||
inference_store: dict | None = None
|
||||
conversations_store: dict | None = None
|
||||
|
||||
def run_config(
|
||||
self,
|
||||
|
@ -240,6 +241,11 @@ class RunConfigSettings(BaseModel):
|
|||
__distro_dir__=f"~/.llama/distributions/{name}",
|
||||
db_name="inference_store.db",
|
||||
),
|
||||
"conversations_store": self.conversations_store
|
||||
or SqliteSqlStoreConfig.sample_run_config(
|
||||
__distro_dir__=f"~/.llama/distributions/{name}",
|
||||
db_name="conversations.db",
|
||||
),
|
||||
"models": [m.model_dump(exclude_none=True) for m in (self.default_models or [])],
|
||||
"shields": [s.model_dump(exclude_none=True) for s in (self.default_shields or [])],
|
||||
"vector_dbs": [],
|
||||
|
|
|
@ -3,3 +3,5 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .watsonx import get_distribution_template # noqa: F401
|
||||
|
|
|
@ -3,44 +3,33 @@ distribution_spec:
|
|||
description: Use watsonx for running LLM inference
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: watsonx
|
||||
provider_type: remote::watsonx
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
- provider_type: remote::watsonx
|
||||
- provider_type: inline::sentence-transformers
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
- provider_type: inline::faiss
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
- provider_type: inline::llama-guard
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
- provider_type: inline::meta-reference
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
- provider_type: remote::huggingface
|
||||
- provider_type: inline::localfs
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
- provider_id: llm-as-judge
|
||||
provider_type: inline::llm-as-judge
|
||||
- provider_id: braintrust
|
||||
provider_type: inline::braintrust
|
||||
- provider_type: inline::basic
|
||||
- provider_type: inline::llm-as-judge
|
||||
- provider_type: inline::braintrust
|
||||
tool_runtime:
|
||||
- provider_type: remote::brave-search
|
||||
- provider_type: remote::tavily-search
|
||||
- provider_type: inline::rag-runtime
|
||||
- provider_type: remote::model-context-protocol
|
||||
files:
|
||||
- provider_type: inline::localfs
|
||||
image_type: venv
|
||||
additional_pip_packages:
|
||||
- aiosqlite
|
||||
- sqlalchemy[asyncio]
|
||||
- aiosqlite
|
||||
- aiosqlite
|
||||
|
|
|
@ -4,17 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ToolGroupInput
|
||||
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry
|
||||
from llama_stack.core.datatypes import BuildProvider, Provider, ToolGroupInput
|
||||
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
|
||||
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
|
||||
from llama_stack.providers.inline.inference.sentence_transformers import (
|
||||
SentenceTransformersInferenceConfig,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
|
||||
from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES
|
||||
|
||||
|
||||
def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
|
||||
|
@ -52,15 +46,6 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
|
|||
config=WatsonXConfig.sample_run_config(),
|
||||
)
|
||||
|
||||
embedding_provider = Provider(
|
||||
provider_id="sentence-transformers",
|
||||
provider_type="inline::sentence-transformers",
|
||||
config=SentenceTransformersInferenceConfig.sample_run_config(),
|
||||
)
|
||||
|
||||
available_models = {
|
||||
"watsonx": MODEL_ENTRIES,
|
||||
}
|
||||
default_tool_groups = [
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::websearch",
|
||||
|
@ -72,36 +57,25 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
|
|||
),
|
||||
]
|
||||
|
||||
embedding_model = ModelInput(
|
||||
model_id="all-MiniLM-L6-v2",
|
||||
provider_id="sentence-transformers",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
},
|
||||
)
|
||||
|
||||
files_provider = Provider(
|
||||
provider_id="meta-reference-files",
|
||||
provider_type="inline::localfs",
|
||||
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
)
|
||||
default_models, _ = get_model_registry(available_models)
|
||||
return DistributionTemplate(
|
||||
name=name,
|
||||
distro_type="remote_hosted",
|
||||
description="Use watsonx for running LLM inference",
|
||||
container_image=None,
|
||||
template_path=Path(__file__).parent / "doc_template.md",
|
||||
template_path=None,
|
||||
providers=providers,
|
||||
available_models_by_provider=available_models,
|
||||
run_configs={
|
||||
"run.yaml": RunConfigSettings(
|
||||
provider_overrides={
|
||||
"inference": [inference_provider, embedding_provider],
|
||||
"inference": [inference_provider],
|
||||
"files": [files_provider],
|
||||
},
|
||||
default_models=default_models + [embedding_model],
|
||||
default_models=[],
|
||||
default_tool_groups=default_tool_groups,
|
||||
),
|
||||
},
|
||||
|
|
|
@ -30,13 +30,19 @@ CATEGORIES = [
|
|||
"tools",
|
||||
"client",
|
||||
"telemetry",
|
||||
"openai",
|
||||
"openai_responses",
|
||||
"openai_conversations",
|
||||
"testing",
|
||||
"providers",
|
||||
"models",
|
||||
"files",
|
||||
"vector_io",
|
||||
"tool_runtime",
|
||||
"cli",
|
||||
"post_training",
|
||||
"scoring",
|
||||
"tests",
|
||||
]
|
||||
UNCATEGORIZED = "uncategorized"
|
||||
|
||||
|
@ -128,7 +134,10 @@ def strip_rich_markup(text):
|
|||
|
||||
class CustomRichHandler(RichHandler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs["console"] = Console()
|
||||
# Set a reasonable default width for console output, especially when redirected to files
|
||||
console_width = int(os.environ.get("LLAMA_STACK_LOG_WIDTH", "120"))
|
||||
# Don't force terminal codes to avoid ANSI escape codes in log files
|
||||
kwargs["console"] = Console(width=console_width)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def emit(self, record):
|
||||
|
@ -261,11 +270,12 @@ def get_logger(
|
|||
if root_category in _category_levels:
|
||||
log_level = _category_levels[root_category]
|
||||
else:
|
||||
log_level = _category_levels.get("root", DEFAULT_LOG_LEVEL)
|
||||
if category != UNCATEGORIZED:
|
||||
logging.warning(
|
||||
f"Unknown logging category: {category}. Falling back to default 'root' level: {log_level}"
|
||||
raise ValueError(
|
||||
f"Unknown logging category: {category}. To resolve, choose a valid category from the CATEGORIES list "
|
||||
f"or add it to the CATEGORIES list. Available categories: {CATEGORIES}"
|
||||
)
|
||||
log_level = _category_levels.get("root", DEFAULT_LOG_LEVEL)
|
||||
logger.setLevel(log_level)
|
||||
return logging.LoggerAdapter(logger, {"category": category})
|
||||
|
||||
|
|
|
@ -11,19 +11,13 @@
|
|||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import json
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
RawContent,
|
||||
RawMediaItem,
|
||||
RawMessage,
|
||||
RawTextItem,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.models.llama.llama4.tokenizer import Tokenizer
|
||||
|
@ -175,25 +169,6 @@ def llama3_1_builtin_code_interpreter_dialog(tool_prompt_format=ToolPromptFormat
|
|||
return messages
|
||||
|
||||
|
||||
def llama3_1_builtin_tool_call_with_image_dialog(
|
||||
tool_prompt_format=ToolPromptFormat.json,
|
||||
):
|
||||
this_dir = Path(__file__).parent
|
||||
with open(this_dir / "llama3/dog.jpg", "rb") as f:
|
||||
img = f.read()
|
||||
|
||||
interface = LLama31Interface(tool_prompt_format)
|
||||
|
||||
messages = interface.system_messages(**system_message_builtin_tools_only())
|
||||
messages += interface.user_message(content=[RawMediaItem(data=img), RawTextItem(text="What is this dog breed?")])
|
||||
messages += interface.assistant_response_messages(
|
||||
"Based on the description of the dog in the image, it appears to be a small breed dog, possibly a terrier mix",
|
||||
StopReason.end_of_turn,
|
||||
)
|
||||
messages += interface.user_message("Search the web for some food recommendations for the indentified breed")
|
||||
return messages
|
||||
|
||||
|
||||
def llama3_1_custom_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
|
||||
interface = LLama31Interface(tool_prompt_format)
|
||||
|
||||
|
@ -202,35 +177,6 @@ def llama3_1_custom_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
|
|||
return messages
|
||||
|
||||
|
||||
def llama3_1_e2e_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
|
||||
tool_response = json.dumps(["great song1", "awesome song2", "cool song3"])
|
||||
interface = LLama31Interface(tool_prompt_format)
|
||||
|
||||
messages = interface.system_messages(**system_message_custom_tools_only())
|
||||
messages += interface.user_message(content="Use tools to get latest trending songs")
|
||||
messages.append(
|
||||
RawMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
stop_reason=StopReason.end_of_message,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
call_id="call_id",
|
||||
tool_name="trending_songs",
|
||||
arguments={"n": "10", "genre": "latest"},
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
messages.append(
|
||||
RawMessage(
|
||||
role="assistant",
|
||||
content=tool_response,
|
||||
)
|
||||
)
|
||||
return messages
|
||||
|
||||
|
||||
def llama3_2_user_assistant_conversation():
|
||||
return UseCase(
|
||||
title="User and assistant conversation",
|
||||
|
|
|
@ -9,7 +9,7 @@ from pathlib import Path
|
|||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(__name__, "tokenizer_utils")
|
||||
logger = get_logger(__name__, "models")
|
||||
|
||||
|
||||
def load_bpe_file(model_path: Path) -> dict[bytes, int]:
|
||||
|
|
|
@ -21,7 +21,9 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Ap
|
|||
deps[Api.safety],
|
||||
deps[Api.tool_runtime],
|
||||
deps[Api.tool_groups],
|
||||
deps[Api.conversations],
|
||||
policy,
|
||||
Api.telemetry in deps,
|
||||
)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -7,8 +7,6 @@
|
|||
import copy
|
||||
import json
|
||||
import re
|
||||
import secrets
|
||||
import string
|
||||
import uuid
|
||||
import warnings
|
||||
from collections.abc import AsyncGenerator
|
||||
|
@ -51,6 +49,7 @@ from llama_stack.apis.inference import (
|
|||
Inference,
|
||||
Message,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIMessageParam,
|
||||
OpenAISystemMessageParam,
|
||||
|
@ -84,11 +83,6 @@ from llama_stack.providers.utils.telemetry import tracing
|
|||
from .persistence import AgentPersistence
|
||||
from .safety import SafetyException, ShieldRunnerMixin
|
||||
|
||||
|
||||
def make_random_string(length: int = 8):
|
||||
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
|
||||
|
||||
|
||||
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
|
||||
MEMORY_QUERY_TOOL = "knowledge_search"
|
||||
WEB_SEARCH_TOOL = "web_search"
|
||||
|
@ -110,6 +104,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
persistence_store: KVStore,
|
||||
created_at: str,
|
||||
policy: list[AccessRule],
|
||||
telemetry_enabled: bool = False,
|
||||
):
|
||||
self.agent_id = agent_id
|
||||
self.agent_config = agent_config
|
||||
|
@ -120,6 +115,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
self.tool_runtime_api = tool_runtime_api
|
||||
self.tool_groups_api = tool_groups_api
|
||||
self.created_at = created_at
|
||||
self.telemetry_enabled = telemetry_enabled
|
||||
|
||||
ShieldRunnerMixin.__init__(
|
||||
self,
|
||||
|
@ -188,28 +184,30 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
||||
turn_id = str(uuid.uuid4())
|
||||
span = tracing.get_current_span()
|
||||
if span:
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
span.set_attribute("turn_id", turn_id)
|
||||
if self.agent_config.name:
|
||||
span.set_attribute("agent_name", self.agent_config.name)
|
||||
if self.telemetry_enabled:
|
||||
span = tracing.get_current_span()
|
||||
if span is not None:
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
span.set_attribute("turn_id", turn_id)
|
||||
if self.agent_config.name:
|
||||
span.set_attribute("agent_name", self.agent_config.name)
|
||||
|
||||
await self._initialize_tools(request.toolgroups)
|
||||
async for chunk in self._run_turn(request, turn_id):
|
||||
yield chunk
|
||||
|
||||
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
||||
span = tracing.get_current_span()
|
||||
if span:
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
span.set_attribute("turn_id", request.turn_id)
|
||||
if self.agent_config.name:
|
||||
span.set_attribute("agent_name", self.agent_config.name)
|
||||
if self.telemetry_enabled:
|
||||
span = tracing.get_current_span()
|
||||
if span is not None:
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
span.set_attribute("turn_id", request.turn_id)
|
||||
if self.agent_config.name:
|
||||
span.set_attribute("agent_name", self.agent_config.name)
|
||||
|
||||
await self._initialize_tools()
|
||||
async for chunk in self._run_turn(request):
|
||||
|
@ -395,9 +393,12 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
touchpoint: str,
|
||||
) -> AsyncGenerator:
|
||||
async with tracing.span("run_shields") as span:
|
||||
span.set_attribute("input", [m.model_dump_json() for m in messages])
|
||||
if self.telemetry_enabled and span is not None:
|
||||
span.set_attribute("input", [m.model_dump_json() for m in messages])
|
||||
if len(shields) == 0:
|
||||
span.set_attribute("output", "no shields")
|
||||
|
||||
if len(shields) == 0:
|
||||
span.set_attribute("output", "no shields")
|
||||
return
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
|
@ -430,7 +431,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
)
|
||||
span.set_attribute("output", e.violation.model_dump_json())
|
||||
if self.telemetry_enabled and span is not None:
|
||||
span.set_attribute("output", e.violation.model_dump_json())
|
||||
|
||||
yield CompletionMessage(
|
||||
content=str(e),
|
||||
|
@ -453,7 +455,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
)
|
||||
span.set_attribute("output", "no violations")
|
||||
if self.telemetry_enabled and span is not None:
|
||||
span.set_attribute("output", "no violations")
|
||||
|
||||
async def _run(
|
||||
self,
|
||||
|
@ -518,8 +521,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
stop_reason: StopReason | None = None
|
||||
|
||||
async with tracing.span("inference") as span:
|
||||
if self.agent_config.name:
|
||||
span.set_attribute("agent_name", self.agent_config.name)
|
||||
if self.telemetry_enabled and span is not None:
|
||||
if self.agent_config.name:
|
||||
span.set_attribute("agent_name", self.agent_config.name)
|
||||
|
||||
def _serialize_nested(value):
|
||||
"""Recursively serialize nested Pydantic models to dicts."""
|
||||
|
@ -579,7 +583,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
max_tokens = getattr(sampling_params, "max_tokens", None)
|
||||
|
||||
# Use OpenAI chat completion
|
||||
openai_stream = await self.inference_api.openai_chat_completion(
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model=self.agent_config.model,
|
||||
messages=openai_messages,
|
||||
tools=openai_tools if openai_tools else None,
|
||||
|
@ -590,6 +594,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
max_tokens=max_tokens,
|
||||
stream=True,
|
||||
)
|
||||
openai_stream = await self.inference_api.openai_chat_completion(params)
|
||||
|
||||
# Convert OpenAI stream back to Llama Stack format
|
||||
response_stream = convert_openai_chat_completion_stream(
|
||||
|
@ -637,18 +642,19 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
else:
|
||||
raise ValueError(f"Unexpected delta type {type(delta)}")
|
||||
|
||||
span.set_attribute("stop_reason", stop_reason or StopReason.end_of_turn)
|
||||
span.set_attribute(
|
||||
"input",
|
||||
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
|
||||
)
|
||||
output_attr = json.dumps(
|
||||
{
|
||||
"content": content,
|
||||
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
|
||||
}
|
||||
)
|
||||
span.set_attribute("output", output_attr)
|
||||
if self.telemetry_enabled and span is not None:
|
||||
span.set_attribute("stop_reason", stop_reason or StopReason.end_of_turn)
|
||||
span.set_attribute(
|
||||
"input",
|
||||
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
|
||||
)
|
||||
output_attr = json.dumps(
|
||||
{
|
||||
"content": content,
|
||||
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
|
||||
}
|
||||
)
|
||||
span.set_attribute("output", output_attr)
|
||||
|
||||
n_iter += 1
|
||||
await self.storage.set_num_infer_iters_in_turn(session_id, turn_id, n_iter)
|
||||
|
@ -756,7 +762,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
{
|
||||
"tool_name": tool_call.tool_name,
|
||||
"input": message.model_dump_json(),
|
||||
},
|
||||
}
|
||||
if self.telemetry_enabled
|
||||
else {},
|
||||
) as span:
|
||||
tool_execution_start_time = datetime.now(UTC).isoformat()
|
||||
tool_result = await self.execute_tool_call_maybe(
|
||||
|
@ -771,7 +779,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
call_id=tool_call.call_id,
|
||||
content=tool_result.content,
|
||||
)
|
||||
span.set_attribute("output", result_message.model_dump_json())
|
||||
if self.telemetry_enabled and span is not None:
|
||||
span.set_attribute("output", result_message.model_dump_json())
|
||||
|
||||
# Store tool execution step
|
||||
tool_execution_step = ToolExecutionStep(
|
||||
|
|
|
@ -30,6 +30,7 @@ from llama_stack.apis.agents import (
|
|||
)
|
||||
from llama_stack.apis.agents.openai_responses import OpenAIResponseText
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.apis.conversations import Conversations
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
ToolConfig,
|
||||
|
@ -63,7 +64,9 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
safety_api: Safety,
|
||||
tool_runtime_api: ToolRuntime,
|
||||
tool_groups_api: ToolGroups,
|
||||
conversations_api: Conversations,
|
||||
policy: list[AccessRule],
|
||||
telemetry_enabled: bool = False,
|
||||
):
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
|
@ -71,6 +74,8 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self.safety_api = safety_api
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
self.tool_groups_api = tool_groups_api
|
||||
self.conversations_api = conversations_api
|
||||
self.telemetry_enabled = telemetry_enabled
|
||||
|
||||
self.in_memory_store = InmemoryKVStoreImpl()
|
||||
self.openai_responses_impl: OpenAIResponsesImpl | None = None
|
||||
|
@ -86,6 +91,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
tool_runtime_api=self.tool_runtime_api,
|
||||
responses_store=self.responses_store,
|
||||
vector_io_api=self.vector_io_api,
|
||||
conversations_api=self.conversations_api,
|
||||
)
|
||||
|
||||
async def create_agent(
|
||||
|
@ -135,6 +141,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
),
|
||||
created_at=agent_info.created_at,
|
||||
policy=self.policy,
|
||||
telemetry_enabled=self.telemetry_enabled,
|
||||
)
|
||||
|
||||
async def create_agent_session(
|
||||
|
@ -322,6 +329,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
model: str,
|
||||
instructions: str | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
conversation: str | None = None,
|
||||
store: bool | None = True,
|
||||
stream: bool | None = False,
|
||||
temperature: float | None = None,
|
||||
|
@ -336,6 +344,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
model,
|
||||
instructions,
|
||||
previous_response_id,
|
||||
conversation,
|
||||
store,
|
||||
stream,
|
||||
temperature,
|
||||
|
|
|
@ -24,6 +24,11 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseText,
|
||||
OpenAIResponseTextFormat,
|
||||
)
|
||||
from llama_stack.apis.common.errors import (
|
||||
InvalidConversationIdError,
|
||||
)
|
||||
from llama_stack.apis.conversations import Conversations
|
||||
from llama_stack.apis.conversations.conversations import ConversationItem
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAIMessageParam,
|
||||
|
@ -39,7 +44,7 @@ from llama_stack.providers.utils.responses.responses_store import (
|
|||
|
||||
from .streaming import StreamingResponseOrchestrator
|
||||
from .tool_executor import ToolExecutor
|
||||
from .types import ChatCompletionContext
|
||||
from .types import ChatCompletionContext, ToolContext
|
||||
from .utils import (
|
||||
convert_response_input_to_chat_messages,
|
||||
convert_response_text_to_chat_response_format,
|
||||
|
@ -61,12 +66,14 @@ class OpenAIResponsesImpl:
|
|||
tool_runtime_api: ToolRuntime,
|
||||
responses_store: ResponsesStore,
|
||||
vector_io_api: VectorIO, # VectorIO
|
||||
conversations_api: Conversations,
|
||||
):
|
||||
self.inference_api = inference_api
|
||||
self.tool_groups_api = tool_groups_api
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
self.responses_store = responses_store
|
||||
self.vector_io_api = vector_io_api
|
||||
self.conversations_api = conversations_api
|
||||
self.tool_executor = ToolExecutor(
|
||||
tool_groups_api=tool_groups_api,
|
||||
tool_runtime_api=tool_runtime_api,
|
||||
|
@ -91,13 +98,15 @@ class OpenAIResponsesImpl:
|
|||
async def _process_input_with_previous_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
tools: list[OpenAIResponseInputTool] | None,
|
||||
previous_response_id: str | None,
|
||||
) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam]]:
|
||||
"""Process input with optional previous response context.
|
||||
|
||||
Returns:
|
||||
tuple: (all_input for storage, messages for chat completion)
|
||||
tuple: (all_input for storage, messages for chat completion, tool context)
|
||||
"""
|
||||
tool_context = ToolContext(tools)
|
||||
if previous_response_id:
|
||||
previous_response: _OpenAIResponseObjectWithInputAndMessages = (
|
||||
await self.responses_store.get_response_object(previous_response_id)
|
||||
|
@ -108,16 +117,18 @@ class OpenAIResponsesImpl:
|
|||
# Use stored messages directly and convert only new input
|
||||
message_adapter = TypeAdapter(list[OpenAIMessageParam])
|
||||
messages = message_adapter.validate_python(previous_response.messages)
|
||||
new_messages = await convert_response_input_to_chat_messages(input)
|
||||
new_messages = await convert_response_input_to_chat_messages(input, previous_messages=messages)
|
||||
messages.extend(new_messages)
|
||||
else:
|
||||
# Backward compatibility: reconstruct from inputs
|
||||
messages = await convert_response_input_to_chat_messages(all_input)
|
||||
|
||||
tool_context.recover_tools_from_previous_response(previous_response)
|
||||
else:
|
||||
all_input = input
|
||||
messages = await convert_response_input_to_chat_messages(input)
|
||||
|
||||
return all_input, messages
|
||||
return all_input, messages, tool_context
|
||||
|
||||
async def _prepend_instructions(self, messages, instructions):
|
||||
if instructions:
|
||||
|
@ -201,6 +212,7 @@ class OpenAIResponsesImpl:
|
|||
model: str,
|
||||
instructions: str | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
conversation: str | None = None,
|
||||
store: bool | None = True,
|
||||
stream: bool | None = False,
|
||||
temperature: float | None = None,
|
||||
|
@ -217,11 +229,27 @@ class OpenAIResponsesImpl:
|
|||
if shields is not None:
|
||||
raise NotImplementedError("Shields parameter is not yet implemented in the meta-reference provider")
|
||||
|
||||
if conversation is not None and previous_response_id is not None:
|
||||
raise ValueError(
|
||||
"Mutually exclusive parameters: 'previous_response_id' and 'conversation'. Ensure you are only providing one of these parameters."
|
||||
)
|
||||
|
||||
original_input = input # needed for syncing to Conversations
|
||||
if conversation is not None:
|
||||
if not conversation.startswith("conv_"):
|
||||
raise InvalidConversationIdError(conversation)
|
||||
|
||||
# Check conversation exists (raises ConversationNotFoundError if not)
|
||||
_ = await self.conversations_api.get_conversation(conversation)
|
||||
input = await self._load_conversation_context(conversation, input)
|
||||
|
||||
stream_gen = self._create_streaming_response(
|
||||
input=input,
|
||||
original_input=original_input,
|
||||
model=model,
|
||||
instructions=instructions,
|
||||
previous_response_id=previous_response_id,
|
||||
conversation=conversation,
|
||||
store=store,
|
||||
temperature=temperature,
|
||||
text=text,
|
||||
|
@ -232,24 +260,42 @@ class OpenAIResponsesImpl:
|
|||
if stream:
|
||||
return stream_gen
|
||||
else:
|
||||
response = None
|
||||
async for stream_chunk in stream_gen:
|
||||
if stream_chunk.type == "response.completed":
|
||||
if response is not None:
|
||||
raise ValueError("The response stream completed multiple times! Earlier response: {response}")
|
||||
response = stream_chunk.response
|
||||
# don't leave the generator half complete!
|
||||
final_response = None
|
||||
final_event_type = None
|
||||
failed_response = None
|
||||
|
||||
if response is None:
|
||||
raise ValueError("The response stream never completed")
|
||||
return response
|
||||
async for stream_chunk in stream_gen:
|
||||
if stream_chunk.type in {"response.completed", "response.incomplete"}:
|
||||
if final_response is not None:
|
||||
raise ValueError(
|
||||
"The response stream produced multiple terminal responses! "
|
||||
f"Earlier response from {final_event_type}"
|
||||
)
|
||||
final_response = stream_chunk.response
|
||||
final_event_type = stream_chunk.type
|
||||
elif stream_chunk.type == "response.failed":
|
||||
failed_response = stream_chunk.response
|
||||
|
||||
if failed_response is not None:
|
||||
error_message = (
|
||||
failed_response.error.message
|
||||
if failed_response and failed_response.error
|
||||
else "Response stream failed without error details"
|
||||
)
|
||||
raise RuntimeError(f"OpenAI response failed: {error_message}")
|
||||
|
||||
if final_response is None:
|
||||
raise ValueError("The response stream never reached a terminal state")
|
||||
return final_response
|
||||
|
||||
async def _create_streaming_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
original_input: str | list[OpenAIResponseInput] | None = None,
|
||||
instructions: str | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
conversation: str | None = None,
|
||||
store: bool | None = True,
|
||||
temperature: float | None = None,
|
||||
text: OpenAIResponseText | None = None,
|
||||
|
@ -257,7 +303,9 @@ class OpenAIResponsesImpl:
|
|||
max_infer_iters: int | None = 10,
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Input preprocessing
|
||||
all_input, messages = await self._process_input_with_previous_response(input, previous_response_id)
|
||||
all_input, messages, tool_context = await self._process_input_with_previous_response(
|
||||
input, tools, previous_response_id
|
||||
)
|
||||
await self._prepend_instructions(messages, instructions)
|
||||
|
||||
# Structured outputs
|
||||
|
@ -269,11 +317,12 @@ class OpenAIResponsesImpl:
|
|||
response_tools=tools,
|
||||
temperature=temperature,
|
||||
response_format=response_format,
|
||||
inputs=input,
|
||||
tool_context=tool_context,
|
||||
inputs=all_input,
|
||||
)
|
||||
|
||||
# Create orchestrator and delegate streaming logic
|
||||
response_id = f"resp-{uuid.uuid4()}"
|
||||
response_id = f"resp_{uuid.uuid4()}"
|
||||
created_at = int(time.time())
|
||||
|
||||
orchestrator = StreamingResponseOrchestrator(
|
||||
|
@ -288,18 +337,110 @@ class OpenAIResponsesImpl:
|
|||
|
||||
# Stream the response
|
||||
final_response = None
|
||||
failed_response = None
|
||||
async for stream_chunk in orchestrator.create_response():
|
||||
if stream_chunk.type == "response.completed":
|
||||
if stream_chunk.type in {"response.completed", "response.incomplete"}:
|
||||
final_response = stream_chunk.response
|
||||
elif stream_chunk.type == "response.failed":
|
||||
failed_response = stream_chunk.response
|
||||
yield stream_chunk
|
||||
|
||||
# Store the response if requested
|
||||
if store and final_response:
|
||||
await self._store_response(
|
||||
response=final_response,
|
||||
input=all_input,
|
||||
messages=orchestrator.final_messages,
|
||||
)
|
||||
# Store and sync immediately after yielding terminal events
|
||||
# This ensures the storage/syncing happens even if the consumer breaks early
|
||||
if (
|
||||
stream_chunk.type in {"response.completed", "response.incomplete"}
|
||||
and store
|
||||
and final_response
|
||||
and failed_response is None
|
||||
):
|
||||
await self._store_response(
|
||||
response=final_response,
|
||||
input=all_input,
|
||||
messages=orchestrator.final_messages,
|
||||
)
|
||||
|
||||
if stream_chunk.type in {"response.completed", "response.incomplete"} and conversation and final_response:
|
||||
# for Conversations, we need to use the original_input if it's available, otherwise use input
|
||||
sync_input = original_input if original_input is not None else input
|
||||
await self._sync_response_to_conversation(conversation, sync_input, final_response)
|
||||
|
||||
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||
return await self.responses_store.delete_response_object(response_id)
|
||||
|
||||
async def _load_conversation_context(
|
||||
self, conversation_id: str, content: str | list[OpenAIResponseInput]
|
||||
) -> list[OpenAIResponseInput]:
|
||||
"""Load conversation history and merge with provided content."""
|
||||
conversation_items = await self.conversations_api.list(conversation_id, order="asc")
|
||||
|
||||
context_messages = []
|
||||
for item in conversation_items.data:
|
||||
if isinstance(item, OpenAIResponseMessage):
|
||||
if item.role == "user":
|
||||
context_messages.append(
|
||||
OpenAIResponseMessage(
|
||||
role="user", content=item.content, id=item.id if hasattr(item, "id") else None
|
||||
)
|
||||
)
|
||||
elif item.role == "assistant":
|
||||
context_messages.append(
|
||||
OpenAIResponseMessage(
|
||||
role="assistant", content=item.content, id=item.id if hasattr(item, "id") else None
|
||||
)
|
||||
)
|
||||
|
||||
# add new content to context
|
||||
if isinstance(content, str):
|
||||
context_messages.append(OpenAIResponseMessage(role="user", content=content))
|
||||
elif isinstance(content, list):
|
||||
context_messages.extend(content)
|
||||
|
||||
return context_messages
|
||||
|
||||
async def _sync_response_to_conversation(
|
||||
self, conversation_id: str, content: str | list[OpenAIResponseInput], response: OpenAIResponseObject
|
||||
) -> None:
|
||||
"""Sync content and response messages to the conversation."""
|
||||
conversation_items = []
|
||||
|
||||
# add user content message(s)
|
||||
if isinstance(content, str):
|
||||
conversation_items.append(
|
||||
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": content}]}
|
||||
)
|
||||
elif isinstance(content, list):
|
||||
for item in content:
|
||||
if not isinstance(item, OpenAIResponseMessage):
|
||||
raise NotImplementedError(f"Unsupported input item type: {type(item)}")
|
||||
|
||||
if item.role == "user":
|
||||
if isinstance(item.content, str):
|
||||
conversation_items.append(
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": item.content}],
|
||||
}
|
||||
)
|
||||
elif isinstance(item.content, list):
|
||||
conversation_items.append({"type": "message", "role": "user", "content": item.content})
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported user message content type: {type(item.content)}")
|
||||
elif item.role == "assistant":
|
||||
if isinstance(item.content, list):
|
||||
conversation_items.append({"type": "message", "role": "assistant", "content": item.content})
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported assistant message content type: {type(item.content)}")
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported message role: {item.role}")
|
||||
|
||||
# add assistant response message
|
||||
for output_item in response.output:
|
||||
if isinstance(output_item, OpenAIResponseMessage) and output_item.role == "assistant":
|
||||
if hasattr(output_item, "content") and isinstance(output_item.content, list):
|
||||
conversation_items.append({"type": "message", "role": "assistant", "content": output_item.content})
|
||||
|
||||
if conversation_items:
|
||||
adapter = TypeAdapter(list[ConversationItem])
|
||||
validated_items = adapter.validate_python(conversation_items)
|
||||
await self.conversations_api.add_items(conversation_id, validated_items)
|
||||
|
|
|
@ -13,6 +13,9 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
ApprovalFilter,
|
||||
MCPListToolsTool,
|
||||
OpenAIResponseContentPartOutputText,
|
||||
OpenAIResponseContentPartReasoningText,
|
||||
OpenAIResponseContentPartRefusal,
|
||||
OpenAIResponseError,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseMCPApprovalRequest,
|
||||
|
@ -22,8 +25,11 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseObjectStreamResponseContentPartAdded,
|
||||
OpenAIResponseObjectStreamResponseContentPartDone,
|
||||
OpenAIResponseObjectStreamResponseCreated,
|
||||
OpenAIResponseObjectStreamResponseFailed,
|
||||
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta,
|
||||
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone,
|
||||
OpenAIResponseObjectStreamResponseIncomplete,
|
||||
OpenAIResponseObjectStreamResponseInProgress,
|
||||
OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta,
|
||||
OpenAIResponseObjectStreamResponseMcpCallArgumentsDone,
|
||||
OpenAIResponseObjectStreamResponseMcpListToolsCompleted,
|
||||
|
@ -31,21 +37,31 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseObjectStreamResponseOutputItemAdded,
|
||||
OpenAIResponseObjectStreamResponseOutputItemDone,
|
||||
OpenAIResponseObjectStreamResponseOutputTextDelta,
|
||||
OpenAIResponseObjectStreamResponseReasoningTextDelta,
|
||||
OpenAIResponseObjectStreamResponseReasoningTextDone,
|
||||
OpenAIResponseObjectStreamResponseRefusalDelta,
|
||||
OpenAIResponseObjectStreamResponseRefusalDone,
|
||||
OpenAIResponseOutput,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseOutputMessageMCPListTools,
|
||||
OpenAIResponseText,
|
||||
OpenAIResponseUsage,
|
||||
OpenAIResponseUsageInputTokensDetails,
|
||||
OpenAIResponseUsageOutputTokensDetails,
|
||||
WebSearchToolTypes,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChoice,
|
||||
OpenAIMessageParam,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
||||
from .types import ChatCompletionContext, ChatCompletionResult
|
||||
from .utils import convert_chat_choice_to_response_message, is_function_tool_call
|
||||
|
@ -94,113 +110,174 @@ class StreamingResponseOrchestrator:
|
|||
self.tool_executor = tool_executor
|
||||
self.sequence_number = 0
|
||||
# Store MCP tool mapping that gets built during tool processing
|
||||
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
|
||||
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools or {}
|
||||
# Track final messages after all tool executions
|
||||
self.final_messages: list[OpenAIMessageParam] = []
|
||||
# mapping for annotations
|
||||
self.citation_files: dict[str, str] = {}
|
||||
# Track accumulated usage across all inference calls
|
||||
self.accumulated_usage: OpenAIResponseUsage | None = None
|
||||
|
||||
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Initialize output messages
|
||||
output_messages: list[OpenAIResponseOutput] = []
|
||||
# Create initial response and emit response.created immediately
|
||||
initial_response = OpenAIResponseObject(
|
||||
def _clone_outputs(self, outputs: list[OpenAIResponseOutput]) -> list[OpenAIResponseOutput]:
|
||||
cloned: list[OpenAIResponseOutput] = []
|
||||
for item in outputs:
|
||||
if hasattr(item, "model_copy"):
|
||||
cloned.append(item.model_copy(deep=True))
|
||||
else:
|
||||
cloned.append(item)
|
||||
return cloned
|
||||
|
||||
def _snapshot_response(
|
||||
self,
|
||||
status: str,
|
||||
outputs: list[OpenAIResponseOutput],
|
||||
*,
|
||||
error: OpenAIResponseError | None = None,
|
||||
) -> OpenAIResponseObject:
|
||||
return OpenAIResponseObject(
|
||||
created_at=self.created_at,
|
||||
id=self.response_id,
|
||||
model=self.ctx.model,
|
||||
object="response",
|
||||
status="in_progress",
|
||||
output=output_messages.copy(),
|
||||
status=status,
|
||||
output=self._clone_outputs(outputs),
|
||||
text=self.text,
|
||||
tools=self.ctx.available_tools(),
|
||||
error=error,
|
||||
usage=self.accumulated_usage,
|
||||
)
|
||||
|
||||
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
|
||||
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
output_messages: list[OpenAIResponseOutput] = []
|
||||
|
||||
# Process all tools (including MCP tools) and emit streaming events
|
||||
if self.ctx.response_tools:
|
||||
async for stream_event in self._process_tools(self.ctx.response_tools, output_messages):
|
||||
yield stream_event
|
||||
# Emit response.created followed by response.in_progress to align with OpenAI streaming
|
||||
yield OpenAIResponseObjectStreamResponseCreated(
|
||||
response=self._snapshot_response("in_progress", output_messages)
|
||||
)
|
||||
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseInProgress(
|
||||
response=self._snapshot_response("in_progress", output_messages),
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
async for stream_event in self._process_tools(output_messages):
|
||||
yield stream_event
|
||||
|
||||
n_iter = 0
|
||||
messages = self.ctx.messages.copy()
|
||||
final_status = "completed"
|
||||
last_completion_result: ChatCompletionResult | None = None
|
||||
|
||||
while True:
|
||||
# Text is the default response format for chat completion so don't need to pass it
|
||||
# (some providers don't support non-empty response_format when tools are present)
|
||||
response_format = None if self.ctx.response_format.type == "text" else self.ctx.response_format
|
||||
completion_result = await self.inference_api.openai_chat_completion(
|
||||
model=self.ctx.model,
|
||||
messages=messages,
|
||||
tools=self.ctx.chat_tools,
|
||||
stream=True,
|
||||
temperature=self.ctx.temperature,
|
||||
response_format=response_format,
|
||||
)
|
||||
try:
|
||||
while True:
|
||||
# Text is the default response format for chat completion so don't need to pass it
|
||||
# (some providers don't support non-empty response_format when tools are present)
|
||||
response_format = None if self.ctx.response_format.type == "text" else self.ctx.response_format
|
||||
logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model=self.ctx.model,
|
||||
messages=messages,
|
||||
tools=self.ctx.chat_tools,
|
||||
stream=True,
|
||||
temperature=self.ctx.temperature,
|
||||
response_format=response_format,
|
||||
stream_options={
|
||||
"include_usage": True,
|
||||
},
|
||||
)
|
||||
completion_result = await self.inference_api.openai_chat_completion(params)
|
||||
|
||||
# Process streaming chunks and build complete response
|
||||
completion_result_data = None
|
||||
async for stream_event_or_result in self._process_streaming_chunks(completion_result, output_messages):
|
||||
if isinstance(stream_event_or_result, ChatCompletionResult):
|
||||
completion_result_data = stream_event_or_result
|
||||
else:
|
||||
yield stream_event_or_result
|
||||
if not completion_result_data:
|
||||
raise ValueError("Streaming chunk processor failed to return completion data")
|
||||
current_response = self._build_chat_completion(completion_result_data)
|
||||
# Process streaming chunks and build complete response
|
||||
completion_result_data = None
|
||||
async for stream_event_or_result in self._process_streaming_chunks(completion_result, output_messages):
|
||||
if isinstance(stream_event_or_result, ChatCompletionResult):
|
||||
completion_result_data = stream_event_or_result
|
||||
else:
|
||||
yield stream_event_or_result
|
||||
if not completion_result_data:
|
||||
raise ValueError("Streaming chunk processor failed to return completion data")
|
||||
last_completion_result = completion_result_data
|
||||
current_response = self._build_chat_completion(completion_result_data)
|
||||
|
||||
function_tool_calls, non_function_tool_calls, approvals, next_turn_messages = self._separate_tool_calls(
|
||||
current_response, messages
|
||||
)
|
||||
(
|
||||
function_tool_calls,
|
||||
non_function_tool_calls,
|
||||
approvals,
|
||||
next_turn_messages,
|
||||
) = self._separate_tool_calls(current_response, messages)
|
||||
|
||||
# add any approval requests required
|
||||
for tool_call in approvals:
|
||||
async for evt in self._add_mcp_approval_request(
|
||||
tool_call.function.name, tool_call.function.arguments, output_messages
|
||||
# add any approval requests required
|
||||
for tool_call in approvals:
|
||||
async for evt in self._add_mcp_approval_request(
|
||||
tool_call.function.name, tool_call.function.arguments, output_messages
|
||||
):
|
||||
yield evt
|
||||
|
||||
# Handle choices with no tool calls
|
||||
for choice in current_response.choices:
|
||||
if not (choice.message.tool_calls and self.ctx.response_tools):
|
||||
output_messages.append(
|
||||
await convert_chat_choice_to_response_message(
|
||||
choice,
|
||||
self.citation_files,
|
||||
message_id=completion_result_data.message_item_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Execute tool calls and coordinate results
|
||||
async for stream_event in self._coordinate_tool_execution(
|
||||
function_tool_calls,
|
||||
non_function_tool_calls,
|
||||
completion_result_data,
|
||||
output_messages,
|
||||
next_turn_messages,
|
||||
):
|
||||
yield evt
|
||||
yield stream_event
|
||||
|
||||
# Handle choices with no tool calls
|
||||
for choice in current_response.choices:
|
||||
if not (choice.message.tool_calls and self.ctx.response_tools):
|
||||
output_messages.append(await convert_chat_choice_to_response_message(choice))
|
||||
messages = next_turn_messages
|
||||
|
||||
# Execute tool calls and coordinate results
|
||||
async for stream_event in self._coordinate_tool_execution(
|
||||
function_tool_calls,
|
||||
non_function_tool_calls,
|
||||
completion_result_data,
|
||||
output_messages,
|
||||
next_turn_messages,
|
||||
):
|
||||
yield stream_event
|
||||
if not function_tool_calls and not non_function_tool_calls:
|
||||
break
|
||||
|
||||
if not function_tool_calls and not non_function_tool_calls:
|
||||
break
|
||||
if function_tool_calls:
|
||||
logger.info("Exiting inference loop since there is a function (client-side) tool call")
|
||||
break
|
||||
|
||||
if function_tool_calls:
|
||||
logger.info("Exiting inference loop since there is a function (client-side) tool call")
|
||||
break
|
||||
n_iter += 1
|
||||
if n_iter >= self.max_infer_iters:
|
||||
logger.info(
|
||||
f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}"
|
||||
)
|
||||
final_status = "incomplete"
|
||||
break
|
||||
|
||||
n_iter += 1
|
||||
if n_iter >= self.max_infer_iters:
|
||||
logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}")
|
||||
break
|
||||
if last_completion_result and last_completion_result.finish_reason == "length":
|
||||
final_status = "incomplete"
|
||||
|
||||
messages = next_turn_messages
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self.final_messages = messages.copy()
|
||||
self.sequence_number += 1
|
||||
error = OpenAIResponseError(code="internal_error", message=str(exc))
|
||||
failure_response = self._snapshot_response("failed", output_messages, error=error)
|
||||
yield OpenAIResponseObjectStreamResponseFailed(
|
||||
response=failure_response,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
return
|
||||
|
||||
self.final_messages = messages.copy() + [current_response.choices[0].message]
|
||||
self.final_messages = messages.copy()
|
||||
|
||||
# Create final response
|
||||
final_response = OpenAIResponseObject(
|
||||
created_at=self.created_at,
|
||||
id=self.response_id,
|
||||
model=self.ctx.model,
|
||||
object="response",
|
||||
status="completed",
|
||||
text=self.text,
|
||||
output=output_messages,
|
||||
)
|
||||
|
||||
# Emit response.completed
|
||||
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
|
||||
if final_status == "incomplete":
|
||||
self.sequence_number += 1
|
||||
final_response = self._snapshot_response("incomplete", output_messages)
|
||||
yield OpenAIResponseObjectStreamResponseIncomplete(
|
||||
response=final_response,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
else:
|
||||
final_response = self._snapshot_response("completed", output_messages)
|
||||
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
|
||||
|
||||
def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list, list]:
|
||||
"""Separate tool calls into function and non-function categories."""
|
||||
|
@ -211,6 +288,8 @@ class StreamingResponseOrchestrator:
|
|||
|
||||
for choice in current_response.choices:
|
||||
next_turn_messages.append(choice.message)
|
||||
logger.debug(f"Choice message content: {choice.message.content}")
|
||||
logger.debug(f"Choice message tool_calls: {choice.message.tool_calls}")
|
||||
|
||||
if choice.message.tool_calls and self.ctx.response_tools:
|
||||
for tool_call in choice.message.tool_calls:
|
||||
|
@ -227,14 +306,183 @@ class StreamingResponseOrchestrator:
|
|||
non_function_tool_calls.append(tool_call)
|
||||
else:
|
||||
logger.info(f"Approval denied for {tool_call.id} on {tool_call.function.name}")
|
||||
next_turn_messages.pop()
|
||||
else:
|
||||
logger.info(f"Requesting approval for {tool_call.id} on {tool_call.function.name}")
|
||||
approvals.append(tool_call)
|
||||
next_turn_messages.pop()
|
||||
else:
|
||||
non_function_tool_calls.append(tool_call)
|
||||
|
||||
return function_tool_calls, non_function_tool_calls, approvals, next_turn_messages
|
||||
|
||||
def _accumulate_chunk_usage(self, chunk: OpenAIChatCompletionChunk) -> None:
|
||||
"""Accumulate usage from a streaming chunk into the response usage format."""
|
||||
if not chunk.usage:
|
||||
return
|
||||
|
||||
if self.accumulated_usage is None:
|
||||
# Convert from chat completion format to response format
|
||||
self.accumulated_usage = OpenAIResponseUsage(
|
||||
input_tokens=chunk.usage.prompt_tokens,
|
||||
output_tokens=chunk.usage.completion_tokens,
|
||||
total_tokens=chunk.usage.total_tokens,
|
||||
input_tokens_details=(
|
||||
OpenAIResponseUsageInputTokensDetails(cached_tokens=chunk.usage.prompt_tokens_details.cached_tokens)
|
||||
if chunk.usage.prompt_tokens_details
|
||||
else None
|
||||
),
|
||||
output_tokens_details=(
|
||||
OpenAIResponseUsageOutputTokensDetails(
|
||||
reasoning_tokens=chunk.usage.completion_tokens_details.reasoning_tokens
|
||||
)
|
||||
if chunk.usage.completion_tokens_details
|
||||
else None
|
||||
),
|
||||
)
|
||||
else:
|
||||
# Accumulate across multiple inference calls
|
||||
self.accumulated_usage = OpenAIResponseUsage(
|
||||
input_tokens=self.accumulated_usage.input_tokens + chunk.usage.prompt_tokens,
|
||||
output_tokens=self.accumulated_usage.output_tokens + chunk.usage.completion_tokens,
|
||||
total_tokens=self.accumulated_usage.total_tokens + chunk.usage.total_tokens,
|
||||
# Use latest non-null details
|
||||
input_tokens_details=(
|
||||
OpenAIResponseUsageInputTokensDetails(cached_tokens=chunk.usage.prompt_tokens_details.cached_tokens)
|
||||
if chunk.usage.prompt_tokens_details
|
||||
else self.accumulated_usage.input_tokens_details
|
||||
),
|
||||
output_tokens_details=(
|
||||
OpenAIResponseUsageOutputTokensDetails(
|
||||
reasoning_tokens=chunk.usage.completion_tokens_details.reasoning_tokens
|
||||
)
|
||||
if chunk.usage.completion_tokens_details
|
||||
else self.accumulated_usage.output_tokens_details
|
||||
),
|
||||
)
|
||||
|
||||
async def _handle_reasoning_content_chunk(
|
||||
self,
|
||||
reasoning_content: str,
|
||||
reasoning_part_emitted: bool,
|
||||
reasoning_content_index: int,
|
||||
message_item_id: str,
|
||||
message_output_index: int,
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Emit content_part.added event for first reasoning chunk
|
||||
if not reasoning_part_emitted:
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseContentPartAdded(
|
||||
content_index=reasoning_content_index,
|
||||
response_id=self.response_id,
|
||||
item_id=message_item_id,
|
||||
output_index=message_output_index,
|
||||
part=OpenAIResponseContentPartReasoningText(
|
||||
text="", # Will be filled incrementally via reasoning deltas
|
||||
),
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
# Emit reasoning_text.delta event
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseReasoningTextDelta(
|
||||
content_index=reasoning_content_index,
|
||||
delta=reasoning_content,
|
||||
item_id=message_item_id,
|
||||
output_index=message_output_index,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
async def _handle_refusal_content_chunk(
|
||||
self,
|
||||
refusal_content: str,
|
||||
refusal_part_emitted: bool,
|
||||
refusal_content_index: int,
|
||||
message_item_id: str,
|
||||
message_output_index: int,
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Emit content_part.added event for first refusal chunk
|
||||
if not refusal_part_emitted:
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseContentPartAdded(
|
||||
content_index=refusal_content_index,
|
||||
response_id=self.response_id,
|
||||
item_id=message_item_id,
|
||||
output_index=message_output_index,
|
||||
part=OpenAIResponseContentPartRefusal(
|
||||
refusal="", # Will be filled incrementally via refusal deltas
|
||||
),
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
# Emit refusal.delta event
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseRefusalDelta(
|
||||
content_index=refusal_content_index,
|
||||
delta=refusal_content,
|
||||
item_id=message_item_id,
|
||||
output_index=message_output_index,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
async def _emit_reasoning_done_events(
|
||||
self,
|
||||
reasoning_text_accumulated: list[str],
|
||||
reasoning_content_index: int,
|
||||
message_item_id: str,
|
||||
message_output_index: int,
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
final_reasoning_text = "".join(reasoning_text_accumulated)
|
||||
# Emit reasoning_text.done event
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseReasoningTextDone(
|
||||
content_index=reasoning_content_index,
|
||||
text=final_reasoning_text,
|
||||
item_id=message_item_id,
|
||||
output_index=message_output_index,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
# Emit content_part.done for reasoning
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseContentPartDone(
|
||||
content_index=reasoning_content_index,
|
||||
response_id=self.response_id,
|
||||
item_id=message_item_id,
|
||||
output_index=message_output_index,
|
||||
part=OpenAIResponseContentPartReasoningText(
|
||||
text=final_reasoning_text,
|
||||
),
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
async def _emit_refusal_done_events(
|
||||
self,
|
||||
refusal_text_accumulated: list[str],
|
||||
refusal_content_index: int,
|
||||
message_item_id: str,
|
||||
message_output_index: int,
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
final_refusal_text = "".join(refusal_text_accumulated)
|
||||
# Emit refusal.done event
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseRefusalDone(
|
||||
content_index=refusal_content_index,
|
||||
refusal=final_refusal_text,
|
||||
item_id=message_item_id,
|
||||
output_index=message_output_index,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
# Emit content_part.done for refusal
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseContentPartDone(
|
||||
content_index=refusal_content_index,
|
||||
response_id=self.response_id,
|
||||
item_id=message_item_id,
|
||||
output_index=message_output_index,
|
||||
part=OpenAIResponseContentPartRefusal(
|
||||
refusal=final_refusal_text,
|
||||
),
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
async def _process_streaming_chunks(
|
||||
self, completion_result, output_messages: list[OpenAIResponseOutput]
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream | ChatCompletionResult]:
|
||||
|
@ -253,11 +501,23 @@ class StreamingResponseOrchestrator:
|
|||
tool_call_item_ids: dict[int, str] = {}
|
||||
# Track content parts for streaming events
|
||||
content_part_emitted = False
|
||||
reasoning_part_emitted = False
|
||||
refusal_part_emitted = False
|
||||
content_index = 0
|
||||
reasoning_content_index = 1 # reasoning is a separate content part
|
||||
refusal_content_index = 2 # refusal is a separate content part
|
||||
message_output_index = len(output_messages)
|
||||
reasoning_text_accumulated = []
|
||||
refusal_text_accumulated = []
|
||||
|
||||
async for chunk in completion_result:
|
||||
chat_response_id = chunk.id
|
||||
chunk_created = chunk.created
|
||||
chunk_model = chunk.model
|
||||
|
||||
# Accumulate usage from chunks (typically in final chunk with stream_options)
|
||||
self._accumulate_chunk_usage(chunk)
|
||||
|
||||
for chunk_choice in chunk.choices:
|
||||
# Emit incremental text content as delta events
|
||||
if chunk_choice.delta.content:
|
||||
|
@ -266,8 +526,10 @@ class StreamingResponseOrchestrator:
|
|||
content_part_emitted = True
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseContentPartAdded(
|
||||
content_index=content_index,
|
||||
response_id=self.response_id,
|
||||
item_id=message_item_id,
|
||||
output_index=message_output_index,
|
||||
part=OpenAIResponseContentPartOutputText(
|
||||
text="", # Will be filled incrementally via text deltas
|
||||
),
|
||||
|
@ -275,10 +537,10 @@ class StreamingResponseOrchestrator:
|
|||
)
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
|
||||
content_index=0,
|
||||
content_index=content_index,
|
||||
delta=chunk_choice.delta.content,
|
||||
item_id=message_item_id,
|
||||
output_index=0,
|
||||
output_index=message_output_index,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
|
@ -287,6 +549,32 @@ class StreamingResponseOrchestrator:
|
|||
if chunk_choice.finish_reason:
|
||||
chunk_finish_reason = chunk_choice.finish_reason
|
||||
|
||||
# Handle reasoning content if present (non-standard field for o1/o3 models)
|
||||
if hasattr(chunk_choice.delta, "reasoning_content") and chunk_choice.delta.reasoning_content:
|
||||
async for event in self._handle_reasoning_content_chunk(
|
||||
reasoning_content=chunk_choice.delta.reasoning_content,
|
||||
reasoning_part_emitted=reasoning_part_emitted,
|
||||
reasoning_content_index=reasoning_content_index,
|
||||
message_item_id=message_item_id,
|
||||
message_output_index=message_output_index,
|
||||
):
|
||||
yield event
|
||||
reasoning_part_emitted = True
|
||||
reasoning_text_accumulated.append(chunk_choice.delta.reasoning_content)
|
||||
|
||||
# Handle refusal content if present
|
||||
if chunk_choice.delta.refusal:
|
||||
async for event in self._handle_refusal_content_chunk(
|
||||
refusal_content=chunk_choice.delta.refusal,
|
||||
refusal_part_emitted=refusal_part_emitted,
|
||||
refusal_content_index=refusal_content_index,
|
||||
message_item_id=message_item_id,
|
||||
message_output_index=message_output_index,
|
||||
):
|
||||
yield event
|
||||
refusal_part_emitted = True
|
||||
refusal_text_accumulated.append(chunk_choice.delta.refusal)
|
||||
|
||||
# Aggregate tool call arguments across chunks
|
||||
if chunk_choice.delta.tool_calls:
|
||||
for tool_call in chunk_choice.delta.tool_calls:
|
||||
|
@ -378,14 +666,36 @@ class StreamingResponseOrchestrator:
|
|||
final_text = "".join(chat_response_content)
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseContentPartDone(
|
||||
content_index=content_index,
|
||||
response_id=self.response_id,
|
||||
item_id=message_item_id,
|
||||
output_index=message_output_index,
|
||||
part=OpenAIResponseContentPartOutputText(
|
||||
text=final_text,
|
||||
),
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Emit reasoning done events if reasoning content was streamed
|
||||
if reasoning_part_emitted:
|
||||
async for event in self._emit_reasoning_done_events(
|
||||
reasoning_text_accumulated=reasoning_text_accumulated,
|
||||
reasoning_content_index=reasoning_content_index,
|
||||
message_item_id=message_item_id,
|
||||
message_output_index=message_output_index,
|
||||
):
|
||||
yield event
|
||||
|
||||
# Emit refusal done events if refusal content was streamed
|
||||
if refusal_part_emitted:
|
||||
async for event in self._emit_refusal_done_events(
|
||||
refusal_text_accumulated=refusal_text_accumulated,
|
||||
refusal_content_index=refusal_content_index,
|
||||
message_item_id=message_item_id,
|
||||
message_output_index=message_output_index,
|
||||
):
|
||||
yield event
|
||||
|
||||
# Clear content when there are tool calls (OpenAI spec behavior)
|
||||
if chat_response_tool_calls:
|
||||
chat_response_content = []
|
||||
|
@ -470,6 +780,8 @@ class StreamingResponseOrchestrator:
|
|||
tool_call_log = result.final_output_message
|
||||
tool_response_message = result.final_input_message
|
||||
self.sequence_number = result.sequence_number
|
||||
if result.citation_files:
|
||||
self.citation_files.update(result.citation_files)
|
||||
|
||||
if tool_call_log:
|
||||
output_messages.append(tool_call_log)
|
||||
|
@ -518,7 +830,7 @@ class StreamingResponseOrchestrator:
|
|||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
async def _process_tools(
|
||||
async def _process_new_tools(
|
||||
self, tools: list[OpenAIResponseInputTool], output_messages: list[OpenAIResponseOutput]
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
"""Process all tools and emit appropriate streaming events."""
|
||||
|
@ -573,7 +885,6 @@ class StreamingResponseOrchestrator:
|
|||
yield OpenAIResponseObjectStreamResponseMcpListToolsInProgress(
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
try:
|
||||
# Parse allowed/never allowed tools
|
||||
always_allowed = None
|
||||
|
@ -586,14 +897,22 @@ class StreamingResponseOrchestrator:
|
|||
never_allowed = mcp_tool.allowed_tools.never
|
||||
|
||||
# Call list_mcp_tools
|
||||
tool_defs = await list_mcp_tools(
|
||||
endpoint=mcp_tool.server_url,
|
||||
headers=mcp_tool.headers or {},
|
||||
)
|
||||
tool_defs = None
|
||||
list_id = f"mcp_list_{uuid.uuid4()}"
|
||||
attributes = {
|
||||
"server_label": mcp_tool.server_label,
|
||||
"server_url": mcp_tool.server_url,
|
||||
"mcp_list_tools_id": list_id,
|
||||
}
|
||||
async with tracing.span("list_mcp_tools", attributes):
|
||||
tool_defs = await list_mcp_tools(
|
||||
endpoint=mcp_tool.server_url,
|
||||
headers=mcp_tool.headers or {},
|
||||
)
|
||||
|
||||
# Create the MCP list tools message
|
||||
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
|
||||
id=f"mcp_list_{uuid.uuid4()}",
|
||||
id=list_id,
|
||||
server_label=mcp_tool.server_label,
|
||||
tools=[],
|
||||
)
|
||||
|
@ -627,39 +946,26 @@ class StreamingResponseOrchestrator:
|
|||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Add the MCP list message to output
|
||||
output_messages.append(mcp_list_message)
|
||||
|
||||
# Emit output_item.added for the MCP list tools message
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
|
||||
response_id=self.response_id,
|
||||
item=mcp_list_message,
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Emit mcp_list_tools.completed
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseMcpListToolsCompleted(
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Emit output_item.done for the MCP list tools message
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemDone(
|
||||
response_id=self.response_id,
|
||||
item=mcp_list_message,
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
async for stream_event in self._add_mcp_list_tools(mcp_list_message, output_messages):
|
||||
yield stream_event
|
||||
|
||||
except Exception as e:
|
||||
# TODO: Emit mcp_list_tools.failed event if needed
|
||||
logger.exception(f"Failed to list MCP tools from {mcp_tool.server_url}: {e}")
|
||||
raise
|
||||
|
||||
async def _process_tools(
|
||||
self, output_messages: list[OpenAIResponseOutput]
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Handle all mcp tool lists from previous response that are still valid:
|
||||
for tool in self.ctx.tool_context.previous_tool_listings:
|
||||
async for evt in self._reuse_mcp_list_tools(tool, output_messages):
|
||||
yield evt
|
||||
# Process all remaining tools (including MCP tools) and emit streaming events
|
||||
if self.ctx.tool_context.tools_to_process:
|
||||
async for stream_event in self._process_new_tools(self.ctx.tool_context.tools_to_process, output_messages):
|
||||
yield stream_event
|
||||
|
||||
def _approval_required(self, tool_name: str) -> bool:
|
||||
if tool_name not in self.mcp_tool_to_server:
|
||||
return False
|
||||
|
@ -694,7 +1000,6 @@ class StreamingResponseOrchestrator:
|
|||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemDone(
|
||||
response_id=self.response_id,
|
||||
|
@ -702,3 +1007,60 @@ class StreamingResponseOrchestrator:
|
|||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
async def _add_mcp_list_tools(
|
||||
self, mcp_list_message: OpenAIResponseOutputMessageMCPListTools, output_messages: list[OpenAIResponseOutput]
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Add the MCP list message to output
|
||||
output_messages.append(mcp_list_message)
|
||||
|
||||
# Emit output_item.added for the MCP list tools message
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
|
||||
response_id=self.response_id,
|
||||
item=mcp_list_message,
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
# Emit mcp_list_tools.completed
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseMcpListToolsCompleted(
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Emit output_item.done for the MCP list tools message
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemDone(
|
||||
response_id=self.response_id,
|
||||
item=mcp_list_message,
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
async def _reuse_mcp_list_tools(
|
||||
self, original: OpenAIResponseOutputMessageMCPListTools, output_messages: list[OpenAIResponseOutput]
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
for t in original.tools:
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
|
||||
|
||||
# convert from input_schema to map of ToolParamDefinitions...
|
||||
tool_def = ToolDefinition(
|
||||
tool_name=t.name,
|
||||
description=t.description,
|
||||
input_schema=t.input_schema,
|
||||
)
|
||||
# ...then can convert that to openai completions tool
|
||||
openai_tool = convert_tooldef_to_openai_tool(tool_def)
|
||||
if self.ctx.chat_tools is None:
|
||||
self.ctx.chat_tools = []
|
||||
self.ctx.chat_tools.append(openai_tool)
|
||||
|
||||
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
|
||||
id=f"mcp_list_{uuid.uuid4()}",
|
||||
server_label=original.server_label,
|
||||
tools=original.tools,
|
||||
)
|
||||
|
||||
async for stream_event in self._add_mcp_list_tools(mcp_list_message, output_messages):
|
||||
yield stream_event
|
||||
|
|
|
@ -11,6 +11,9 @@ from collections.abc import AsyncIterator
|
|||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInputToolFileSearch,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseObjectStreamResponseFileSearchCallCompleted,
|
||||
OpenAIResponseObjectStreamResponseFileSearchCallInProgress,
|
||||
OpenAIResponseObjectStreamResponseFileSearchCallSearching,
|
||||
OpenAIResponseObjectStreamResponseMcpCallCompleted,
|
||||
OpenAIResponseObjectStreamResponseMcpCallFailed,
|
||||
OpenAIResponseObjectStreamResponseMcpCallInProgress,
|
||||
|
@ -35,6 +38,7 @@ from llama_stack.apis.inference import (
|
|||
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
||||
from .types import ChatCompletionContext, ToolExecutionResult
|
||||
|
||||
|
@ -94,7 +98,10 @@ class ToolExecutor:
|
|||
|
||||
# Yield the final result
|
||||
yield ToolExecutionResult(
|
||||
sequence_number=sequence_number, final_output_message=output_message, final_input_message=input_message
|
||||
sequence_number=sequence_number,
|
||||
final_output_message=output_message,
|
||||
final_input_message=input_message,
|
||||
citation_files=result.metadata.get("citation_files") if result and result.metadata else None,
|
||||
)
|
||||
|
||||
async def _execute_knowledge_search_via_vector_store(
|
||||
|
@ -129,8 +136,6 @@ class ToolExecutor:
|
|||
for results in all_results:
|
||||
search_results.extend(results)
|
||||
|
||||
# Convert search results to tool result format matching memory.py
|
||||
# Format the results as interleaved content similar to memory.py
|
||||
content_items = []
|
||||
content_items.append(
|
||||
TextContentItem(
|
||||
|
@ -138,27 +143,58 @@ class ToolExecutor:
|
|||
)
|
||||
)
|
||||
|
||||
unique_files = set()
|
||||
for i, result_item in enumerate(search_results):
|
||||
chunk_text = result_item.content[0].text if result_item.content else ""
|
||||
metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}"
|
||||
# Get file_id from attributes if result_item.file_id is empty
|
||||
file_id = result_item.file_id or (
|
||||
result_item.attributes.get("document_id") if result_item.attributes else None
|
||||
)
|
||||
metadata_text = f"document_id: {file_id}, score: {result_item.score}"
|
||||
if result_item.attributes:
|
||||
metadata_text += f", attributes: {result_item.attributes}"
|
||||
text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n"
|
||||
|
||||
text_content = f"[{i + 1}] {metadata_text} (cite as <|{file_id}|>)\n{chunk_text}\n"
|
||||
content_items.append(TextContentItem(text=text_content))
|
||||
unique_files.add(file_id)
|
||||
|
||||
content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
|
||||
|
||||
citation_instruction = ""
|
||||
if unique_files:
|
||||
citation_instruction = (
|
||||
" Cite sources immediately at the end of sentences before punctuation, using `<|file-id|>` format (e.g., 'This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.'). "
|
||||
"Do not add extra punctuation. Use only the file IDs provided (do not invent new ones)."
|
||||
)
|
||||
|
||||
content_items.append(
|
||||
TextContentItem(
|
||||
text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n',
|
||||
text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.{citation_instruction}\n',
|
||||
)
|
||||
)
|
||||
|
||||
# handling missing attributes for old versions
|
||||
citation_files = {}
|
||||
for result in search_results:
|
||||
file_id = result.file_id
|
||||
if not file_id and result.attributes:
|
||||
file_id = result.attributes.get("document_id")
|
||||
|
||||
filename = result.filename
|
||||
if not filename and result.attributes:
|
||||
filename = result.attributes.get("filename")
|
||||
if not filename:
|
||||
filename = "unknown"
|
||||
|
||||
citation_files[file_id] = filename
|
||||
|
||||
return ToolInvocationResult(
|
||||
content=content_items,
|
||||
metadata={
|
||||
"document_ids": [r.file_id for r in search_results],
|
||||
"chunks": [r.content[0].text if r.content else "" for r in search_results],
|
||||
"scores": [r.score for r in search_results],
|
||||
"citation_files": citation_files,
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -188,7 +224,13 @@ class ToolExecutor:
|
|||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
# Note: knowledge_search and other custom tools don't have specific streaming events in OpenAI spec
|
||||
elif function_name == "knowledge_search":
|
||||
sequence_number += 1
|
||||
progress_event = OpenAIResponseObjectStreamResponseFileSearchCallInProgress(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
|
||||
if progress_event:
|
||||
yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
|
||||
|
@ -203,6 +245,16 @@ class ToolExecutor:
|
|||
)
|
||||
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
|
||||
|
||||
# For file search, emit searching event
|
||||
if function_name == "knowledge_search":
|
||||
sequence_number += 1
|
||||
searching_event = OpenAIResponseObjectStreamResponseFileSearchCallSearching(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
|
||||
|
||||
async def _execute_tool(
|
||||
self,
|
||||
function_name: str,
|
||||
|
@ -219,12 +271,18 @@ class ToolExecutor:
|
|||
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool
|
||||
|
||||
mcp_tool = mcp_tool_to_server[function_name]
|
||||
result = await invoke_mcp_tool(
|
||||
endpoint=mcp_tool.server_url,
|
||||
headers=mcp_tool.headers or {},
|
||||
tool_name=function_name,
|
||||
kwargs=tool_kwargs,
|
||||
)
|
||||
attributes = {
|
||||
"server_label": mcp_tool.server_label,
|
||||
"server_url": mcp_tool.server_url,
|
||||
"tool_name": function_name,
|
||||
}
|
||||
async with tracing.span("invoke_mcp_tool", attributes):
|
||||
result = await invoke_mcp_tool(
|
||||
endpoint=mcp_tool.server_url,
|
||||
headers=mcp_tool.headers or {},
|
||||
tool_name=function_name,
|
||||
kwargs=tool_kwargs,
|
||||
)
|
||||
elif function_name == "knowledge_search":
|
||||
response_file_search_tool = next(
|
||||
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
|
||||
|
@ -234,15 +292,20 @@ class ToolExecutor:
|
|||
# Use vector_stores.search API instead of knowledge_search tool
|
||||
# to support filters and ranking_options
|
||||
query = tool_kwargs.get("query", "")
|
||||
result = await self._execute_knowledge_search_via_vector_store(
|
||||
query=query,
|
||||
response_file_search_tool=response_file_search_tool,
|
||||
)
|
||||
async with tracing.span("knowledge_search", {}):
|
||||
result = await self._execute_knowledge_search_via_vector_store(
|
||||
query=query,
|
||||
response_file_search_tool=response_file_search_tool,
|
||||
)
|
||||
else:
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=function_name,
|
||||
kwargs=tool_kwargs,
|
||||
)
|
||||
attributes = {
|
||||
"tool_name": function_name,
|
||||
}
|
||||
async with tracing.span("invoke_tool", attributes):
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=function_name,
|
||||
kwargs=tool_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
error_exc = e
|
||||
|
||||
|
@ -278,7 +341,13 @@ class ToolExecutor:
|
|||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
# Note: knowledge_search and other custom tools don't have specific completion events in OpenAI spec
|
||||
elif function_name == "knowledge_search":
|
||||
sequence_number += 1
|
||||
completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
|
||||
if completion_event:
|
||||
yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)
|
||||
|
|
|
@ -12,10 +12,18 @@ from pydantic import BaseModel
|
|||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseInputToolFileSearch,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseInputToolWebSearch,
|
||||
OpenAIResponseMCPApprovalRequest,
|
||||
OpenAIResponseMCPApprovalResponse,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseOutput,
|
||||
OpenAIResponseOutputMessageMCPListTools,
|
||||
OpenAIResponseTool,
|
||||
OpenAIResponseToolMCP,
|
||||
)
|
||||
from llama_stack.apis.inference import OpenAIChatCompletionToolCall, OpenAIMessageParam, OpenAIResponseFormatParam
|
||||
|
||||
|
@ -27,6 +35,7 @@ class ToolExecutionResult(BaseModel):
|
|||
sequence_number: int
|
||||
final_output_message: OpenAIResponseOutput | None = None
|
||||
final_input_message: OpenAIMessageParam | None = None
|
||||
citation_files: dict[str, str] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -54,6 +63,86 @@ class ChatCompletionResult:
|
|||
return bool(self.tool_calls)
|
||||
|
||||
|
||||
class ToolContext(BaseModel):
|
||||
"""Holds information about tools from this and (if relevant)
|
||||
previous response in order to facilitate reuse of previous
|
||||
listings where appropriate."""
|
||||
|
||||
# tools argument passed into current request:
|
||||
current_tools: list[OpenAIResponseInputTool]
|
||||
# reconstructed map of tool -> mcp server from previous response:
|
||||
previous_tools: dict[str, OpenAIResponseInputToolMCP]
|
||||
# reusable mcp-list-tools objects from previous response:
|
||||
previous_tool_listings: list[OpenAIResponseOutputMessageMCPListTools]
|
||||
# tool arguments from current request that still need to be processed:
|
||||
tools_to_process: list[OpenAIResponseInputTool]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
current_tools: list[OpenAIResponseInputTool] | None,
|
||||
):
|
||||
super().__init__(
|
||||
current_tools=current_tools or [],
|
||||
previous_tools={},
|
||||
previous_tool_listings=[],
|
||||
tools_to_process=current_tools or [],
|
||||
)
|
||||
|
||||
def recover_tools_from_previous_response(
|
||||
self,
|
||||
previous_response: OpenAIResponseObject,
|
||||
):
|
||||
"""Determine which mcp_list_tools objects from previous response we can reuse."""
|
||||
|
||||
if self.current_tools and previous_response.tools:
|
||||
previous_tools_by_label: dict[str, OpenAIResponseToolMCP] = {}
|
||||
for tool in previous_response.tools:
|
||||
if isinstance(tool, OpenAIResponseToolMCP):
|
||||
previous_tools_by_label[tool.server_label] = tool
|
||||
# collect tool definitions which are the same in current and previous requests:
|
||||
tools_to_process = []
|
||||
matched: dict[str, OpenAIResponseInputToolMCP] = {}
|
||||
for tool in self.current_tools:
|
||||
if isinstance(tool, OpenAIResponseInputToolMCP) and tool.server_label in previous_tools_by_label:
|
||||
previous_tool = previous_tools_by_label[tool.server_label]
|
||||
if previous_tool.allowed_tools == tool.allowed_tools:
|
||||
matched[tool.server_label] = tool
|
||||
else:
|
||||
tools_to_process.append(tool)
|
||||
else:
|
||||
tools_to_process.append(tool)
|
||||
# tools that are not the same or were not previously defined need to be processed:
|
||||
self.tools_to_process = tools_to_process
|
||||
# for all matched definitions, get the mcp_list_tools objects from the previous output:
|
||||
self.previous_tool_listings = [
|
||||
obj for obj in previous_response.output if obj.type == "mcp_list_tools" and obj.server_label in matched
|
||||
]
|
||||
# reconstruct the tool to server mappings that can be reused:
|
||||
for listing in self.previous_tool_listings:
|
||||
definition = matched[listing.server_label]
|
||||
for tool in listing.tools:
|
||||
self.previous_tools[tool.name] = definition
|
||||
|
||||
def available_tools(self) -> list[OpenAIResponseTool]:
|
||||
if not self.current_tools:
|
||||
return []
|
||||
|
||||
def convert_tool(tool: OpenAIResponseInputTool) -> OpenAIResponseTool:
|
||||
if isinstance(tool, OpenAIResponseInputToolWebSearch):
|
||||
return tool
|
||||
if isinstance(tool, OpenAIResponseInputToolFileSearch):
|
||||
return tool
|
||||
if isinstance(tool, OpenAIResponseInputToolFunction):
|
||||
return tool
|
||||
if isinstance(tool, OpenAIResponseInputToolMCP):
|
||||
return OpenAIResponseToolMCP(
|
||||
server_label=tool.server_label,
|
||||
allowed_tools=tool.allowed_tools,
|
||||
)
|
||||
|
||||
return [convert_tool(tool) for tool in self.current_tools]
|
||||
|
||||
|
||||
class ChatCompletionContext(BaseModel):
|
||||
model: str
|
||||
messages: list[OpenAIMessageParam]
|
||||
|
@ -61,6 +150,7 @@ class ChatCompletionContext(BaseModel):
|
|||
chat_tools: list[ChatCompletionToolParam] | None = None
|
||||
temperature: float | None
|
||||
response_format: OpenAIResponseFormatParam
|
||||
tool_context: ToolContext | None
|
||||
approval_requests: list[OpenAIResponseMCPApprovalRequest] = []
|
||||
approval_responses: dict[str, OpenAIResponseMCPApprovalResponse] = {}
|
||||
|
||||
|
@ -71,6 +161,7 @@ class ChatCompletionContext(BaseModel):
|
|||
response_tools: list[OpenAIResponseInputTool] | None,
|
||||
temperature: float | None,
|
||||
response_format: OpenAIResponseFormatParam,
|
||||
tool_context: ToolContext,
|
||||
inputs: list[OpenAIResponseInput] | str,
|
||||
):
|
||||
super().__init__(
|
||||
|
@ -79,6 +170,7 @@ class ChatCompletionContext(BaseModel):
|
|||
response_tools=response_tools,
|
||||
temperature=temperature,
|
||||
response_format=response_format,
|
||||
tool_context=tool_context,
|
||||
)
|
||||
if not isinstance(inputs, str):
|
||||
self.approval_requests = [input for input in inputs if input.type == "mcp_approval_request"]
|
||||
|
@ -95,3 +187,8 @@ class ChatCompletionContext(BaseModel):
|
|||
if request.name == tool_name and request.arguments == arguments:
|
||||
return request
|
||||
return None
|
||||
|
||||
def available_tools(self) -> list[OpenAIResponseTool]:
|
||||
if not self.tool_context:
|
||||
return []
|
||||
return self.tool_context.available_tools()
|
||||
|
|
|
@ -4,9 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import re
|
||||
import uuid
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseAnnotationFileCitation,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputFunctionToolCallOutput,
|
||||
OpenAIResponseInputMessageContent,
|
||||
|
@ -45,7 +47,12 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
|
||||
|
||||
async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
|
||||
async def convert_chat_choice_to_response_message(
|
||||
choice: OpenAIChoice,
|
||||
citation_files: dict[str, str] | None = None,
|
||||
*,
|
||||
message_id: str | None = None,
|
||||
) -> OpenAIResponseMessage:
|
||||
"""Convert an OpenAI Chat Completion choice into an OpenAI Response output message."""
|
||||
output_content = ""
|
||||
if isinstance(choice.message.content, str):
|
||||
|
@ -57,9 +64,11 @@ async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenA
|
|||
f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
|
||||
)
|
||||
|
||||
annotations, clean_text = _extract_citations_from_text(output_content, citation_files or {})
|
||||
|
||||
return OpenAIResponseMessage(
|
||||
id=f"msg_{uuid.uuid4()}",
|
||||
content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
|
||||
id=message_id or f"msg_{uuid.uuid4()}",
|
||||
content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=annotations)],
|
||||
status="completed",
|
||||
role="assistant",
|
||||
)
|
||||
|
@ -97,9 +106,13 @@ async def convert_response_content_to_chat_content(
|
|||
|
||||
async def convert_response_input_to_chat_messages(
|
||||
input: str | list[OpenAIResponseInput],
|
||||
previous_messages: list[OpenAIMessageParam] | None = None,
|
||||
) -> list[OpenAIMessageParam]:
|
||||
"""
|
||||
Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
|
||||
|
||||
:param input: The input to convert
|
||||
:param previous_messages: Optional previous messages to check for function_call references
|
||||
"""
|
||||
messages: list[OpenAIMessageParam] = []
|
||||
if isinstance(input, list):
|
||||
|
@ -163,16 +176,53 @@ async def convert_response_input_to_chat_messages(
|
|||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
|
||||
)
|
||||
# Skip user messages that duplicate the last user message in previous_messages
|
||||
# This handles cases where input includes context for function_call_outputs
|
||||
if previous_messages and input_item.role == "user":
|
||||
last_user_msg = None
|
||||
for msg in reversed(previous_messages):
|
||||
if isinstance(msg, OpenAIUserMessageParam):
|
||||
last_user_msg = msg
|
||||
break
|
||||
if last_user_msg:
|
||||
last_user_content = getattr(last_user_msg, "content", None)
|
||||
if last_user_content == content:
|
||||
continue # Skip duplicate user message
|
||||
messages.append(message_type(content=content))
|
||||
if len(tool_call_results):
|
||||
raise ValueError(
|
||||
f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
|
||||
)
|
||||
# Check if unpaired function_call_outputs reference function_calls from previous messages
|
||||
if previous_messages:
|
||||
previous_call_ids = _extract_tool_call_ids(previous_messages)
|
||||
for call_id in list(tool_call_results.keys()):
|
||||
if call_id in previous_call_ids:
|
||||
# Valid: this output references a call from previous messages
|
||||
# Add the tool message
|
||||
messages.append(tool_call_results[call_id])
|
||||
del tool_call_results[call_id]
|
||||
|
||||
# If still have unpaired outputs, error
|
||||
if len(tool_call_results):
|
||||
raise ValueError(
|
||||
f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
|
||||
)
|
||||
else:
|
||||
messages.append(OpenAIUserMessageParam(content=input))
|
||||
return messages
|
||||
|
||||
|
||||
def _extract_tool_call_ids(messages: list[OpenAIMessageParam]) -> set[str]:
|
||||
"""Extract all tool_call IDs from messages."""
|
||||
call_ids = set()
|
||||
for msg in messages:
|
||||
if isinstance(msg, OpenAIAssistantMessageParam):
|
||||
tool_calls = getattr(msg, "tool_calls", None)
|
||||
if tool_calls:
|
||||
for tool_call in tool_calls:
|
||||
# tool_call is a Pydantic model, use attribute access
|
||||
call_ids.add(tool_call.id)
|
||||
return call_ids
|
||||
|
||||
|
||||
async def convert_response_text_to_chat_response_format(
|
||||
text: OpenAIResponseText,
|
||||
) -> OpenAIResponseFormatParam:
|
||||
|
@ -200,6 +250,53 @@ async def get_message_type_by_role(role: str):
|
|||
return role_to_type.get(role)
|
||||
|
||||
|
||||
def _extract_citations_from_text(
|
||||
text: str, citation_files: dict[str, str]
|
||||
) -> tuple[list[OpenAIResponseAnnotationFileCitation], str]:
|
||||
"""Extract citation markers from text and create annotations
|
||||
|
||||
Args:
|
||||
text: The text containing citation markers like [file-Cn3MSNn72ENTiiq11Qda4A]
|
||||
citation_files: Dictionary mapping file_id to filename
|
||||
|
||||
Returns:
|
||||
Tuple of (annotations_list, clean_text_without_markers)
|
||||
"""
|
||||
file_id_regex = re.compile(r"<\|(?P<file_id>file-[A-Za-z0-9_-]+)\|>")
|
||||
|
||||
annotations = []
|
||||
parts = []
|
||||
total_len = 0
|
||||
last_end = 0
|
||||
|
||||
for m in file_id_regex.finditer(text):
|
||||
# segment before the marker
|
||||
prefix = text[last_end : m.start()]
|
||||
|
||||
# drop one space if it exists (since marker is at sentence end)
|
||||
if prefix.endswith(" "):
|
||||
prefix = prefix[:-1]
|
||||
|
||||
parts.append(prefix)
|
||||
total_len += len(prefix)
|
||||
|
||||
fid = m.group(1)
|
||||
if fid in citation_files:
|
||||
annotations.append(
|
||||
OpenAIResponseAnnotationFileCitation(
|
||||
file_id=fid,
|
||||
filename=citation_files[fid],
|
||||
index=total_len, # index points to punctuation
|
||||
)
|
||||
)
|
||||
|
||||
last_end = m.end()
|
||||
|
||||
parts.append(text[last_end:])
|
||||
cleaned_text = "".join(parts)
|
||||
return annotations, cleaned_text
|
||||
|
||||
|
||||
def is_function_tool_call(
|
||||
tool_call: OpenAIChatCompletionToolCall,
|
||||
tools: list[OpenAIResponseInputTool],
|
||||
|
|
|
@ -22,7 +22,10 @@ from llama_stack.apis.files import Files, OpenAIFilePurpose
|
|||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIMessageParam,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
|
@ -178,9 +181,9 @@ class ReferenceBatchesImpl(Batches):
|
|||
|
||||
# TODO: set expiration time for garbage collection
|
||||
|
||||
if endpoint not in ["/v1/chat/completions", "/v1/completions"]:
|
||||
if endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]:
|
||||
raise ValueError(
|
||||
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions. Code: invalid_value. Param: endpoint",
|
||||
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint",
|
||||
)
|
||||
|
||||
if completion_window != "24h":
|
||||
|
@ -425,18 +428,23 @@ class ReferenceBatchesImpl(Batches):
|
|||
valid = False
|
||||
|
||||
if batch.endpoint == "/v1/chat/completions":
|
||||
required_params = [
|
||||
required_params: list[tuple[str, Any, str]] = [
|
||||
("model", str, "a string"),
|
||||
# messages is specific to /v1/chat/completions
|
||||
# we could skip validating messages here and let inference fail. however,
|
||||
# that would be a very expensive way to find out messages is wrong.
|
||||
("messages", list, "an array"), # TODO: allow messages to be a string?
|
||||
]
|
||||
else: # /v1/completions
|
||||
elif batch.endpoint == "/v1/completions":
|
||||
required_params = [
|
||||
("model", str, "a string"),
|
||||
("prompt", str, "a string"), # TODO: allow prompt to be a list of strings??
|
||||
]
|
||||
else: # /v1/embeddings
|
||||
required_params = [
|
||||
("model", str, "a string"),
|
||||
("input", (str, list), "a string or array of strings"),
|
||||
]
|
||||
|
||||
for param, expected_type, type_string in required_params:
|
||||
if param not in body:
|
||||
|
@ -601,7 +609,8 @@ class ReferenceBatchesImpl(Batches):
|
|||
# TODO(SECURITY): review body for security issues
|
||||
if request.url == "/v1/chat/completions":
|
||||
request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
|
||||
chat_response = await self.inference_api.openai_chat_completion(**request.body)
|
||||
chat_params = OpenAIChatCompletionRequestWithExtraBody(**request.body)
|
||||
chat_response = await self.inference_api.openai_chat_completion(chat_params)
|
||||
|
||||
# this is for mypy, we don't allow streaming so we'll get the right type
|
||||
assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
|
||||
|
@ -614,8 +623,9 @@ class ReferenceBatchesImpl(Batches):
|
|||
"body": chat_response.model_dump_json(),
|
||||
},
|
||||
}
|
||||
else: # /v1/completions
|
||||
completion_response = await self.inference_api.openai_completion(**request.body)
|
||||
elif request.url == "/v1/completions":
|
||||
completion_params = OpenAICompletionRequestWithExtraBody(**request.body)
|
||||
completion_response = await self.inference_api.openai_completion(completion_params)
|
||||
|
||||
# this is for mypy, we don't allow streaming so we'll get the right type
|
||||
assert hasattr(completion_response, "model_dump_json"), (
|
||||
|
@ -630,6 +640,22 @@ class ReferenceBatchesImpl(Batches):
|
|||
"body": completion_response.model_dump_json(),
|
||||
},
|
||||
}
|
||||
else: # /v1/embeddings
|
||||
embeddings_response = await self.inference_api.openai_embeddings(
|
||||
OpenAIEmbeddingsRequestWithExtraBody(**request.body)
|
||||
)
|
||||
assert hasattr(embeddings_response, "model_dump_json"), (
|
||||
"Embeddings response must have model_dump_json method"
|
||||
)
|
||||
return {
|
||||
"id": request_id,
|
||||
"custom_id": request.custom_id,
|
||||
"response": {
|
||||
"status_code": 200,
|
||||
"request_id": request_id, # TODO: should this be different?
|
||||
"body": embeddings_response.model_dump_json(),
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
|
||||
return {
|
||||
|
|
|
@ -12,7 +12,14 @@ from llama_stack.apis.agents import Agents, StepType
|
|||
from llama_stack.apis.benchmarks import Benchmark
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.inference import Inference, OpenAISystemMessageParam, OpenAIUserMessageParam, UserMessage
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
|
||||
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
|
||||
|
@ -168,11 +175,12 @@ class MetaReferenceEvalImpl(
|
|||
sampling_params["stop"] = candidate.sampling_params.stop
|
||||
|
||||
input_content = json.loads(x[ColumnName.completion_input.value])
|
||||
response = await self.inference_api.openai_completion(
|
||||
params = OpenAICompletionRequestWithExtraBody(
|
||||
model=candidate.model,
|
||||
prompt=input_content,
|
||||
**sampling_params,
|
||||
)
|
||||
response = await self.inference_api.openai_completion(params)
|
||||
generations.append({ColumnName.generated_answer.value: response.choices[0].text})
|
||||
elif ColumnName.chat_completion_input.value in x:
|
||||
chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
|
||||
|
@ -187,11 +195,12 @@ class MetaReferenceEvalImpl(
|
|||
messages += [OpenAISystemMessageParam(**x) for x in chat_completion_input_json if x["role"] == "system"]
|
||||
|
||||
messages += input_messages
|
||||
response = await self.inference_api.openai_chat_completion(
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model=candidate.model,
|
||||
messages=messages,
|
||||
**sampling_params,
|
||||
)
|
||||
response = await self.inference_api.openai_chat_completion(params)
|
||||
generations.append({ColumnName.generated_answer.value: response.choices[0].message.content})
|
||||
else:
|
||||
raise ValueError("Invalid input row")
|
||||
|
|
|
@ -22,6 +22,7 @@ from llama_stack.apis.files import (
|
|||
OpenAIFilePurpose,
|
||||
)
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.core.id_generation import generate_object_id
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.files.form_data import parse_expires_after
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
|
@ -65,7 +66,7 @@ class LocalfsFilesImpl(Files):
|
|||
|
||||
def _generate_file_id(self) -> str:
|
||||
"""Generate a unique file ID for OpenAI API."""
|
||||
return f"file-{uuid.uuid4().hex}"
|
||||
return generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
|
||||
|
||||
def _get_file_path(self, file_id: str) -> Path:
|
||||
"""Get the filesystem path for a file ID."""
|
||||
|
@ -95,7 +96,9 @@ class LocalfsFilesImpl(Files):
|
|||
raise RuntimeError("Files provider not initialized")
|
||||
|
||||
if expires_after is not None:
|
||||
raise NotImplementedError("File expiration is not supported by this provider")
|
||||
logger.warning(
|
||||
f"File expiration is not supported by this provider, ignoring expires_after: {expires_after}"
|
||||
)
|
||||
|
||||
file_id = self._generate_file_id()
|
||||
file_path = self._get_file_path(file_id)
|
||||
|
|
|
@ -18,7 +18,7 @@ def model_checkpoint_dir(model_id) -> str:
|
|||
|
||||
assert checkpoint_dir.exists(), (
|
||||
f"Could not find checkpoints in: {model_local_dir(model_id)}. "
|
||||
f"If you try to use the native llama model, Please download model using `llama download --model-id {model_id}`"
|
||||
f"Otherwise, please save you model checkpoint under {model_local_dir(model_id)}"
|
||||
f"If you try to use the native llama model, please download the model using `llama-model download --source meta --model-id {model_id}` (see https://github.com/meta-llama/llama-models). "
|
||||
f"Otherwise, please save your model checkpoint under {model_local_dir(model_id)}"
|
||||
)
|
||||
return str(checkpoint_dir)
|
||||
|
|
|
@ -6,16 +6,16 @@
|
|||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
InferenceProvider,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
OpenAICompletion,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.log import get_logger
|
||||
|
@ -65,7 +65,10 @@ class MetaReferenceInferenceImpl(
|
|||
if self.config.create_distributed_process_group:
|
||||
self.generator.stop()
|
||||
|
||||
async def openai_completion(self, *args, **kwargs):
|
||||
async def openai_completion(
|
||||
self,
|
||||
params: OpenAICompletionRequestWithExtraBody,
|
||||
) -> OpenAICompletion:
|
||||
raise NotImplementedError("OpenAI completion not supported by meta reference provider")
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
|
@ -150,28 +153,6 @@ class MetaReferenceInferenceImpl(
|
|||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
params: OpenAIChatCompletionRequestWithExtraBody,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider")
|
||||
|
|
|
@ -5,17 +5,16 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
InferenceProvider,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.log import get_logger
|
||||
|
@ -73,56 +72,12 @@ class SentenceTransformersInferenceImpl(
|
|||
|
||||
async def openai_completion(
|
||||
self,
|
||||
# Standard OpenAI completion parameters
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
# vLLM-specific parameters
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
# for fill-in-the-middle type completion
|
||||
suffix: str | None = None,
|
||||
params: OpenAICompletionRequestWithExtraBody,
|
||||
) -> OpenAICompletion:
|
||||
raise NotImplementedError("OpenAI completion not supported by sentence transformers provider")
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
params: OpenAIChatCompletionRequestWithExtraBody,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
raise NotImplementedError("OpenAI chat completion not supported by sentence transformers provider")
|
||||
|
|
|
@ -104,9 +104,10 @@ class LoraFinetuningSingleDevice:
|
|||
if not any(p.exists() for p in paths):
|
||||
checkpoint_dir = checkpoint_dir / "original"
|
||||
|
||||
hf_repo = model.huggingface_repo or f"meta-llama/{model.descriptor()}"
|
||||
assert checkpoint_dir.exists(), (
|
||||
f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. "
|
||||
f"Please download model using `llama download --model-id {model.descriptor()}`"
|
||||
f"Please download the model using `huggingface-cli download {hf_repo} --local-dir ~/.llama/{model.descriptor()}`"
|
||||
)
|
||||
return str(checkpoint_dir)
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any
|
|||
if TYPE_CHECKING:
|
||||
from codeshield.cs import CodeShieldScanResult
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.inference import OpenAIMessageParam
|
||||
from llama_stack.apis.safety import (
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
|
@ -53,7 +53,7 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
|
|||
async def run_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
messages: list[Message],
|
||||
messages: list[OpenAIMessageParam],
|
||||
params: dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
shield = await self.shield_store.get_shield(shield_id)
|
||||
|
|
|
@ -10,7 +10,12 @@ from string import Template
|
|||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
|
||||
from llama_stack.apis.inference import Inference, Message, UserMessage
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAIMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.apis.safety import (
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
|
@ -159,7 +164,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
async def run_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
messages: list[Message],
|
||||
messages: list[OpenAIMessageParam],
|
||||
params: dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
shield = await self.shield_store.get_shield(shield_id)
|
||||
|
@ -169,8 +174,8 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
messages = messages.copy()
|
||||
# some shields like llama-guard require the first message to be a user message
|
||||
# since this might be a tool call, first role might not be user
|
||||
if len(messages) > 0 and messages[0].role != Role.user.value:
|
||||
messages[0] = UserMessage(content=messages[0].content)
|
||||
if len(messages) > 0 and messages[0].role != "user":
|
||||
messages[0] = OpenAIUserMessageParam(content=messages[0].content)
|
||||
|
||||
# Use the inference API's model resolution instead of hardcoded mappings
|
||||
# This allows the shield to work with any registered model
|
||||
|
@ -202,7 +207,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
messages = [input]
|
||||
|
||||
# convert to user messages format with role
|
||||
messages = [UserMessage(content=m) for m in messages]
|
||||
messages = [OpenAIUserMessageParam(content=m) for m in messages]
|
||||
|
||||
# Determine safety categories based on the model type
|
||||
# For known Llama Guard models, use specific categories
|
||||
|
@ -271,7 +276,7 @@ class LlamaGuardShield:
|
|||
|
||||
return final_categories
|
||||
|
||||
def validate_messages(self, messages: list[Message]) -> None:
|
||||
def validate_messages(self, messages: list[OpenAIMessageParam]) -> list[OpenAIMessageParam]:
|
||||
if len(messages) == 0:
|
||||
raise ValueError("Messages must not be empty")
|
||||
if messages[0].role != Role.user.value:
|
||||
|
@ -282,7 +287,7 @@ class LlamaGuardShield:
|
|||
|
||||
return messages
|
||||
|
||||
async def run(self, messages: list[Message]) -> RunShieldResponse:
|
||||
async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
|
||||
messages = self.validate_messages(messages)
|
||||
|
||||
if self.model == CoreModelId.llama_guard_3_11b_vision.value:
|
||||
|
@ -290,20 +295,21 @@ class LlamaGuardShield:
|
|||
else:
|
||||
shield_input_message = self.build_text_shield_input(messages)
|
||||
|
||||
response = await self.inference_api.openai_chat_completion(
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model=self.model,
|
||||
messages=[shield_input_message],
|
||||
stream=False,
|
||||
temperature=0.0, # default is 1, which is too high for safety
|
||||
)
|
||||
response = await self.inference_api.openai_chat_completion(params)
|
||||
content = response.choices[0].message.content
|
||||
content = content.strip()
|
||||
return self.get_shield_response(content)
|
||||
|
||||
def build_text_shield_input(self, messages: list[Message]) -> UserMessage:
|
||||
return UserMessage(content=self.build_prompt(messages))
|
||||
def build_text_shield_input(self, messages: list[OpenAIMessageParam]) -> OpenAIUserMessageParam:
|
||||
return OpenAIUserMessageParam(content=self.build_prompt(messages))
|
||||
|
||||
def build_vision_shield_input(self, messages: list[Message]) -> UserMessage:
|
||||
def build_vision_shield_input(self, messages: list[OpenAIMessageParam]) -> OpenAIUserMessageParam:
|
||||
conversation = []
|
||||
most_recent_img = None
|
||||
|
||||
|
@ -326,7 +332,7 @@ class LlamaGuardShield:
|
|||
else:
|
||||
raise ValueError(f"Unknown content type: {c}")
|
||||
|
||||
conversation.append(UserMessage(content=content))
|
||||
conversation.append(OpenAIUserMessageParam(content=content))
|
||||
else:
|
||||
raise ValueError(f"Unknown content type: {m.content}")
|
||||
|
||||
|
@ -335,9 +341,9 @@ class LlamaGuardShield:
|
|||
prompt.append(most_recent_img)
|
||||
prompt.append(self.build_prompt(conversation[::-1]))
|
||||
|
||||
return UserMessage(content=prompt)
|
||||
return OpenAIUserMessageParam(content=prompt)
|
||||
|
||||
def build_prompt(self, messages: list[Message]) -> str:
|
||||
def build_prompt(self, messages: list[OpenAIMessageParam]) -> str:
|
||||
categories = self.get_safety_categories()
|
||||
categories_str = "\n".join(categories)
|
||||
conversations_str = "\n\n".join(
|
||||
|
@ -370,18 +376,20 @@ class LlamaGuardShield:
|
|||
|
||||
raise ValueError(f"Unexpected response: {response}")
|
||||
|
||||
async def run_moderation(self, messages: list[Message]) -> ModerationObject:
|
||||
async def run_moderation(self, messages: list[OpenAIMessageParam]) -> ModerationObject:
|
||||
if not messages:
|
||||
return self.create_moderation_object(self.model)
|
||||
|
||||
# TODO: Add Image based support for OpenAI Moderations
|
||||
shield_input_message = self.build_text_shield_input(messages)
|
||||
|
||||
response = await self.inference_api.openai_chat_completion(
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model=self.model,
|
||||
messages=[shield_input_message],
|
||||
stream=False,
|
||||
temperature=0.0, # default is 1, which is too high for safety
|
||||
)
|
||||
response = await self.inference_api.openai_chat_completion(params)
|
||||
content = response.choices[0].message.content
|
||||
content = content.strip()
|
||||
return self.get_moderation_object(content)
|
||||
|
|
|
@ -9,7 +9,7 @@ from typing import Any
|
|||
import torch
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.inference import OpenAIMessageParam
|
||||
from llama_stack.apis.safety import (
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
|
@ -22,9 +22,7 @@ from llama_stack.apis.shields import Shield
|
|||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
|
||||
from .config import PromptGuardConfig, PromptGuardType
|
||||
|
||||
|
@ -56,7 +54,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
async def run_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
messages: list[Message],
|
||||
messages: list[OpenAIMessageParam],
|
||||
params: dict[str, Any],
|
||||
) -> RunShieldResponse:
|
||||
shield = await self.shield_store.get_shield(shield_id)
|
||||
|
@ -93,7 +91,7 @@ class PromptGuardShield:
|
|||
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(model_dir, device_map=self.device)
|
||||
|
||||
async def run(self, messages: list[Message]) -> RunShieldResponse:
|
||||
async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
|
||||
message = messages[-1]
|
||||
text = interleaved_content_as_str(message.content)
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
import re
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inference import Inference, OpenAIChatCompletionRequestWithExtraBody
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
@ -55,7 +55,7 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
|
|||
generated_answer=generated_answer,
|
||||
)
|
||||
|
||||
judge_response = await self.inference_api.openai_chat_completion(
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model=fn_def.params.judge_model,
|
||||
messages=[
|
||||
{
|
||||
|
@ -64,6 +64,7 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
|
|||
}
|
||||
],
|
||||
)
|
||||
judge_response = await self.inference_api.openai_chat_completion(params)
|
||||
content = judge_response.choices[0].message.content
|
||||
rating_regexes = fn_def.params.judge_score_regexes
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
from jinja2 import Template
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.inference import OpenAIUserMessageParam
|
||||
from llama_stack.apis.inference import OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
|
||||
from llama_stack.apis.tools.rag_tool import (
|
||||
DefaultRAGQueryGeneratorConfig,
|
||||
LLMRAGQueryGeneratorConfig,
|
||||
|
@ -65,11 +65,12 @@ async def llm_rag_query_generator(
|
|||
|
||||
model = config.model
|
||||
message = OpenAIUserMessageParam(content=rendered_content)
|
||||
response = await inference_api.openai_chat_completion(
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model=model,
|
||||
messages=[message],
|
||||
stream=False,
|
||||
)
|
||||
response = await inference_api.openai_chat_completion(params)
|
||||
|
||||
query = response.choices[0].message.content
|
||||
|
||||
|
|
|
@ -8,8 +8,6 @@ import asyncio
|
|||
import base64
|
||||
import io
|
||||
import mimetypes
|
||||
import secrets
|
||||
import string
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
@ -52,10 +50,6 @@ from .context_retriever import generate_rag_query
|
|||
log = get_logger(name=__name__, category="tool_runtime")
|
||||
|
||||
|
||||
def make_random_string(length: int = 8):
|
||||
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
|
||||
|
||||
|
||||
async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
|
||||
"""Get raw binary data and mime type from a RAGDocument for file upload."""
|
||||
if isinstance(doc.content, URL):
|
||||
|
@ -331,5 +325,8 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
|||
|
||||
return ToolInvocationResult(
|
||||
content=result.content or [],
|
||||
metadata=result.metadata,
|
||||
metadata={
|
||||
**(result.metadata or {}),
|
||||
"citation_files": getattr(result, "citation_files", None),
|
||||
},
|
||||
)
|
||||
|
|
|
@ -200,12 +200,10 @@ class FaissIndex(EmbeddingIndex):
|
|||
|
||||
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.files_api = files_api
|
||||
self.cache: dict[str, VectorDBWithIndex] = {}
|
||||
self.kvstore: KVStore | None = None
|
||||
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.kvstore = await kvstore_impl(self.config.persistence)
|
||||
|
@ -227,8 +225,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
|
|||
await self.initialize_openai_vector_stores()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
# Cleanup if needed
|
||||
pass
|
||||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def health(self) -> HealthResponse:
|
||||
"""
|
||||
|
|
|
@ -410,12 +410,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
"""
|
||||
|
||||
def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.files_api = files_api
|
||||
self.cache: dict[str, VectorDBWithIndex] = {}
|
||||
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
|
||||
self.kvstore: KVStore | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.kvstore = await kvstore_impl(self.config.persistence)
|
||||
|
@ -436,8 +434,8 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
await self.initialize_openai_vector_stores()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
# nothing to do since we don't maintain a persistent connection
|
||||
pass
|
||||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def list_vector_dbs(self) -> list[VectorDB]:
|
||||
return [v.vector_db for v in self.cache.values()]
|
||||
|
|
|
@ -32,9 +32,12 @@ def available_providers() -> list[ProviderSpec]:
|
|||
Api.inference,
|
||||
Api.safety,
|
||||
Api.vector_io,
|
||||
Api.vector_dbs,
|
||||
Api.tool_runtime,
|
||||
Api.tool_groups,
|
||||
Api.conversations,
|
||||
],
|
||||
optional_api_dependencies=[
|
||||
Api.telemetry,
|
||||
],
|
||||
description="Meta's reference implementation of an agent system that can use tools, access vector databases, and perform complex reasoning tasks.",
|
||||
),
|
||||
|
|
|
@ -52,9 +52,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
api=Api.inference,
|
||||
adapter_type="cerebras",
|
||||
provider_type="remote::cerebras",
|
||||
pip_packages=[
|
||||
"cerebras_cloud_sdk",
|
||||
],
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.cerebras",
|
||||
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
|
||||
description="Cerebras inference provider for running models on Cerebras Cloud platform.",
|
||||
|
@ -169,7 +167,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
api=Api.inference,
|
||||
adapter_type="openai",
|
||||
provider_type="remote::openai",
|
||||
pip_packages=["litellm"],
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.openai",
|
||||
config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
|
||||
|
@ -179,7 +177,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
api=Api.inference,
|
||||
adapter_type="anthropic",
|
||||
provider_type="remote::anthropic",
|
||||
pip_packages=["litellm"],
|
||||
pip_packages=["anthropic"],
|
||||
module="llama_stack.providers.remote.inference.anthropic",
|
||||
config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
|
||||
|
@ -189,9 +187,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
api=Api.inference,
|
||||
adapter_type="gemini",
|
||||
provider_type="remote::gemini",
|
||||
pip_packages=[
|
||||
"litellm",
|
||||
],
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.gemini",
|
||||
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
|
||||
|
@ -202,7 +198,6 @@ def available_providers() -> list[ProviderSpec]:
|
|||
adapter_type="vertexai",
|
||||
provider_type="remote::vertexai",
|
||||
pip_packages=[
|
||||
"litellm",
|
||||
"google-cloud-aiplatform",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.vertexai",
|
||||
|
@ -233,9 +228,7 @@ Available Models:
|
|||
api=Api.inference,
|
||||
adapter_type="groq",
|
||||
provider_type="remote::groq",
|
||||
pip_packages=[
|
||||
"litellm",
|
||||
],
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.groq",
|
||||
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
|
||||
|
@ -245,7 +238,7 @@ Available Models:
|
|||
api=Api.inference,
|
||||
adapter_type="llama-openai-compat",
|
||||
provider_type="remote::llama-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.llama_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
|
||||
|
@ -255,9 +248,7 @@ Available Models:
|
|||
api=Api.inference,
|
||||
adapter_type="sambanova",
|
||||
provider_type="remote::sambanova",
|
||||
pip_packages=[
|
||||
"litellm",
|
||||
],
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.sambanova",
|
||||
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
|
||||
|
@ -277,7 +268,7 @@ Available Models:
|
|||
api=Api.inference,
|
||||
adapter_type="watsonx",
|
||||
provider_type="remote::watsonx",
|
||||
pip_packages=["ibm_watsonx_ai"],
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.watsonx",
|
||||
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
|
||||
|
@ -287,7 +278,7 @@ Available Models:
|
|||
api=Api.inference,
|
||||
provider_type="remote::azure",
|
||||
adapter_type="azure",
|
||||
pip_packages=["litellm"],
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.azure",
|
||||
config_class="llama_stack.providers.remote.inference.azure.AzureConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.azure.config.AzureProviderDataValidator",
|
||||
|
|
|
@ -11,6 +11,7 @@ from llama_stack.providers.datatypes import (
|
|||
ProviderSpec,
|
||||
RemoteProviderSpec,
|
||||
)
|
||||
from llama_stack.providers.registry.vector_io import DEFAULT_VECTOR_IO_DEPS
|
||||
|
||||
|
||||
def available_providers() -> list[ProviderSpec]:
|
||||
|
@ -18,9 +19,8 @@ def available_providers() -> list[ProviderSpec]:
|
|||
InlineProviderSpec(
|
||||
api=Api.tool_runtime,
|
||||
provider_type="inline::rag-runtime",
|
||||
pip_packages=[
|
||||
"chardet",
|
||||
"pypdf",
|
||||
pip_packages=DEFAULT_VECTOR_IO_DEPS
|
||||
+ [
|
||||
"tqdm",
|
||||
"numpy",
|
||||
"scikit-learn",
|
||||
|
|
|
@ -12,13 +12,16 @@ from llama_stack.providers.datatypes import (
|
|||
RemoteProviderSpec,
|
||||
)
|
||||
|
||||
# Common dependencies for all vector IO providers that support document processing
|
||||
DEFAULT_VECTOR_IO_DEPS = ["chardet", "pypdf"]
|
||||
|
||||
|
||||
def available_providers() -> list[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::meta-reference",
|
||||
pip_packages=["faiss-cpu"],
|
||||
pip_packages=["faiss-cpu"] + DEFAULT_VECTOR_IO_DEPS,
|
||||
module="llama_stack.providers.inline.vector_io.faiss",
|
||||
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
|
||||
deprecation_warning="Please use the `inline::faiss` provider instead.",
|
||||
|
@ -29,7 +32,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::faiss",
|
||||
pip_packages=["faiss-cpu"],
|
||||
pip_packages=["faiss-cpu"] + DEFAULT_VECTOR_IO_DEPS,
|
||||
module="llama_stack.providers.inline.vector_io.faiss",
|
||||
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
|
@ -82,7 +85,7 @@ more details about Faiss in general.
|
|||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::sqlite-vec",
|
||||
pip_packages=["sqlite-vec"],
|
||||
pip_packages=["sqlite-vec"] + DEFAULT_VECTOR_IO_DEPS,
|
||||
module="llama_stack.providers.inline.vector_io.sqlite_vec",
|
||||
config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
|
@ -289,7 +292,7 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f
|
|||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::sqlite_vec",
|
||||
pip_packages=["sqlite-vec"],
|
||||
pip_packages=["sqlite-vec"] + DEFAULT_VECTOR_IO_DEPS,
|
||||
module="llama_stack.providers.inline.vector_io.sqlite_vec",
|
||||
config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
|
||||
deprecation_warning="Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead.",
|
||||
|
@ -303,7 +306,7 @@ Please refer to the sqlite-vec provider documentation.
|
|||
api=Api.vector_io,
|
||||
adapter_type="chromadb",
|
||||
provider_type="remote::chromadb",
|
||||
pip_packages=["chromadb-client"],
|
||||
pip_packages=["chromadb-client"] + DEFAULT_VECTOR_IO_DEPS,
|
||||
module="llama_stack.providers.remote.vector_io.chroma",
|
||||
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
|
@ -345,7 +348,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
|
|||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::chromadb",
|
||||
pip_packages=["chromadb"],
|
||||
pip_packages=["chromadb"] + DEFAULT_VECTOR_IO_DEPS,
|
||||
module="llama_stack.providers.inline.vector_io.chroma",
|
||||
config_class="llama_stack.providers.inline.vector_io.chroma.ChromaVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
|
@ -389,7 +392,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
|
|||
api=Api.vector_io,
|
||||
adapter_type="pgvector",
|
||||
provider_type="remote::pgvector",
|
||||
pip_packages=["psycopg2-binary"],
|
||||
pip_packages=["psycopg2-binary"] + DEFAULT_VECTOR_IO_DEPS,
|
||||
module="llama_stack.providers.remote.vector_io.pgvector",
|
||||
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
|
@ -500,7 +503,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
|
|||
api=Api.vector_io,
|
||||
adapter_type="weaviate",
|
||||
provider_type="remote::weaviate",
|
||||
pip_packages=["weaviate-client>=4.16.5"],
|
||||
pip_packages=["weaviate-client>=4.16.5"] + DEFAULT_VECTOR_IO_DEPS,
|
||||
module="llama_stack.providers.remote.vector_io.weaviate",
|
||||
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
|
||||
|
@ -541,7 +544,7 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more
|
|||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::qdrant",
|
||||
pip_packages=["qdrant-client"],
|
||||
pip_packages=["qdrant-client"] + DEFAULT_VECTOR_IO_DEPS,
|
||||
module="llama_stack.providers.inline.vector_io.qdrant",
|
||||
config_class="llama_stack.providers.inline.vector_io.qdrant.QdrantVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
|
@ -594,7 +597,7 @@ See the [Qdrant documentation](https://qdrant.tech/documentation/) for more deta
|
|||
api=Api.vector_io,
|
||||
adapter_type="qdrant",
|
||||
provider_type="remote::qdrant",
|
||||
pip_packages=["qdrant-client"],
|
||||
pip_packages=["qdrant-client"] + DEFAULT_VECTOR_IO_DEPS,
|
||||
module="llama_stack.providers.remote.vector_io.qdrant",
|
||||
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
|
@ -607,7 +610,7 @@ Please refer to the inline provider documentation.
|
|||
api=Api.vector_io,
|
||||
adapter_type="milvus",
|
||||
provider_type="remote::milvus",
|
||||
pip_packages=["pymilvus>=2.4.10"],
|
||||
pip_packages=["pymilvus>=2.4.10"] + DEFAULT_VECTOR_IO_DEPS,
|
||||
module="llama_stack.providers.remote.vector_io.milvus",
|
||||
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
|
@ -813,7 +816,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
|
|||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::milvus",
|
||||
pip_packages=["pymilvus[milvus-lite]>=2.4.10"],
|
||||
pip_packages=["pymilvus[milvus-lite]>=2.4.10"] + DEFAULT_VECTOR_IO_DEPS,
|
||||
module="llama_stack.providers.inline.vector_io.milvus",
|
||||
config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
|
|
|
@ -23,6 +23,7 @@ from llama_stack.apis.files import (
|
|||
OpenAIFilePurpose,
|
||||
)
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.core.id_generation import generate_object_id
|
||||
from llama_stack.providers.utils.files.form_data import parse_expires_after
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
|
@ -198,7 +199,7 @@ class S3FilesImpl(Files):
|
|||
purpose: Annotated[OpenAIFilePurpose, Form()],
|
||||
expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
|
||||
) -> OpenAIFileObject:
|
||||
file_id = f"file-{uuid.uuid4().hex}"
|
||||
file_id = generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
|
||||
|
||||
filename = getattr(file, "filename", None) or "uploaded_file"
|
||||
|
||||
|
|
|
@ -10,6 +10,6 @@ from .config import AnthropicConfig
|
|||
async def get_adapter_impl(config: AnthropicConfig, _deps):
|
||||
from .anthropic import AnthropicInferenceAdapter
|
||||
|
||||
impl = AnthropicInferenceAdapter(config)
|
||||
impl = AnthropicInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -4,13 +4,19 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from collections.abc import Iterable
|
||||
|
||||
from anthropic import AsyncAnthropic
|
||||
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import AnthropicConfig
|
||||
|
||||
|
||||
class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
class AnthropicInferenceAdapter(OpenAIMixin):
|
||||
config: AnthropicConfig
|
||||
|
||||
provider_data_api_key_field: str = "anthropic_api_key"
|
||||
# source: https://docs.claude.com/en/docs/build-with-claude/embeddings
|
||||
# TODO: add support for voyageai, which is where these models are hosted
|
||||
# embedding_model_metadata = {
|
||||
|
@ -23,22 +29,8 @@ class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
|||
# "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000},
|
||||
# }
|
||||
|
||||
def __init__(self, config: AnthropicConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
litellm_provider_name="anthropic",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="anthropic_api_key",
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
await super().shutdown()
|
||||
|
||||
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||
|
||||
def get_base_url(self):
|
||||
return "https://api.anthropic.com/v1"
|
||||
|
||||
async def list_provider_model_ids(self) -> Iterable[str]:
|
||||
return [m.id async for m in AsyncAnthropic(api_key=self.get_api_key()).models.list()]
|
||||
|
|
|
@ -21,11 +21,6 @@ class AnthropicProviderDataValidator(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class AnthropicConfig(RemoteInferenceProviderConfig):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for Anthropic models",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
|
|
|
@ -10,6 +10,6 @@ from .config import AzureConfig
|
|||
async def get_adapter_impl(config: AzureConfig, _deps):
|
||||
from .azure import AzureInferenceAdapter
|
||||
|
||||
impl = AzureInferenceAdapter(config)
|
||||
impl = AzureInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -4,31 +4,17 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from llama_stack.apis.inference import ChatCompletionRequest
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
|
||||
LiteLLMOpenAIMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import AzureConfig
|
||||
|
||||
|
||||
class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
def __init__(self, config: AzureConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
litellm_provider_name="azure",
|
||||
api_key_from_config=config.api_key.get_secret_value(),
|
||||
provider_data_api_key_field="azure_api_key",
|
||||
openai_compat_api_base=str(config.api_base),
|
||||
)
|
||||
self.config = config
|
||||
class AzureInferenceAdapter(OpenAIMixin):
|
||||
config: AzureConfig
|
||||
|
||||
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
|
||||
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||
provider_data_api_key_field: str = "azure_api_key"
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
|
@ -37,26 +23,3 @@ class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
|||
Returns the Azure API base URL from the configuration.
|
||||
"""
|
||||
return urljoin(str(self.config.api_base), "/openai/v1")
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
|
||||
# Get base parameters from parent
|
||||
params = await super()._get_params(request)
|
||||
|
||||
# Add Azure specific parameters
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data:
|
||||
if getattr(provider_data, "azure_api_key", None):
|
||||
params["api_key"] = provider_data.azure_api_key
|
||||
if getattr(provider_data, "azure_api_base", None):
|
||||
params["api_base"] = provider_data.azure_api_base
|
||||
if getattr(provider_data, "azure_api_version", None):
|
||||
params["api_version"] = provider_data.azure_api_version
|
||||
if getattr(provider_data, "azure_api_type", None):
|
||||
params["api_type"] = provider_data.azure_api_type
|
||||
else:
|
||||
params["api_key"] = self.config.api_key.get_secret_value()
|
||||
params["api_base"] = str(self.config.api_base)
|
||||
params["api_version"] = self.config.api_version
|
||||
params["api_type"] = self.config.api_type
|
||||
|
||||
return params
|
||||
|
|
|
@ -32,9 +32,6 @@ class AzureProviderDataValidator(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class AzureConfig(RemoteInferenceProviderConfig):
|
||||
api_key: SecretStr = Field(
|
||||
description="Azure API key for Azure",
|
||||
)
|
||||
api_base: HttpUrl = Field(
|
||||
description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
|
||||
)
|
||||
|
|
|
@ -6,21 +6,21 @@
|
|||
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from botocore.client import BaseClient
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
Inference,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIEmbeddingsResponse,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
|
||||
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
|
||||
|
@ -125,66 +125,18 @@ class BedrockInferenceAdapter(
|
|||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
params: OpenAIEmbeddingsRequestWithExtraBody,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
# Standard OpenAI completion parameters
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
# vLLM-specific parameters
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
# for fill-in-the-middle type completion
|
||||
suffix: str | None = None,
|
||||
params: OpenAICompletionRequestWithExtraBody,
|
||||
) -> OpenAICompletion:
|
||||
raise NotImplementedError("OpenAI completion not supported by the Bedrock provider")
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
params: OpenAIChatCompletionRequestWithExtraBody,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")
|
||||
|
|
|
@ -12,7 +12,7 @@ async def get_adapter_impl(config: CerebrasImplConfig, _deps):
|
|||
|
||||
assert isinstance(config, CerebrasImplConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = CerebrasInferenceAdapter(config)
|
||||
impl = CerebrasInferenceAdapter(config=config)
|
||||
|
||||
await impl.initialize()
|
||||
|
||||
|
|
|
@ -6,77 +6,23 @@
|
|||
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from cerebras.cloud.sdk import AsyncCerebras
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
Inference,
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIEmbeddingsResponse,
|
||||
TopKSamplingStrategy,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
)
|
||||
|
||||
from .config import CerebrasImplConfig
|
||||
|
||||
|
||||
class CerebrasInferenceAdapter(
|
||||
OpenAIMixin,
|
||||
Inference,
|
||||
):
|
||||
def __init__(self, config: CerebrasImplConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
# TODO: make this use provider data, etc. like other providers
|
||||
self._cerebras_client = AsyncCerebras(
|
||||
base_url=self.config.base_url,
|
||||
api_key=self.config.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key.get_secret_value()
|
||||
class CerebrasInferenceAdapter(OpenAIMixin):
|
||||
config: CerebrasImplConfig
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return urljoin(self.config.base_url, "v1")
|
||||
|
||||
async def initialize(self) -> None:
|
||||
return
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
|
||||
if request.sampling_params and isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
|
||||
raise ValueError("`top_k` not supported by Cerebras")
|
||||
|
||||
prompt = ""
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
|
||||
elif isinstance(request, CompletionRequest):
|
||||
prompt = await completion_request_to_prompt(request)
|
||||
else:
|
||||
raise ValueError(f"Unknown request type {type(request)}")
|
||||
|
||||
return {
|
||||
"model": request.model,
|
||||
"prompt": prompt,
|
||||
"stream": request.stream,
|
||||
**get_sampling_options(request.sampling_params),
|
||||
}
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
params: OpenAIEmbeddingsRequestWithExtraBody,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, SecretStr
|
||||
from pydantic import Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
@ -21,10 +21,6 @@ class CerebrasImplConfig(RemoteInferenceProviderConfig):
|
|||
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
|
||||
description="Base URL for the Cerebras API",
|
||||
)
|
||||
api_key: SecretStr = Field(
|
||||
default=SecretStr(os.environ.get("CEREBRAS_API_KEY")),
|
||||
description="Cerebras API Key",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
|
|
|
@ -11,6 +11,6 @@ async def get_adapter_impl(config: DatabricksImplConfig, _deps):
|
|||
from .databricks import DatabricksInferenceAdapter
|
||||
|
||||
assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = DatabricksInferenceAdapter(config)
|
||||
impl = DatabricksInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue