mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-21 12:09:40 +00:00
Merge branch 'main' into nvidia-e2e-notebook
This commit is contained in:
commit
1a492ad0cc
200 changed files with 8714 additions and 3175 deletions
|
@ -37,6 +37,7 @@ from .openai_responses import (
|
|||
OpenAIResponseInputTool,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseText,
|
||||
)
|
||||
|
||||
# TODO: use enum.StrEnum when we drop support for python 3.10
|
||||
|
@ -603,7 +604,9 @@ class Agents(Protocol):
|
|||
store: bool | None = True,
|
||||
stream: bool | None = False,
|
||||
temperature: float | None = None,
|
||||
text: OpenAIResponseText | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
max_infer_iters: int | None = 10, # this is an extension to the OpenAI API
|
||||
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
|
||||
"""Create a new OpenAI response.
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema
|
||||
|
||||
|
@ -126,6 +127,32 @@ OpenAIResponseOutput = Annotated[
|
|||
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
|
||||
|
||||
|
||||
# This has to be a TypedDict because we need a "schema" field and our strong
|
||||
# typing code in the schema generator doesn't support Pydantic aliases. That also
|
||||
# means we can't use a discriminator field here, because TypedDicts don't support
|
||||
# default values which the strong typing code requires for discriminators.
|
||||
class OpenAIResponseTextFormat(TypedDict, total=False):
|
||||
"""Configuration for Responses API text format.
|
||||
|
||||
:param type: Must be "text", "json_schema", or "json_object" to identify the format type
|
||||
:param name: The name of the response format. Only used for json_schema.
|
||||
:param schema: The JSON schema the response should conform to. In a Python SDK, this is often a `pydantic` model. Only used for json_schema.
|
||||
:param description: (Optional) A description of the response format. Only used for json_schema.
|
||||
:param strict: (Optional) Whether to strictly enforce the JSON schema. If true, the response must match the schema exactly. Only used for json_schema.
|
||||
"""
|
||||
|
||||
type: Literal["text"] | Literal["json_schema"] | Literal["json_object"]
|
||||
name: str | None
|
||||
schema: dict[str, Any] | None
|
||||
description: str | None
|
||||
strict: bool | None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseText(BaseModel):
|
||||
format: OpenAIResponseTextFormat | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObject(BaseModel):
|
||||
created_at: int
|
||||
|
@ -138,6 +165,9 @@ class OpenAIResponseObject(BaseModel):
|
|||
previous_response_id: str | None = None
|
||||
status: str
|
||||
temperature: float | None = None
|
||||
# Default to text format to avoid breaking the loading of old responses
|
||||
# before the field was added. New responses will have this set always.
|
||||
text: OpenAIResponseText = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
|
||||
top_p: float | None = None
|
||||
truncation: str | None = None
|
||||
user: str | None = None
|
||||
|
@ -149,6 +179,30 @@ class OpenAIResponseObjectStreamResponseCreated(BaseModel):
|
|||
type: Literal["response.created"] = "response.created"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
|
||||
response: OpenAIResponseObject
|
||||
type: Literal["response.completed"] = "response.completed"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseOutputItemAdded(BaseModel):
|
||||
response_id: str
|
||||
item: OpenAIResponseOutput
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.output_item.added"] = "response.output_item.added"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseOutputItemDone(BaseModel):
|
||||
response_id: str
|
||||
item: OpenAIResponseOutput
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.output_item.done"] = "response.output_item.done"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseOutputTextDelta(BaseModel):
|
||||
content_index: int
|
||||
|
@ -160,14 +214,132 @@ class OpenAIResponseObjectStreamResponseOutputTextDelta(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
|
||||
response: OpenAIResponseObject
|
||||
type: Literal["response.completed"] = "response.completed"
|
||||
class OpenAIResponseObjectStreamResponseOutputTextDone(BaseModel):
|
||||
content_index: int
|
||||
text: str # final text of the output item
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.output_text.done"] = "response.output_text.done"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta(BaseModel):
|
||||
delta: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.function_call_arguments.delta"] = "response.function_call_arguments.delta"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone(BaseModel):
|
||||
arguments: str # final arguments of the function call
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.function_call_arguments.done"] = "response.function_call_arguments.done"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseWebSearchCallInProgress(BaseModel):
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.web_search_call.in_progress"] = "response.web_search_call.in_progress"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseWebSearchCallSearching(BaseModel):
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.web_search_call.searching"] = "response.web_search_call.searching"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseWebSearchCallCompleted(BaseModel):
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.web_search_call.completed"] = "response.web_search_call.completed"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseMcpListToolsInProgress(BaseModel):
|
||||
sequence_number: int
|
||||
type: Literal["response.mcp_list_tools.in_progress"] = "response.mcp_list_tools.in_progress"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseMcpListToolsFailed(BaseModel):
|
||||
sequence_number: int
|
||||
type: Literal["response.mcp_list_tools.failed"] = "response.mcp_list_tools.failed"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseMcpListToolsCompleted(BaseModel):
|
||||
sequence_number: int
|
||||
type: Literal["response.mcp_list_tools.completed"] = "response.mcp_list_tools.completed"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta(BaseModel):
|
||||
delta: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.mcp_call.arguments.delta"] = "response.mcp_call.arguments.delta"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseMcpCallArgumentsDone(BaseModel):
|
||||
arguments: str # final arguments of the MCP call
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.mcp_call.arguments.done"] = "response.mcp_call.arguments.done"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseMcpCallInProgress(BaseModel):
|
||||
item_id: str
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.mcp_call.in_progress"] = "response.mcp_call.in_progress"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseMcpCallFailed(BaseModel):
|
||||
sequence_number: int
|
||||
type: Literal["response.mcp_call.failed"] = "response.mcp_call.failed"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseMcpCallCompleted(BaseModel):
|
||||
sequence_number: int
|
||||
type: Literal["response.mcp_call.completed"] = "response.mcp_call.completed"
|
||||
|
||||
|
||||
OpenAIResponseObjectStream = Annotated[
|
||||
OpenAIResponseObjectStreamResponseCreated
|
||||
| OpenAIResponseObjectStreamResponseOutputItemAdded
|
||||
| OpenAIResponseObjectStreamResponseOutputItemDone
|
||||
| OpenAIResponseObjectStreamResponseOutputTextDelta
|
||||
| OpenAIResponseObjectStreamResponseOutputTextDone
|
||||
| OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta
|
||||
| OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone
|
||||
| OpenAIResponseObjectStreamResponseWebSearchCallInProgress
|
||||
| OpenAIResponseObjectStreamResponseWebSearchCallSearching
|
||||
| OpenAIResponseObjectStreamResponseWebSearchCallCompleted
|
||||
| OpenAIResponseObjectStreamResponseMcpListToolsInProgress
|
||||
| OpenAIResponseObjectStreamResponseMcpListToolsFailed
|
||||
| OpenAIResponseObjectStreamResponseMcpListToolsCompleted
|
||||
| OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta
|
||||
| OpenAIResponseObjectStreamResponseMcpCallArgumentsDone
|
||||
| OpenAIResponseObjectStreamResponseMcpCallInProgress
|
||||
| OpenAIResponseObjectStreamResponseMcpCallFailed
|
||||
| OpenAIResponseObjectStreamResponseMcpCallCompleted
|
||||
| OpenAIResponseObjectStreamResponseCompleted,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
|
|
@ -4,179 +4,158 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Protocol, runtime_checkable
|
||||
from enum import Enum
|
||||
from typing import Annotated, Literal, Protocol, runtime_checkable
|
||||
|
||||
from fastapi import File, Form, Response, UploadFile
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class FileUploadResponse(BaseModel):
|
||||
# OpenAI Files API Models
|
||||
class OpenAIFilePurpose(str, Enum):
|
||||
"""
|
||||
Valid purpose values for OpenAI Files API.
|
||||
"""
|
||||
Response after initiating a file upload session.
|
||||
|
||||
:param id: ID of the upload session
|
||||
:param url: Upload URL for the file or file parts
|
||||
:param offset: Upload content offset
|
||||
:param size: Upload content size
|
||||
ASSISTANTS = "assistants"
|
||||
# TODO: Add other purposes as needed
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIFileObject(BaseModel):
|
||||
"""
|
||||
OpenAI File object as defined in the OpenAI Files API.
|
||||
|
||||
:param object: The object type, which is always "file"
|
||||
:param id: The file identifier, which can be referenced in the API endpoints
|
||||
:param bytes: The size of the file, in bytes
|
||||
:param created_at: The Unix timestamp (in seconds) for when the file was created
|
||||
:param expires_at: The Unix timestamp (in seconds) for when the file expires
|
||||
:param filename: The name of the file
|
||||
:param purpose: The intended purpose of the file
|
||||
"""
|
||||
|
||||
object: Literal["file"] = "file"
|
||||
id: str
|
||||
bytes: int
|
||||
created_at: int
|
||||
expires_at: int
|
||||
filename: str
|
||||
purpose: OpenAIFilePurpose
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListOpenAIFileResponse(BaseModel):
|
||||
"""
|
||||
Response for listing files in OpenAI Files API.
|
||||
|
||||
:param data: List of file objects
|
||||
:param object: The object type, which is always "list"
|
||||
"""
|
||||
|
||||
data: list[OpenAIFileObject]
|
||||
has_more: bool
|
||||
first_id: str
|
||||
last_id: str
|
||||
object: Literal["list"] = "list"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIFileDeleteResponse(BaseModel):
|
||||
"""
|
||||
Response for deleting a file in OpenAI Files API.
|
||||
|
||||
:param id: The file identifier that was deleted
|
||||
:param object: The object type, which is always "file"
|
||||
:param deleted: Whether the file was successfully deleted
|
||||
"""
|
||||
|
||||
id: str
|
||||
url: str
|
||||
offset: int
|
||||
size: int
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BucketResponse(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListBucketResponse(BaseModel):
|
||||
"""
|
||||
Response representing a list of file entries.
|
||||
|
||||
:param data: List of FileResponse entries
|
||||
"""
|
||||
|
||||
data: list[BucketResponse]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class FileResponse(BaseModel):
|
||||
"""
|
||||
Response representing a file entry.
|
||||
|
||||
:param bucket: Bucket under which the file is stored (valid chars: a-zA-Z0-9_-)
|
||||
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
|
||||
:param mime_type: MIME type of the file
|
||||
:param url: Upload URL for the file contents
|
||||
:param bytes: Size of the file in bytes
|
||||
:param created_at: Timestamp of when the file was created
|
||||
"""
|
||||
|
||||
bucket: str
|
||||
key: str
|
||||
mime_type: str
|
||||
url: str
|
||||
bytes: int
|
||||
created_at: int
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListFileResponse(BaseModel):
|
||||
"""
|
||||
Response representing a list of file entries.
|
||||
|
||||
:param data: List of FileResponse entries
|
||||
"""
|
||||
|
||||
data: list[FileResponse]
|
||||
object: Literal["file"] = "file"
|
||||
deleted: bool
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Files(Protocol):
|
||||
@webmethod(route="/files", method="POST")
|
||||
async def create_upload_session(
|
||||
# OpenAI Files API Endpoints
|
||||
@webmethod(route="/openai/v1/files", method="POST")
|
||||
async def openai_upload_file(
|
||||
self,
|
||||
bucket: str,
|
||||
key: str,
|
||||
mime_type: str,
|
||||
size: int,
|
||||
) -> FileUploadResponse:
|
||||
file: Annotated[UploadFile, File()],
|
||||
purpose: Annotated[OpenAIFilePurpose, Form()],
|
||||
) -> OpenAIFileObject:
|
||||
"""
|
||||
Create a new upload session for a file identified by a bucket and key.
|
||||
Upload a file that can be used across various endpoints.
|
||||
|
||||
:param bucket: Bucket under which the file is stored (valid chars: a-zA-Z0-9_-).
|
||||
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
|
||||
:param mime_type: MIME type of the file.
|
||||
:param size: File size in bytes.
|
||||
:returns: A FileUploadResponse.
|
||||
The file upload should be a multipart form request with:
|
||||
- file: The File object (not file name) to be uploaded.
|
||||
- purpose: The intended purpose of the uploaded file.
|
||||
|
||||
:param file: The uploaded file object containing content and metadata (filename, content_type, etc.).
|
||||
:param purpose: The intended purpose of the uploaded file (e.g., "assistants", "fine-tune").
|
||||
:returns: An OpenAIFileObject representing the uploaded file.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/files/session:{upload_id}", method="POST", raw_bytes_request_body=True)
|
||||
async def upload_content_to_session(
|
||||
@webmethod(route="/openai/v1/files", method="GET")
|
||||
async def openai_list_files(
|
||||
self,
|
||||
upload_id: str,
|
||||
) -> FileResponse | None:
|
||||
after: str | None = None,
|
||||
limit: int | None = 10000,
|
||||
order: Order | None = Order.desc,
|
||||
purpose: OpenAIFilePurpose | None = None,
|
||||
) -> ListOpenAIFileResponse:
|
||||
"""
|
||||
Upload file content to an existing upload session.
|
||||
On the server, request body will have the raw bytes that are uploaded.
|
||||
Returns a list of files that belong to the user's organization.
|
||||
|
||||
:param upload_id: ID of the upload session.
|
||||
:returns: A FileResponse or None if the upload is not complete.
|
||||
:param after: A cursor for use in pagination. `after` is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list.
|
||||
:param limit: A limit on the number of objects to be returned. Limit can range between 1 and 10,000, and the default is 10,000.
|
||||
:param order: Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order.
|
||||
:param purpose: Only return files with the given purpose.
|
||||
:returns: An ListOpenAIFileResponse containing the list of files.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/files/session:{upload_id}", method="GET")
|
||||
async def get_upload_session_info(
|
||||
@webmethod(route="/openai/v1/files/{file_id}", method="GET")
|
||||
async def openai_retrieve_file(
|
||||
self,
|
||||
upload_id: str,
|
||||
) -> FileUploadResponse:
|
||||
file_id: str,
|
||||
) -> OpenAIFileObject:
|
||||
"""
|
||||
Returns information about an existsing upload session.
|
||||
Returns information about a specific file.
|
||||
|
||||
:param upload_id: ID of the upload session.
|
||||
:returns: A FileUploadResponse.
|
||||
:param file_id: The ID of the file to use for this request.
|
||||
:returns: An OpenAIFileObject containing file information.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/files", method="GET")
|
||||
async def list_all_buckets(
|
||||
@webmethod(route="/openai/v1/files/{file_id}", method="DELETE")
|
||||
async def openai_delete_file(
|
||||
self,
|
||||
bucket: str,
|
||||
) -> ListBucketResponse:
|
||||
file_id: str,
|
||||
) -> OpenAIFileDeleteResponse:
|
||||
"""
|
||||
List all buckets.
|
||||
Delete a file.
|
||||
|
||||
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
|
||||
:returns: A ListBucketResponse.
|
||||
:param file_id: The ID of the file to use for this request.
|
||||
:returns: An OpenAIFileDeleteResponse indicating successful deletion.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/files/{bucket}", method="GET")
|
||||
async def list_files_in_bucket(
|
||||
@webmethod(route="/openai/v1/files/{file_id}/content", method="GET")
|
||||
async def openai_retrieve_file_content(
|
||||
self,
|
||||
bucket: str,
|
||||
) -> ListFileResponse:
|
||||
file_id: str,
|
||||
) -> Response:
|
||||
"""
|
||||
List all files in a bucket.
|
||||
Returns the contents of the specified file.
|
||||
|
||||
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
|
||||
:returns: A ListFileResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/files/{bucket}/{key:path}", method="GET")
|
||||
async def get_file(
|
||||
self,
|
||||
bucket: str,
|
||||
key: str,
|
||||
) -> FileResponse:
|
||||
"""
|
||||
Get a file info identified by a bucket and key.
|
||||
|
||||
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
|
||||
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
|
||||
:returns: A FileResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/files/{bucket}/{key:path}", method="DELETE")
|
||||
async def delete_file(
|
||||
self,
|
||||
bucket: str,
|
||||
key: str,
|
||||
) -> None:
|
||||
"""
|
||||
Delete a file identified by a bucket and key.
|
||||
|
||||
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
|
||||
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
|
||||
:param file_id: The ID of the file to use for this request.
|
||||
:returns: The raw file content as a binary response.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -783,6 +783,48 @@ class OpenAICompletion(BaseModel):
|
|||
object: Literal["text_completion"] = "text_completion"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIEmbeddingData(BaseModel):
|
||||
"""A single embedding data object from an OpenAI-compatible embeddings response.
|
||||
|
||||
:param object: The object type, which will be "embedding"
|
||||
:param embedding: The embedding vector as a list of floats (when encoding_format="float") or as a base64-encoded string (when encoding_format="base64")
|
||||
:param index: The index of the embedding in the input list
|
||||
"""
|
||||
|
||||
object: Literal["embedding"] = "embedding"
|
||||
embedding: list[float] | str
|
||||
index: int
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIEmbeddingUsage(BaseModel):
|
||||
"""Usage information for an OpenAI-compatible embeddings response.
|
||||
|
||||
:param prompt_tokens: The number of tokens in the input
|
||||
:param total_tokens: The total number of tokens used
|
||||
"""
|
||||
|
||||
prompt_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIEmbeddingsResponse(BaseModel):
|
||||
"""Response from an OpenAI-compatible embeddings request.
|
||||
|
||||
:param object: The object type, which will be "list"
|
||||
:param data: List of embedding data objects
|
||||
:param model: The model that was used to generate the embeddings
|
||||
:param usage: Usage information
|
||||
"""
|
||||
|
||||
object: Literal["list"] = "list"
|
||||
data: list[OpenAIEmbeddingData]
|
||||
model: str
|
||||
usage: OpenAIEmbeddingUsage
|
||||
|
||||
|
||||
class ModelStore(Protocol):
|
||||
async def get_model(self, identifier: str) -> Model: ...
|
||||
|
||||
|
@ -1076,6 +1118,26 @@ class InferenceProvider(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/embeddings", method="POST")
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
"""Generate OpenAI-compatible embeddings for the given input using the specified model.
|
||||
|
||||
:param model: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
|
||||
:param input: Input text to embed, encoded as a string or array of strings. To embed multiple inputs in a single request, pass an array of strings.
|
||||
:param encoding_format: (Optional) The format to return the embeddings in. Can be either "float" or "base64". Defaults to "float".
|
||||
:param dimensions: (Optional) The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
|
||||
:param user: (Optional) A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
|
||||
:returns: An OpenAIEmbeddingsResponse containing the embeddings.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class Inference(InferenceProvider):
|
||||
"""Llama Stack Inference API for generating completions, chat completions, and embeddings.
|
||||
|
|
|
@ -19,8 +19,16 @@ from llama_stack.schema_utils import json_schema_type, webmethod
|
|||
|
||||
|
||||
class Chunk(BaseModel):
|
||||
"""
|
||||
A chunk of content that can be inserted into a vector database.
|
||||
:param content: The content of the chunk, which can be interleaved text, images, or other types.
|
||||
:param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
|
||||
:param metadata: Metadata associated with the chunk, such as document ID, source, or other relevant information.
|
||||
"""
|
||||
|
||||
content: InterleavedContent
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
embedding: list[float] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -50,7 +58,10 @@ class VectorIO(Protocol):
|
|||
"""Insert chunks into a vector database.
|
||||
|
||||
:param vector_db_id: The identifier of the vector database to insert the chunks into.
|
||||
:param chunks: The chunks to insert.
|
||||
:param chunks: The chunks to insert. Each `Chunk` should contain content which can be interleaved text, images, or other types.
|
||||
`metadata`: `dict[str, Any]` and `embedding`: `List[float]` are optional.
|
||||
If `metadata` is provided, you configure how Llama Stack formats the chunk during generation.
|
||||
If `embedding` is not provided, it will be computed later.
|
||||
:param ttl_seconds: The time to live of the chunks.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -35,7 +35,8 @@ class StackRun(Subcommand):
|
|||
"config",
|
||||
type=str,
|
||||
nargs="?", # Make it optional
|
||||
help="Path to config file to use for the run. Required for venv and conda environments.",
|
||||
metavar="config | template",
|
||||
help="Path to config file to use for the run or name of known template (`llama stack list` for a list).",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--port",
|
||||
|
@ -59,7 +60,7 @@ class StackRun(Subcommand):
|
|||
"--image-type",
|
||||
type=str,
|
||||
help="Image Type used during the build. This can be either conda or container or venv.",
|
||||
choices=[e.value for e in ImageType],
|
||||
choices=[e.value for e in ImageType if e.value != ImageType.CONTAINER.value],
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--enable-ui",
|
||||
|
@ -154,7 +155,10 @@ class StackRun(Subcommand):
|
|||
# func=<bound method StackRun._run_stack_run_cmd of <llama_stack.cli.stack.run.StackRun object at 0x10484b010>>
|
||||
if callable(getattr(args, arg)):
|
||||
continue
|
||||
setattr(server_args, arg, getattr(args, arg))
|
||||
if arg == "config" and template_name:
|
||||
server_args.config = str(config_file)
|
||||
else:
|
||||
setattr(server_args, arg, getattr(args, arg))
|
||||
|
||||
# Run the server
|
||||
server_main(server_args)
|
||||
|
|
|
@ -1,86 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import AccessAttributes
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
|
||||
def check_access(
|
||||
obj_identifier: str,
|
||||
obj_attributes: AccessAttributes | None,
|
||||
user_attributes: dict[str, Any] | None = None,
|
||||
) -> bool:
|
||||
"""Check if the current user has access to the given object, based on access attributes.
|
||||
|
||||
Access control algorithm:
|
||||
1. If the resource has no access_attributes, access is GRANTED to all authenticated users
|
||||
2. If the user has no attributes, access is DENIED to any object with access_attributes defined
|
||||
3. For each attribute category in the resource's access_attributes:
|
||||
a. If the user lacks that category, access is DENIED
|
||||
b. If the user has the category but none of the required values, access is DENIED
|
||||
c. If the user has at least one matching value in each required category, access is GRANTED
|
||||
|
||||
Example:
|
||||
# Resource requires:
|
||||
access_attributes = AccessAttributes(
|
||||
roles=["admin", "data-scientist"],
|
||||
teams=["ml-team"]
|
||||
)
|
||||
|
||||
# User has:
|
||||
user_attributes = {
|
||||
"roles": ["data-scientist", "engineer"],
|
||||
"teams": ["ml-team", "infra-team"],
|
||||
"projects": ["llama-3"]
|
||||
}
|
||||
|
||||
# Result: Access GRANTED
|
||||
# - User has the "data-scientist" role (matches one of the required roles)
|
||||
# - AND user is part of the "ml-team" (matches the required team)
|
||||
# - The extra "projects" attribute is ignored
|
||||
|
||||
Args:
|
||||
obj_identifier: The identifier of the resource object to check access for
|
||||
obj_attributes: The access attributes of the resource object
|
||||
user_attributes: The attributes of the current user
|
||||
|
||||
Returns:
|
||||
bool: True if access is granted, False if denied
|
||||
"""
|
||||
# If object has no access attributes, allow access by default
|
||||
if not obj_attributes:
|
||||
return True
|
||||
|
||||
# If no user attributes, deny access to objects with access control
|
||||
if not user_attributes:
|
||||
return False
|
||||
|
||||
dict_attribs = obj_attributes.model_dump(exclude_none=True)
|
||||
if not dict_attribs:
|
||||
return True
|
||||
|
||||
# Check each attribute category (requires ALL categories to match)
|
||||
# TODO: formalize this into a proper ABAC policy
|
||||
for attr_key, required_values in dict_attribs.items():
|
||||
user_values = user_attributes.get(attr_key, [])
|
||||
|
||||
if not user_values:
|
||||
logger.debug(f"Access denied to {obj_identifier}: missing required attribute category '{attr_key}'")
|
||||
return False
|
||||
|
||||
if not any(val in user_values for val in required_values):
|
||||
logger.debug(
|
||||
f"Access denied to {obj_identifier}: "
|
||||
f"no match for attribute '{attr_key}', required one of {required_values}"
|
||||
)
|
||||
return False
|
||||
|
||||
logger.debug(f"Access granted to {obj_identifier}")
|
||||
return True
|
5
llama_stack/distribution/access_control/__init__.py
Normal file
5
llama_stack/distribution/access_control/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
109
llama_stack/distribution/access_control/access_control.py
Normal file
109
llama_stack/distribution/access_control/access_control.py
Normal file
|
@ -0,0 +1,109 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import User
|
||||
|
||||
from .conditions import (
|
||||
Condition,
|
||||
ProtectedResource,
|
||||
parse_conditions,
|
||||
)
|
||||
from .datatypes import (
|
||||
AccessRule,
|
||||
Action,
|
||||
Scope,
|
||||
)
|
||||
|
||||
|
||||
def matches_resource(resource_scope: str, actual_resource: str) -> bool:
|
||||
if resource_scope == actual_resource:
|
||||
return True
|
||||
return resource_scope.endswith("::*") and actual_resource.startswith(resource_scope[:-1])
|
||||
|
||||
|
||||
def matches_scope(
|
||||
scope: Scope,
|
||||
action: Action,
|
||||
resource: str,
|
||||
user: str | None,
|
||||
) -> bool:
|
||||
if scope.resource and not matches_resource(scope.resource, resource):
|
||||
return False
|
||||
if scope.principal and scope.principal != user:
|
||||
return False
|
||||
return action in scope.actions
|
||||
|
||||
|
||||
def as_list(obj: Any) -> list[Any]:
|
||||
if isinstance(obj, list):
|
||||
return obj
|
||||
return [obj]
|
||||
|
||||
|
||||
def matches_conditions(
|
||||
conditions: list[Condition],
|
||||
resource: ProtectedResource,
|
||||
user: User,
|
||||
) -> bool:
|
||||
for condition in conditions:
|
||||
# must match all conditions
|
||||
if not condition.matches(resource, user):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def default_policy() -> list[AccessRule]:
|
||||
# for backwards compatibility, if no rules are provided, assume
|
||||
# full access subject to previous attribute matching rules
|
||||
return [
|
||||
AccessRule(
|
||||
permit=Scope(actions=list(Action)),
|
||||
when=["user in owners " + name for name in ["roles", "teams", "projects", "namespaces"]],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def is_action_allowed(
|
||||
policy: list[AccessRule],
|
||||
action: Action,
|
||||
resource: ProtectedResource,
|
||||
user: User | None,
|
||||
) -> bool:
|
||||
# If user is not set, assume authentication is not enabled
|
||||
if not user:
|
||||
return True
|
||||
|
||||
if not len(policy):
|
||||
policy = default_policy()
|
||||
|
||||
qualified_resource_id = resource.type + "::" + resource.identifier
|
||||
for rule in policy:
|
||||
if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal):
|
||||
if rule.when:
|
||||
if matches_conditions(parse_conditions(as_list(rule.when)), resource, user):
|
||||
return False
|
||||
elif rule.unless:
|
||||
if not matches_conditions(parse_conditions(as_list(rule.unless)), resource, user):
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
elif rule.permit and matches_scope(rule.permit, action, qualified_resource_id, user.principal):
|
||||
if rule.when:
|
||||
if matches_conditions(parse_conditions(as_list(rule.when)), resource, user):
|
||||
return True
|
||||
elif rule.unless:
|
||||
if not matches_conditions(parse_conditions(as_list(rule.unless)), resource, user):
|
||||
return True
|
||||
else:
|
||||
return True
|
||||
# assume access is denied unless we find a rule that permits access
|
||||
return False
|
||||
|
||||
|
||||
class AccessDeniedError(RuntimeError):
|
||||
pass
|
129
llama_stack/distribution/access_control/conditions.py
Normal file
129
llama_stack/distribution/access_control/conditions.py
Normal file
|
@ -0,0 +1,129 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class User(Protocol):
|
||||
principal: str
|
||||
attributes: dict[str, list[str]] | None
|
||||
|
||||
|
||||
class ProtectedResource(Protocol):
|
||||
type: str
|
||||
identifier: str
|
||||
owner: User
|
||||
|
||||
|
||||
class Condition(Protocol):
|
||||
def matches(self, resource: ProtectedResource, user: User) -> bool: ...
|
||||
|
||||
|
||||
class UserInOwnersList:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
def owners_values(self, resource: ProtectedResource) -> list[str] | None:
|
||||
if (
|
||||
hasattr(resource, "owner")
|
||||
and resource.owner
|
||||
and resource.owner.attributes
|
||||
and self.name in resource.owner.attributes
|
||||
):
|
||||
return resource.owner.attributes[self.name]
|
||||
else:
|
||||
return None
|
||||
|
||||
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||
required = self.owners_values(resource)
|
||||
if not required:
|
||||
return True
|
||||
if not user.attributes or self.name not in user.attributes or not user.attributes[self.name]:
|
||||
return False
|
||||
user_values = user.attributes[self.name]
|
||||
for value in required:
|
||||
if value in user_values:
|
||||
return True
|
||||
return False
|
||||
|
||||
def __repr__(self):
|
||||
return f"user in owners {self.name}"
|
||||
|
||||
|
||||
class UserNotInOwnersList(UserInOwnersList):
|
||||
def __init__(self, name: str):
|
||||
super().__init__(name)
|
||||
|
||||
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||
return not super().matches(resource, user)
|
||||
|
||||
def __repr__(self):
|
||||
return f"user not in owners {self.name}"
|
||||
|
||||
|
||||
class UserWithValueInList:
|
||||
def __init__(self, name: str, value: str):
|
||||
self.name = name
|
||||
self.value = value
|
||||
|
||||
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||
if user.attributes and self.name in user.attributes:
|
||||
return self.value in user.attributes[self.name]
|
||||
print(f"User does not have {self.value} in {self.name}")
|
||||
return False
|
||||
|
||||
def __repr__(self):
|
||||
return f"user with {self.value} in {self.name}"
|
||||
|
||||
|
||||
class UserWithValueNotInList(UserWithValueInList):
|
||||
def __init__(self, name: str, value: str):
|
||||
super().__init__(name, value)
|
||||
|
||||
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||
return not super().matches(resource, user)
|
||||
|
||||
def __repr__(self):
|
||||
return f"user with {self.value} not in {self.name}"
|
||||
|
||||
|
||||
class UserIsOwner:
|
||||
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||
return resource.owner.principal == user.principal if resource.owner else False
|
||||
|
||||
def __repr__(self):
|
||||
return "user is owner"
|
||||
|
||||
|
||||
class UserIsNotOwner:
|
||||
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||
return not resource.owner or resource.owner.principal != user.principal
|
||||
|
||||
def __repr__(self):
|
||||
return "user is not owner"
|
||||
|
||||
|
||||
def parse_condition(condition: str) -> Condition:
|
||||
words = condition.split()
|
||||
match words:
|
||||
case ["user", "is", "owner"]:
|
||||
return UserIsOwner()
|
||||
case ["user", "is", "not", "owner"]:
|
||||
return UserIsNotOwner()
|
||||
case ["user", "with", value, "in", name]:
|
||||
return UserWithValueInList(name, value)
|
||||
case ["user", "with", value, "not", "in", name]:
|
||||
return UserWithValueNotInList(name, value)
|
||||
case ["user", "in", "owners", name]:
|
||||
return UserInOwnersList(name)
|
||||
case ["user", "not", "in", "owners", name]:
|
||||
return UserNotInOwnersList(name)
|
||||
case _:
|
||||
raise ValueError(f"Invalid condition: {condition}")
|
||||
|
||||
|
||||
def parse_conditions(conditions: list[str]) -> list[Condition]:
|
||||
return [parse_condition(c) for c in conditions]
|
107
llama_stack/distribution/access_control/datatypes.py
Normal file
107
llama_stack/distribution/access_control/datatypes.py
Normal file
|
@ -0,0 +1,107 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from .conditions import parse_conditions
|
||||
|
||||
|
||||
class Action(str, Enum):
|
||||
CREATE = "create"
|
||||
READ = "read"
|
||||
UPDATE = "update"
|
||||
DELETE = "delete"
|
||||
|
||||
|
||||
class Scope(BaseModel):
|
||||
principal: str | None = None
|
||||
actions: Action | list[Action]
|
||||
resource: str | None = None
|
||||
|
||||
|
||||
def _mutually_exclusive(obj, a: str, b: str):
|
||||
if getattr(obj, a) and getattr(obj, b):
|
||||
raise ValueError(f"{a} and {b} are mutually exclusive")
|
||||
|
||||
|
||||
def _require_one_of(obj, a: str, b: str):
|
||||
if not getattr(obj, a) and not getattr(obj, b):
|
||||
raise ValueError(f"on of {a} or {b} is required")
|
||||
|
||||
|
||||
class AccessRule(BaseModel):
|
||||
"""Access rule based loosely on cedar policy language
|
||||
|
||||
A rule defines a list of action either to permit or to forbid. It may specify a
|
||||
principal or a resource that must match for the rule to take effect. The resource
|
||||
to match should be specified in the form of a type qualified identifier, e.g.
|
||||
model::my-model or vector_db::some-db, or a wildcard for all resources of a type,
|
||||
e.g. model::*. If the principal or resource are not specified, they will match all
|
||||
requests.
|
||||
|
||||
A rule may also specify a condition, either a 'when' or an 'unless', with additional
|
||||
constraints as to where the rule applies. The constraints supported at present are:
|
||||
|
||||
- 'user with <attr-value> in <attr-name>'
|
||||
- 'user with <attr-value> not in <attr-name>'
|
||||
- 'user is owner'
|
||||
- 'user is not owner'
|
||||
- 'user in owners <attr-name>'
|
||||
- 'user not in owners <attr-name>'
|
||||
|
||||
Rules are tested in order to find a match. If a match is found, the request is
|
||||
permitted or forbidden depending on the type of rule. If no match is found, the
|
||||
request is denied. If no rules are specified, a rule that allows any action as
|
||||
long as the resource attributes match the user attributes is added
|
||||
(i.e. the previous behaviour is the default).
|
||||
|
||||
Some examples in yaml:
|
||||
|
||||
- permit:
|
||||
principal: user-1
|
||||
actions: [create, read, delete]
|
||||
resource: model::*
|
||||
description: user-1 has full access to all models
|
||||
- permit:
|
||||
principal: user-2
|
||||
actions: [read]
|
||||
resource: model::model-1
|
||||
description: user-2 has read access to model-1 only
|
||||
- permit:
|
||||
actions: [read]
|
||||
when: user in owner teams
|
||||
description: any user has read access to any resource created by a member of their team
|
||||
- forbid:
|
||||
actions: [create, read, delete]
|
||||
resource: vector_db::*
|
||||
unless: user with admin in roles
|
||||
description: only user with admin role can use vector_db resources
|
||||
|
||||
"""
|
||||
|
||||
permit: Scope | None = None
|
||||
forbid: Scope | None = None
|
||||
when: str | list[str] | None = None
|
||||
unless: str | list[str] | None = None
|
||||
description: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_rule_format(self) -> Self:
|
||||
_require_one_of(self, "permit", "forbid")
|
||||
_mutually_exclusive(self, "permit", "forbid")
|
||||
_mutually_exclusive(self, "when", "unless")
|
||||
if isinstance(self.when, list):
|
||||
parse_conditions(self.when)
|
||||
elif self.when:
|
||||
parse_conditions([self.when])
|
||||
if isinstance(self.unless, list):
|
||||
parse_conditions(self.unless)
|
||||
elif self.unless:
|
||||
parse_conditions([self.unless])
|
||||
return self
|
|
@ -29,6 +29,8 @@ SERVER_DEPENDENCIES = [
|
|||
"fire",
|
||||
"httpx",
|
||||
"uvicorn",
|
||||
"opentelemetry-sdk",
|
||||
"opentelemetry-exporter-otlp-proto-http",
|
||||
]
|
||||
|
||||
|
||||
|
@ -41,23 +43,12 @@ def get_provider_dependencies(
|
|||
config: BuildConfig | DistributionTemplate,
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""Get normal and special dependencies from provider configuration."""
|
||||
# Extract providers based on config type
|
||||
if isinstance(config, DistributionTemplate):
|
||||
providers = config.providers
|
||||
config = config.build_config()
|
||||
|
||||
providers = config.distribution_spec.providers
|
||||
additional_pip_packages = config.additional_pip_packages
|
||||
|
||||
# TODO: This is a hack to get the dependencies for internal APIs into build
|
||||
# We should have a better way to do this by formalizing the concept of "internal" APIs
|
||||
# and providers, with a way to specify dependencies for them.
|
||||
run_configs = config.run_configs
|
||||
additional_pip_packages: list[str] = []
|
||||
if run_configs:
|
||||
for run_config in run_configs.values():
|
||||
run_config_ = run_config.run_config(name="", providers={}, container_image=None)
|
||||
if run_config_.inference_store:
|
||||
additional_pip_packages.extend(run_config_.inference_store.pip_packages)
|
||||
elif isinstance(config, BuildConfig):
|
||||
providers = config.distribution_spec.providers
|
||||
additional_pip_packages = config.additional_pip_packages
|
||||
deps = []
|
||||
registry = get_provider_registry(config)
|
||||
for api_str, provider_or_providers in providers.items():
|
||||
|
@ -85,8 +76,7 @@ def get_provider_dependencies(
|
|||
else:
|
||||
normal_deps.append(package)
|
||||
|
||||
if additional_pip_packages:
|
||||
normal_deps.extend(additional_pip_packages)
|
||||
normal_deps.extend(additional_pip_packages or [])
|
||||
|
||||
return list(set(normal_deps)), list(set(special_deps))
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ from llama_stack.apis.shields import Shield, ShieldInput
|
|||
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.distribution.access_control.datatypes import AccessRule
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig
|
||||
|
@ -35,126 +36,66 @@ LLAMA_STACK_RUN_CONFIG_VERSION = "2"
|
|||
RoutingKey = str | list[str]
|
||||
|
||||
|
||||
class AccessAttributes(BaseModel):
|
||||
"""Structured representation of user attributes for access control.
|
||||
class User(BaseModel):
|
||||
principal: str
|
||||
# further attributes that may be used for access control decisions
|
||||
attributes: dict[str, list[str]] | None = None
|
||||
|
||||
This model defines a structured approach to representing user attributes
|
||||
with common standard categories for access control.
|
||||
|
||||
Standard attribute categories include:
|
||||
- roles: Role-based attributes (e.g., admin, data-scientist)
|
||||
- teams: Team-based attributes (e.g., ml-team, infra-team)
|
||||
- projects: Project access attributes (e.g., llama-3, customer-insights)
|
||||
- namespaces: Namespace-based access control for resource isolation
|
||||
"""
|
||||
|
||||
# Standard attribute categories - the minimal set we need now
|
||||
roles: list[str] | None = Field(
|
||||
default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
|
||||
)
|
||||
|
||||
teams: list[str] | None = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
|
||||
|
||||
projects: list[str] | None = Field(
|
||||
default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
|
||||
)
|
||||
|
||||
namespaces: list[str] | None = Field(
|
||||
default=None, description="Namespace-based access control for resource isolation"
|
||||
)
|
||||
def __init__(self, principal: str, attributes: dict[str, list[str]] | None):
|
||||
super().__init__(principal=principal, attributes=attributes)
|
||||
|
||||
|
||||
class ResourceWithACL(Resource):
|
||||
"""Extension of Resource that adds attribute-based access control capabilities.
|
||||
class ResourceWithOwner(Resource):
|
||||
"""Extension of Resource that adds an optional owner, i.e. the user that created the
|
||||
resource. This can be used to constrain access to the resource."""
|
||||
|
||||
This class adds an optional access_attributes field that allows fine-grained control
|
||||
over which users can access each resource. When attributes are defined, a user must have
|
||||
matching attributes to access the resource.
|
||||
|
||||
Attribute Matching Algorithm:
|
||||
1. If a resource has no access_attributes (None or empty dict), it's visible to all authenticated users
|
||||
2. Each key in access_attributes represents an attribute category (e.g., "roles", "teams", "projects")
|
||||
3. The matching algorithm requires ALL categories to match (AND relationship between categories)
|
||||
4. Within each category, ANY value match is sufficient (OR relationship within a category)
|
||||
|
||||
Examples:
|
||||
# Resource visible to everyone (no access control)
|
||||
model = Model(identifier="llama-2", ...)
|
||||
|
||||
# Resource visible only to admins
|
||||
model = Model(
|
||||
identifier="gpt-4",
|
||||
access_attributes=AccessAttributes(roles=["admin"])
|
||||
)
|
||||
|
||||
# Resource visible to data scientists on the ML team
|
||||
model = Model(
|
||||
identifier="private-model",
|
||||
access_attributes=AccessAttributes(
|
||||
roles=["data-scientist", "researcher"],
|
||||
teams=["ml-team"]
|
||||
)
|
||||
)
|
||||
# ^ User must have at least one of the roles AND be on the ml-team
|
||||
|
||||
# Resource visible to users with specific project access
|
||||
vector_db = VectorDB(
|
||||
identifier="customer-embeddings",
|
||||
access_attributes=AccessAttributes(
|
||||
projects=["customer-insights"],
|
||||
namespaces=["confidential"]
|
||||
)
|
||||
)
|
||||
# ^ User must have access to the customer-insights project AND have confidential namespace
|
||||
"""
|
||||
|
||||
access_attributes: AccessAttributes | None = None
|
||||
owner: User | None = None
|
||||
|
||||
|
||||
# Use the extended Resource for all routable objects
|
||||
class ModelWithACL(Model, ResourceWithACL):
|
||||
class ModelWithOwner(Model, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class ShieldWithACL(Shield, ResourceWithACL):
|
||||
class ShieldWithOwner(Shield, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class VectorDBWithACL(VectorDB, ResourceWithACL):
|
||||
class VectorDBWithOwner(VectorDB, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class DatasetWithACL(Dataset, ResourceWithACL):
|
||||
class DatasetWithOwner(Dataset, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class ScoringFnWithACL(ScoringFn, ResourceWithACL):
|
||||
class ScoringFnWithOwner(ScoringFn, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class BenchmarkWithACL(Benchmark, ResourceWithACL):
|
||||
class BenchmarkWithOwner(Benchmark, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class ToolWithACL(Tool, ResourceWithACL):
|
||||
class ToolWithOwner(Tool, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class ToolGroupWithACL(ToolGroup, ResourceWithACL):
|
||||
class ToolGroupWithOwner(ToolGroup, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | Tool | ToolGroup
|
||||
|
||||
RoutableObjectWithProvider = Annotated[
|
||||
ModelWithACL
|
||||
| ShieldWithACL
|
||||
| VectorDBWithACL
|
||||
| DatasetWithACL
|
||||
| ScoringFnWithACL
|
||||
| BenchmarkWithACL
|
||||
| ToolWithACL
|
||||
| ToolGroupWithACL,
|
||||
ModelWithOwner
|
||||
| ShieldWithOwner
|
||||
| VectorDBWithOwner
|
||||
| DatasetWithOwner
|
||||
| ScoringFnWithOwner
|
||||
| BenchmarkWithOwner
|
||||
| ToolWithOwner
|
||||
| ToolGroupWithOwner,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
@ -234,6 +175,7 @@ class AuthenticationConfig(BaseModel):
|
|||
...,
|
||||
description="Provider-specific configuration",
|
||||
)
|
||||
access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources")
|
||||
|
||||
|
||||
class AuthenticationRequiredError(Exception):
|
||||
|
|
|
@ -149,12 +149,13 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
|||
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
|
||||
|
||||
def request(self, *args, **kwargs):
|
||||
# NOTE: We are using AsyncLlamaStackClient under the hood
|
||||
# A new event loop is needed to convert the AsyncStream
|
||||
# from async client into SyncStream return type for streaming
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
if kwargs.get("stream"):
|
||||
# NOTE: We are using AsyncLlamaStackClient under the hood
|
||||
# A new event loop is needed to convert the AsyncStream
|
||||
# from async client into SyncStream return type for streaming
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
def sync_generator():
|
||||
try:
|
||||
|
@ -172,7 +173,14 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
|||
|
||||
return sync_generator()
|
||||
else:
|
||||
return asyncio.run(self.async_client.request(*args, **kwargs))
|
||||
try:
|
||||
result = loop.run_until_complete(self.async_client.request(*args, **kwargs))
|
||||
finally:
|
||||
pending = asyncio.all_tasks(loop)
|
||||
if pending:
|
||||
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
|
||||
loop.close()
|
||||
return result
|
||||
|
||||
|
||||
class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||
|
|
|
@ -10,6 +10,8 @@ import logging
|
|||
from contextlib import AbstractContextManager
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import User
|
||||
|
||||
from .utils.dynamic import instantiate_class_type
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -21,12 +23,10 @@ PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
|
|||
class RequestProviderDataContext(AbstractContextManager):
|
||||
"""Context manager for request provider data"""
|
||||
|
||||
def __init__(
|
||||
self, provider_data: dict[str, Any] | None = None, auth_attributes: dict[str, list[str]] | None = None
|
||||
):
|
||||
def __init__(self, provider_data: dict[str, Any] | None = None, user: User | None = None):
|
||||
self.provider_data = provider_data or {}
|
||||
if auth_attributes:
|
||||
self.provider_data["__auth_attributes"] = auth_attributes
|
||||
if user:
|
||||
self.provider_data["__authenticated_user"] = user
|
||||
|
||||
self.token = None
|
||||
|
||||
|
@ -95,9 +95,9 @@ def request_provider_data_context(
|
|||
return RequestProviderDataContext(provider_data, auth_attributes)
|
||||
|
||||
|
||||
def get_auth_attributes() -> dict[str, list[str]] | None:
|
||||
def get_authenticated_user() -> User | None:
|
||||
"""Helper to retrieve auth attributes from the provider data context"""
|
||||
provider_data = PROVIDER_DATA_VAR.get()
|
||||
if not provider_data:
|
||||
return None
|
||||
return provider_data.get("__auth_attributes")
|
||||
return provider_data.get("__authenticated_user")
|
||||
|
|
|
@ -28,6 +28,7 @@ from llama_stack.apis.vector_dbs import VectorDBs
|
|||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.distribution.client import get_client_impl
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AccessRule,
|
||||
AutoRoutedProviderSpec,
|
||||
Provider,
|
||||
RoutingTableProviderSpec,
|
||||
|
@ -118,6 +119,7 @@ async def resolve_impls(
|
|||
run_config: StackRunConfig,
|
||||
provider_registry: ProviderRegistry,
|
||||
dist_registry: DistributionRegistry,
|
||||
policy: list[AccessRule],
|
||||
) -> dict[Api, Any]:
|
||||
"""
|
||||
Resolves provider implementations by:
|
||||
|
@ -140,7 +142,7 @@ async def resolve_impls(
|
|||
|
||||
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
|
||||
|
||||
return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config)
|
||||
return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config, policy)
|
||||
|
||||
|
||||
def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
|
||||
|
@ -247,6 +249,7 @@ async def instantiate_providers(
|
|||
router_apis: set[Api],
|
||||
dist_registry: DistributionRegistry,
|
||||
run_config: StackRunConfig,
|
||||
policy: list[AccessRule],
|
||||
) -> dict:
|
||||
"""Instantiates providers asynchronously while managing dependencies."""
|
||||
impls: dict[Api, Any] = {}
|
||||
|
@ -261,7 +264,7 @@ async def instantiate_providers(
|
|||
if isinstance(provider.spec, RoutingTableProviderSpec):
|
||||
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
|
||||
|
||||
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config)
|
||||
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config, policy)
|
||||
|
||||
if api_str.startswith("inner-"):
|
||||
inner_impls_by_provider_id[api_str][provider.provider_id] = impl
|
||||
|
@ -312,6 +315,7 @@ async def instantiate_provider(
|
|||
inner_impls: dict[str, Any],
|
||||
dist_registry: DistributionRegistry,
|
||||
run_config: StackRunConfig,
|
||||
policy: list[AccessRule],
|
||||
):
|
||||
provider_spec = provider.spec
|
||||
if not hasattr(provider_spec, "module"):
|
||||
|
@ -336,13 +340,15 @@ async def instantiate_provider(
|
|||
method = "get_routing_table_impl"
|
||||
|
||||
config = None
|
||||
args = [provider_spec.api, inner_impls, deps, dist_registry]
|
||||
args = [provider_spec.api, inner_impls, deps, dist_registry, policy]
|
||||
else:
|
||||
method = "get_provider_impl"
|
||||
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
config = config_type(**provider.config)
|
||||
args = [config, deps]
|
||||
if "policy" in inspect.signature(getattr(module, method)).parameters:
|
||||
args.append(policy)
|
||||
|
||||
fn = getattr(module, method)
|
||||
impl = await fn(*args)
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import RoutedProtocol
|
||||
from llama_stack.distribution.datatypes import AccessRule, RoutedProtocol
|
||||
from llama_stack.distribution.stack import StackRunConfig
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||
|
@ -18,6 +18,7 @@ async def get_routing_table_impl(
|
|||
impls_by_provider_id: dict[str, RoutedProtocol],
|
||||
_deps,
|
||||
dist_registry: DistributionRegistry,
|
||||
policy: list[AccessRule],
|
||||
) -> Any:
|
||||
from ..routing_tables.benchmarks import BenchmarksRoutingTable
|
||||
from ..routing_tables.datasets import DatasetsRoutingTable
|
||||
|
@ -40,7 +41,7 @@ async def get_routing_table_impl(
|
|||
if api.value not in api_to_tables:
|
||||
raise ValueError(f"API {api.value} not found in router map")
|
||||
|
||||
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry)
|
||||
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry, policy)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
|
|
@ -45,6 +45,7 @@ from llama_stack.apis.inference.inference import (
|
|||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
|
@ -546,6 +547,34 @@ class InferenceRouter(Inference):
|
|||
await self.store.store_chat_completion(response, messages)
|
||||
return response
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
logger.debug(
|
||||
f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}",
|
||||
)
|
||||
model_obj = await self.routing_table.get_model(model)
|
||||
if model_obj is None:
|
||||
raise ValueError(f"Model '{model}' not found")
|
||||
if model_obj.model_type != ModelType.embedding:
|
||||
raise ValueError(f"Model '{model}' is not an embedding model")
|
||||
|
||||
params = dict(
|
||||
model=model_obj.identifier,
|
||||
input=input,
|
||||
encoding_format=encoding_format,
|
||||
dimensions=dimensions,
|
||||
user=user,
|
||||
)
|
||||
|
||||
provider = self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
return await provider.openai_embeddings(**params)
|
||||
|
||||
async def list_chat_completions(
|
||||
self,
|
||||
after: str | None = None,
|
||||
|
|
|
@ -8,7 +8,7 @@ from typing import Any
|
|||
|
||||
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
|
||||
from llama_stack.distribution.datatypes import (
|
||||
BenchmarkWithACL,
|
||||
BenchmarkWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
|
@ -47,7 +47,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
|||
)
|
||||
if provider_benchmark_id is None:
|
||||
provider_benchmark_id = benchmark_id
|
||||
benchmark = BenchmarkWithACL(
|
||||
benchmark = BenchmarkWithOwner(
|
||||
identifier=benchmark_id,
|
||||
dataset_id=dataset_id,
|
||||
scoring_functions=scoring_functions,
|
||||
|
|
|
@ -8,14 +8,14 @@ from typing import Any
|
|||
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.scoring_functions import ScoringFn
|
||||
from llama_stack.distribution.access_control import check_access
|
||||
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AccessAttributes,
|
||||
AccessRule,
|
||||
RoutableObject,
|
||||
RoutableObjectWithProvider,
|
||||
RoutedProtocol,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import get_auth_attributes
|
||||
from llama_stack.distribution.request_headers import get_authenticated_user
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||
|
@ -73,9 +73,11 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
self,
|
||||
impls_by_provider_id: dict[str, RoutedProtocol],
|
||||
dist_registry: DistributionRegistry,
|
||||
policy: list[AccessRule],
|
||||
) -> None:
|
||||
self.impls_by_provider_id = impls_by_provider_id
|
||||
self.dist_registry = dist_registry
|
||||
self.policy = policy
|
||||
|
||||
async def initialize(self) -> None:
|
||||
async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
|
||||
|
@ -166,13 +168,15 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
return None
|
||||
|
||||
# Check if user has permission to access this object
|
||||
if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
|
||||
logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
|
||||
if not is_action_allowed(self.policy, "read", obj, get_authenticated_user()):
|
||||
logger.debug(f"Access denied to {type} '{identifier}'")
|
||||
return None
|
||||
|
||||
return obj
|
||||
|
||||
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
|
||||
if not is_action_allowed(self.policy, "delete", obj, get_authenticated_user()):
|
||||
raise AccessDeniedError()
|
||||
await self.dist_registry.delete(obj.type, obj.identifier)
|
||||
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
|
||||
|
||||
|
@ -187,11 +191,12 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
p = self.impls_by_provider_id[obj.provider_id]
|
||||
|
||||
# If object supports access control but no attributes set, use creator's attributes
|
||||
if not obj.access_attributes:
|
||||
creator_attributes = get_auth_attributes()
|
||||
if creator_attributes:
|
||||
obj.access_attributes = AccessAttributes(**creator_attributes)
|
||||
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
|
||||
creator = get_authenticated_user()
|
||||
if not is_action_allowed(self.policy, "create", obj, creator):
|
||||
raise AccessDeniedError()
|
||||
if creator:
|
||||
obj.owner = creator
|
||||
logger.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}")
|
||||
|
||||
registered_obj = await register_object_with_provider(obj, p)
|
||||
# TODO: This needs to be fixed for all APIs once they return the registered object
|
||||
|
@ -210,9 +215,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
# Apply attribute-based access control filtering
|
||||
if filtered_objs:
|
||||
filtered_objs = [
|
||||
obj
|
||||
for obj in filtered_objs
|
||||
if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
|
||||
obj for obj in filtered_objs if is_action_allowed(self.policy, "read", obj, get_authenticated_user())
|
||||
]
|
||||
|
||||
return filtered_objs
|
||||
|
|
|
@ -19,7 +19,7 @@ from llama_stack.apis.datasets import (
|
|||
)
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.distribution.datatypes import (
|
||||
DatasetWithACL,
|
||||
DatasetWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
|
@ -74,7 +74,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
if metadata is None:
|
||||
metadata = {}
|
||||
|
||||
dataset = DatasetWithACL(
|
||||
dataset = DatasetWithOwner(
|
||||
identifier=dataset_id,
|
||||
provider_resource_id=provider_dataset_id,
|
||||
provider_id=provider_id,
|
||||
|
|
|
@ -9,7 +9,7 @@ from typing import Any
|
|||
|
||||
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
|
||||
from llama_stack.distribution.datatypes import (
|
||||
ModelWithACL,
|
||||
ModelWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
|
@ -65,7 +65,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
model_type = ModelType.llm
|
||||
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
||||
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
||||
model = ModelWithACL(
|
||||
model = ModelWithOwner(
|
||||
identifier=model_id,
|
||||
provider_resource_id=provider_model_id,
|
||||
provider_id=provider_id,
|
||||
|
|
|
@ -13,7 +13,7 @@ from llama_stack.apis.scoring_functions import (
|
|||
ScoringFunctions,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import (
|
||||
ScoringFnWithACL,
|
||||
ScoringFnWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
|
@ -50,7 +50,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
|||
raise ValueError(
|
||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||
)
|
||||
scoring_fn = ScoringFnWithACL(
|
||||
scoring_fn = ScoringFnWithOwner(
|
||||
identifier=scoring_fn_id,
|
||||
description=description,
|
||||
return_type=return_type,
|
||||
|
|
|
@ -9,7 +9,7 @@ from typing import Any
|
|||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
|
||||
from llama_stack.distribution.datatypes import (
|
||||
ShieldWithACL,
|
||||
ShieldWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
|
@ -47,7 +47,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
|||
)
|
||||
if params is None:
|
||||
params = {}
|
||||
shield = ShieldWithACL(
|
||||
shield = ShieldWithOwner(
|
||||
identifier=shield_id,
|
||||
provider_resource_id=provider_shield_id,
|
||||
provider_id=provider_id,
|
||||
|
|
|
@ -8,7 +8,7 @@ from typing import Any
|
|||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
|
||||
from llama_stack.distribution.datatypes import ToolGroupWithACL
|
||||
from llama_stack.distribution.datatypes import ToolGroupWithOwner
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
@ -106,7 +106,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
mcp_endpoint: URL | None = None,
|
||||
args: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
toolgroup = ToolGroupWithACL(
|
||||
toolgroup = ToolGroupWithOwner(
|
||||
identifier=toolgroup_id,
|
||||
provider_id=provider_id,
|
||||
provider_resource_id=toolgroup_id,
|
||||
|
|
|
@ -10,7 +10,7 @@ from llama_stack.apis.models import ModelType
|
|||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
||||
from llama_stack.distribution.datatypes import (
|
||||
VectorDBWithACL,
|
||||
VectorDBWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
|
@ -63,7 +63,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
"embedding_model": embedding_model,
|
||||
"embedding_dimension": model.metadata["embedding_dimension"],
|
||||
}
|
||||
vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
|
||||
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
|
||||
await self.register_object(vector_db)
|
||||
return vector_db
|
||||
|
||||
|
|
|
@ -105,24 +105,16 @@ class AuthenticationMiddleware:
|
|||
logger.exception("Error during authentication")
|
||||
return await self._send_auth_error(send, "Authentication service error")
|
||||
|
||||
# Store attributes in request scope for access control
|
||||
if validation_result.access_attributes:
|
||||
user_attributes = validation_result.access_attributes.model_dump(exclude_none=True)
|
||||
else:
|
||||
logger.warning("No access attributes, setting namespace to token by default")
|
||||
user_attributes = {
|
||||
"roles": [token],
|
||||
}
|
||||
|
||||
# Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
|
||||
# can identify the requester and enforce per-client rate limits.
|
||||
scope["authenticated_client_id"] = token
|
||||
|
||||
# Store attributes in request scope
|
||||
scope["user_attributes"] = user_attributes
|
||||
scope["principal"] = validation_result.principal
|
||||
if validation_result.attributes:
|
||||
scope["user_attributes"] = validation_result.attributes
|
||||
logger.debug(
|
||||
f"Authentication successful: {validation_result.principal} with {len(scope['user_attributes'])} attributes"
|
||||
f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes"
|
||||
)
|
||||
|
||||
return await self.app(scope, receive, send)
|
||||
|
|
|
@ -16,43 +16,18 @@ from jose import jwt
|
|||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from llama_stack.distribution.datatypes import AccessAttributes, AuthenticationConfig, AuthProviderType
|
||||
from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, User
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="auth")
|
||||
|
||||
|
||||
class TokenValidationResult(BaseModel):
|
||||
principal: str | None = Field(
|
||||
default=None,
|
||||
description="The principal (username or persistent identifier) of the authenticated user",
|
||||
)
|
||||
access_attributes: AccessAttributes | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Structured user attributes for attribute-based access control.
|
||||
|
||||
These attributes determine which resources the user can access.
|
||||
The model provides standard categories like "roles", "teams", "projects", and "namespaces".
|
||||
Each attribute category contains a list of values that the user has for that category.
|
||||
During access control checks, these values are compared against resource requirements.
|
||||
|
||||
Example with standard categories:
|
||||
```json
|
||||
{
|
||||
"roles": ["admin", "data-scientist"],
|
||||
"teams": ["ml-team"],
|
||||
"projects": ["llama-3"],
|
||||
"namespaces": ["research"]
|
||||
}
|
||||
```
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
class AuthResponse(TokenValidationResult):
|
||||
class AuthResponse(BaseModel):
|
||||
"""The format of the authentication response from the auth endpoint."""
|
||||
|
||||
principal: str
|
||||
# further attributes that may be used for access control decisions
|
||||
attributes: dict[str, list[str]] | None = None
|
||||
message: str | None = Field(
|
||||
default=None, description="Optional message providing additional context about the authentication result."
|
||||
)
|
||||
|
@ -78,7 +53,7 @@ class AuthProvider(ABC):
|
|||
"""Abstract base class for authentication providers."""
|
||||
|
||||
@abstractmethod
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> User:
|
||||
"""Validate a token and return access attributes."""
|
||||
pass
|
||||
|
||||
|
@ -88,10 +63,10 @@ class AuthProvider(ABC):
|
|||
pass
|
||||
|
||||
|
||||
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes:
|
||||
attributes = AccessAttributes()
|
||||
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
|
||||
attributes: dict[str, list[str]] = {}
|
||||
for claim_key, attribute_key in mapping.items():
|
||||
if claim_key not in claims or not hasattr(attributes, attribute_key):
|
||||
if claim_key not in claims:
|
||||
continue
|
||||
claim = claims[claim_key]
|
||||
if isinstance(claim, list):
|
||||
|
@ -99,11 +74,10 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str])
|
|||
else:
|
||||
values = claim.split()
|
||||
|
||||
current = getattr(attributes, attribute_key)
|
||||
if current:
|
||||
current.extend(values)
|
||||
if attribute_key in attributes:
|
||||
attributes[attribute_key].extend(values)
|
||||
else:
|
||||
setattr(attributes, attribute_key, values)
|
||||
attributes[attribute_key] = values
|
||||
return attributes
|
||||
|
||||
|
||||
|
@ -145,8 +119,6 @@ class OAuth2TokenAuthProviderConfig(BaseModel):
|
|||
for key, value in v.items():
|
||||
if not value:
|
||||
raise ValueError(f"claims_mapping value cannot be empty: {key}")
|
||||
if value not in AccessAttributes.model_fields:
|
||||
raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
|
@ -171,14 +143,14 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
self._jwks: dict[str, str] = {}
|
||||
self._jwks_lock = Lock()
|
||||
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> User:
|
||||
if self.config.jwks:
|
||||
return await self.validate_jwt_token(token, scope)
|
||||
if self.config.introspection:
|
||||
return await self.introspect_token(token, scope)
|
||||
raise ValueError("One of jwks or introspection must be configured")
|
||||
|
||||
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
||||
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User:
|
||||
"""Validate a token using the JWT token."""
|
||||
await self._refresh_jwks()
|
||||
|
||||
|
@ -203,12 +175,12 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
# We should incorporate these into the access attributes.
|
||||
principal = claims["sub"]
|
||||
access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
|
||||
return TokenValidationResult(
|
||||
return User(
|
||||
principal=principal,
|
||||
access_attributes=access_attributes,
|
||||
attributes=access_attributes,
|
||||
)
|
||||
|
||||
async def introspect_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
||||
async def introspect_token(self, token: str, scope: dict | None = None) -> User:
|
||||
"""Validate a token using token introspection as defined by RFC 7662."""
|
||||
form = {
|
||||
"token": token,
|
||||
|
@ -242,9 +214,9 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
raise ValueError("Token not active")
|
||||
principal = fields["sub"] or fields["username"]
|
||||
access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping)
|
||||
return TokenValidationResult(
|
||||
return User(
|
||||
principal=principal,
|
||||
access_attributes=access_attributes,
|
||||
attributes=access_attributes,
|
||||
)
|
||||
except httpx.TimeoutException:
|
||||
logger.exception("Token introspection request timed out")
|
||||
|
@ -299,7 +271,7 @@ class CustomAuthProvider(AuthProvider):
|
|||
self.config = config
|
||||
self._client = None
|
||||
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> User:
|
||||
"""Validate a token using the custom authentication endpoint."""
|
||||
if scope is None:
|
||||
scope = {}
|
||||
|
@ -341,7 +313,7 @@ class CustomAuthProvider(AuthProvider):
|
|||
try:
|
||||
response_data = response.json()
|
||||
auth_response = AuthResponse(**response_data)
|
||||
return auth_response
|
||||
return User(auth_response.principal, auth_response.attributes)
|
||||
except Exception as e:
|
||||
logger.exception("Error parsing authentication response")
|
||||
raise ValueError("Invalid authentication response format") from e
|
||||
|
|
|
@ -18,7 +18,7 @@ from collections.abc import Callable
|
|||
from contextlib import asynccontextmanager
|
||||
from importlib.metadata import version as parse_version
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any
|
||||
from typing import Annotated, Any, get_origin
|
||||
|
||||
import rich.pretty
|
||||
import yaml
|
||||
|
@ -26,17 +26,13 @@ from aiohttp import hdrs
|
|||
from fastapi import Body, FastAPI, HTTPException, Request
|
||||
from fastapi import Path as FastapiPath
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from openai import BadRequestError
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.request_headers import (
|
||||
PROVIDER_DATA_VAR,
|
||||
request_provider_data_context,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.server.routes import (
|
||||
find_matching_route,
|
||||
|
@ -217,11 +213,13 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
|||
async def route_handler(request: Request, **kwargs):
|
||||
# Get auth attributes from the request scope
|
||||
user_attributes = request.scope.get("user_attributes", {})
|
||||
principal = request.scope.get("principal", "")
|
||||
user = User(principal, user_attributes)
|
||||
|
||||
await log_request_pre_validation(request)
|
||||
|
||||
# Use context manager with both provider data and auth attributes
|
||||
with request_provider_data_context(request.headers, user_attributes):
|
||||
with request_provider_data_context(request.headers, user):
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
|
||||
try:
|
||||
|
@ -244,15 +242,23 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
|||
|
||||
path_params = extract_path_params(route)
|
||||
if method == "post":
|
||||
# Annotate parameters that are in the path with Path(...) and others with Body(...)
|
||||
new_params = [new_params[0]] + [
|
||||
(
|
||||
param.replace(annotation=Annotated[param.annotation, FastapiPath(..., title=param.name)])
|
||||
if param.name in path_params
|
||||
else param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
|
||||
)
|
||||
for param in new_params[1:]
|
||||
]
|
||||
# Annotate parameters that are in the path with Path(...) and others with Body(...),
|
||||
# but preserve existing File() and Form() annotations for multipart form data
|
||||
new_params = (
|
||||
[new_params[0]]
|
||||
+ [
|
||||
(
|
||||
param.replace(annotation=Annotated[param.annotation, FastapiPath(..., title=param.name)])
|
||||
if param.name in path_params
|
||||
else (
|
||||
param # Keep original annotation if it's already an Annotated type
|
||||
if get_origin(param.annotation) is Annotated
|
||||
else param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
|
||||
)
|
||||
)
|
||||
for param in new_params[1:]
|
||||
]
|
||||
)
|
||||
|
||||
route_handler.__signature__ = sig.replace(parameters=new_params)
|
||||
|
||||
|
@ -472,17 +478,6 @@ def main(args: argparse.Namespace | None = None):
|
|||
window_seconds=window_seconds,
|
||||
)
|
||||
|
||||
# --- CORS middleware for local development ---
|
||||
# TODO: move to reverse proxy
|
||||
ui_port = os.environ.get("LLAMA_STACK_UI_PORT", 8322)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=[f"http://localhost:{ui_port}"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
try:
|
||||
impls = asyncio.run(construct_stack(config))
|
||||
except InvalidProviderError as e:
|
||||
|
|
|
@ -223,7 +223,10 @@ async def construct_stack(
|
|||
run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
|
||||
) -> dict[Api, Any]:
|
||||
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
|
||||
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
|
||||
policy = run_config.server.auth.access_policy if run_config.server.auth else []
|
||||
impls = await resolve_impls(
|
||||
run_config, provider_registry or get_provider_registry(run_config), dist_registry, policy
|
||||
)
|
||||
|
||||
# Add internal implementations after all other providers are resolved
|
||||
add_internal_implementations(impls, run_config)
|
||||
|
|
|
@ -7,10 +7,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
|
||||
CONTAINER_OPTS=${CONTAINER_OPTS:-}
|
||||
LLAMA_CHECKPOINT_DIR=${LLAMA_CHECKPOINT_DIR:-}
|
||||
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
|
||||
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
|
||||
PYPI_VERSION=${PYPI_VERSION:-}
|
||||
VIRTUAL_ENV=${VIRTUAL_ENV:-}
|
||||
|
@ -132,63 +128,7 @@ if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
|
|||
$env_vars \
|
||||
$other_args
|
||||
elif [[ "$env_type" == "container" ]]; then
|
||||
set -x
|
||||
|
||||
# Check if container command is available
|
||||
if ! is_command_available $CONTAINER_BINARY; then
|
||||
printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if is_command_available selinuxenabled &> /dev/null && selinuxenabled; then
|
||||
# Disable SELinux labels
|
||||
CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable"
|
||||
fi
|
||||
|
||||
mounts=""
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
mounts="$mounts -v $(readlink -f $LLAMA_STACK_DIR):/app/llama-stack-source"
|
||||
fi
|
||||
if [ -n "$LLAMA_CHECKPOINT_DIR" ]; then
|
||||
mounts="$mounts -v $LLAMA_CHECKPOINT_DIR:/root/.llama"
|
||||
CONTAINER_OPTS="$CONTAINER_OPTS --gpus=all"
|
||||
fi
|
||||
|
||||
if [ -n "$PYPI_VERSION" ]; then
|
||||
version_tag="$PYPI_VERSION"
|
||||
elif [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
version_tag="dev"
|
||||
elif [ -n "$TEST_PYPI_VERSION" ]; then
|
||||
version_tag="test-$TEST_PYPI_VERSION"
|
||||
else
|
||||
if ! is_command_available jq; then
|
||||
echo -e "${RED}Error: jq not found" >&2
|
||||
exit 1
|
||||
fi
|
||||
URL="https://pypi.org/pypi/llama-stack/json"
|
||||
version_tag=$(curl -s $URL | jq -r '.info.version')
|
||||
fi
|
||||
|
||||
# Build the command with optional yaml config
|
||||
cmd="$CONTAINER_BINARY run $CONTAINER_OPTS -it \
|
||||
-p $port:$port \
|
||||
$env_vars \
|
||||
$mounts \
|
||||
--env LLAMA_STACK_PORT=$port \
|
||||
--entrypoint python \
|
||||
$container_image:$version_tag \
|
||||
-m llama_stack.distribution.server.server"
|
||||
|
||||
# Add yaml config if provided, otherwise use default
|
||||
if [ -n "$yaml_config" ]; then
|
||||
cmd="$cmd -v $yaml_config:/app/run.yaml --config /app/run.yaml"
|
||||
else
|
||||
cmd="$cmd --config /app/run.yaml"
|
||||
fi
|
||||
|
||||
# Add any other args
|
||||
cmd="$cmd $other_args"
|
||||
|
||||
# Execute the command
|
||||
eval $cmd
|
||||
echo -e "${RED}Warning: Llama Stack no longer supports running Containers via the 'llama stack run' command.${NC}"
|
||||
echo -e "Please refer to the documentation for more information: https://llama-stack.readthedocs.io/en/latest/distributions/building_distro.html#llama-stack-build"
|
||||
exit 1
|
||||
fi
|
||||
|
|
|
@ -23,11 +23,8 @@ from llama_stack.distribution.utils.image_types import LlamaStackImageType
|
|||
|
||||
def formulate_run_args(image_type, image_name, config, template_name) -> list:
|
||||
env_name = ""
|
||||
if image_type == LlamaStackImageType.CONTAINER.value:
|
||||
env_name = (
|
||||
f"distribution-{template_name}" if template_name else (config.container_image if config else image_name)
|
||||
)
|
||||
elif image_type == LlamaStackImageType.CONDA.value:
|
||||
|
||||
if image_type == LlamaStackImageType.CONDA.value:
|
||||
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
|
||||
env_name = image_name or current_conda_env
|
||||
if not env_name:
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from collections.abc import Collection, Iterator, Sequence, Set
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
|
@ -14,7 +13,8 @@ from typing import (
|
|||
)
|
||||
|
||||
import tiktoken
|
||||
from tiktoken.load import load_tiktoken_bpe
|
||||
|
||||
from llama_stack.models.llama.tokenizer_utils import load_bpe_file
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -48,19 +48,20 @@ class Tokenizer:
|
|||
global _INSTANCE
|
||||
|
||||
if _INSTANCE is None:
|
||||
_INSTANCE = Tokenizer(os.path.join(os.path.dirname(__file__), "tokenizer.model"))
|
||||
_INSTANCE = Tokenizer(Path(__file__).parent / "tokenizer.model")
|
||||
return _INSTANCE
|
||||
|
||||
def __init__(self, model_path: str):
|
||||
def __init__(self, model_path: Path):
|
||||
"""
|
||||
Initializes the Tokenizer with a Tiktoken model.
|
||||
|
||||
Args:
|
||||
model_path (str): The path to the Tiktoken model file.
|
||||
"""
|
||||
assert os.path.isfile(model_path), model_path
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(f"Tokenizer model file not found: {model_path}")
|
||||
|
||||
mergeable_ranks = load_tiktoken_bpe(model_path)
|
||||
mergeable_ranks = load_bpe_file(model_path)
|
||||
num_base_tokens = len(mergeable_ranks)
|
||||
special_tokens = [
|
||||
"<|begin_of_text|>",
|
||||
|
@ -83,7 +84,7 @@ class Tokenizer:
|
|||
|
||||
self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)}
|
||||
self.model = tiktoken.Encoding(
|
||||
name=Path(model_path).name,
|
||||
name=model_path.name,
|
||||
pat_str=self.pat_str,
|
||||
mergeable_ranks=mergeable_ranks,
|
||||
special_tokens=self.special_tokens,
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from collections.abc import Collection, Iterator, Sequence, Set
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
|
@ -14,7 +13,8 @@ from typing import (
|
|||
)
|
||||
|
||||
import tiktoken
|
||||
from tiktoken.load import load_tiktoken_bpe
|
||||
|
||||
from llama_stack.models.llama.tokenizer_utils import load_bpe_file
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -118,19 +118,20 @@ class Tokenizer:
|
|||
global _INSTANCE
|
||||
|
||||
if _INSTANCE is None:
|
||||
_INSTANCE = Tokenizer(os.path.join(os.path.dirname(__file__), "tokenizer.model"))
|
||||
_INSTANCE = Tokenizer(Path(__file__).parent / "tokenizer.model")
|
||||
return _INSTANCE
|
||||
|
||||
def __init__(self, model_path: str):
|
||||
def __init__(self, model_path: Path):
|
||||
"""
|
||||
Initializes the Tokenizer with a Tiktoken model.
|
||||
|
||||
Args:
|
||||
model_path (str): The path to the Tiktoken model file.
|
||||
model_path (Path): The path to the Tiktoken model file.
|
||||
"""
|
||||
assert os.path.isfile(model_path), model_path
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(f"Tokenizer model file not found: {model_path}")
|
||||
|
||||
mergeable_ranks = load_tiktoken_bpe(model_path)
|
||||
mergeable_ranks = load_bpe_file(model_path)
|
||||
num_base_tokens = len(mergeable_ranks)
|
||||
|
||||
special_tokens = BASIC_SPECIAL_TOKENS + LLAMA4_SPECIAL_TOKENS
|
||||
|
@ -144,7 +145,7 @@ class Tokenizer:
|
|||
|
||||
self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)}
|
||||
self.model = tiktoken.Encoding(
|
||||
name=Path(model_path).name,
|
||||
name=model_path.name,
|
||||
pat_str=self.O200K_PATTERN,
|
||||
mergeable_ranks=mergeable_ranks,
|
||||
special_tokens=self.special_tokens,
|
||||
|
|
40
llama_stack/models/llama/tokenizer_utils.py
Normal file
40
llama_stack/models/llama/tokenizer_utils.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(__name__, "tokenizer_utils")
|
||||
|
||||
|
||||
def load_bpe_file(model_path: Path) -> dict[bytes, int]:
|
||||
"""
|
||||
Load BPE file directly and return mergeable ranks.
|
||||
|
||||
Args:
|
||||
model_path (Path): Path to the BPE model file.
|
||||
|
||||
Returns:
|
||||
dict[bytes, int]: Dictionary mapping byte sequences to their ranks.
|
||||
"""
|
||||
mergeable_ranks = {}
|
||||
|
||||
with open(model_path, encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
for line in content.splitlines():
|
||||
if not line.strip(): # Skip empty lines
|
||||
continue
|
||||
try:
|
||||
token, rank = line.split()
|
||||
mergeable_ranks[base64.b64decode(token)] = int(rank)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse line '{line}': {e}")
|
||||
continue
|
||||
|
||||
return mergeable_ranks
|
|
@ -6,12 +6,12 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.distribution.datatypes import AccessRule, Api
|
||||
|
||||
from .config import MetaReferenceAgentsImplConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any]):
|
||||
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any], policy: list[AccessRule]):
|
||||
from .agents import MetaReferenceAgentsImpl
|
||||
|
||||
impl = MetaReferenceAgentsImpl(
|
||||
|
@ -21,6 +21,7 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Ap
|
|||
deps[Api.safety],
|
||||
deps[Api.tool_runtime],
|
||||
deps[Api.tool_groups],
|
||||
policy,
|
||||
)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -60,6 +60,7 @@ from llama_stack.apis.inference import (
|
|||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.distribution.datatypes import AccessRule
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
|
@ -96,13 +97,14 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
vector_io_api: VectorIO,
|
||||
persistence_store: KVStore,
|
||||
created_at: str,
|
||||
policy: list[AccessRule],
|
||||
):
|
||||
self.agent_id = agent_id
|
||||
self.agent_config = agent_config
|
||||
self.inference_api = inference_api
|
||||
self.safety_api = safety_api
|
||||
self.vector_io_api = vector_io_api
|
||||
self.storage = AgentPersistence(agent_id, persistence_store)
|
||||
self.storage = AgentPersistence(agent_id, persistence_store, policy)
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
self.tool_groups_api = tool_groups_api
|
||||
self.created_at = created_at
|
||||
|
|
|
@ -29,6 +29,7 @@ from llama_stack.apis.agents import (
|
|||
Session,
|
||||
Turn,
|
||||
)
|
||||
from llama_stack.apis.agents.openai_responses import OpenAIResponseText
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
|
@ -40,6 +41,7 @@ from llama_stack.apis.inference import (
|
|||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.distribution.datatypes import AccessRule
|
||||
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
|
||||
from llama_stack.providers.utils.pagination import paginate_records
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
|
@ -61,6 +63,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
safety_api: Safety,
|
||||
tool_runtime_api: ToolRuntime,
|
||||
tool_groups_api: ToolGroups,
|
||||
policy: list[AccessRule],
|
||||
):
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
|
@ -71,6 +74,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
|
||||
self.in_memory_store = InmemoryKVStoreImpl()
|
||||
self.openai_responses_impl: OpenAIResponsesImpl | None = None
|
||||
self.policy = policy
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.persistence_store = await kvstore_impl(self.config.persistence_store)
|
||||
|
@ -129,6 +133,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
|
||||
),
|
||||
created_at=agent_info.created_at,
|
||||
policy=self.policy,
|
||||
)
|
||||
|
||||
async def create_agent_session(
|
||||
|
@ -324,10 +329,12 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
store: bool | None = True,
|
||||
stream: bool | None = False,
|
||||
temperature: float | None = None,
|
||||
text: OpenAIResponseText | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
) -> OpenAIResponseObject:
|
||||
return await self.openai_responses_impl.create_openai_response(
|
||||
input, model, instructions, previous_response_id, store, stream, temperature, tools
|
||||
input, model, instructions, previous_response_id, store, stream, temperature, text, tools, max_infer_iters
|
||||
)
|
||||
|
||||
async def list_openai_responses(
|
||||
|
|
|
@ -8,7 +8,7 @@ import json
|
|||
import time
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
from pydantic import BaseModel
|
||||
|
@ -37,6 +37,8 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseOutputMessageMCPListTools,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
OpenAIResponseText,
|
||||
OpenAIResponseTextFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
Inference,
|
||||
|
@ -50,7 +52,12 @@ from llama_stack.apis.inference.inference import (
|
|||
OpenAIChoice,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIImageURL,
|
||||
OpenAIJSONSchema,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatJSONObject,
|
||||
OpenAIResponseFormatJSONSchema,
|
||||
OpenAIResponseFormatParam,
|
||||
OpenAIResponseFormatText,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
|
@ -158,6 +165,21 @@ async def _convert_chat_choice_to_response_message(choice: OpenAIChoice) -> Open
|
|||
)
|
||||
|
||||
|
||||
async def _convert_response_text_to_chat_response_format(text: OpenAIResponseText) -> OpenAIResponseFormatParam:
|
||||
"""
|
||||
Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format.
|
||||
"""
|
||||
if not text.format or text.format["type"] == "text":
|
||||
return OpenAIResponseFormatText(type="text")
|
||||
if text.format["type"] == "json_object":
|
||||
return OpenAIResponseFormatJSONObject()
|
||||
if text.format["type"] == "json_schema":
|
||||
return OpenAIResponseFormatJSONSchema(
|
||||
json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"])
|
||||
)
|
||||
raise ValueError(f"Unsupported text format: {text.format}")
|
||||
|
||||
|
||||
async def _get_message_type_by_role(role: str):
|
||||
role_to_type = {
|
||||
"user": OpenAIUserMessageParam,
|
||||
|
@ -178,8 +200,8 @@ class ChatCompletionContext(BaseModel):
|
|||
messages: list[OpenAIMessageParam]
|
||||
tools: list[ChatCompletionToolParam] | None = None
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
|
||||
stream: bool
|
||||
temperature: float | None
|
||||
response_format: OpenAIResponseFormatParam
|
||||
|
||||
|
||||
class OpenAIResponsesImpl:
|
||||
|
@ -258,37 +280,6 @@ class OpenAIResponsesImpl:
|
|||
"""
|
||||
return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)
|
||||
|
||||
async def _process_response_choices(
|
||||
self,
|
||||
chat_response: OpenAIChatCompletion,
|
||||
ctx: ChatCompletionContext,
|
||||
tools: list[OpenAIResponseInputTool] | None,
|
||||
) -> list[OpenAIResponseOutput]:
|
||||
"""Handle tool execution and response message creation."""
|
||||
output_messages: list[OpenAIResponseOutput] = []
|
||||
# Execute tool calls if any
|
||||
for choice in chat_response.choices:
|
||||
if choice.message.tool_calls and tools:
|
||||
# Assume if the first tool is a function, all tools are functions
|
||||
if tools[0].type == "function":
|
||||
for tool_call in choice.message.tool_calls:
|
||||
output_messages.append(
|
||||
OpenAIResponseOutputMessageFunctionToolCall(
|
||||
arguments=tool_call.function.arguments or "",
|
||||
call_id=tool_call.id,
|
||||
name=tool_call.function.name or "",
|
||||
id=f"fc_{uuid.uuid4()}",
|
||||
status="completed",
|
||||
)
|
||||
)
|
||||
else:
|
||||
tool_messages = await self._execute_tool_and_return_final_output(choice, ctx)
|
||||
output_messages.extend(tool_messages)
|
||||
else:
|
||||
output_messages.append(await _convert_chat_choice_to_response_message(choice))
|
||||
|
||||
return output_messages
|
||||
|
||||
async def _store_response(
|
||||
self,
|
||||
response: OpenAIResponseObject,
|
||||
|
@ -331,10 +322,52 @@ class OpenAIResponsesImpl:
|
|||
store: bool | None = True,
|
||||
stream: bool | None = False,
|
||||
temperature: float | None = None,
|
||||
text: OpenAIResponseText | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
):
|
||||
stream = False if stream is None else stream
|
||||
stream = bool(stream)
|
||||
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
|
||||
|
||||
stream_gen = self._create_streaming_response(
|
||||
input=input,
|
||||
model=model,
|
||||
instructions=instructions,
|
||||
previous_response_id=previous_response_id,
|
||||
store=store,
|
||||
temperature=temperature,
|
||||
text=text,
|
||||
tools=tools,
|
||||
max_infer_iters=max_infer_iters,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return stream_gen
|
||||
else:
|
||||
response = None
|
||||
async for stream_chunk in stream_gen:
|
||||
if stream_chunk.type == "response.completed":
|
||||
if response is not None:
|
||||
raise ValueError("The response stream completed multiple times! Earlier response: {response}")
|
||||
response = stream_chunk.response
|
||||
# don't leave the generator half complete!
|
||||
|
||||
if response is None:
|
||||
raise ValueError("The response stream never completed")
|
||||
return response
|
||||
|
||||
async def _create_streaming_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
instructions: str | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
store: bool | None = True,
|
||||
temperature: float | None = None,
|
||||
text: OpenAIResponseText | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
output_messages: list[OpenAIResponseOutput] = []
|
||||
|
||||
# Input preprocessing
|
||||
|
@ -342,7 +375,10 @@ class OpenAIResponsesImpl:
|
|||
messages = await _convert_response_input_to_chat_messages(input)
|
||||
await self._prepend_instructions(messages, instructions)
|
||||
|
||||
# Tool setup
|
||||
# Structured outputs
|
||||
response_format = await _convert_response_text_to_chat_response_format(text)
|
||||
|
||||
# Tool setup, TODO: refactor this slightly since this can also yield events
|
||||
chat_tools, mcp_tool_to_server, mcp_list_message = (
|
||||
await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
|
||||
)
|
||||
|
@ -354,89 +390,10 @@ class OpenAIResponsesImpl:
|
|||
messages=messages,
|
||||
tools=chat_tools,
|
||||
mcp_tool_to_server=mcp_tool_to_server,
|
||||
stream=stream,
|
||||
temperature=temperature,
|
||||
response_format=response_format,
|
||||
)
|
||||
|
||||
inference_result = await self.inference_api.openai_chat_completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=chat_tools,
|
||||
stream=stream,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._create_streaming_response(
|
||||
inference_result=inference_result,
|
||||
ctx=ctx,
|
||||
output_messages=output_messages,
|
||||
input=input,
|
||||
model=model,
|
||||
store=store,
|
||||
tools=tools,
|
||||
)
|
||||
else:
|
||||
return await self._create_non_streaming_response(
|
||||
inference_result=inference_result,
|
||||
ctx=ctx,
|
||||
output_messages=output_messages,
|
||||
input=input,
|
||||
model=model,
|
||||
store=store,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
async def _create_non_streaming_response(
|
||||
self,
|
||||
inference_result: Any,
|
||||
ctx: ChatCompletionContext,
|
||||
output_messages: list[OpenAIResponseOutput],
|
||||
input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
store: bool | None,
|
||||
tools: list[OpenAIResponseInputTool] | None,
|
||||
) -> OpenAIResponseObject:
|
||||
chat_response = OpenAIChatCompletion(**inference_result.model_dump())
|
||||
|
||||
# Process response choices (tool execution and message creation)
|
||||
output_messages.extend(
|
||||
await self._process_response_choices(
|
||||
chat_response=chat_response,
|
||||
ctx=ctx,
|
||||
tools=tools,
|
||||
)
|
||||
)
|
||||
|
||||
response = OpenAIResponseObject(
|
||||
created_at=chat_response.created,
|
||||
id=f"resp-{uuid.uuid4()}",
|
||||
model=model,
|
||||
object="response",
|
||||
status="completed",
|
||||
output=output_messages,
|
||||
)
|
||||
logger.debug(f"OpenAI Responses response: {response}")
|
||||
|
||||
# Store response if requested
|
||||
if store:
|
||||
await self._store_response(
|
||||
response=response,
|
||||
input=input,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
async def _create_streaming_response(
|
||||
self,
|
||||
inference_result: Any,
|
||||
ctx: ChatCompletionContext,
|
||||
output_messages: list[OpenAIResponseOutput],
|
||||
input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
store: bool | None,
|
||||
tools: list[OpenAIResponseInputTool] | None,
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Create initial response and emit response.created immediately
|
||||
response_id = f"resp-{uuid.uuid4()}"
|
||||
created_at = int(time.time())
|
||||
|
@ -448,87 +405,144 @@ class OpenAIResponsesImpl:
|
|||
object="response",
|
||||
status="in_progress",
|
||||
output=output_messages.copy(),
|
||||
text=text,
|
||||
)
|
||||
|
||||
# Emit response.created immediately
|
||||
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
|
||||
|
||||
# For streaming, inference_result is an async iterator of chunks
|
||||
# Stream chunks and emit delta events as they arrive
|
||||
chat_response_id = ""
|
||||
chat_response_content = []
|
||||
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
|
||||
chunk_created = 0
|
||||
chunk_model = ""
|
||||
chunk_finish_reason = ""
|
||||
sequence_number = 0
|
||||
n_iter = 0
|
||||
messages = ctx.messages.copy()
|
||||
|
||||
# Create a placeholder message item for delta events
|
||||
message_item_id = f"msg_{uuid.uuid4()}"
|
||||
|
||||
async for chunk in inference_result:
|
||||
chat_response_id = chunk.id
|
||||
chunk_created = chunk.created
|
||||
chunk_model = chunk.model
|
||||
for chunk_choice in chunk.choices:
|
||||
# Emit incremental text content as delta events
|
||||
if chunk_choice.delta.content:
|
||||
sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
|
||||
content_index=0,
|
||||
delta=chunk_choice.delta.content,
|
||||
item_id=message_item_id,
|
||||
output_index=0,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
|
||||
# Collect content for final response
|
||||
chat_response_content.append(chunk_choice.delta.content or "")
|
||||
if chunk_choice.finish_reason:
|
||||
chunk_finish_reason = chunk_choice.finish_reason
|
||||
|
||||
# Aggregate tool call arguments across chunks, using their index as the aggregation key
|
||||
if chunk_choice.delta.tool_calls:
|
||||
for tool_call in chunk_choice.delta.tool_calls:
|
||||
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
|
||||
if response_tool_call:
|
||||
response_tool_call.function.arguments += tool_call.function.arguments
|
||||
else:
|
||||
tool_call_dict: dict[str, Any] = tool_call.model_dump()
|
||||
tool_call_dict.pop("type", None)
|
||||
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
|
||||
chat_response_tool_calls[tool_call.index] = response_tool_call
|
||||
|
||||
# Convert collected chunks to complete response
|
||||
if chat_response_tool_calls:
|
||||
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
|
||||
else:
|
||||
tool_calls = None
|
||||
assistant_message = OpenAIAssistantMessageParam(
|
||||
content="".join(chat_response_content),
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
chat_response_obj = OpenAIChatCompletion(
|
||||
id=chat_response_id,
|
||||
choices=[
|
||||
OpenAIChoice(
|
||||
message=assistant_message,
|
||||
finish_reason=chunk_finish_reason,
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
created=chunk_created,
|
||||
model=chunk_model,
|
||||
)
|
||||
|
||||
# Process response choices (tool execution and message creation)
|
||||
output_messages.extend(
|
||||
await self._process_response_choices(
|
||||
chat_response=chat_response_obj,
|
||||
ctx=ctx,
|
||||
tools=tools,
|
||||
while True:
|
||||
completion_result = await self.inference_api.openai_chat_completion(
|
||||
model=ctx.model,
|
||||
messages=messages,
|
||||
tools=ctx.tools,
|
||||
stream=True,
|
||||
temperature=ctx.temperature,
|
||||
response_format=ctx.response_format,
|
||||
)
|
||||
)
|
||||
|
||||
# Process streaming chunks and build complete response
|
||||
chat_response_id = ""
|
||||
chat_response_content = []
|
||||
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
|
||||
chunk_created = 0
|
||||
chunk_model = ""
|
||||
chunk_finish_reason = ""
|
||||
sequence_number = 0
|
||||
|
||||
# Create a placeholder message item for delta events
|
||||
message_item_id = f"msg_{uuid.uuid4()}"
|
||||
|
||||
async for chunk in completion_result:
|
||||
chat_response_id = chunk.id
|
||||
chunk_created = chunk.created
|
||||
chunk_model = chunk.model
|
||||
for chunk_choice in chunk.choices:
|
||||
# Emit incremental text content as delta events
|
||||
if chunk_choice.delta.content:
|
||||
sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
|
||||
content_index=0,
|
||||
delta=chunk_choice.delta.content,
|
||||
item_id=message_item_id,
|
||||
output_index=0,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
|
||||
# Collect content for final response
|
||||
chat_response_content.append(chunk_choice.delta.content or "")
|
||||
if chunk_choice.finish_reason:
|
||||
chunk_finish_reason = chunk_choice.finish_reason
|
||||
|
||||
# Aggregate tool call arguments across chunks
|
||||
if chunk_choice.delta.tool_calls:
|
||||
for tool_call in chunk_choice.delta.tool_calls:
|
||||
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
|
||||
if response_tool_call:
|
||||
# Don't attempt to concatenate arguments if we don't have any new argumentsAdd commentMore actions
|
||||
if tool_call.function.arguments:
|
||||
# Guard against an initial None argument before we concatenate
|
||||
response_tool_call.function.arguments = (
|
||||
response_tool_call.function.arguments or ""
|
||||
) + tool_call.function.arguments
|
||||
else:
|
||||
tool_call_dict: dict[str, Any] = tool_call.model_dump()
|
||||
tool_call_dict.pop("type", None)
|
||||
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
|
||||
chat_response_tool_calls[tool_call.index] = response_tool_call
|
||||
|
||||
# Convert collected chunks to complete response
|
||||
if chat_response_tool_calls:
|
||||
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
|
||||
else:
|
||||
tool_calls = None
|
||||
assistant_message = OpenAIAssistantMessageParam(
|
||||
content="".join(chat_response_content),
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
current_response = OpenAIChatCompletion(
|
||||
id=chat_response_id,
|
||||
choices=[
|
||||
OpenAIChoice(
|
||||
message=assistant_message,
|
||||
finish_reason=chunk_finish_reason,
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
created=chunk_created,
|
||||
model=chunk_model,
|
||||
)
|
||||
|
||||
function_tool_calls = []
|
||||
non_function_tool_calls = []
|
||||
|
||||
next_turn_messages = messages.copy()
|
||||
for choice in current_response.choices:
|
||||
next_turn_messages.append(choice.message)
|
||||
|
||||
if choice.message.tool_calls and tools:
|
||||
for tool_call in choice.message.tool_calls:
|
||||
if _is_function_tool_call(tool_call, tools):
|
||||
function_tool_calls.append(tool_call)
|
||||
else:
|
||||
non_function_tool_calls.append(tool_call)
|
||||
else:
|
||||
output_messages.append(await _convert_chat_choice_to_response_message(choice))
|
||||
|
||||
# execute non-function tool calls
|
||||
for tool_call in non_function_tool_calls:
|
||||
tool_call_log, tool_response_message = await self._execute_tool_call(tool_call, ctx)
|
||||
if tool_call_log:
|
||||
output_messages.append(tool_call_log)
|
||||
if tool_response_message:
|
||||
next_turn_messages.append(tool_response_message)
|
||||
|
||||
for tool_call in function_tool_calls:
|
||||
output_messages.append(
|
||||
OpenAIResponseOutputMessageFunctionToolCall(
|
||||
arguments=tool_call.function.arguments or "",
|
||||
call_id=tool_call.id,
|
||||
name=tool_call.function.name or "",
|
||||
id=f"fc_{uuid.uuid4()}",
|
||||
status="completed",
|
||||
)
|
||||
)
|
||||
|
||||
if not function_tool_calls and not non_function_tool_calls:
|
||||
break
|
||||
|
||||
if function_tool_calls:
|
||||
logger.info("Exiting inference loop since there is a function (client-side) tool call")
|
||||
break
|
||||
|
||||
n_iter += 1
|
||||
if n_iter >= max_infer_iters:
|
||||
logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {max_infer_iters=}")
|
||||
break
|
||||
|
||||
messages = next_turn_messages
|
||||
|
||||
# Create final response
|
||||
final_response = OpenAIResponseObject(
|
||||
|
@ -537,18 +551,19 @@ class OpenAIResponsesImpl:
|
|||
model=model,
|
||||
object="response",
|
||||
status="completed",
|
||||
text=text,
|
||||
output=output_messages,
|
||||
)
|
||||
|
||||
# Emit response.completed
|
||||
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
|
||||
|
||||
if store:
|
||||
await self._store_response(
|
||||
response=final_response,
|
||||
input=input,
|
||||
)
|
||||
|
||||
# Emit response.completed
|
||||
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
|
||||
|
||||
async def _convert_response_tools_to_chat_tools(
|
||||
self, tools: list[OpenAIResponseInputTool]
|
||||
) -> tuple[
|
||||
|
@ -641,49 +656,6 @@ class OpenAIResponsesImpl:
|
|||
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
|
||||
return chat_tools, mcp_tool_to_server, mcp_list_message
|
||||
|
||||
async def _execute_tool_and_return_final_output(
|
||||
self,
|
||||
choice: OpenAIChoice,
|
||||
ctx: ChatCompletionContext,
|
||||
) -> list[OpenAIResponseOutput]:
|
||||
output_messages: list[OpenAIResponseOutput] = []
|
||||
|
||||
if not isinstance(choice.message, OpenAIAssistantMessageParam):
|
||||
return output_messages
|
||||
|
||||
if not choice.message.tool_calls:
|
||||
return output_messages
|
||||
|
||||
next_turn_messages = ctx.messages.copy()
|
||||
|
||||
# Add the assistant message with tool_calls response to the messages list
|
||||
next_turn_messages.append(choice.message)
|
||||
|
||||
for tool_call in choice.message.tool_calls:
|
||||
# TODO: telemetry spans for tool calls
|
||||
tool_call_log, further_input = await self._execute_tool_call(tool_call, ctx)
|
||||
if tool_call_log:
|
||||
output_messages.append(tool_call_log)
|
||||
if further_input:
|
||||
next_turn_messages.append(further_input)
|
||||
|
||||
tool_results_chat_response = await self.inference_api.openai_chat_completion(
|
||||
model=ctx.model,
|
||||
messages=next_turn_messages,
|
||||
stream=ctx.stream,
|
||||
temperature=ctx.temperature,
|
||||
)
|
||||
# type cast to appease mypy: this is needed because we don't handle streaming properly :)
|
||||
tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
|
||||
|
||||
# Huge TODO: these are NOT the final outputs, we must keep the loop going
|
||||
tool_final_outputs = [
|
||||
await _convert_chat_choice_to_response_message(choice) for choice in tool_results_chat_response.choices
|
||||
]
|
||||
# TODO: Wire in annotations with URLs, titles, etc to these output messages
|
||||
output_messages.extend(tool_final_outputs)
|
||||
return output_messages
|
||||
|
||||
async def _execute_tool_call(
|
||||
self,
|
||||
tool_call: OpenAIChatCompletionToolCall,
|
||||
|
@ -767,5 +739,20 @@ class OpenAIResponsesImpl:
|
|||
else:
|
||||
raise ValueError(f"Unknown result content type: {type(result.content)}")
|
||||
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
|
||||
else:
|
||||
text = str(error_exc)
|
||||
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
|
||||
|
||||
return message, input_message
|
||||
|
||||
|
||||
def _is_function_tool_call(
|
||||
tool_call: OpenAIChatCompletionToolCall,
|
||||
tools: list[OpenAIResponseInputTool],
|
||||
) -> bool:
|
||||
if not tool_call.function:
|
||||
return False
|
||||
for t in tools:
|
||||
if t.type == "function" and t.name == tool_call.function.name:
|
||||
return True
|
||||
return False
|
||||
|
|
|
@ -10,9 +10,10 @@ import uuid
|
|||
from datetime import datetime, timezone
|
||||
|
||||
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
|
||||
from llama_stack.distribution.access_control import check_access
|
||||
from llama_stack.distribution.datatypes import AccessAttributes
|
||||
from llama_stack.distribution.request_headers import get_auth_attributes
|
||||
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
from llama_stack.distribution.access_control.datatypes import AccessRule
|
||||
from llama_stack.distribution.datatypes import User
|
||||
from llama_stack.distribution.request_headers import get_authenticated_user
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -22,7 +23,9 @@ class AgentSessionInfo(Session):
|
|||
# TODO: is this used anywhere?
|
||||
vector_db_id: str | None = None
|
||||
started_at: datetime
|
||||
access_attributes: AccessAttributes | None = None
|
||||
owner: User | None = None
|
||||
identifier: str | None = None
|
||||
type: str = "session"
|
||||
|
||||
|
||||
class AgentInfo(AgentConfig):
|
||||
|
@ -30,24 +33,27 @@ class AgentInfo(AgentConfig):
|
|||
|
||||
|
||||
class AgentPersistence:
|
||||
def __init__(self, agent_id: str, kvstore: KVStore):
|
||||
def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]):
|
||||
self.agent_id = agent_id
|
||||
self.kvstore = kvstore
|
||||
self.policy = policy
|
||||
|
||||
async def create_session(self, name: str) -> str:
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
# Get current user's auth attributes for new sessions
|
||||
auth_attributes = get_auth_attributes()
|
||||
access_attributes = AccessAttributes(**auth_attributes) if auth_attributes else None
|
||||
user = get_authenticated_user()
|
||||
|
||||
session_info = AgentSessionInfo(
|
||||
session_id=session_id,
|
||||
session_name=name,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
access_attributes=access_attributes,
|
||||
owner=user,
|
||||
turns=[],
|
||||
identifier=name, # should this be qualified in any way?
|
||||
)
|
||||
if not is_action_allowed(self.policy, "create", session_info, user):
|
||||
raise AccessDeniedError()
|
||||
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
|
@ -73,10 +79,10 @@ class AgentPersistence:
|
|||
def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
|
||||
"""Check if current user has access to the session."""
|
||||
# Handle backward compatibility for old sessions without access control
|
||||
if not hasattr(session_info, "access_attributes"):
|
||||
if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
|
||||
return True
|
||||
|
||||
return check_access(session_info.session_id, session_info.access_attributes, get_auth_attributes())
|
||||
return is_action_allowed(self.policy, "read", session_info, get_authenticated_user())
|
||||
|
||||
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
|
||||
"""Get session info if the user has access to it. For internal use by sub-session methods."""
|
||||
|
|
20
llama_stack/providers/inline/files/localfs/__init__.py
Normal file
20
llama_stack/providers/inline/files/localfs/__init__.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import LocalfsFilesImplConfig
|
||||
from .files import LocalfsFilesImpl
|
||||
|
||||
__all__ = ["LocalfsFilesImpl", "LocalfsFilesImplConfig"]
|
||||
|
||||
|
||||
async def get_provider_impl(config: LocalfsFilesImplConfig, deps: dict[Api, Any]):
|
||||
impl = LocalfsFilesImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
31
llama_stack/providers/inline/files/localfs/config.py
Normal file
31
llama_stack/providers/inline/files/localfs/config.py
Normal file
|
@ -0,0 +1,31 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
|
||||
|
||||
|
||||
class LocalfsFilesImplConfig(BaseModel):
|
||||
storage_dir: str = Field(
|
||||
description="Directory to store uploaded files",
|
||||
)
|
||||
metadata_store: SqlStoreConfig = Field(
|
||||
description="SQL store configuration for file metadata",
|
||||
)
|
||||
ttl_secs: int = 365 * 24 * 60 * 60 # 1 year
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
|
||||
return {
|
||||
"storage_dir": "${env.FILES_STORAGE_DIR:" + __distro_dir__ + "/files}",
|
||||
"metadata_store": SqliteSqlStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="files_metadata.db",
|
||||
),
|
||||
}
|
214
llama_stack/providers/inline/files/localfs/files.py
Normal file
214
llama_stack/providers/inline/files/localfs/files.py
Normal file
|
@ -0,0 +1,214 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import File, Form, Response, UploadFile
|
||||
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.files import (
|
||||
Files,
|
||||
ListOpenAIFileResponse,
|
||||
OpenAIFileDeleteResponse,
|
||||
OpenAIFileObject,
|
||||
OpenAIFilePurpose,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl
|
||||
|
||||
from .config import LocalfsFilesImplConfig
|
||||
|
||||
|
||||
class LocalfsFilesImpl(Files):
|
||||
def __init__(self, config: LocalfsFilesImplConfig) -> None:
|
||||
self.config = config
|
||||
self.sql_store: SqlStore | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the files provider by setting up storage directory and metadata database."""
|
||||
# Create storage directory if it doesn't exist
|
||||
storage_path = Path(self.config.storage_dir)
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Initialize SQL store for metadata
|
||||
self.sql_store = sqlstore_impl(self.config.metadata_store)
|
||||
await self.sql_store.create_table(
|
||||
"openai_files",
|
||||
{
|
||||
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
|
||||
"filename": ColumnType.STRING,
|
||||
"purpose": ColumnType.STRING,
|
||||
"bytes": ColumnType.INTEGER,
|
||||
"created_at": ColumnType.INTEGER,
|
||||
"expires_at": ColumnType.INTEGER,
|
||||
"file_path": ColumnType.STRING, # Path to actual file on disk
|
||||
},
|
||||
)
|
||||
|
||||
def _generate_file_id(self) -> str:
|
||||
"""Generate a unique file ID for OpenAI API."""
|
||||
return f"file-{uuid.uuid4().hex}"
|
||||
|
||||
def _get_file_path(self, file_id: str) -> Path:
|
||||
"""Get the filesystem path for a file ID."""
|
||||
return Path(self.config.storage_dir) / file_id
|
||||
|
||||
# OpenAI Files API Implementation
|
||||
async def openai_upload_file(
|
||||
self,
|
||||
file: Annotated[UploadFile, File()],
|
||||
purpose: Annotated[OpenAIFilePurpose, Form()],
|
||||
) -> OpenAIFileObject:
|
||||
"""Upload a file that can be used across various endpoints."""
|
||||
if not self.sql_store:
|
||||
raise RuntimeError("Files provider not initialized")
|
||||
|
||||
file_id = self._generate_file_id()
|
||||
file_path = self._get_file_path(file_id)
|
||||
|
||||
content = await file.read()
|
||||
file_size = len(content)
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
created_at = int(time.time())
|
||||
expires_at = created_at + self.config.ttl_secs
|
||||
|
||||
await self.sql_store.insert(
|
||||
"openai_files",
|
||||
{
|
||||
"id": file_id,
|
||||
"filename": file.filename or "uploaded_file",
|
||||
"purpose": purpose.value,
|
||||
"bytes": file_size,
|
||||
"created_at": created_at,
|
||||
"expires_at": expires_at,
|
||||
"file_path": file_path.as_posix(),
|
||||
},
|
||||
)
|
||||
|
||||
return OpenAIFileObject(
|
||||
id=file_id,
|
||||
filename=file.filename or "uploaded_file",
|
||||
purpose=purpose,
|
||||
bytes=file_size,
|
||||
created_at=created_at,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
async def openai_list_files(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int | None = 10000,
|
||||
order: Order | None = Order.desc,
|
||||
purpose: OpenAIFilePurpose | None = None,
|
||||
) -> ListOpenAIFileResponse:
|
||||
"""Returns a list of files that belong to the user's organization."""
|
||||
if not self.sql_store:
|
||||
raise RuntimeError("Files provider not initialized")
|
||||
|
||||
# TODO: Implement 'after' pagination properly
|
||||
if after:
|
||||
raise NotImplementedError("After pagination not yet implemented")
|
||||
|
||||
where = None
|
||||
if purpose:
|
||||
where = {"purpose": purpose.value}
|
||||
|
||||
rows = await self.sql_store.fetch_all(
|
||||
"openai_files",
|
||||
where=where,
|
||||
order_by=[("created_at", order.value if order else Order.desc.value)],
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
files = [
|
||||
OpenAIFileObject(
|
||||
id=row["id"],
|
||||
filename=row["filename"],
|
||||
purpose=OpenAIFilePurpose(row["purpose"]),
|
||||
bytes=row["bytes"],
|
||||
created_at=row["created_at"],
|
||||
expires_at=row["expires_at"],
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
return ListOpenAIFileResponse(
|
||||
data=files,
|
||||
has_more=False, # TODO: Implement proper pagination
|
||||
first_id=files[0].id if files else "",
|
||||
last_id=files[-1].id if files else "",
|
||||
)
|
||||
|
||||
async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
|
||||
"""Returns information about a specific file."""
|
||||
if not self.sql_store:
|
||||
raise RuntimeError("Files provider not initialized")
|
||||
|
||||
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||
if not row:
|
||||
raise ValueError(f"File with id {file_id} not found")
|
||||
|
||||
return OpenAIFileObject(
|
||||
id=row["id"],
|
||||
filename=row["filename"],
|
||||
purpose=OpenAIFilePurpose(row["purpose"]),
|
||||
bytes=row["bytes"],
|
||||
created_at=row["created_at"],
|
||||
expires_at=row["expires_at"],
|
||||
)
|
||||
|
||||
async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
|
||||
"""Delete a file."""
|
||||
if not self.sql_store:
|
||||
raise RuntimeError("Files provider not initialized")
|
||||
|
||||
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||
if not row:
|
||||
raise ValueError(f"File with id {file_id} not found")
|
||||
|
||||
# Delete physical file
|
||||
file_path = Path(row["file_path"])
|
||||
if file_path.exists():
|
||||
file_path.unlink()
|
||||
|
||||
# Delete metadata from database
|
||||
await self.sql_store.delete("openai_files", where={"id": file_id})
|
||||
|
||||
return OpenAIFileDeleteResponse(
|
||||
id=file_id,
|
||||
deleted=True,
|
||||
)
|
||||
|
||||
async def openai_retrieve_file_content(self, file_id: str) -> Response:
|
||||
"""Returns the contents of the specified file."""
|
||||
if not self.sql_store:
|
||||
raise RuntimeError("Files provider not initialized")
|
||||
|
||||
# Get file metadata
|
||||
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||
if not row:
|
||||
raise ValueError(f"File with id {file_id} not found")
|
||||
|
||||
# Read file content
|
||||
file_path = Path(row["file_path"])
|
||||
if not file_path.exists():
|
||||
raise ValueError(f"File content not found on disk: {file_path}")
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
content = f.read()
|
||||
|
||||
# Return as binary response with appropriate content type
|
||||
return Response(
|
||||
content=content,
|
||||
media_type="application/octet-stream",
|
||||
headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'},
|
||||
)
|
|
@ -40,6 +40,7 @@ from llama_stack.apis.inference import (
|
|||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
|
@ -410,6 +411,16 @@ class VLLMInferenceImpl(
|
|||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
|
|
@ -30,7 +30,7 @@ class TelemetryConfig(BaseModel):
|
|||
)
|
||||
service_name: str = Field(
|
||||
# service name is always the same, use zero-width space to avoid clutter
|
||||
default="",
|
||||
default="\u200b",
|
||||
description="The service name to use for telemetry",
|
||||
)
|
||||
sinks: list[TelemetrySink] = Field(
|
||||
|
@ -52,7 +52,7 @@ class TelemetryConfig(BaseModel):
|
|||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> dict[str, Any]:
|
||||
return {
|
||||
"service_name": "${env.OTEL_SERVICE_NAME:}",
|
||||
"service_name": "${env.OTEL_SERVICE_NAME:\u200b}",
|
||||
"sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
|
||||
"sqlite_db_path": "${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name,
|
||||
}
|
||||
|
|
|
@ -146,7 +146,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
|||
]
|
||||
for i, chunk in enumerate(chunks):
|
||||
metadata = chunk.metadata
|
||||
tokens += metadata["token_count"]
|
||||
tokens += metadata.get("token_count", 0)
|
||||
tokens += metadata.get("metadata_token_count", 0)
|
||||
|
||||
if tokens > query_config.max_tokens_in_context:
|
||||
|
|
|
@ -24,7 +24,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
"pandas",
|
||||
"scikit-learn",
|
||||
]
|
||||
+ kvstore_dependencies(),
|
||||
+ kvstore_dependencies(), # TODO make this dynamic based on the kvstore config
|
||||
module="llama_stack.providers.inline.agents.meta_reference",
|
||||
config_class="llama_stack.providers.inline.agents.meta_reference.MetaReferenceAgentsImplConfig",
|
||||
api_dependencies=[
|
||||
|
|
|
@ -4,8 +4,22 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
from llama_stack.providers.datatypes import (
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages
|
||||
|
||||
|
||||
def available_providers() -> list[ProviderSpec]:
|
||||
return []
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.files,
|
||||
provider_type="inline::localfs",
|
||||
# TODO: make this dynamic according to the sql store type
|
||||
pip_packages=sql_store_pip_packages,
|
||||
module="llama_stack.providers.inline.files.localfs",
|
||||
config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig",
|
||||
),
|
||||
]
|
||||
|
|
|
@ -15,7 +15,6 @@ from llama_stack.providers.datatypes import (
|
|||
|
||||
META_REFERENCE_DEPS = [
|
||||
"accelerate",
|
||||
"blobfile",
|
||||
"fairscale",
|
||||
"torch",
|
||||
"torchvision",
|
||||
|
|
|
@ -20,7 +20,6 @@ def available_providers() -> list[ProviderSpec]:
|
|||
api=Api.tool_runtime,
|
||||
provider_type="inline::rag-runtime",
|
||||
pip_packages=[
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"pypdf",
|
||||
"tqdm",
|
||||
|
|
|
@ -22,6 +22,7 @@ from llama_stack.apis.inference import (
|
|||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
|
@ -197,3 +198,13 @@ class BedrockInferenceAdapter(
|
|||
response_body = json.loads(response.get("body").read())
|
||||
embeddings.append(response_body.get("embedding"))
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -21,6 +21,7 @@ from llama_stack.apis.inference import (
|
|||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
|
@ -194,3 +195,13 @@ class CerebrasInferenceAdapter(
|
|||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -20,6 +20,7 @@ from llama_stack.apis.inference import (
|
|||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
|
@ -152,3 +153,13 @@ class DatabricksInferenceAdapter(
|
|||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -37,6 +37,7 @@ from llama_stack.apis.inference.inference import (
|
|||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
|
@ -254,7 +255,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
params = {
|
||||
"model": request.model,
|
||||
**input_dict,
|
||||
"stream": request.stream,
|
||||
"stream": bool(request.stream),
|
||||
**self._build_options(request.sampling_params, request.response_format, request.logprobs),
|
||||
}
|
||||
logger.debug(f"params to fireworks: {params}")
|
||||
|
@ -286,6 +287,16 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
embeddings = [data.embedding for data in response.data]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -29,6 +29,7 @@ from llama_stack.apis.inference import (
|
|||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
|
@ -238,6 +239,16 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
#
|
||||
return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
|
|
@ -12,7 +12,7 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_model_entry,
|
||||
)
|
||||
|
||||
model_entries = [
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"llama3.1:8b-instruct-fp16",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
|
@ -32,6 +33,7 @@ from llama_stack.apis.inference import (
|
|||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
|
@ -76,7 +78,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
request_has_media,
|
||||
)
|
||||
|
||||
from .models import model_entries
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
@ -86,7 +88,7 @@ class OllamaInferenceAdapter(
|
|||
ModelsProtocolPrivate,
|
||||
):
|
||||
def __init__(self, url: str) -> None:
|
||||
self.register_helper = ModelRegistryHelper(model_entries)
|
||||
self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
|
||||
self.url = url
|
||||
|
||||
@property
|
||||
|
@ -343,21 +345,27 @@ class OllamaInferenceAdapter(
|
|||
model = await self.register_helper.register_model(model)
|
||||
except ValueError:
|
||||
pass # Ignore statically unknown model, will check live listing
|
||||
|
||||
if model.provider_resource_id is None:
|
||||
raise ValueError("Model provider_resource_id cannot be None")
|
||||
|
||||
if model.model_type == ModelType.embedding:
|
||||
logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
|
||||
await self.client.pull(model.provider_resource_id)
|
||||
# TODO: you should pull here only if the model is not found in a list
|
||||
response = await self.client.list()
|
||||
if model.provider_resource_id not in [m.model for m in response.models]:
|
||||
await self.client.pull(model.provider_resource_id)
|
||||
|
||||
# we use list() here instead of ps() -
|
||||
# - ps() only lists running models, not available models
|
||||
# - models not currently running are run by the ollama server as needed
|
||||
response = await self.client.list()
|
||||
available_models = [m["model"] for m in response["models"]]
|
||||
if model.provider_resource_id is None:
|
||||
raise ValueError("Model provider_resource_id cannot be None")
|
||||
available_models = [m.model for m in response.models]
|
||||
provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id)
|
||||
if provider_resource_id is None:
|
||||
provider_resource_id = model.provider_resource_id
|
||||
if provider_resource_id not in available_models:
|
||||
available_models_latest = [m["model"].split(":latest")[0] for m in response["models"]]
|
||||
available_models_latest = [m.model.split(":latest")[0] for m in response.models]
|
||||
if provider_resource_id in available_models_latest:
|
||||
logger.warning(
|
||||
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
|
||||
|
@ -370,6 +378,16 @@ class OllamaInferenceAdapter(
|
|||
|
||||
return model
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -469,7 +487,25 @@ class OllamaInferenceAdapter(
|
|||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
return await self.openai_client.chat.completions.create(**params) # type: ignore
|
||||
response = await self.openai_client.chat.completions.create(**params)
|
||||
return await self._adjust_ollama_chat_completion_response_ids(response)
|
||||
|
||||
async def _adjust_ollama_chat_completion_response_ids(
|
||||
self,
|
||||
response: OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk],
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
id = f"chatcmpl-{uuid.uuid4()}"
|
||||
if isinstance(response, AsyncIterator):
|
||||
|
||||
async def stream_with_chunk_ids() -> AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
async for chunk in response:
|
||||
chunk.id = id
|
||||
yield chunk
|
||||
|
||||
return stream_with_chunk_ids()
|
||||
else:
|
||||
response.id = id
|
||||
return response
|
||||
|
||||
async def batch_completion(
|
||||
self,
|
||||
|
|
|
@ -14,6 +14,9 @@ from llama_stack.apis.inference.inference import (
|
|||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
|
@ -38,6 +41,7 @@ logger = logging.getLogger(__name__)
|
|||
# | batch_chat_completion | LiteLLMOpenAIMixin |
|
||||
# | openai_completion | AsyncOpenAI |
|
||||
# | openai_chat_completion | AsyncOpenAI |
|
||||
# | openai_embeddings | AsyncOpenAI |
|
||||
#
|
||||
class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
def __init__(self, config: OpenAIConfig) -> None:
|
||||
|
@ -171,3 +175,51 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
user=user,
|
||||
)
|
||||
return await self._openai_client.chat.completions.create(**params)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
model_id = (await self.model_store.get_model(model)).provider_resource_id
|
||||
if model_id.startswith("openai/"):
|
||||
model_id = model_id[len("openai/") :]
|
||||
|
||||
# Prepare parameters for OpenAI embeddings API
|
||||
params = {
|
||||
"model": model_id,
|
||||
"input": input,
|
||||
}
|
||||
|
||||
if encoding_format is not None:
|
||||
params["encoding_format"] = encoding_format
|
||||
if dimensions is not None:
|
||||
params["dimensions"] = dimensions
|
||||
if user is not None:
|
||||
params["user"] = user
|
||||
|
||||
# Call OpenAI embeddings API
|
||||
response = await self._openai_client.embeddings.create(**params)
|
||||
|
||||
data = []
|
||||
for i, embedding_data in enumerate(response.data):
|
||||
data.append(
|
||||
OpenAIEmbeddingData(
|
||||
embedding=embedding_data.embedding,
|
||||
index=i,
|
||||
)
|
||||
)
|
||||
|
||||
usage = OpenAIEmbeddingUsage(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
)
|
||||
|
||||
return OpenAIEmbeddingsResponse(
|
||||
data=data,
|
||||
model=response.model,
|
||||
usage=usage,
|
||||
)
|
||||
|
|
|
@ -19,6 +19,7 @@ from llama_stack.apis.inference import (
|
|||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
|
@ -210,6 +211,16 @@ class PassthroughInferenceAdapter(Inference):
|
|||
task_type=task_type,
|
||||
)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -8,6 +8,7 @@ from collections.abc import AsyncGenerator
|
|||
from openai import OpenAI
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.inference.inference import OpenAIEmbeddingsResponse
|
||||
|
||||
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
@ -134,3 +135,13 @@ class RunpodInferenceAdapter(
|
|||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -218,7 +218,7 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
"json_schema": {
|
||||
"name": name,
|
||||
"schema": fmt,
|
||||
"strict": True,
|
||||
"strict": False,
|
||||
},
|
||||
}
|
||||
if request.tools:
|
||||
|
|
|
@ -23,6 +23,7 @@ from llama_stack.apis.inference import (
|
|||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
SamplingParams,
|
||||
|
@ -291,6 +292,16 @@ class _HfAdapter(
|
|||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class TGIAdapter(_HfAdapter):
|
||||
async def initialize(self, config: TGIImplConfig) -> None:
|
||||
|
|
|
@ -23,6 +23,7 @@ from llama_stack.apis.inference import (
|
|||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
SamplingParams,
|
||||
|
@ -267,6 +268,16 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
embeddings = [item.embedding for item in r.data]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -38,6 +38,7 @@ from llama_stack.apis.inference import (
|
|||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
|
@ -507,6 +508,16 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
embeddings = [data.embedding for data in response.data]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -21,6 +21,7 @@ from llama_stack.apis.inference import (
|
|||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
|
@ -260,6 +261,16 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError("embedding is not supported for watsonx")
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -4,7 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import struct
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -15,6 +17,9 @@ from llama_stack.apis.inference import (
|
|||
EmbeddingTaskType,
|
||||
InterleavedContentItem,
|
||||
ModelStore,
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
TextTruncation,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
|
@ -43,6 +48,50 @@ class SentenceTransformerEmbeddingMixin:
|
|||
)
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
# Convert input to list format if it's a single string
|
||||
input_list = [input] if isinstance(input, str) else input
|
||||
if not input_list:
|
||||
raise ValueError("Empty list not supported")
|
||||
|
||||
# Get the model and generate embeddings
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
embedding_model = self._load_sentence_transformer_model(model_obj.provider_resource_id)
|
||||
embeddings = embedding_model.encode(input_list, show_progress_bar=False)
|
||||
|
||||
# Convert embeddings to the requested format
|
||||
data = []
|
||||
for i, embedding in enumerate(embeddings):
|
||||
if encoding_format == "base64":
|
||||
# Convert float array to base64 string
|
||||
float_bytes = struct.pack(f"{len(embedding)}f", *embedding)
|
||||
embedding_value = base64.b64encode(float_bytes).decode("ascii")
|
||||
else:
|
||||
# Default to float format
|
||||
embedding_value = embedding.tolist()
|
||||
|
||||
data.append(
|
||||
OpenAIEmbeddingData(
|
||||
embedding=embedding_value,
|
||||
index=i,
|
||||
)
|
||||
)
|
||||
|
||||
# Not returning actual token usage
|
||||
usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
|
||||
return OpenAIEmbeddingsResponse(
|
||||
data=data,
|
||||
model=model_obj.provider_resource_id,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
|
||||
global EMBEDDING_MODELS
|
||||
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
import struct
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
|
@ -35,6 +37,9 @@ from llama_stack.apis.inference.inference import (
|
|||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
|
@ -264,6 +269,52 @@ class LiteLLMOpenAIMixin(
|
|||
embeddings = [data["embedding"] for data in response["data"]]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
|
||||
# Convert input to list if it's a string
|
||||
input_list = [input] if isinstance(input, str) else input
|
||||
|
||||
# Call litellm embedding function
|
||||
# litellm.drop_params = True
|
||||
response = litellm.embedding(
|
||||
model=self.get_litellm_model_name(model_obj.provider_resource_id),
|
||||
input=input_list,
|
||||
api_key=self.get_api_key(),
|
||||
api_base=self.api_base,
|
||||
dimensions=dimensions,
|
||||
)
|
||||
|
||||
# Convert response to OpenAI format
|
||||
data = []
|
||||
for i, embedding_data in enumerate(response["data"]):
|
||||
# we encode to base64 if the encoding format is base64 in the request
|
||||
if encoding_format == "base64":
|
||||
byte_data = b"".join(struct.pack("f", f) for f in embedding_data["embedding"])
|
||||
embedding = base64.b64encode(byte_data).decode("utf-8")
|
||||
else:
|
||||
embedding = embedding_data["embedding"]
|
||||
|
||||
data.append(OpenAIEmbeddingData(embedding=embedding, index=i))
|
||||
|
||||
usage = OpenAIEmbeddingUsage(
|
||||
prompt_tokens=response["usage"]["prompt_tokens"],
|
||||
total_tokens=response["usage"]["total_tokens"],
|
||||
)
|
||||
|
||||
return OpenAIEmbeddingsResponse(
|
||||
data=data,
|
||||
model=model_obj.provider_resource_id,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -36,6 +36,10 @@ class RedisKVStoreConfig(CommonConfig):
|
|||
def url(self) -> str:
|
||||
return f"redis://{self.host}:{self.port}"
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> list[str]:
|
||||
return ["redis"]
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls):
|
||||
return {
|
||||
|
@ -53,6 +57,10 @@ class SqliteKVStoreConfig(CommonConfig):
|
|||
description="File path for the sqlite database",
|
||||
)
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> list[str]:
|
||||
return ["aiosqlite"]
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, db_name: str = "kvstore.db"):
|
||||
return {
|
||||
|
@ -65,22 +73,22 @@ class SqliteKVStoreConfig(CommonConfig):
|
|||
class PostgresKVStoreConfig(CommonConfig):
|
||||
type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value
|
||||
host: str = "localhost"
|
||||
port: int = 5432
|
||||
port: str = "5432"
|
||||
db: str = "llamastack"
|
||||
user: str
|
||||
password: str | None = None
|
||||
table_name: str = "llamastack_kvstore"
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, table_name: str = "llamastack_kvstore"):
|
||||
def sample_run_config(cls, table_name: str = "llamastack_kvstore", **kwargs):
|
||||
return {
|
||||
"type": "postgres",
|
||||
"namespace": None,
|
||||
"host": "${env.POSTGRES_HOST:localhost}",
|
||||
"port": "${env.POSTGRES_PORT:5432}",
|
||||
"db": "${env.POSTGRES_DB}",
|
||||
"user": "${env.POSTGRES_USER}",
|
||||
"password": "${env.POSTGRES_PASSWORD}",
|
||||
"db": "${env.POSTGRES_DB:llamastack}",
|
||||
"user": "${env.POSTGRES_USER:llamastack}",
|
||||
"password": "${env.POSTGRES_PASSWORD:llamastack}",
|
||||
"table_name": "${env.POSTGRES_TABLE_NAME:" + table_name + "}",
|
||||
}
|
||||
|
||||
|
@ -100,6 +108,10 @@ class PostgresKVStoreConfig(CommonConfig):
|
|||
raise ValueError("Table name must be less than 63 characters")
|
||||
return v
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> list[str]:
|
||||
return ["psycopg2-binary"]
|
||||
|
||||
|
||||
class MongoDBKVStoreConfig(CommonConfig):
|
||||
type: Literal[KVStoreType.mongodb.value] = KVStoreType.mongodb.value
|
||||
|
@ -110,6 +122,10 @@ class MongoDBKVStoreConfig(CommonConfig):
|
|||
password: str | None = None
|
||||
collection_name: str = "llamastack_kvstore"
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> list[str]:
|
||||
return ["pymongo"]
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, collection_name: str = "llamastack_kvstore"):
|
||||
return {
|
||||
|
|
|
@ -10,6 +10,13 @@ from .config import KVStoreConfig, KVStoreType
|
|||
|
||||
|
||||
def kvstore_dependencies():
|
||||
"""
|
||||
Returns all possible kvstore dependencies for registry/provider specifications.
|
||||
|
||||
NOTE: For specific kvstore implementations, use config.pip_packages instead.
|
||||
This function returns the union of all dependencies for cases where the specific
|
||||
kvstore type is not known at declaration time (e.g., provider registries).
|
||||
"""
|
||||
return ["aiosqlite", "psycopg2-binary", "redis", "pymongo"]
|
||||
|
||||
|
||||
|
|
|
@ -171,6 +171,22 @@ def make_overlapped_chunks(
|
|||
return chunks
|
||||
|
||||
|
||||
def _validate_embedding(embedding: NDArray, index: int, expected_dimension: int):
|
||||
"""Helper method to validate embedding format and dimensions"""
|
||||
if not isinstance(embedding, (list | np.ndarray)):
|
||||
raise ValueError(f"Embedding at index {index} must be a list or numpy array, got {type(embedding)}")
|
||||
|
||||
if isinstance(embedding, np.ndarray):
|
||||
if not np.issubdtype(embedding.dtype, np.number):
|
||||
raise ValueError(f"Embedding at index {index} contains non-numeric values")
|
||||
else:
|
||||
if not all(isinstance(e, (float | int | np.number)) for e in embedding):
|
||||
raise ValueError(f"Embedding at index {index} contains non-numeric values")
|
||||
|
||||
if len(embedding) != expected_dimension:
|
||||
raise ValueError(f"Embedding at index {index} has dimension {len(embedding)}, expected {expected_dimension}")
|
||||
|
||||
|
||||
class EmbeddingIndex(ABC):
|
||||
@abstractmethod
|
||||
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
|
||||
|
@ -199,11 +215,22 @@ class VectorDBWithIndex:
|
|||
self,
|
||||
chunks: list[Chunk],
|
||||
) -> None:
|
||||
embeddings_response = await self.inference_api.embeddings(
|
||||
self.vector_db.embedding_model, [x.content for x in chunks]
|
||||
)
|
||||
embeddings = np.array(embeddings_response.embeddings)
|
||||
chunks_to_embed = []
|
||||
for i, c in enumerate(chunks):
|
||||
if c.embedding is None:
|
||||
chunks_to_embed.append(c)
|
||||
else:
|
||||
_validate_embedding(c.embedding, i, self.vector_db.embedding_dimension)
|
||||
|
||||
if chunks_to_embed:
|
||||
resp = await self.inference_api.embeddings(
|
||||
self.vector_db.embedding_model,
|
||||
[c.content for c in chunks_to_embed],
|
||||
)
|
||||
for c, embedding in zip(chunks_to_embed, resp.embeddings, strict=False):
|
||||
c.embedding = embedding
|
||||
|
||||
embeddings = np.array([c.embedding for c in chunks], dtype=np.float32)
|
||||
await self.index.add_chunks(chunks, embeddings)
|
||||
|
||||
async def query_chunks(
|
||||
|
|
|
@ -19,10 +19,10 @@ from sqlalchemy import (
|
|||
Text,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
from ..api import ColumnDefinition, ColumnType, SqlStore
|
||||
from ..sqlstore import SqliteSqlStoreConfig
|
||||
from .api import ColumnDefinition, ColumnType, SqlStore
|
||||
from .sqlstore import SqlAlchemySqlStoreConfig
|
||||
|
||||
TYPE_MAPPING: dict[ColumnType, Any] = {
|
||||
ColumnType.INTEGER: Integer,
|
||||
|
@ -35,9 +35,10 @@ TYPE_MAPPING: dict[ColumnType, Any] = {
|
|||
}
|
||||
|
||||
|
||||
class SqliteSqlStoreImpl(SqlStore):
|
||||
def __init__(self, config: SqliteSqlStoreConfig):
|
||||
self.engine = create_async_engine(config.engine_str)
|
||||
class SqlAlchemySqlStoreImpl(SqlStore):
|
||||
def __init__(self, config: SqlAlchemySqlStoreConfig):
|
||||
self.config = config
|
||||
self.async_session = async_sessionmaker(create_async_engine(config.engine_str))
|
||||
self.metadata = MetaData()
|
||||
|
||||
async def create_table(
|
||||
|
@ -78,13 +79,14 @@ class SqliteSqlStoreImpl(SqlStore):
|
|||
|
||||
# Create the table in the database if it doesn't exist
|
||||
# checkfirst=True ensures it doesn't try to recreate if it's already there
|
||||
async with self.engine.begin() as conn:
|
||||
engine = create_async_engine(self.config.engine_str)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)
|
||||
|
||||
async def insert(self, table: str, data: Mapping[str, Any]) -> None:
|
||||
async with self.engine.begin() as conn:
|
||||
await conn.execute(self.metadata.tables[table].insert(), data)
|
||||
await conn.commit()
|
||||
async with self.async_session() as session:
|
||||
await session.execute(self.metadata.tables[table].insert(), data)
|
||||
await session.commit()
|
||||
|
||||
async def fetch_all(
|
||||
self,
|
||||
|
@ -93,7 +95,7 @@ class SqliteSqlStoreImpl(SqlStore):
|
|||
limit: int | None = None,
|
||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
async with self.engine.begin() as conn:
|
||||
async with self.async_session() as session:
|
||||
query = select(self.metadata.tables[table])
|
||||
if where:
|
||||
for key, value in where.items():
|
||||
|
@ -117,7 +119,7 @@ class SqliteSqlStoreImpl(SqlStore):
|
|||
query = query.order_by(self.metadata.tables[table].c[name].desc())
|
||||
else:
|
||||
raise ValueError(f"Invalid order '{order_type}' for column '{name}'")
|
||||
result = await conn.execute(query)
|
||||
result = await session.execute(query)
|
||||
if result.rowcount == 0:
|
||||
return []
|
||||
return [dict(row._mapping) for row in result]
|
||||
|
@ -142,20 +144,20 @@ class SqliteSqlStoreImpl(SqlStore):
|
|||
if not where:
|
||||
raise ValueError("where is required for update")
|
||||
|
||||
async with self.engine.begin() as conn:
|
||||
async with self.async_session() as session:
|
||||
stmt = self.metadata.tables[table].update()
|
||||
for key, value in where.items():
|
||||
stmt = stmt.where(self.metadata.tables[table].c[key] == value)
|
||||
await conn.execute(stmt, data)
|
||||
await conn.commit()
|
||||
await session.execute(stmt, data)
|
||||
await session.commit()
|
||||
|
||||
async def delete(self, table: str, where: Mapping[str, Any]) -> None:
|
||||
if not where:
|
||||
raise ValueError("where is required for delete")
|
||||
|
||||
async with self.engine.begin() as conn:
|
||||
async with self.async_session() as session:
|
||||
stmt = self.metadata.tables[table].delete()
|
||||
for key, value in where.items():
|
||||
stmt = stmt.where(self.metadata.tables[table].c[key] == value)
|
||||
await conn.execute(stmt)
|
||||
await conn.commit()
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Literal
|
||||
|
@ -15,13 +16,26 @@ from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
|||
|
||||
from .api import SqlStore
|
||||
|
||||
sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]
|
||||
|
||||
|
||||
class SqlStoreType(Enum):
|
||||
sqlite = "sqlite"
|
||||
postgres = "postgres"
|
||||
|
||||
|
||||
class SqliteSqlStoreConfig(BaseModel):
|
||||
class SqlAlchemySqlStoreConfig(BaseModel):
|
||||
@property
|
||||
@abstractmethod
|
||||
def engine_str(self) -> str: ...
|
||||
|
||||
# TODO: move this when we have a better way to specify dependencies with internal APIs
|
||||
@property
|
||||
def pip_packages(self) -> list[str]:
|
||||
return ["sqlalchemy[asyncio]"]
|
||||
|
||||
|
||||
class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
||||
type: Literal["sqlite"] = SqlStoreType.sqlite.value
|
||||
db_path: str = Field(
|
||||
default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
||||
|
@ -39,18 +53,37 @@ class SqliteSqlStoreConfig(BaseModel):
|
|||
db_path="${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name,
|
||||
)
|
||||
|
||||
# TODO: move this when we have a better way to specify dependencies with internal APIs
|
||||
@property
|
||||
def pip_packages(self) -> list[str]:
|
||||
return ["sqlalchemy[asyncio]"]
|
||||
return super().pip_packages + ["aiosqlite"]
|
||||
|
||||
|
||||
class PostgresSqlStoreConfig(BaseModel):
|
||||
class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
||||
type: Literal["postgres"] = SqlStoreType.postgres.value
|
||||
host: str = "localhost"
|
||||
port: str = "5432"
|
||||
db: str = "llamastack"
|
||||
user: str
|
||||
password: str | None = None
|
||||
|
||||
@property
|
||||
def engine_str(self) -> str:
|
||||
return f"postgresql+asyncpg://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}"
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> list[str]:
|
||||
raise NotImplementedError("Postgres is not implemented yet")
|
||||
return super().pip_packages + ["asyncpg"]
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs):
|
||||
return cls(
|
||||
type="postgres",
|
||||
host="${env.POSTGRES_HOST:localhost}",
|
||||
port="${env.POSTGRES_PORT:5432}",
|
||||
db="${env.POSTGRES_DB:llamastack}",
|
||||
user="${env.POSTGRES_USER:llamastack}",
|
||||
password="${env.POSTGRES_PASSWORD:llamastack}",
|
||||
)
|
||||
|
||||
|
||||
SqlStoreConfig = Annotated[
|
||||
|
@ -60,12 +93,10 @@ SqlStoreConfig = Annotated[
|
|||
|
||||
|
||||
def sqlstore_impl(config: SqlStoreConfig) -> SqlStore:
|
||||
if config.type == SqlStoreType.sqlite.value:
|
||||
from .sqlite.sqlite import SqliteSqlStoreImpl
|
||||
if config.type in [SqlStoreType.sqlite.value, SqlStoreType.postgres.value]:
|
||||
from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
|
||||
|
||||
impl = SqliteSqlStoreImpl(config)
|
||||
elif config.type == SqlStoreType.postgres.value:
|
||||
raise NotImplementedError("Postgres is not implemented yet")
|
||||
impl = SqlAlchemySqlStoreImpl(config)
|
||||
else:
|
||||
raise ValueError(f"Unknown sqlstore type {config.type}")
|
||||
|
||||
|
|
|
@ -30,4 +30,5 @@ distribution_spec:
|
|||
- remote::model-context-protocol
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- aiosqlite
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -42,7 +42,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/trace_store.db
|
||||
eval:
|
||||
|
|
|
@ -30,4 +30,5 @@ distribution_spec:
|
|||
- inline::rag-runtime
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- aiosqlite
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -82,7 +82,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/trace_store.db
|
||||
tool_runtime:
|
||||
|
|
|
@ -31,4 +31,5 @@ distribution_spec:
|
|||
- remote::model-context-protocol
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- aiosqlite
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -45,7 +45,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/trace_store.db
|
||||
eval:
|
||||
|
|
|
@ -31,5 +31,5 @@ distribution_spec:
|
|||
- inline::rag-runtime
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- sqlalchemy[asyncio]
|
||||
- aiosqlite
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -48,7 +48,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/trace_store.db
|
||||
eval:
|
||||
|
|
|
@ -44,7 +44,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/trace_store.db
|
||||
eval:
|
||||
|
|
|
@ -24,6 +24,8 @@ distribution_spec:
|
|||
- inline::basic
|
||||
- inline::llm-as-judge
|
||||
- inline::braintrust
|
||||
files:
|
||||
- inline::localfs
|
||||
tool_runtime:
|
||||
- remote::brave-search
|
||||
- remote::tavily-search
|
||||
|
@ -32,5 +34,5 @@ distribution_spec:
|
|||
- remote::model-context-protocol
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- sqlalchemy[asyncio]
|
||||
- aiosqlite
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -13,6 +13,7 @@ from llama_stack.distribution.datatypes import (
|
|||
ShieldInput,
|
||||
ToolGroupInput,
|
||||
)
|
||||
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
|
||||
from llama_stack.providers.inline.inference.sentence_transformers import (
|
||||
SentenceTransformersInferenceConfig,
|
||||
)
|
||||
|
@ -36,6 +37,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"eval": ["inline::meta-reference"],
|
||||
"datasetio": ["remote::huggingface", "inline::localfs"],
|
||||
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
|
||||
"files": ["inline::localfs"],
|
||||
"tool_runtime": [
|
||||
"remote::brave-search",
|
||||
"remote::tavily-search",
|
||||
|
@ -62,6 +64,11 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
provider_type="inline::faiss",
|
||||
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
)
|
||||
files_provider = Provider(
|
||||
provider_id="meta-reference-files",
|
||||
provider_type="inline::localfs",
|
||||
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
)
|
||||
|
||||
available_models = {
|
||||
"fireworks": MODEL_ENTRIES,
|
||||
|
@ -104,6 +111,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
provider_overrides={
|
||||
"inference": [inference_provider, embedding_provider],
|
||||
"vector_io": [vector_io_provider],
|
||||
"files": [files_provider],
|
||||
},
|
||||
default_models=default_models + [embedding_model],
|
||||
default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
|
||||
|
@ -116,6 +124,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
embedding_provider,
|
||||
],
|
||||
"vector_io": [vector_io_provider],
|
||||
"files": [files_provider],
|
||||
"safety": [
|
||||
Provider(
|
||||
provider_id="llama-guard",
|
||||
|
|
|
@ -4,6 +4,7 @@ apis:
|
|||
- agents
|
||||
- datasetio
|
||||
- eval
|
||||
- files
|
||||
- inference
|
||||
- safety
|
||||
- scoring
|
||||
|
@ -53,7 +54,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/trace_store.db
|
||||
eval:
|
||||
|
@ -90,6 +91,14 @@ providers:
|
|||
provider_type: inline::braintrust
|
||||
config:
|
||||
openai_api_key: ${env.OPENAI_API_KEY:}
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:~/.llama/distributions/fireworks/files}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/files_metadata.db
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
|
|
|
@ -4,6 +4,7 @@ apis:
|
|||
- agents
|
||||
- datasetio
|
||||
- eval
|
||||
- files
|
||||
- inference
|
||||
- safety
|
||||
- scoring
|
||||
|
@ -48,7 +49,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/trace_store.db
|
||||
eval:
|
||||
|
@ -85,6 +86,14 @@ providers:
|
|||
provider_type: inline::braintrust
|
||||
config:
|
||||
openai_api_key: ${env.OPENAI_API_KEY:}
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:~/.llama/distributions/fireworks/files}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/files_metadata.db
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
|
|
|
@ -27,4 +27,5 @@ distribution_spec:
|
|||
- inline::rag-runtime
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- aiosqlite
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -48,7 +48,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/trace_store.db
|
||||
eval:
|
||||
|
@ -112,7 +112,7 @@ models:
|
|||
provider_model_id: groq/llama3-8b-8192
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.1-8B-Instruct
|
||||
model_id: groq/meta-llama/Llama-3.1-8B-Instruct
|
||||
provider_id: groq
|
||||
provider_model_id: groq/llama3-8b-8192
|
||||
model_type: llm
|
||||
|
@ -127,7 +127,7 @@ models:
|
|||
provider_model_id: groq/llama3-70b-8192
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3-70B-Instruct
|
||||
model_id: groq/meta-llama/Llama-3-70B-Instruct
|
||||
provider_id: groq
|
||||
provider_model_id: groq/llama3-70b-8192
|
||||
model_type: llm
|
||||
|
@ -137,7 +137,7 @@ models:
|
|||
provider_model_id: groq/llama-3.3-70b-versatile
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.3-70B-Instruct
|
||||
model_id: groq/meta-llama/Llama-3.3-70B-Instruct
|
||||
provider_id: groq
|
||||
provider_model_id: groq/llama-3.3-70b-versatile
|
||||
model_type: llm
|
||||
|
@ -147,7 +147,7 @@ models:
|
|||
provider_model_id: groq/llama-3.2-3b-preview
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-3B-Instruct
|
||||
model_id: groq/meta-llama/Llama-3.2-3B-Instruct
|
||||
provider_id: groq
|
||||
provider_model_id: groq/llama-3.2-3b-preview
|
||||
model_type: llm
|
||||
|
@ -157,7 +157,7 @@ models:
|
|||
provider_model_id: groq/llama-4-scout-17b-16e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
|
||||
model_id: groq/meta-llama/Llama-4-Scout-17B-16E-Instruct
|
||||
provider_id: groq
|
||||
provider_model_id: groq/llama-4-scout-17b-16e-instruct
|
||||
model_type: llm
|
||||
|
@ -167,7 +167,7 @@ models:
|
|||
provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
|
||||
model_id: groq/meta-llama/Llama-4-Scout-17B-16E-Instruct
|
||||
provider_id: groq
|
||||
provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
|
||||
model_type: llm
|
||||
|
@ -177,7 +177,7 @@ models:
|
|||
provider_model_id: groq/llama-4-maverick-17b-128e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct
|
||||
model_id: groq/meta-llama/Llama-4-Maverick-17B-128E-Instruct
|
||||
provider_id: groq
|
||||
provider_model_id: groq/llama-4-maverick-17b-128e-instruct
|
||||
model_type: llm
|
||||
|
@ -187,7 +187,7 @@ models:
|
|||
provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct
|
||||
model_id: groq/meta-llama/Llama-4-Maverick-17B-128E-Instruct
|
||||
provider_id: groq
|
||||
provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
|
||||
model_type: llm
|
||||
|
|
|
@ -30,5 +30,5 @@ distribution_spec:
|
|||
- remote::model-context-protocol
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- sqlalchemy[asyncio]
|
||||
- aiosqlite
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -53,7 +53,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/trace_store.db
|
||||
eval:
|
||||
|
|
|
@ -48,7 +48,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/trace_store.db
|
||||
eval:
|
||||
|
|
|
@ -31,5 +31,5 @@ distribution_spec:
|
|||
- remote::model-context-protocol
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- sqlalchemy[asyncio]
|
||||
- aiosqlite
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -53,7 +53,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/trace_store.db
|
||||
eval:
|
||||
|
|
|
@ -48,7 +48,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/trace_store.db
|
||||
eval:
|
||||
|
|
|
@ -31,4 +31,5 @@ distribution_spec:
|
|||
- remote::model-context-protocol
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- aiosqlite
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -57,7 +57,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/trace_store.db
|
||||
eval:
|
||||
|
|
|
@ -30,5 +30,5 @@ distribution_spec:
|
|||
- remote::model-context-protocol
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- sqlalchemy[asyncio]
|
||||
- aiosqlite
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -63,7 +63,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/trace_store.db
|
||||
eval:
|
||||
|
|
|
@ -53,7 +53,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/trace_store.db
|
||||
eval:
|
||||
|
|
|
@ -25,5 +25,5 @@ distribution_spec:
|
|||
- inline::rag-runtime
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- sqlalchemy[asyncio]
|
||||
- aiosqlite
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -53,7 +53,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/trace_store.db
|
||||
eval:
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue